/*
 * Decompiled with CFR 0.152.
 */
package uk.ac.manchester.tornado.drivers.opencl.mm;

import java.lang.foreign.MemorySegment;
import java.lang.runtime.SwitchBootstraps;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import uk.ac.manchester.tornado.api.common.Access;
import uk.ac.manchester.tornado.api.exceptions.TornadoInternalError;
import uk.ac.manchester.tornado.api.exceptions.TornadoMemoryException;
import uk.ac.manchester.tornado.api.exceptions.TornadoOutOfMemoryException;
import uk.ac.manchester.tornado.api.exceptions.TornadoRuntimeException;
import uk.ac.manchester.tornado.api.memory.XPUBuffer;
import uk.ac.manchester.tornado.api.types.arrays.TornadoNativeArray;
import uk.ac.manchester.tornado.api.types.collections.TornadoCollectionInterface;
import uk.ac.manchester.tornado.api.types.images.TornadoImagesInterface;
import uk.ac.manchester.tornado.api.types.matrix.TornadoMatrixInterface;
import uk.ac.manchester.tornado.api.types.volumes.TornadoVolumesInterface;
import uk.ac.manchester.tornado.drivers.opencl.OCLDeviceContext;
import uk.ac.manchester.tornado.runtime.common.TornadoLogger;
import uk.ac.manchester.tornado.runtime.common.TornadoOptions;
import uk.ac.manchester.tornado.runtime.common.exceptions.TornadoUnsupportedError;

/**
 * {@link XPUBuffer} implementation that moves the header-prefixed {@link MemorySegment} of a
 * Tornado native array, collection, image, matrix or volume between the host and a buffer
 * obtained through an {@link OCLDeviceContext}.
 *
 * <p>The wrapper keeps the id, offset and size of the device buffer it is currently attached to.
 * While no buffer is attached, the id (and, after {@link #markAsFreeBuffer()}, the size) holds
 * {@code -1}.
 */
public class OCLMemorySegmentWrapper implements XPUBuffer {
    /** Sentinel stored in {@code bufferId} / {@code bufferSize} while no device buffer is attached. */
    private static final int INIT_VALUE = -1;

    private final OCLDeviceContext deviceContext;
    private final long batchSize;
    private long bufferId;
    private long bufferOffset;
    private long bufferSize;
    private long subregionSize;
    private Access access;
    private final int sizeOfType;

    /**
     * Creates a wrapper with a known buffer size and no device buffer attached yet.
     *
     * @param bufferSize    size in bytes recorded for the buffer
     * @param deviceContext device context used for every transfer and allocation
     * @param batchSize     batch size; values greater than zero switch the wrapper to batch mode
     * @param access        access mode passed to the buffer provider when the buffer is released
     * @param sizeOfType    size of the element type; must be strictly positive
     * @throws TornadoRuntimeException if {@code sizeOfType} is zero or negative
     */
    public OCLMemorySegmentWrapper(long bufferSize, OCLDeviceContext deviceContext, long batchSize, Access access, int sizeOfType) {
        if (sizeOfType <= 0) {
            throw new TornadoRuntimeException("Invalid size of type " + sizeOfType);
        }
        this.deviceContext = deviceContext;
        this.batchSize = batchSize;
        this.bufferSize = bufferSize;
        this.bufferId = INIT_VALUE;
        this.bufferOffset = 0L;
        this.access = access;
        this.sizeOfType = sizeOfType;
    }

    /**
     * Creates a wrapper whose buffer size is not known yet (recorded as {@code -1}).
     *
     * @param deviceContext device context used for every transfer and allocation
     * @param batchSize     batch size; values greater than zero switch the wrapper to batch mode
     * @param access        access mode passed to the buffer provider when the buffer is released
     * @param sizeOfType    size of the element type; must be strictly positive
     * @throws TornadoRuntimeException if {@code sizeOfType} is zero or negative
     */
    public OCLMemorySegmentWrapper(OCLDeviceContext deviceContext, long batchSize, Access access, int sizeOfType) {
        this(INIT_VALUE, deviceContext, batchSize, access, sizeOfType);
    }

    /**
     * @return the id of the attached device buffer, or {@code -1} when none is attached
     */
    public long toBuffer() {
        return bufferId;
    }

