/*
 * Decompiled with CFR 0.152.
 */
package uk.ac.manchester.tornado.drivers.opencl.scheduler;

import java.util.Arrays;
import uk.ac.manchester.tornado.drivers.opencl.OCLDeviceContext;
import uk.ac.manchester.tornado.drivers.opencl.OCLTargetDevice;
import uk.ac.manchester.tornado.drivers.opencl.scheduler.OCLKernelScheduler;
import uk.ac.manchester.tornado.runtime.tasks.meta.TaskDataContext;

/**
 * OpenCL kernel scheduler for NVIDIA GPUs.
 *
 * <p>Derives a task's global work sizes from its iteration domain (or an explicit batch size).
 * Derives its local work sizes from the device's reported per-dimension max work-item sizes, and
 * can shrink an existing local work configuration that exceeds the device's max work-group size.
 */
public class OCLNVIDIAGPUScheduler extends OCLKernelScheduler {

    /** NVIDIA warp width; used for optional global-size rounding and as a group-size threshold. */
    private static final int WARP_SIZE = 32;

    /**
     * When {@code true}, global work sizes that are not a multiple of {@link #WARP_SIZE} are
     * rounded up to the next multiple. Defaults to {@code false}; nothing in this class sets it.
     */
    private boolean ADJUST_IRREGULAR = false;

    /** Per-dimension max work-item sizes reported by the device, cached at construction. */
    private final long[] maxWorkItemSizes;

    /**
     * Creates a scheduler for the device behind {@code context}.
     *
     * @param context the OpenCL device context; its device's max work-item sizes are cached
     */
    public OCLNVIDIAGPUScheduler(OCLDeviceContext context) {
        super(context);
        OCLTargetDevice device = context.getDevice();
        this.maxWorkItemSizes = device.getDeviceMaxWorkItemSizes();
    }

    /**
     * Writes the global work size for each of the task's dimensions into {@code meta}'s global
     * work array.
     *
     * @param meta task metadata providing the dimension count, domain and global work array
     * @param batchThreads if positive, used as the global size for every dimension; otherwise each
     *     dimension uses the cardinality of the corresponding domain entry
     */
    @Override
    public void calculateGlobalWork(TaskDataContext meta, long batchThreads) {
        long[] globalWork = meta.getGlobalWork();
        for (int i = 0; i < meta.getDims(); ++i) {
            long value = batchThreads <= 0L ? (long) meta.getDomain().get(i).cardinality() : batchThreads;
            if (this.ADJUST_IRREGULAR && value % WARP_SIZE != 0L) {
                // Round up to the next multiple of the warp width.
                value = (value / WARP_SIZE + 1L) * WARP_SIZE;
            }
            globalWork[i] = value;
        }
    }

    /**
     * Re-initialises {@code meta}'s local work array and, for 1 to 3 dimensions, fills each
     * dimension with a group size derived from the effective max work-item size and the global
     * work size already stored in {@code meta}. Other dimension counts leave the freshly
     * initialised array untouched.
     *
     * @param meta task metadata whose local work is (re)initialised and populated
     */
    @Override
    public void calculateLocalWork(TaskDataContext meta) {
        long[] localWork = meta.initLocalWork();
        int dims = meta.getDims();
        if (dims < 1 || dims > 3) {
            return;
        }
        // Both inputs are loop-invariant: compute them once rather than once per dimension.
        long[] effectiveMaxSizes = this.calculateEffectiveMaxWorkItemSizes(meta);
        long[] globalWork = meta.getGlobalWork();
        for (int i = 0; i < dims; ++i) {
            localWork[i] = this.calculateGroupSize(effectiveMaxSizes[i], globalWork[i], dims);
        }
    }

    /**
     * Shrinks {@code meta}'s local work array in place when the product of its entries exceeds the
     * device's max work-group size. Does nothing if no local work has been set or if the dimension
     * count is outside 1 to 3.
     *
     * @param meta task metadata whose local work may be adapted
     */
    @Override
    public void checkAndAdaptLocalWork(TaskDataContext meta) {
        long[] localWork = meta.getLocalWork();
        if (localWork == null) {
            return;
        }
        int dims = meta.getDims();
        if (dims < 1 || dims > 3) {
            return;
        }
        // Compute the adapted configuration once and copy back the active dimensions. The original
        // code repeated the whole check (including device queries) after every single write.
        long[] adapted = this.checkAndAdaptLocalDimensions(localWork);
        if (adapted != localWork) {
            System.arraycopy(adapted, 0, localWork, 0, dims);
        }
    }

