/*
 * Decompiled with CFR 0.152.
 */
package uk.ac.manchester.tornado.unittests.foundation;

import java.util.Random;
import org.junit.Assert;
import org.junit.Test;
import uk.ac.manchester.tornado.api.GridScheduler;
import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
import uk.ac.manchester.tornado.api.KernelContext;
import uk.ac.manchester.tornado.api.TaskGraph;
import uk.ac.manchester.tornado.api.TornadoExecutionPlan;
import uk.ac.manchester.tornado.api.WorkerGrid;
import uk.ac.manchester.tornado.api.WorkerGrid1D;
import uk.ac.manchester.tornado.api.enums.TornadoVMBackendType;
import uk.ac.manchester.tornado.api.exceptions.TornadoExecutionPlanException;
import uk.ac.manchester.tornado.api.types.HalfFloat;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
import uk.ac.manchester.tornado.unittests.common.TornadoTestBase;

/**
 * Unit tests for TornadoVM {@link HalfFloat} / {@link HalfFloatArray} support.
 *
 * <p>Covers FP32-to-FP16 conversion, both as {@link KernelContext} kernels and as a loop-based
 * task. It also covers a half-precision matrix-vector product that reduces per-thread partial
 * sums in a work-group local array.
 */
public class TestHalfFloats extends TornadoTestBase {

    /** Number of elements used by the FP32-to-FP16 conversion tests. */
    private static final int NUM_ELEMENTS = 1024;

    /** Local work-group size used by the KernelContext conversion tests. */
    private static final int CONVERSION_LOCAL_SIZE = 32;

    /** Maximum absolute error tolerated between an FP32 source value and its FP16 conversion. */
    private static final float CONVERSION_DELTA = 0.001f;

    /**
     * Kernel: converts element {@code globalIdx} of {@code wrapX} to FP16 via an intermediate
     * {@code float} local and stores it at the same index of {@code x}.
     *
     * @param context kernel context supplying the global thread index
     * @param wrapX   FP32 source values
     * @param x       FP16 destination
     */
    public static void convertFP32toFP16v1(KernelContext context, FloatArray wrapX, HalfFloatArray x) {
        int i = context.globalIdx;
        float valInput = wrapX.get(i);
        HalfFloat val = new HalfFloat(valInput);
        x.set(i, val);
    }

    /**
     * Kernel: same as {@link #convertFP32toFP16v1} but passes the loaded value straight into the
     * {@link HalfFloat} constructor, without the intermediate local.
     *
     * @param context kernel context supplying the global thread index
     * @param wrapX   FP32 source values
     * @param x       FP16 destination
     */
    public static void convertFP32toFP16v2(KernelContext context, FloatArray wrapX, HalfFloatArray x) {
        int i = context.globalIdx;
        HalfFloat val = new HalfFloat(wrapX.get(i));
        x.set(i, val);
    }

    /**
     * Loop-based task: converts every element of {@code wrapX} to FP16 and stores it in {@code x}.
     * The loop runs over {@code x.getSize()} elements.
     *
     * <p>NOTE(review): despite the name, the loop index carries no parallel annotation. Decompiled
     * bytecode drops local-variable annotations, so the original presumably had {@code @Parallel}
     * here — confirm against the upstream source, otherwise this task may run sequentially.
     *
     * @param wrapX FP32 source values (assumed at least as long as {@code x})
     * @param x     FP16 destination
     */
    public static void convertFP32toFP16Parallel(FloatArray wrapX, HalfFloatArray x) {
        for (int i = 0; i < x.getSize(); ++i) {
            float valInput = wrapX.get(i);
            HalfFloat val = new HalfFloat(valInput);
            x.set(i, val);
        }
    }

    /**
     * Kernel: computes one row of {@code output = weights * x}, one work-group per row.
     *
     * <p>Work-groups whose index is {@code >= dim0} return immediately. Only the thread with local
     * index 0 writes the reduced row sum.
     *
     * @param context            kernel context supplying group and local indices
     * @param x                  FP16 input vector of length {@code dim1}
     * @param output             FP32 result vector of length {@code dim0}
     * @param weights            FP16 row-major matrix of {@code dim0} rows by {@code dim1} columns
     * @param dim1               number of columns (input length)
     * @param dim0               number of rows (output length)
     * @param localWorkGroupSize threads per work-group; must match the scheduled local size
     */
    public static void matrixVectorGenericOptimized(KernelContext context, HalfFloatArray x, FloatArray output, HalfFloatArray weights, int dim1, int dim0, int localWorkGroupSize) {
        int rowId = context.groupIdx;
        int localId = context.localIdx;
        // The whole work-group shares rowId, so this exit is uniform and no barrier is skipped by a subset.
        if (rowId >= dim0) {
            return;
        }
        float sum = matrixVectorRowMajorOptimized(context, localWorkGroupSize, x, weights, dim1);
        if (localId == 0) {
            output.set(rowId, sum);
        }
    }