    /**
     * Attaches this wrapper to the buffer described by {@code bufferWrapper} at the wrapper's
     * current offset, then advances that offset by this wrapper's size.
     *
     * @param bufferWrapper shared descriptor whose {@code bufferOffset} is mutated by this call
     */
    public void setBuffer(XPUBuffer.XPUBufferWrapper bufferWrapper) {
        this.bufferId = bufferWrapper.buffer;
        this.bufferOffset = bufferWrapper.bufferOffset;
        bufferWrapper.bufferOffset += this.bufferSize;
    }

    /**
     * @return the offset of this wrapper inside its device buffer
     */
    public long getBufferOffset() {
        return bufferOffset;
    }

    /**
     * Reads the device buffer into {@code reference} with host offset 0, no partial size and no
     * event dependencies.
     *
     * @param executionPlanId id of the execution plan issuing the transfer
     * @param reference       host object to read into
     */
    public void read(long executionPlanId, Object reference) {
        read(executionPlanId, reference, 0L, 0L, null, false);
    }

    /**
     * Extracts the header-prefixed memory segment from a supported Tornado data structure.
     *
     * @param reference host object; must not be {@code null}
     * @return the segment returned by the object's {@code getSegmentWithHeader()}
     * @throws NullPointerException   if {@code reference} is {@code null}
     * @throws TornadoMemoryException if {@code reference} is none of the supported types
     */
    private MemorySegment getSegmentWithHeader(Object reference) {
        Objects.requireNonNull(reference, "reference");
        // A pattern-matching switch replaces the decompiled SwitchBootstraps.typeSwitch call,
        // which is an invokedynamic bootstrap method and cannot be invoked from source code.
        return switch (reference) {
            case TornadoNativeArray nativeArray -> nativeArray.getSegmentWithHeader();
            case TornadoCollectionInterface collection -> collection.getSegmentWithHeader();
            case TornadoImagesInterface image -> image.getSegmentWithHeader();
            case TornadoMatrixInterface matrix -> matrix.getSegmentWithHeader();
            case TornadoVolumesInterface volume -> volume.getSegmentWithHeader();
            default -> throw new TornadoMemoryException("Memory Segment not supported: " + reference.getClass());
        };
    }

    /**
     * Reads from the device buffer into the memory segment of {@code reference}.
     *
     * <p>Three cases are distinguished:
     * <ul>
     *   <li>{@code partialReadSize != 0}: {@code partialReadSize} bytes are read, using
     *       {@code hostOffset} as both the device and the host offset;</li>
     *   <li>otherwise, when the wrapper's batch size is not positive: the sub-region size (or the
     *       full buffer size if no sub-region is set) is read from this wrapper's buffer
     *       offset;</li>
     *   <li>otherwise (batch mode): the same number of bytes is read starting right after the
     *       array header on both the device and the host side.</li>
     * </ul>
     *
     * @param executionPlanId id of the execution plan issuing the transfer
     * @param reference       host object to read into
     * @param hostOffset      offset into the host segment
     * @param partialReadSize number of bytes for a partial read, or 0 for none
     * @param events          events to wait on; only forwarded when {@code useDeps} is true
     * @param useDeps         whether event dependencies are used
     * @return the event returned by the device context when {@code useDeps} is true, else -1
     */
    public int read(long executionPlanId, Object reference, long hostOffset, long partialReadSize, int[] events, boolean useDeps) {
        MemorySegment segment = getSegmentWithHeader(reference);
        long subRegion = getSizeSubRegionSize();
        long numBytes = subRegion > 0L ? subRegion : bufferSize;
        int[] waitEvents = useDeps ? events : null;

        int returnEvent;
        if (partialReadSize != 0L) {
            returnEvent = deviceContext.readBuffer(executionPlanId, toBuffer(), hostOffset, partialReadSize, segment.address(), hostOffset, waitEvents);
        } else if (batchSize <= 0L) {
            returnEvent = deviceContext.readBuffer(executionPlanId, toBuffer(), bufferOffset, numBytes, segment.address(), hostOffset, waitEvents);
        } else {
            returnEvent = deviceContext.readBuffer(executionPlanId, toBuffer(), TornadoNativeArray.ARRAY_HEADER, numBytes, segment.address(), hostOffset + TornadoNativeArray.ARRAY_HEADER, waitEvents);
        }
        return useDeps ? returnEvent : -1;
    }