    /**
     * Returns {@code localWorkGroups} unchanged when the product of its entries fits within the
     * device's max work-group size, otherwise a new, adapted array.
     *
     * @param localWorkGroups current local work sizes (not modified)
     * @return either {@code localWorkGroups} itself or a freshly allocated adapted array
     */
    private long[] checkAndAdaptLocalDimensions(long[] localWorkGroups) {
        long[] blockMaxWorkGroupSize = this.deviceContext.getDevice().getDeviceMaxWorkGroupSize();
        // NOTE(review): the reported limit is summed; presumably the array holds a single value
        // -- TODO confirm against OCLTargetDevice.getDeviceMaxWorkGroupSize().
        long maxWorkGroupSize = Arrays.stream(blockMaxWorkGroupSize).sum();
        long totalThreads = Arrays.stream(localWorkGroups).reduce(1L, (a, b) -> a * b);
        if (totalThreads > maxWorkGroupSize) {
            return this.adaptLocalDimensions(localWorkGroups, maxWorkGroupSize);
        }
        return localWorkGroups;
    }

    /**
     * Builds a new local work configuration whose last dimension is reduced so that the product
     * of all dimensions does not exceed {@code maxWorkGroupSize}. Leading dimensions are kept.
     *
     * @param localWorkGroups current local work sizes (not modified)
     * @param maxWorkGroupSize device max work-group size
     * @return a new array of the same length; all zeros for lengths other than 1 to 3
     */
    private long[] adaptLocalDimensions(long[] localWorkGroups, long maxWorkGroupSize) {
        long[] newLocalWorkGroup = new long[localWorkGroups.length];
        switch (localWorkGroups.length) {
            case 3: {
                newLocalWorkGroup[0] = localWorkGroups[0];
                newLocalWorkGroup[1] = localWorkGroups[1];
                newLocalWorkGroup[2] = maxWorkGroupSize / (newLocalWorkGroup[0] * newLocalWorkGroup[1]);
                break;
            }
            case 2: {
                // Keep dimension 0; previously it was left as 0 in the returned array.
                newLocalWorkGroup[0] = localWorkGroups[0];
                newLocalWorkGroup[1] = maxWorkGroupSize / localWorkGroups[0];
                break;
            }
            case 1: {
                newLocalWorkGroup[0] = maxWorkGroupSize;
                break;
            }
        }
        return newLocalWorkGroup;
    }

    /**
     * Picks a group size for one dimension.
     *
     * <p>The size is the largest divisor of {@code globalWorkSize} that does not exceed
     * {@code maxBlockSize}. The bound is quartered when it equals the global size. For
     * multi-dimensional kernels, a result of at least {@link #WARP_SIZE} is then halved.
     *
     * @param maxBlockSize effective max work-item size for this dimension
     * @param globalWorkSize global work size for this dimension
     * @param dim number of dimensions of the kernel (not the dimension index)
     * @return the chosen group size; 1 when the bounded value would be 0
     */
    private int calculateGroupSize(long maxBlockSize, long globalWorkSize, int dim) {
        if (maxBlockSize == globalWorkSize) {
            maxBlockSize /= 4L;
        }
        int value = (int) Math.min(maxBlockSize, globalWorkSize);
        if (value == 0) {
            return 1;
        }
        // Walk down to the largest value that evenly divides the global size (terminates at 1).
        while (globalWorkSize % (long) value != 0L) {
            --value;
        }
        if (value >= WARP_SIZE && dim > 1) {
            value /= 2;
        }
        return value;
    }

    /**
     * Returns a 3-element array of per-dimension upper bounds for group sizes.
     *
     * <p>For a 1-D task, dimension 0 receives the device's full max work-item size. Otherwise each
     * active dimension receives the floor of the square root of its device max, presumably so that
     * the combined group stays within device limits (TODO confirm intent). Unused dimensions stay 1.
     *
     * @param metaData task metadata providing the dimension count
     * @return a new array of length 3
     */
    private long[] calculateEffectiveMaxWorkItemSizes(TaskDataContext metaData) {
        long[] localWorkGroups = new long[]{1L, 1L, 1L};
        if (metaData.getDims() == 1) {
            localWorkGroups[0] = this.maxWorkItemSizes[0];
        } else {
            for (int i = 0; i < metaData.getDims(); ++i) {
                localWorkGroups[i] = (long) Math.sqrt(this.maxWorkItemSizes[i]);
            }
        }
        return localWorkGroups;
    }
}