    /**
     * Computes the dot product of row {@code groupIdx} of {@code w} with {@code x}.
     *
     * <p>Each thread accumulates a strided partial sum in FP16. The partial sums are then combined
     * by a tree reduction in a local float array, separated by local barriers.
     *
     * <p>Because the stride halves each step and any odd remainder is dropped, {@code localSize}
     * must be a power of two. It is also assumed to equal the work-group size, so every
     * {@code localId < localSize} — TODO confirm at call sites.
     *
     * @param context   kernel context supplying group/local indices, barriers and local memory
     * @param localSize number of threads cooperating on the row
     * @param x         FP16 input vector of length {@code n}
     * @param w         FP16 row-major matrix with rows of length {@code n}
     * @param n         row length
     * @return the reduced row sum (meaningful for every thread after the final barrier)
     */
    public static float matrixVectorRowMajorOptimized(KernelContext context, int localSize, HalfFloatArray x, HalfFloatArray w, int n) {
        int rowId = context.groupIdx;
        int localId = context.localIdx;
        float[] localSum = context.allocateFloatLocalArray(localSize);
        int rowOffset = rowId * n;

        // Strided per-thread partial dot product, accumulated in half precision.
        HalfFloat partialSum = new HalfFloat(0.0f);
        for (int j = localId; j < n; j += localSize) {
            int matrixIdx = rowOffset + j;
            HalfFloat mul = HalfFloat.mult(w.get(matrixIdx), x.get(j));
            partialSum = HalfFloat.add(partialSum, mul);
        }

        // NOTE(review): getHalfFloatValue() is stored into a float slot, while the host reference
        // uses getFloat32(). If getHalfFloatValue() returns the raw 16-bit encoding rather than the
        // numeric value, this relies on how the TornadoVM JIT lowers the call — confirm.
        localSum[localId] = partialSum.getHalfFloatValue();
        context.localBarrier();

        // Tree reduction: halve the active range each step until localSum[0] holds the total.
        for (int stride = localSize / 2; stride > 0; stride >>= 1) {
            if (localId < stride) {
                localSum[localId] += localSum[localId + stride];
            }
            context.localBarrier();
        }
        return localSum[0];
    }

    /**
     * Host reference for the matrix-vector product.
     *
     * <p>Each product is rounded through {@link HalfFloat#mult}, exactly as in the kernel, but the
     * products are summed in FP32.
     *
     * @param output  FP32 result vector of length {@code d}
     * @param weights FP16 row-major matrix of {@code d} rows by {@code n} columns
     * @param input   FP16 input vector of length {@code n}
     * @param n       number of columns
     * @param d       number of rows
     */
    private void matrixVectorSequentialHalf(FloatArray output, HalfFloatArray weights, HalfFloatArray input, int n, int d) {
        for (int i = 0; i < d; ++i) {
            float sum = 0.0f;
            for (int j = 0; j < n; ++j) {
                sum += HalfFloat.mult(weights.get(i * n + j), input.get(j)).getFloat32();
            }
            output.set(i, sum);
        }
    }

    /**
     * Fills {@code array} with pseudo-random values in {@code [min, max)} (before FP16 rounding).
     *
     * <p>Uses {@link Random#nextFloat()}, whose result lies in {@code [0, 1)}. The previous
     * {@code nextInt()} spanned the whole int range and produced values far outside {@code [min, max)},
     * beyond what FP16 can represent.
     *
     * @param array  destination array; every element is overwritten
     * @param min    inclusive lower bound
     * @param max    exclusive upper bound
     * @param random source of randomness (seed it for reproducible tests)
     */
    private static void fillRandomDataFp16(HalfFloatArray array, float min, float max, Random random) {
        float range = max - min;
        for (int i = 0; i < array.getSize(); ++i) {
            array.set(i, new HalfFloat(min + random.nextFloat() * range));
        }
    }