    /**
     * Writes the whole memory segment of {@code reference} to the device buffer.
     *
     * @param executionPlanId id of the execution plan issuing the transfer
     * @param reference       host object to write from
     * @throws TornadoUnsupportedError if the wrapper was created with a positive batch size
     */
    public void write(long executionPlanId, Object reference) {
        MemorySegment segment = getSegmentWithHeader(reference);
        if (batchSize > 0L) {
            throw new TornadoUnsupportedError("[UNSUPPORTED] batch processing for writeBuffer operation", new Object[0]);
        }
        deviceContext.writeBuffer(executionPlanId, toBuffer(), bufferOffset, bufferSize, segment.address(), 0L, null);
    }

    /**
     * Enqueues a read of the whole buffer into the memory segment of {@code reference}.
     *
     * @param executionPlanId id of the execution plan issuing the transfer
     * @param reference       host object to read into
     * @param hostOffset      offset into the host segment
     * @param events          events to wait on; only forwarded when {@code useDeps} is true
     * @param useDeps         whether event dependencies are used
     * @return the event returned by the device context when {@code useDeps} is true, else -1
     * @throws TornadoUnsupportedError if the wrapper was created with a positive batch size
     */
    public int enqueueRead(long executionPlanId, Object reference, long hostOffset, int[] events, boolean useDeps) {
        MemorySegment segment = getSegmentWithHeader(reference);
        if (batchSize > 0L) {
            throw new TornadoUnsupportedError("[UNSUPPORTED] batch processing for enqueueReadBuffer operation", new Object[0]);
        }
        int[] waitEvents = useDeps ? events : null;
        int returnEvent = deviceContext.enqueueReadBuffer(executionPlanId, toBuffer(), bufferOffset, bufferSize, segment.address(), hostOffset, waitEvents);
        return useDeps ? returnEvent : -1;
    }

    /**
     * Enqueues a write of the memory segment of {@code reference} to the device buffer.
     *
     * <p>When {@code batchSize} is not positive a single write is enqueued. Otherwise two writes
     * are enqueued: the array header to device offset 0, followed by the payload placed right
     * after the header on both the device and the host side.
     *
     * @param executionPlanId id of the execution plan issuing the transfer
     * @param reference       host object to write from
     * @param batchSize       batch size for this call; note this is the parameter, not the field
     * @param hostOffset      offset into the host segment
     * @param events          events to wait on; only forwarded when {@code useDeps} is true
     * @param useDeps         whether event dependencies are used
     * @return the events of every enqueued write, in enqueue order
     */
    public List<Integer> enqueueWrite(long executionPlanId, Object reference, long batchSize, long hostOffset, int[] events, boolean useDeps) {
        List<Integer> returnEvents = new ArrayList<>();
        MemorySegment segment = getSegmentWithHeader(reference);
        int[] waitEvents = useDeps ? events : null;

        if (batchSize <= 0L) {
            returnEvents.add(deviceContext.enqueueWriteBuffer(executionPlanId, toBuffer(), bufferOffset, bufferSize, segment.address(), hostOffset, waitEvents));
        } else {
            returnEvents.add(deviceContext.enqueueWriteBuffer(executionPlanId, toBuffer(), 0L, TornadoNativeArray.ARRAY_HEADER, segment.address(), 0L, waitEvents));
            returnEvents.add(deviceContext.enqueueWriteBuffer(executionPlanId, toBuffer(), bufferOffset + TornadoNativeArray.ARRAY_HEADER, bufferSize, segment.address(),
                    hostOffset + TornadoNativeArray.ARRAY_HEADER, waitEvents));
        }
        return returnEvents;
    }

