/*
 * Decompiled with CFR 0.152.
 */
package uk.ac.manchester.tornado.runtime.common;

import java.lang.reflect.Array;
import java.lang.runtime.SwitchBootstraps;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.Objects;
import uk.ac.manchester.tornado.api.exceptions.TornadoRuntimeException;
import uk.ac.manchester.tornado.api.types.arrays.ByteArray;
import uk.ac.manchester.tornado.api.types.arrays.CharArray;
import uk.ac.manchester.tornado.api.types.arrays.DoubleArray;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
import uk.ac.manchester.tornado.api.types.arrays.IntArray;
import uk.ac.manchester.tornado.api.types.arrays.LongArray;
import uk.ac.manchester.tornado.api.types.arrays.ShortArray;
import uk.ac.manchester.tornado.api.types.arrays.TornadoNativeArray;
import uk.ac.manchester.tornado.runtime.common.TornadoOptions;
import uk.ac.manchester.tornado.runtime.common.enums.DataTypeSize;
import uk.ac.manchester.tornado.runtime.graph.TornadoExecutionContext;

/**
 * Describes how the inputs of an execution context are split into batches of a fixed byte size.
 *
 * <p>An instance holds the number of full chunks, the size of the trailing partial chunk, and the
 * element size (in bytes) shared by every batched input. Instances are immutable.
 */
public class BatchConfiguration {
    /** Number of full batches of {@code batchSize} bytes that fit in the input. */
    private final int totalChunks;
    /** Bytes left over after the full batches ({@code totalSize % batchSize}). */
    private final int remainingChunkSize;
    /** Size in bytes of a single element of the batched inputs. */
    private final short numBytesType;

    /**
     * Creates a batch configuration.
     *
     * @param totalChunks        number of full chunks
     * @param remainingChunkSize size of the trailing partial chunk
     * @param numBytesType       size in bytes of one element of the batched inputs
     */
    public BatchConfiguration(int totalChunks, int remainingChunkSize, short numBytesType) {
        this.totalChunks = totalChunks;
        this.remainingChunkSize = remainingChunkSize;
        this.numBytesType = numBytesType;
    }

    /**
     * Computes how the objects of {@code context} are divided into batches of {@code batchSize} bytes.
     *
     * <p>Every object must be either a Java primitive array whose component type is known to
     * {@link DataTypeSize}, or one of the supported {@link TornadoNativeArray} subtypes. All objects
     * must have the same total byte size and the same element size.
     *
     * @param context   execution context whose objects are batched
     * @param batchSize batch size in bytes; must be positive
     * @return the number of full chunks, the remainder, and the common element size
     * @throws TornadoRuntimeException if {@code batchSize} is not positive, the context has no
     *                                 objects, an object type is unsupported, or the objects differ
     *                                 in total size or element size
     * @throws ArithmeticException     if the chunk count or remainder does not fit in an {@code int}
     */
    public static BatchConfiguration computeChunkSizes(TornadoExecutionContext context, long batchSize) {
        if (batchSize <= 0) {
            throw new TornadoRuntimeException("[UNSUPPORTED] Batch size must be positive, but was: " + batchSize);
        }
        long totalSize = 0L;
        HashSet<Long> inputSizes = new HashSet<>();
        LinkedHashSet<Byte> elementSizes = new LinkedHashSet<>();
        for (Object o : context.getObjects()) {
            byte elementSize;
            if (o.getClass().isArray()) {
                DataTypeSize dataTypeSize = DataTypeSize.findDataTypeSize(o.getClass().getComponentType());
                if (dataTypeSize == null) {
                    throw new TornadoRuntimeException("[UNSUPPORTED] Data type not supported for processing in batches");
                }
                elementSize = dataTypeSize.getSize();
                // Widen before multiplying so large arrays do not overflow an int product.
                totalSize = Array.getLength(o) * (long) elementSize;
            } else if (o instanceof TornadoNativeArray nativeArray) {
                totalSize = nativeArray.getNumBytesOfSegment();
                elementSize = elementSizeOf(nativeArray);
            } else {
                throw new TornadoRuntimeException("Unsupported type: " + String.valueOf(o.getClass()));
            }
            // All inputs must agree, so collecting distinct values is enough to detect a mismatch.
            inputSizes.add(totalSize);
            elementSizes.add(elementSize);
        }
        if (elementSizes.isEmpty()) {
            throw new TornadoRuntimeException("[UNSUPPORTED] No input objects to process in batches");
        }
        if (inputSizes.size() > 1) {
            throw new TornadoRuntimeException("[UNSUPPORTED] Input objects with different sizes not currently supported");
        }
        if (elementSizes.size() > 1) {
            throw new TornadoRuntimeException("[UNSUPPORTED] Input objects with different element sizes not currently supported");
        }
        // Fail loudly instead of silently truncating if the results exceed the int range.
        int totalChunks = Math.toIntExact(totalSize / batchSize);
        int remainingChunkSize = Math.toIntExact(totalSize % batchSize);
        if (TornadoOptions.DEBUG) {
            System.out.println("Batch Size: " + batchSize);
            System.out.println("Total chunks: " + totalChunks);
            System.out.println("remainingChunkSize: " + remainingChunkSize);
        }
        return new BatchConfiguration(totalChunks, remainingChunkSize, elementSizes.getFirst().byteValue());
    }

    /**
     * Returns the element size of a supported native array type.
     *
     * @param nativeArray native array to inspect
     * @return size of one element, as reported by the matching {@link DataTypeSize}
     * @throws TornadoRuntimeException if the array type is not one of the supported subtypes
     */
    private static byte elementSizeOf(TornadoNativeArray nativeArray) {
        return switch (nativeArray) {
            case IntArray ignored -> DataTypeSize.INT.getSize();
            case FloatArray ignored -> DataTypeSize.FLOAT.getSize();
            case DoubleArray ignored -> DataTypeSize.DOUBLE.getSize();
            case LongArray ignored -> DataTypeSize.LONG.getSize();
            case ShortArray ignored -> DataTypeSize.SHORT.getSize();
            case ByteArray ignored -> DataTypeSize.BYTE.getSize();
            case CharArray ignored -> DataTypeSize.CHAR.getSize();
            default -> throw new TornadoRuntimeException("Unsupported array type: " + String.valueOf(nativeArray.getClass()));
        };
    }

    /**
     * Returns the number of full chunks.
     *
     * @return number of full chunks
     */
    public int getTotalChunks() {
        return this.totalChunks;
    }

    /**
     * Returns the size of the trailing partial chunk (0 when the input divides evenly).
     *
     * @return size of the remaining partial chunk
     */
    public int getRemainingChunkSize() {
        return this.remainingChunkSize;
    }

    /**
     * Returns the size in bytes of one element of the batched inputs.
     *
     * @return element size in bytes
     */
    public short getNumBytesType() {
        return this.numBytesType;
    }
}