    /**
     * Snapshots {@code taskGraph} and executes it with a 1D grid bound to task {@code "s0.t0"}.
     *
     * <p>The grid has {@link #NUM_ELEMENTS} global threads, split into work-groups of
     * {@link #CONVERSION_LOCAL_SIZE}.
     */
    private static void executeWithConversionGrid(TaskGraph taskGraph) throws TornadoExecutionPlanException {
        ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
        WorkerGrid1D workerGrid = new WorkerGrid1D(NUM_ELEMENTS);
        workerGrid.setLocalWork(CONVERSION_LOCAL_SIZE, 1, 1);
        GridScheduler scheduler = new GridScheduler("s0.t0", workerGrid);
        try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
            executionPlan.withGridScheduler(scheduler).execute();
        }
    }

    /**
     * Asserts that each element of {@code actual}, widened back to FP32, is within
     * {@link #CONVERSION_DELTA} of the matching element of {@code expected}.
     */
    private static void assertConvertedWithinDelta(FloatArray expected, HalfFloatArray actual) {
        for (int i = 0; i < expected.getSize(); ++i) {
            Assert.assertEquals(expected.get(i), actual.get(i).getFloat32(), CONVERSION_DELTA);
        }
    }

    /** Conversion through {@link #convertFP32toFP16v1} (intermediate float local). */
    @Test
    public void testConvertFP32toFP16v1() throws TornadoExecutionPlanException {
        FloatArray x = new FloatArray(NUM_ELEMENTS);
        HalfFloatArray y = new HalfFloatArray(NUM_ELEMENTS);
        x.init(new Random().nextFloat());
        KernelContext context = new KernelContext();
        // NOTE(review): transfer-mode literal 1 was inlined by decompilation; presumably
        // DataTransferMode.EVERY_EXECUTION — confirm.
        TaskGraph tg = new TaskGraph("s0")
                .transferToDevice(1, x)
                .task("t0", TestHalfFloats::convertFP32toFP16v1, context, x, y)
                .transferToHost(1, y);
        executeWithConversionGrid(tg);
        assertConvertedWithinDelta(x, y);
    }

    /** Conversion through {@link #convertFP32toFP16v2} (value passed directly to the constructor). */
    @Test
    public void testConvertFP32toFP16v2() throws TornadoExecutionPlanException {
        FloatArray x = new FloatArray(NUM_ELEMENTS);
        HalfFloatArray y = new HalfFloatArray(NUM_ELEMENTS);
        x.init(new Random().nextFloat());
        KernelContext context = new KernelContext();
        TaskGraph tg = new TaskGraph("s0")
                .transferToDevice(1, x)
                .task("t0", TestHalfFloats::convertFP32toFP16v2, context, x, y)
                .transferToHost(1, y);
        executeWithConversionGrid(tg);
        assertConvertedWithinDelta(x, y);
    }

    /** Conversion through the loop-based {@link #convertFP32toFP16Parallel} task, no grid scheduler. */
    @Test
    public void testConvertFP32toFP16Parallel() throws TornadoExecutionPlanException {
        FloatArray x = new FloatArray(NUM_ELEMENTS);
        HalfFloatArray y = new HalfFloatArray(NUM_ELEMENTS);
        x.init(new Random().nextFloat());
        TaskGraph tg = new TaskGraph("s0")
                .transferToDevice(1, x)
                .task("t0", TestHalfFloats::convertFP32toFP16Parallel, x, y)
                .transferToHost(1, y);
        ImmutableTaskGraph immutableTaskGraph = tg.snapshot();
        try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
            executionPlan.execute();
        }
        assertConvertedWithinDelta(x, y);
    }

    /**
     * Half-precision matrix-vector product on the device, compared against the host reference
     * {@link #matrixVectorSequentialHalf} with an absolute tolerance of 0.1. Skipped on SPIR-V.
     */
    @Test
    public void testMatrixVectorHalfFloatOptimized() throws TornadoExecutionPlanException {
        assertNotBackend(TornadoVMBackendType.SPIRV);
        Random random = new Random(42L);
        int localWorkgroupSize = 64;
        int inputDim = 8192;
        int outputDim = 2048;
        HalfFloatArray input = new HalfFloatArray(inputDim);
        HalfFloatArray weights = new HalfFloatArray(inputDim * outputDim);
        FloatArray output = new FloatArray(outputDim);
        FloatArray outputSeq = new FloatArray(outputDim);
        fillRandomDataFp16(input, -1.0f, 1.0f, random);
        fillRandomDataFp16(weights, -0.1f, 0.1f, random);

        // One work-group of localWorkgroupSize threads per output row.
        WorkerGrid1D hybridWorker = new WorkerGrid1D(outputDim * localWorkgroupSize);
        hybridWorker.setLocalWork(localWorkgroupSize, 1, 1);
        GridScheduler scheduler = new GridScheduler("s0.t0", hybridWorker);

        // NOTE(review): transfer-mode literals 0/1 were inlined by decompilation; presumably
        // DataTransferMode.FIRST_EXECUTION / EVERY_EXECUTION — confirm.
        TaskGraph taskGraph = new TaskGraph("s0")
                .transferToDevice(0, input, weights)
                .task("t0", TestHalfFloats::matrixVectorGenericOptimized, new KernelContext(), input, output, weights, inputDim, outputDim, localWorkgroupSize)
                .transferToHost(1, output);
        ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
        try (TornadoExecutionPlan tornadoExecutor = new TornadoExecutionPlan(immutableTaskGraph)) {
            tornadoExecutor.withGridScheduler(scheduler).execute();
        }

        matrixVectorSequentialHalf(outputSeq, weights, input, inputDim, outputDim);
        for (int i = 0; i < output.getSize(); ++i) {
            Assert.assertEquals(outputSeq.get(i), output.get(i), 0.1f);
        }
    }
}