    /**
     * Obtains a device buffer from the buffer provider for {@code reference}.
     *
     * <p>Without batching the buffer is as large as the object's memory segment. With batching it
     * holds {@code batchSize} bytes plus the array header, while the recorded size stays
     * {@code batchSize}.
     *
     * @param reference host object the buffer is allocated for
     * @param batchSize batch size in bytes, or a non-positive value for no batching
     * @param access    access mode forwarded to the buffer provider
     * @throws TornadoMemoryException if the resulting size is zero or negative
     */
    public void allocate(Object reference, long batchSize, Access access) throws TornadoOutOfMemoryException, TornadoMemoryException {
        MemorySegment segment = getSegmentWithHeader(reference);
        boolean batched = batchSize > 0L;
        this.bufferSize = batched ? batchSize : segment.byteSize();

        // Reject an invalid size before asking the provider, so no device buffer is requested
        // for an allocation that is about to fail.
        if (this.bufferSize <= 0L) {
            throw new TornadoMemoryException("[ERROR] Bytes Allocated <= 0: " + this.bufferSize);
        }

        long bytesToAllocate = batched ? this.bufferSize + TornadoNativeArray.ARRAY_HEADER : this.bufferSize;
        // NOTE(review): the access passed here is not stored; markAsFreeBuffer() uses the access
        // given at construction time - confirm callers always pass the same value.
        this.bufferId = deviceContext.getBufferProvider().getOrAllocateBufferWithSize(bytesToAllocate, access);

        if (TornadoOptions.FULL_DEBUG) {
            new TornadoLogger().info("allocated: %s", new Object[]{toString()});
        }
    }

    /**
     * Hands the attached buffer back to the buffer provider and resets the id and size of this
     * wrapper to {@code -1}.
     *
     * @throws TornadoMemoryException declared by the signature; not thrown directly by this method
     */
    public void markAsFreeBuffer() throws TornadoMemoryException {
        TornadoInternalError.guarantee(bufferId != INIT_VALUE, "Fatal error: trying to deallocate an invalid buffer", new Object[0]);
        deviceContext.getBufferProvider().markBufferReleased(bufferId, access);
        bufferId = INIT_VALUE;
        bufferSize = INIT_VALUE;
        if (TornadoOptions.FULL_DEBUG) {
            new TornadoLogger().info("deallocated: %s", new Object[]{toString()});
        }
    }

    /**
     * Delegates to the buffer provider's {@code deallocate} for this wrapper's access mode.
     *
     * @return the value returned by the buffer provider
     */
    public long deallocate() {
        return deviceContext.getBufferProvider().deallocate(access);
    }

    /**
     * @return the recorded buffer size in bytes, or {@code -1} when unknown or released
     */
    public long size() {
        return bufferSize;
    }

    /**
     * Sets the sub-region size that {@code read} prefers over the full buffer size when positive.
     *
     * @param batchSize sub-region size in bytes
     */
    public void setSizeSubRegion(long batchSize) {
        this.subregionSize = batchSize;
    }

    /**
     * @return the sub-region size set through {@link #setSizeSubRegion(long)}; 0 if never set
     */
    public long getSizeSubRegionSize() {
        return subregionSize;
    }

    /**
     * @return the batch size given at construction time
     */
    public long getBatchSize() {
        return batchSize;
    }

    /**
     * Maps this wrapper onto the device memory of another wrapper and adopts the buffer id
     * returned by the device context.
     *
     * @param executionPlanId id of the execution plan issuing the mapping
     * @param srcPointer      source buffer; must be an {@code OCLMemorySegmentWrapper}
     * @param offset          offset forwarded to the device context
     * @throws TornadoRuntimeException if {@code srcPointer} is not an {@code OCLMemorySegmentWrapper}
     */
    public void mapOnDeviceMemoryRegion(long executionPlanId, XPUBuffer srcPointer, long offset) {
        if (!(srcPointer instanceof OCLMemorySegmentWrapper source)) {
            throw new TornadoRuntimeException("[ERROR] copy pointer must be an instance of OCLMemorySegmentWrapper: " + srcPointer);
        }
        long sizeSource = source.bufferSize;
        long sizeDest = this.bufferSize;
        this.bufferId = deviceContext.mapOnDeviceMemoryRegion(executionPlanId, this.bufferId, source.bufferId, offset, sizeOfType, sizeSource, sizeDest);
    }

    /**
     * @return the element type size given at construction time
     */
    public int getSizeOfType() {
        return sizeOfType;
    }
}

