/*
 * Decompiled with CFR 0.152.
 */
package uk.ac.manchester.tornado.unittests.quantization;

import java.util.Random;
import java.util.stream.IntStream;
import org.junit.Assert;
import org.junit.Test;
import uk.ac.manchester.tornado.api.GridScheduler;
import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
import uk.ac.manchester.tornado.api.KernelContext;
import uk.ac.manchester.tornado.api.TaskGraph;
import uk.ac.manchester.tornado.api.TornadoExecutionPlan;
import uk.ac.manchester.tornado.api.TornadoExecutionResult;
import uk.ac.manchester.tornado.api.WorkerGrid;
import uk.ac.manchester.tornado.api.WorkerGrid1D;
import uk.ac.manchester.tornado.api.enums.TornadoVMBackendType;
import uk.ac.manchester.tornado.api.exceptions.TornadoExecutionPlanException;
import uk.ac.manchester.tornado.api.math.TornadoMath;
import uk.ac.manchester.tornado.api.types.HalfFloat;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
import uk.ac.manchester.tornado.api.types.arrays.Int8Array;
import uk.ac.manchester.tornado.api.types.arrays.IntArray;
import uk.ac.manchester.tornado.api.utils.QuantizationUtils;
import uk.ac.manchester.tornado.unittests.common.TornadoTestBase;

public class QuantizationTests
extends TornadoTestBase {
    // Tile width constant (128). NOTE(review): matrixVectorDP4ALocalMemory uses the bare literal 128
    // for its local tile; presumably that is this constant inlined at compile time — confirm.
    private static final int TILE_SIZE = 128;

    /**
     * Fills {@code result} with 4-wide int8 dot products of {@code a} and {@code b}.
     * <p>
     * Slot {@code i} of {@code result} receives {@code QuantizationUtils.dp4a} evaluated at element
     * offset {@code i * 4} of both inputs, starting from an accumulator of zero.
     *
     * @param a      first int8 operand
     * @param b      second int8 operand
     * @param result destination; its size determines how many dot products are computed
     */
    public static void performDP4A(Int8Array a, Int8Array b, IntArray result) {
        for (int slot = 0; slot < result.getSize(); slot++) {
            // Both operands are read at the same element offset.
            long offset = slot * 4;
            result.set(slot, QuantizationUtils.dp4a(a, offset, b, offset, 0));
        }
    }

    /**
     * Kernel that derives a symmetric int8 quantization scale from the largest absolute value of {@code x}.
     * <p>
     * Each work-group reduces the absolute values of its slice in a local array and stores the group
     * maximum in {@code max[groupIdx]}. The work-item with global id 0 then folds all per-group entries,
     * writes the overall maximum to {@code max[0]}, and stores {@code scale = maxAbs / 127}
     * (or 1 when the maximum is 0) in {@code x_scale[0]} and its reciprocal in {@code inv_scale[0]}.
     * <p>
     * NOTE(review): no synchronization between work-groups is visible here, so the final fold
     * presumably relies on every group having already published its maximum — confirm.
     *
     * @param context      kernel context supplying thread ids, barriers and local memory
     * @param max          per-group maxima; must hold at least one element per work-group
     * @param x            input values
     * @param x_scale      receives the scale in element 0
     * @param inv_scale    receives {@code 1 / scale} in element 0
     * @param localMemSize number of floats to allocate in the local array
     * @param arraySize    number of valid elements in {@code x}
     */
    public static void reductionCalculateMax(KernelContext context, FloatArray max, FloatArray x, FloatArray x_scale, FloatArray inv_scale, int localMemSize, int arraySize) {
        final int globalId = context.globalIdx;
        final int localId = context.localIdx;
        final int workGroup = context.groupIdx;
        final int workGroupSize = context.localGroupSizeX;
        float[] scratch = context.allocateFloatLocalArray(localMemSize);

        // Threads past the end of the input contribute a neutral 0.
        if (globalId < arraySize) {
            scratch[localId] = TornadoMath.abs(x.get(globalId));
        } else {
            scratch[localId] = 0.0f;
        }

        // Halving reduction inside the work-group; the barrier precedes every step.
        for (int half = workGroupSize / 2; half > 0; half /= 2) {
            context.localBarrier();
            if (localId < half) {
                scratch[localId] = TornadoMath.max(scratch[localId], scratch[localId + half]);
            }
        }

        if (localId == 0) {
            max.set(workGroup, scratch[0]);
        }

        if (globalId == 0) {
            final int groupCount = (arraySize + workGroupSize - 1) / workGroupSize;
            float largest = 0.0f;
            for (int g = 0; g < groupCount; g++) {
                largest = TornadoMath.max(largest, max.get(g));
            }
            max.set(0, largest);

            float scale;
            if (largest == 0.0f) {
                // Avoid a zero scale (and a division by zero) for an all-zero input.
                scale = 1.0f;
            } else {
                scale = largest / 127.0f;
            }
            inv_scale.set(0, 1.0f / scale);
            x_scale.set(0, scale);
        }
    }

    /**
     * Kernel that quantizes one element of {@code x} per work-item.
     * <p>
     * The element at the work-item's global id is multiplied by {@code inv_scale[0]}, offset by 0.5,
     * floored, narrowed to a byte and stored at the same index of {@code x_quant}.
     *
     * @param context   kernel context supplying the global id
     * @param x         input values
     * @param inv_scale reciprocal of the quantization scale in element 0
     * @param x_quant   destination for the quantized bytes
     */
    public static void quantizeKernelContext(KernelContext context, FloatArray x, FloatArray inv_scale, Int8Array x_quant) {
        final int idx = context.globalIdx;
        final float inverseScale = inv_scale.get(0);
        // Adding 0.5 before floor rounds to the nearest integer.
        float shifted = x.get(idx) * inverseScale + 0.5f;
        x_quant.set(idx, (byte) TornadoMath.floor(shifted));
    }

    /**
     * Host-side symmetric int8 quantization of a whole array with a single scale.
     * <p>
     * The scale is {@code maxAbs / 127} (1 when every element is 0); each element is multiplied by
     * the reciprocal scale, offset by 0.5, floored and narrowed to a byte.
     *
     * @param x       values to quantize
     * @param x_quant destination for the quantized bytes (same indexing as {@code x})
     * @param x_scale receives the scale in element 0
     */
    private static void quantizeFloatArray(FloatArray x, Int8Array x_quant, FloatArray x_scale) {
        float largest = 0.0f;
        for (int k = 0; k < x.getSize(); k++) {
            float magnitude = TornadoMath.abs(x.get(k));
            largest = TornadoMath.max(largest, magnitude);
        }

        float scale;
        if (largest == 0.0f) {
            scale = 1.0f;
        } else {
            scale = largest / 127.0f;
        }
        final float reciprocal = 1.0f / scale;

        for (int k = 0; k < x.getSize(); k++) {
            float shifted = x.get(k) * reciprocal + 0.5f;
            x_quant.set(k, (byte) TornadoMath.floor(shifted));
        }
        x_scale.set(0, scale);
    }

    /**
     * Computes one row of an int8 matrix-vector product using {@code dp4a}, one work-group per row.
     * <p>
     * The row is selected by the group id. Each work-item accumulates 4-element dot products at
     * columns {@code localId * 4, localId * 4 + localSize * 4, ...}; the per-item sums are then
     * combined with a halving reduction in a local int array.
     *
     * @param context   kernel context supplying ids, barriers and local memory
     * @param localSize number of work-items cooperating on the row (also the local array length)
     * @param w_quant   quantized weights, row-major with {@code n} columns
     * @param x_quant   quantized input vector
     * @param n         number of columns
     * @param w_scale   weight scale passed to {@code dequantizeFusedResult}
     * @param x_scale   input scale passed to {@code dequantizeFusedResult}
     * @return the result of {@code QuantizationUtils.dequantizeFusedResult} on the reduced sum for
     *         local id 0; {@code 0.0f} for every other work-item
     */
    public static float matrixVectorRowMajorOptimizedDP4A(KernelContext context, int localSize, Int8Array w_quant, Int8Array x_quant, int n, float w_scale, float x_scale) {
        final int row = context.groupIdx;
        final int lane = context.localIdx;
        if (lane >= localSize) {
            return 0.0f;
        }

        int[] scratch = context.allocateIntLocalArray(localSize);
        final int rowBase = row * n;
        int acc = 0;
        for (int col = lane * 4; col < n; col += localSize * 4) {
            acc = QuantizationUtils.dp4a(w_quant, rowBase + col, x_quant, col, acc);
        }
        scratch[lane] = acc;
        context.localBarrier();

        // Halving reduction: after the last step scratch[0] holds the row total.
        for (int half = localSize / 2; half > 0; half >>= 1) {
            if (lane < half) {
                scratch[lane] += scratch[lane + half];
            }
            context.localBarrier();
        }

        if (lane != 0) {
            return 0.0f;
        }
        return QuantizationUtils.dequantizeFusedResult(scratch[0], w_scale, x_scale);
    }

    /**
     * Kernel entry point for the dp4a matrix-vector product: one work-group per output row.
     * <p>
     * Groups whose id is {@code >= d} exit immediately. Otherwise the row value is computed by
     * {@link #matrixVectorRowMajorOptimizedDP4A} and local id 0 stores it in {@code output[row]}.
     *
     * @param context            kernel context
     * @param w_quant            quantized weights, row-major with {@code n} columns
     * @param x_quant            quantized input vector
     * @param output             destination, one float per row
     * @param n                  number of columns
     * @param d                  number of rows
     * @param localWorkGroupSize work-items per group
     * @param w_scale            weight scale in element 0
     * @param x_scale            input scale in element 0
     */
    public static void matrixVectorGenericDP4A(KernelContext context, Int8Array w_quant, Int8Array x_quant, FloatArray output, int n, int d, int localWorkGroupSize, FloatArray w_scale, FloatArray x_scale) {
        final int row = context.groupIdx;
        final int lane = context.localIdx;
        if (row >= d) {
            return;
        }

        final float weightScale = w_scale.get(0);
        final float inputScale = x_scale.get(0);
        float rowValue = QuantizationTests.matrixVectorRowMajorOptimizedDP4A(context, localWorkGroupSize, w_quant, x_quant, n, weightScale, inputScale);
        if (lane == 0) {
            output.set(row, rowValue);
        }
    }

    /**
     * Kernel entry point for the local-memory-tiled dp4a matrix-vector product.
     * <p>
     * Groups whose id is {@code >= d} exit immediately. The integer row sum comes from
     * {@link #matrixVectorDP4ALocalMemory}; local id 0 multiplies it by {@code w_scale[0]} and then
     * {@code x_scale[0]} and stores the result in {@code output[row]}.
     *
     * @param context            kernel context
     * @param w_quant            quantized weights, row-major with {@code n} columns
     * @param x_quant            quantized input vector
     * @param output             destination, one float per row
     * @param n                  number of columns
     * @param d                  number of rows
     * @param localWorkGroupSize work-items per group
     * @param w_scale            weight scale in element 0
     * @param x_scale            input scale in element 0
     */
    public static void matrixVectorGenericLocalMemory(KernelContext context, Int8Array w_quant, Int8Array x_quant, FloatArray output, int n, int d, int localWorkGroupSize, FloatArray w_scale, FloatArray x_scale) {
        final int row = context.groupIdx;
        if (row >= d) {
            return;
        }

        final int rowSum = QuantizationTests.matrixVectorDP4ALocalMemory(context, localWorkGroupSize, w_quant, x_quant, n);
        if (context.localIdx == 0) {
            // Dequantize: scale by the weight scale first, then by the input scale.
            float scaledByWeights = (float) rowSum * w_scale.get(0);
            output.set(row, scaledByWeights * x_scale.get(0));
        }
    }

    /**
     * Computes the integer dot product of one weight row with {@code x_quant}, staging the input
     * vector through a 128-byte local tile.
     * <p>
     * For each tile of 128 columns the work-items cooperatively copy the input slice into the local
     * byte array, synchronize, and accumulate {@code dp4a} products of the weight row against the
     * tile. The per-item totals are then combined with a halving reduction.
     *
     * @param context   kernel context supplying ids, barriers and local memory
     * @param localSize work-items per group (also the length of the reduction array)
     * @param w_quant   quantized weights, row-major with {@code n} columns
     * @param x_quant   quantized input vector
     * @param n         number of columns
     * @return the value left in slot 0 of the local reduction array
     */
    public static int matrixVectorDP4ALocalMemory(KernelContext context, int localSize, Int8Array w_quant, Int8Array x_quant, int n) {
        // Compile-time constant; kept as a literal so the local allocation size stays constant.
        final int tile = 128;
        final int row = context.groupIdx;
        final int lane = context.localIdx;
        byte[] stagedInput = context.allocateByteLocalArray(tile);
        int[] scratch = context.allocateIntLocalArray(localSize);
        final int rowBase = row * n;

        int laneTotal = 0;
        for (int base = 0; base < n; base += tile) {
            // Cooperative load of this tile of the input vector.
            for (int k = lane; k < tile && base + k < n; k += localSize) {
                stagedInput[k] = x_quant.get(base + k);
            }
            context.localBarrier();

            int tileSum = 0;
            for (int col = lane * 4; col < tile; col += localSize * 4) {
                if (base + col < n) {
                    tileSum = QuantizationUtils.dp4a(w_quant, rowBase + base + col, stagedInput, col, tileSum);
                }
            }
            laneTotal += tileSum;
            // Keep the tile intact until every work-item has finished reading it.
            context.localBarrier();
        }

        scratch[lane] = laneTotal;
        context.localBarrier();
        for (int half = localSize / 2; half > 0; half >>= 1) {
            if (lane < half) {
                scratch[lane] += scratch[lane + half];
            }
            context.localBarrier();
        }
        return scratch[0];
    }

    /**
     * Kernel entry point for the packed-int matrix-vector product.
     * <p>
     * Groups whose id is {@code >= d} exit immediately. The integer row sum comes from
     * {@link #matrixVectorPacked}; local id 0 multiplies it by {@code w_scale[0]} and then
     * {@code x_scale[0]} and stores the result in {@code output[row]}.
     *
     * @param context            kernel context
     * @param w_quant            quantized weights, row-major with {@code n} columns
     * @param x_quant            quantized input vector
     * @param output             destination, one float per row
     * @param n                  number of columns
     * @param d                  number of rows
     * @param localWorkGroupSize work-items per group
     * @param w_scale            weight scale in element 0
     * @param x_scale            input scale in element 0
     */
    public static void matrixVectorGenericPacked(KernelContext context, Int8Array w_quant, Int8Array x_quant, FloatArray output, int n, int d, int localWorkGroupSize, FloatArray w_scale, FloatArray x_scale) {
        final int row = context.groupIdx;
        if (row >= d) {
            return;
        }

        final int rowSum = QuantizationTests.matrixVectorPacked(context, localWorkGroupSize, w_quant, x_quant, n);
        if (context.localIdx == 0) {
            // Dequantize: scale by the weight scale first, then by the input scale.
            float scaledByWeights = (float) rowSum * w_scale.get(0);
            output.set(row, scaledByWeights * x_scale.get(0));
        }
    }

    /**
     * Computes the integer dot product of one weight row with {@code x_quant} by packing four int8
     * values into each 32-bit word and calling {@code QuantizationUtils.dp4a_packed}.
     * <p>
     * The input vector is staged through a local array of 32 packed words per tile. Bytes are packed
     * little-endian (byte 0 in the low 8 bits). The number of packed words is the ceiling of
     * {@code n / 4}; when {@code n} is not a multiple of 4, the missing lanes of the final word are
     * zero-filled in both operands, so (assuming {@code dp4a_packed} is a lane-wise byte dot product)
     * they add nothing to the sum and no trailing element is dropped.
     *
     * @param context   kernel context supplying ids, barriers and local memory
     * @param localSize work-items per group (also the length of the reduction array)
     * @param w_quant   quantized weights, row-major with {@code n} columns
     * @param x_quant   quantized input vector
     * @param n         number of columns
     * @return the value left in slot 0 of the local reduction array
     */
    public static int matrixVectorPacked(KernelContext context, int localSize, Int8Array w_quant, Int8Array x_quant, int n) {
        int rowId = context.groupIdx;
        int localId = context.localIdx;
        int[] x_tile_packed = context.allocateIntLocalArray(32);
        int[] localSum = context.allocateIntLocalArray(localSize);
        int rowOffset = rowId * n;
        int rowEnd = rowOffset + n;
        int totalSum = 0;
        // Round up so a partial trailing group of 1-3 elements still gets a (zero-padded) word.
        int nPacked = (n + 3) / 4;
        for (int tileStart = 0; tileStart < nPacked; tileStart += 32) {
            int tileSize = Math.min(32, nPacked - tileStart);

            // Stage the packed input tile; the first byte of each word is always in range.
            for (int i = localId; i < tileSize; i += localSize) {
                int globalIdx = (tileStart + i) * 4;
                byte b0 = x_quant.get(globalIdx);
                byte b1 = globalIdx + 1 < n ? x_quant.get(globalIdx + 1) : (byte) 0;
                byte b2 = globalIdx + 2 < n ? x_quant.get(globalIdx + 2) : (byte) 0;
                byte b3 = globalIdx + 3 < n ? x_quant.get(globalIdx + 3) : (byte) 0;
                x_tile_packed[i] = b0 & 0xFF | (b1 & 0xFF) << 8 | (b2 & 0xFF) << 16 | (b3 & 0xFF) << 24;
            }
            context.localBarrier();

            int partialSum = 0;
            for (int i = localId; i < tileSize; i += localSize) {
                int w_globalIdx = rowOffset + (tileStart + i) * 4;
                byte b0 = w_quant.get(w_globalIdx);
                // Never read past the end of this row: pad with zeros instead.
                byte b1 = w_globalIdx + 1 < rowEnd ? w_quant.get(w_globalIdx + 1) : (byte) 0;
                byte b2 = w_globalIdx + 2 < rowEnd ? w_quant.get(w_globalIdx + 2) : (byte) 0;
                byte b3 = w_globalIdx + 3 < rowEnd ? w_quant.get(w_globalIdx + 3) : (byte) 0;
                int w_packed = b0 & 0xFF | (b1 & 0xFF) << 8 | (b2 & 0xFF) << 16 | (b3 & 0xFF) << 24;
                partialSum = QuantizationUtils.dp4a_packed(w_packed, x_tile_packed[i], partialSum);
            }
            totalSum += partialSum;
            // Keep the tile intact until every work-item has finished reading it.
            context.localBarrier();
        }

        localSum[localId] = totalSum;
        context.localBarrier();
        for (int stride = localSize / 2; stride > 0; stride >>= 1) {
            if (localId < stride) {
                localSum[localId] += localSum[localId + stride];
            }
            context.localBarrier();
        }
        return localSum[0];
    }

    /**
     * Reference float matrix-vector product computed on the host.
     *
     * @param output  receives {@code d} row results
     * @param weights row-major matrix with {@code d} rows and {@code n} columns
     * @param input   vector of length {@code n}
     * @param n       number of columns
     * @param d       number of rows
     */
    private void matrixVectorSequential(FloatArray output, FloatArray weights, FloatArray input, int n, int d) {
        for (int row = 0; row < d; row++) {
            final int rowBase = row * n;
            float acc = 0.0f;
            for (int col = 0; col < n; col++) {
                acc += weights.get(rowBase + col) * input.get(col);
            }
            output.set(row, acc);
        }
    }

    /**
     * Overwrites every element of {@code array} with {@code min + random.nextFloat() * (max - min)}.
     *
     * @param array  array to fill
     * @param min    lower end of the value range
     * @param max    upper end of the value range
     * @param random source of randomness; one {@code nextFloat()} draw per element, in index order
     */
    private static void fillRandomData(FloatArray array, float min, float max, Random random) {
        final float span = max - min;
        final int length = array.getSize();
        for (int k = 0; k < length; k++) {
            float sample = random.nextFloat() * span;
            array.set(k, min + sample);
        }
    }

    /**
     * Kernel entry point for the block-scaled int8 matrix-vector product: one work-group per row.
     * <p>
     * Groups whose id is {@code >= dim0} exit immediately. Otherwise the row value is computed by
     * {@link #matrixVectorRowMajorOptimizedVectorized} and local id 0 stores it in {@code output[row]}.
     *
     * @param context            kernel context
     * @param x                  float input vector
     * @param output             destination, one float per row
     * @param weightsQ           quantized weights, row-major with {@code dim1} columns
     * @param weightScales       per-block half-float scales
     * @param dim1               number of columns
     * @param dim0               number of rows
     * @param localWorkGroupSize work-items per group
     */
    public static void matrixVectorGenericVectorized(KernelContext context, FloatArray x, FloatArray output, Int8Array weightsQ, HalfFloatArray weightScales, int dim1, int dim0, int localWorkGroupSize) {
        final int row = context.groupIdx;
        final int lane = context.localIdx;
        if (row >= dim0) {
            return;
        }

        float rowValue = QuantizationTests.matrixVectorRowMajorOptimizedVectorized(context, localWorkGroupSize, x, weightsQ, weightScales, dim1);
        if (lane == 0) {
            output.set(row, rowValue);
        }
    }

    /**
     * Computes one row of a matrix-vector product where the weights are int8 with one half-float
     * scale per block of 32 columns and the input vector is float.
     * <p>
     * Each work-item processes groups of four consecutive columns with four independent
     * accumulators, then handles any columns past the last multiple of four one at a time. The
     * per-item sums are combined with a halving reduction in a local float array.
     *
     * @param context      kernel context supplying ids, barriers and local memory
     * @param localSize    work-items per group (also the length of the reduction array)
     * @param x            float input vector
     * @param weightsQ     quantized weights, row-major with {@code n} columns
     * @param weightScales scales indexed by {@code row * (n / 32) + column / 32}
     * @param n            number of columns
     * @return the value left in slot 0 of the local reduction array
     */
    public static float matrixVectorRowMajorOptimizedVectorized(KernelContext context, int localSize, FloatArray x, Int8Array weightsQ, HalfFloatArray weightScales, int n) {
        final int row = context.groupIdx;
        final int lane = context.localIdx;
        final int blockSize = 32;
        float[] scratch = context.allocateFloatLocalArray(localSize);
        final int weightBase = row * n;
        final int scaleBase = row * (n / blockSize);

        float acc0 = 0.0f;
        float acc1 = 0.0f;
        float acc2 = 0.0f;
        float acc3 = 0.0f;
        // Main loop: four columns per step, all sharing the scale looked up for the first column.
        for (int col = lane * 4; col < n - 3; col += localSize * 4) {
            float scale = weightScales.get(scaleBase + col / blockSize).getFloat32();
            final int w = weightBase + col;
            acc0 += (float) weightsQ.get(w) * scale * x.get(col);
            acc1 += (float) weightsQ.get(w + 1) * scale * x.get(col + 1);
            acc2 += (float) weightsQ.get(w + 2) * scale * x.get(col + 2);
            acc3 += (float) weightsQ.get(w + 3) * scale * x.get(col + 3);
        }
        float laneTotal = acc0 + acc1 + acc2 + acc3;

        // Tail: columns beyond the last multiple of four.
        for (int col = n / 4 * 4 + lane; col < n; col += localSize) {
            float scale = weightScales.get(scaleBase + col / blockSize).getFloat32();
            laneTotal += (float) weightsQ.get(weightBase + col) * scale * x.get(col);
        }

        scratch[lane] = laneTotal;
        context.localBarrier();
        for (int half = localSize / 2; half > 0; half >>= 1) {
            if (lane < half) {
                scratch[lane] += scratch[lane + half];
            }
            context.localBarrier();
        }
        return scratch[0];
    }

    /**
     * Kernel entry point for the 4-way unrolled dp4a matrix-vector product: one work-group per row.
     * <p>
     * Groups whose id is {@code >= d} exit immediately. Otherwise the row value is computed by
     * {@link #matrixVectorDP4A4Way} and local id 0 stores it in {@code output[row]}.
     *
     * @param context            kernel context
     * @param w_quant            quantized weights, row-major with {@code n} columns
     * @param x_quant            quantized input vector
     * @param output             destination, one float per row
     * @param n                  number of columns
     * @param d                  number of rows
     * @param localWorkGroupSize work-items per group
     * @param w_scale            weight scale in element 0
     * @param x_scale            input scale in element 0
     */
    public static void matrixVectorGenericDP4A4Way(KernelContext context, Int8Array w_quant, Int8Array x_quant, FloatArray output, int n, int d, int localWorkGroupSize, FloatArray w_scale, FloatArray x_scale) {
        final int row = context.groupIdx;
        final int lane = context.localIdx;
        if (row >= d) {
            return;
        }

        final float weightScale = w_scale.get(0);
        final float inputScale = x_scale.get(0);
        float rowValue = QuantizationTests.matrixVectorDP4A4Way(context, localWorkGroupSize, w_quant, x_quant, n, weightScale, inputScale);
        if (lane == 0) {
            output.set(row, rowValue);
        }
    }

    /**
     * Computes one row of an int8 matrix-vector product with four {@code dp4a} calls per loop step.
     * <p>
     * Each work-item consumes 16 columns per step while at least 16 remain at its position, then
     * handles whole groups of four columns past the last multiple of 16. A group of four that would
     * extend beyond {@code n} is skipped. Per-item sums are combined with a halving reduction and the
     * total is multiplied by {@code w_scale * x_scale}.
     *
     * @param context   kernel context supplying ids, barriers and local memory
     * @param localSize work-items per group (also the length of the reduction array)
     * @param w_quant   quantized weights, row-major with {@code n} columns
     * @param x_quant   quantized input vector
     * @param n         number of columns
     * @param w_scale   weight scale
     * @param x_scale   input scale
     * @return slot 0 of the local reduction array converted to float and scaled
     */
    public static float matrixVectorDP4A4Way(KernelContext context, int localSize, Int8Array w_quant, Int8Array x_quant, int n, float w_scale, float x_scale) {
        final int row = context.groupIdx;
        final int lane = context.localIdx;
        int[] scratch = context.allocateIntLocalArray(localSize);
        final int rowBase = row * n;
        final float dequantFactor = w_scale * x_scale;

        int acc = 0;
        final int unrolledStep = localSize * 16;
        final int unrolledEnd = n - 15;
        for (int col = lane * 16; col < unrolledEnd; col += unrolledStep) {
            final int w = rowBase + col;
            acc = QuantizationUtils.dp4a(w_quant, w, x_quant, col, acc);
            acc = QuantizationUtils.dp4a(w_quant, w + 4, x_quant, col + 4, acc);
            acc = QuantizationUtils.dp4a(w_quant, w + 8, x_quant, col + 8, acc);
            acc = QuantizationUtils.dp4a(w_quant, w + 12, x_quant, col + 12, acc);
        }

        // Remainder past the last multiple of 16, four columns at a time.
        for (int col = n / 16 * 16 + lane * 4; col < n; col += localSize * 4) {
            if (col + 3 < n) {
                acc = QuantizationUtils.dp4a(w_quant, rowBase + col, x_quant, col, acc);
            }
        }

        scratch[lane] = acc;
        context.localBarrier();
        for (int half = localSize / 2; half > 0; half >>= 1) {
            if (lane < half) {
                scratch[lane] += scratch[lane + half];
            }
            context.localBarrier();
        }
        return (float) scratch[0] * dequantFactor;
    }

    /**
     * Quantizes a row-major float matrix to int8 with one half-float scale per block of
     * {@code blockSize} consecutive columns.
     * <p>
     * For every block the scale is {@code maxAbs / 127} (0 for an all-zero block). Each value is
     * multiplied by the reciprocal scale (0 when the scale is 0), rounded with {@link Math#round(float)}
     * and clamped to {@code [-127, 127]}.
     *
     * @param weightsFP32 source matrix, {@code rows * cols} floats
     * @param outQ        destination for the quantized bytes, same indexing as the source
     * @param outScales   destination for the scales, indexed by {@code row * (cols / blockSize) + block}
     * @param rows        number of rows
     * @param cols        number of columns
     * @param blockSize   columns per quantization block
     * @throws IllegalArgumentException if {@code cols} is not a multiple of {@code blockSize}
     */
    public static void quantizeWeightsToQ8(FloatArray weightsFP32, Int8Array outQ, HalfFloatArray outScales, int rows, int cols, int blockSize) {
        if (cols % blockSize != 0) {
            throw new IllegalArgumentException("cols must be multiple of BLOCK_SIZE=" + blockSize);
        }

        final int blocksPerRow = cols / blockSize;
        for (int row = 0; row < rows; row++) {
            for (int block = 0; block < blocksPerRow; block++) {
                final int first = row * cols + block * blockSize;
                final int end = first + blockSize;

                // Largest magnitude in the block determines its scale.
                float peak = 0.0f;
                for (int idx = first; idx < end; idx++) {
                    float magnitude = Math.abs(weightsFP32.get(idx));
                    if (magnitude > peak) {
                        peak = magnitude;
                    }
                }

                float scale = 0.0f;
                if (peak != 0.0f) {
                    scale = peak / 127.0f;
                }
                outScales.set(row * blocksPerRow + block, new HalfFloat(scale));

                // Checked separately from peak: the division above may itself produce 0.
                float reciprocal = 0.0f;
                if (scale != 0.0f) {
                    reciprocal = 1.0f / scale;
                }

                for (int idx = first; idx < end; idx++) {
                    int rounded = Math.round(weightsFP32.get(idx) * reciprocal);
                    int clamped = Math.max(-127, Math.min(127, rounded));
                    outQ.set(idx, (byte) clamped);
                }
            }
        }
    }

    /**
     * Runs {@link #performDP4A} through a TornadoVM task graph and compares every output slot with
     * the result of calling the same method directly on the host.
     * <p>
     * The inputs come from a fixed-seed generator so that a failure can be reproduced.
     *
     * @throws TornadoExecutionPlanException if the execution plan fails
     */
    @Test
    public void testDP4A() throws TornadoExecutionPlanException {
        this.assertNotBackend(TornadoVMBackendType.OPENCL);
        this.assertNotBackend(TornadoVMBackendType.SPIRV);

        final int N = 512;
        Int8Array a = new Int8Array(N);
        Int8Array b = new Int8Array(N);
        IntArray result = new IntArray(N / 4);
        IntArray resultSeq = new IntArray(N / 4);

        // Fixed seed, matching the other tests in this class.
        Random r = new Random(42L);
        IntStream.range(0, N).forEach(i -> {
            a.set(i, (byte) r.nextInt());
            b.set(i, (byte) r.nextInt());
        });

        // Transfer-mode codes 0 / 1 are kept as in the original.
        // NOTE(review): presumably DataTransferMode.FIRST_EXECUTION / EVERY_EXECUTION — confirm.
        TaskGraph taskGraph = new TaskGraph("s0")
                .transferToDevice(0, a, b)
                .task("t0", QuantizationTests::performDP4A, a, b, result)
                .transferToHost(1, result);

        ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
        try (TornadoExecutionPlan tornadoExecutor = new TornadoExecutionPlan(immutableTaskGraph)) {
            tornadoExecutor.execute();
        }

        QuantizationTests.performDP4A(a, b, resultSeq);
        for (int i = 0; i < result.getSize(); ++i) {
            Assert.assertEquals(resultSeq.get(i), result.get(i), 1.0E-4);
        }
    }

    /**
     * Quantizes a random float vector on the device (scale computation followed by per-element
     * quantization) and checks that dequantizing each byte with the returned scale reproduces the
     * input to within 0.01.
     *
     * @throws TornadoExecutionPlanException if the execution plan fails
     */
    @Test
    public void testQuantization() throws TornadoExecutionPlanException {
        this.assertNotBackend(TornadoVMBackendType.SPIRV);

        Random random = new Random(42L);
        final int local_workgroup_size = 128;
        final int inputDim = 8192;
        FloatArray input = new FloatArray(inputDim);
        QuantizationTests.fillRandomData(input, -1.0f, 1.0f, random);

        // reductionCalculateMax writes max[groupIdx] for every work-group, so the buffer needs one
        // slot per group rather than a single element.
        final int numGroups = (inputDim + local_workgroup_size - 1) / local_workgroup_size;
        FloatArray x_scale = new FloatArray(1);
        FloatArray x_max = new FloatArray(numGroups);
        FloatArray inv_scale = new FloatArray(1);
        Int8Array x_quant = new Int8Array(input.getSize());

        // Both kernels run one work-item per input element.
        WorkerGrid1D worker = new WorkerGrid1D(inputDim);
        GridScheduler scheduler = new GridScheduler();
        scheduler.addWorkerGrid("s0_quant.scales", worker);
        scheduler.addWorkerGrid("s0_quant.quantize", worker);
        worker.setLocalWork(local_workgroup_size, 1, 1);

        // Transfer-mode codes 0 / 1 are kept as in the original.
        // NOTE(review): presumably DataTransferMode.FIRST_EXECUTION / EVERY_EXECUTION — confirm.
        TaskGraph taskGraphQ = new TaskGraph("s0_quant")
                .transferToDevice(0, input, x_quant, x_scale, x_max, inv_scale)
                .task("scales", QuantizationTests::reductionCalculateMax, new KernelContext(), x_max, input, x_scale, inv_scale, local_workgroup_size, inputDim)
                .task("quantize", QuantizationTests::quantizeKernelContext, new KernelContext(), input, inv_scale, x_quant)
                .transferToHost(1, x_quant, x_scale);

        ImmutableTaskGraph immutableTaskGraphQ = taskGraphQ.snapshot();
        try (TornadoExecutionPlan tornadoExecutor = new TornadoExecutionPlan(immutableTaskGraphQ)) {
            tornadoExecutor.withGridScheduler(scheduler).execute();
        }

        for (int i = 0; i < input.getSize(); ++i) {
            float dequantized = (float) x_quant.get(i) * x_scale.get(0);
            Assert.assertEquals(input.get(i), dequantized, 0.01f);
        }
    }

    /**
     * End-to-end check of the dp4a matrix-vector kernel ({@link #matrixVectorGenericDP4A}): the input
     * vector is quantized on the device, multiplied with host-quantized weights, and the result is
     * compared against the float reference product to within 0.1.
     *
     * @throws TornadoExecutionPlanException if the execution plan fails
     */
    @Test
    public void testMatrixVectorDP4AKernelContext() throws TornadoExecutionPlanException {
        this.assertNotBackend(TornadoVMBackendType.OPENCL);
        this.assertNotBackend(TornadoVMBackendType.SPIRV);

        Random random = new Random(42L);
        final int local_workgroup_size = 32;
        final int inputDim = 8192;
        final int outputDim = 2048;

        FloatArray input = new FloatArray(inputDim);
        FloatArray weights = new FloatArray(inputDim * outputDim);
        FloatArray output = new FloatArray(outputDim);
        FloatArray outputSeq = new FloatArray(outputDim);
        QuantizationTests.fillRandomData(input, -1.0f, 1.0f, random);
        QuantizationTests.fillRandomData(weights, -0.1f, 0.1f, random);

        // Weights are quantized once on the host.
        Int8Array w_quant = new Int8Array(weights.getSize());
        FloatArray w_scale = new FloatArray(1);
        QuantizationTests.quantizeFloatArray(weights, w_quant, w_scale);

        // Buffers for the on-device quantization of the input; x_max holds one slot per work-group.
        FloatArray x_scale = new FloatArray(1);
        final int maxNumGroups = (inputDim + local_workgroup_size - 1) / local_workgroup_size;
        FloatArray x_max = new FloatArray(maxNumGroups);
        FloatArray inv_scale = new FloatArray(1);
        Int8Array x_quant = new Int8Array(input.getSize());

        WorkerGrid1D workerQuant = new WorkerGrid1D(inputDim);
        // The mat-vec kernel uses one work-group per output row, so outputDim groups suffice.
        WorkerGrid1D workerDp4a = new WorkerGrid1D(local_workgroup_size * outputDim);
        GridScheduler schedulerDp4a = new GridScheduler();
        schedulerDp4a.addWorkerGrid("s0_quant_kc.scales", workerQuant);
        schedulerDp4a.addWorkerGrid("s0_quant_kc.quantize", workerQuant);
        schedulerDp4a.addWorkerGrid("s0_quant_kc.dp4amatvec", workerDp4a);
        workerQuant.setLocalWork(local_workgroup_size, 1, 1);
        workerDp4a.setLocalWork(local_workgroup_size, 1, 1);

        // Transfer-mode codes 0 / 1 are kept as in the original.
        // NOTE(review): presumably DataTransferMode.FIRST_EXECUTION / EVERY_EXECUTION — confirm.
        TaskGraph taskGraph = new TaskGraph("s0_quant_kc")
                .transferToDevice(0, input, x_quant, x_scale, x_max, inv_scale, w_quant, w_scale)
                .task("scales", QuantizationTests::reductionCalculateMax, new KernelContext(), x_max, input, x_scale, inv_scale, local_workgroup_size, inputDim)
                .task("quantize", QuantizationTests::quantizeKernelContext, new KernelContext(), input, inv_scale, x_quant)
                .task("dp4amatvec", QuantizationTests::matrixVectorGenericDP4A, new KernelContext(), w_quant, x_quant, output, inputDim, outputDim, local_workgroup_size, w_scale, x_scale)
                .transferToHost(1, output, x_quant, x_scale, x_max, inv_scale);

        ImmutableTaskGraph immutabletaskGraph = taskGraph.snapshot();
        try (TornadoExecutionPlan tornadoExecutor = new TornadoExecutionPlan(immutabletaskGraph)) {
            TornadoExecutionResult result = tornadoExecutor.withGridScheduler(schedulerDp4a).execute();
            System.out.println("Execution result: " + String.valueOf(result));
        }

        this.matrixVectorSequential(outputSeq, weights, input, inputDim, outputDim);
        for (int i = 0; i < output.getSize(); ++i) {
            Assert.assertEquals(outputSeq.get(i), output.get(i), 0.1f);
        }
    }

    /**
     * End-to-end check of the local-memory-tiled dp4a kernel ({@link #matrixVectorGenericLocalMemory}):
     * the input vector is quantized on the device, multiplied with host-quantized weights, and the
     * result is compared against the float reference product to within 0.1.
     *
     * @throws TornadoExecutionPlanException if the execution plan fails
     */
    @Test
    public void testMatrixVectorDP4AKernelLocalMemory() throws TornadoExecutionPlanException {
        this.assertNotBackend(TornadoVMBackendType.OPENCL);
        this.assertNotBackend(TornadoVMBackendType.SPIRV);

        Random random = new Random(42L);
        final int local_workgroup_size = 32;
        final int inputDim = 8192;
        final int outputDim = 2048;

        FloatArray input = new FloatArray(inputDim);
        FloatArray weights = new FloatArray(inputDim * outputDim);
        FloatArray output = new FloatArray(outputDim);
        FloatArray outputSeq = new FloatArray(outputDim);
        QuantizationTests.fillRandomData(input, -1.0f, 1.0f, random);
        QuantizationTests.fillRandomData(weights, -0.1f, 0.1f, random);

        // Weights are quantized once on the host.
        Int8Array w_quant = new Int8Array(weights.getSize());
        FloatArray w_scale = new FloatArray(1);
        QuantizationTests.quantizeFloatArray(weights, w_quant, w_scale);

        // Buffers for the on-device quantization of the input; x_max holds one slot per work-group.
        FloatArray x_scale = new FloatArray(1);
        final int maxNumGroups = (inputDim + local_workgroup_size - 1) / local_workgroup_size;
        FloatArray x_max = new FloatArray(maxNumGroups);
        FloatArray inv_scale = new FloatArray(1);
        Int8Array x_quant = new Int8Array(input.getSize());

        WorkerGrid1D workerQuant = new WorkerGrid1D(inputDim);
        // The mat-vec kernel uses one work-group per output row, so outputDim groups suffice.
        WorkerGrid1D workerDp4a = new WorkerGrid1D(local_workgroup_size * outputDim);
        GridScheduler schedulerDp4a = new GridScheduler();
        schedulerDp4a.addWorkerGrid("s0_quant_kc.scales", workerQuant);
        schedulerDp4a.addWorkerGrid("s0_quant_kc.quantize", workerQuant);
        schedulerDp4a.addWorkerGrid("s0_quant_kc.dp4amatvec", workerDp4a);
        workerQuant.setLocalWork(local_workgroup_size, 1, 1);
        workerDp4a.setLocalWork(local_workgroup_size, 1, 1);

        // Transfer-mode codes 0 / 1 are kept as in the original.
        // NOTE(review): presumably DataTransferMode.FIRST_EXECUTION / EVERY_EXECUTION — confirm.
        TaskGraph taskGraph = new TaskGraph("s0_quant_kc")
                .transferToDevice(0, input, x_quant, x_scale, x_max, inv_scale, w_quant, w_scale)
                .task("scales", QuantizationTests::reductionCalculateMax, new KernelContext(), x_max, input, x_scale, inv_scale, local_workgroup_size, inputDim)
                .task("quantize", QuantizationTests::quantizeKernelContext, new KernelContext(), input, inv_scale, x_quant)
                .task("dp4amatvec", QuantizationTests::matrixVectorGenericLocalMemory, new KernelContext(), w_quant, x_quant, output, inputDim, outputDim, local_workgroup_size, w_scale, x_scale)
                .transferToHost(1, output);

        ImmutableTaskGraph immutabletaskGraph = taskGraph.snapshot();
        try (TornadoExecutionPlan tornadoExecutor = new TornadoExecutionPlan(immutabletaskGraph)) {
            tornadoExecutor.withGridScheduler(schedulerDp4a).execute();
        }

        this.matrixVectorSequential(outputSeq, weights, input, inputDim, outputDim);
        for (int i = 0; i < output.getSize(); ++i) {
            Assert.assertEquals(outputSeq.get(i), output.get(i), 0.1f);
        }
    }

    /**
     * End-to-end check of the packed-int dp4a kernel ({@link #matrixVectorGenericPacked}): the input
     * vector is quantized on the device, multiplied with host-quantized weights, and the result is
     * compared against the float reference product to within 0.1.
     *
     * @throws TornadoExecutionPlanException if the execution plan fails
     */
    @Test
    public void testMatrixVectorDP4AKernelPacked() throws TornadoExecutionPlanException {
        this.assertNotBackend(TornadoVMBackendType.OPENCL);
        this.assertNotBackend(TornadoVMBackendType.SPIRV);

        Random random = new Random(42L);
        final int local_workgroup_size = 32;
        final int inputDim = 8192;
        final int outputDim = 2048;

        FloatArray input = new FloatArray(inputDim);
        FloatArray weights = new FloatArray(inputDim * outputDim);
        FloatArray output = new FloatArray(outputDim);
        FloatArray outputSeq = new FloatArray(outputDim);
        QuantizationTests.fillRandomData(input, -1.0f, 1.0f, random);
        QuantizationTests.fillRandomData(weights, -0.1f, 0.1f, random);

        // Weights are quantized once on the host.
        Int8Array w_quant = new Int8Array(weights.getSize());
        FloatArray w_scale = new FloatArray(1);
        QuantizationTests.quantizeFloatArray(weights, w_quant, w_scale);

        // Buffers for the on-device quantization of the input; x_max holds one slot per work-group.
        FloatArray x_scale = new FloatArray(1);
        final int maxNumGroups = (inputDim + local_workgroup_size - 1) / local_workgroup_size;
        FloatArray x_max = new FloatArray(maxNumGroups);
        FloatArray inv_scale = new FloatArray(1);
        Int8Array x_quant = new Int8Array(input.getSize());

        WorkerGrid1D workerQuant = new WorkerGrid1D(inputDim);
        // The mat-vec kernel uses one work-group per output row, so outputDim groups suffice.
        WorkerGrid1D workerDp4a = new WorkerGrid1D(local_workgroup_size * outputDim);
        GridScheduler schedulerDp4a = new GridScheduler();
        schedulerDp4a.addWorkerGrid("s0_quant_kc.scales", workerQuant);
        schedulerDp4a.addWorkerGrid("s0_quant_kc.quantize", workerQuant);
        schedulerDp4a.addWorkerGrid("s0_quant_kc.dp4amatvec", workerDp4a);
        workerQuant.setLocalWork(local_workgroup_size, 1, 1);
        workerDp4a.setLocalWork(local_workgroup_size, 1, 1);

        // Transfer-mode codes 0 / 1 are kept as in the original.
        // NOTE(review): presumably DataTransferMode.FIRST_EXECUTION / EVERY_EXECUTION — confirm.
        TaskGraph taskGraph = new TaskGraph("s0_quant_kc")
                .transferToDevice(0, input, x_quant, x_scale, x_max, inv_scale, w_quant, w_scale)
                .task("scales", QuantizationTests::reductionCalculateMax, new KernelContext(), x_max, input, x_scale, inv_scale, local_workgroup_size, inputDim)
                .task("quantize", QuantizationTests::quantizeKernelContext, new KernelContext(), input, inv_scale, x_quant)
                .task("dp4amatvec", QuantizationTests::matrixVectorGenericPacked, new KernelContext(), w_quant, x_quant, output, inputDim, outputDim, local_workgroup_size, w_scale, x_scale)
                .transferToHost(1, output, x_quant, x_scale, inv_scale, x_max);

        ImmutableTaskGraph immutabletaskGraph = taskGraph.snapshot();
        try (TornadoExecutionPlan tornadoExecutor = new TornadoExecutionPlan(immutabletaskGraph)) {
            tornadoExecutor.withGridScheduler(schedulerDp4a).execute();
        }

        this.matrixVectorSequential(outputSeq, weights, input, inputDim, outputDim);
        for (int i = 0; i < output.getSize(); ++i) {
            Assert.assertEquals(outputSeq.get(i), output.get(i), 0.1f);
        }
    }

    /**
     * End-to-end check of the 4-way unrolled dp4a kernel ({@link #matrixVectorGenericDP4A4Way}): the
     * input vector is quantized on the device, multiplied with host-quantized weights, and the
     * result is compared against the float reference product to within 0.1.
     *
     * @throws TornadoExecutionPlanException if the execution plan fails
     */
    @Test
    public void testMatrixVector4WayDP4AKernel() throws TornadoExecutionPlanException {
        this.assertNotBackend(TornadoVMBackendType.OPENCL);
        this.assertNotBackend(TornadoVMBackendType.SPIRV);

        Random random = new Random(42L);
        final int local_workgroup_size = 32;
        final int inputDim = 8192;
        final int outputDim = 2048;

        FloatArray input = new FloatArray(inputDim);
        FloatArray weights = new FloatArray(inputDim * outputDim);
        FloatArray output = new FloatArray(outputDim);
        FloatArray outputSeq = new FloatArray(outputDim);
        QuantizationTests.fillRandomData(input, -1.0f, 1.0f, random);
        QuantizationTests.fillRandomData(weights, -0.1f, 0.1f, random);

        // Weights are quantized once on the host.
        Int8Array w_quant = new Int8Array(weights.getSize());
        FloatArray w_scale = new FloatArray(1);
        QuantizationTests.quantizeFloatArray(weights, w_quant, w_scale);

        // Buffers for the on-device quantization of the input; x_max holds one slot per work-group.
        FloatArray x_scale = new FloatArray(1);
        final int maxNumGroups = (inputDim + local_workgroup_size - 1) / local_workgroup_size;
        FloatArray x_max = new FloatArray(maxNumGroups);
        FloatArray inv_scale = new FloatArray(1);
        Int8Array x_quant = new Int8Array(input.getSize());

        WorkerGrid1D workerQuant = new WorkerGrid1D(inputDim);
        // The mat-vec kernel uses one work-group per output row, so outputDim groups suffice.
        WorkerGrid1D workerDp4a = new WorkerGrid1D(local_workgroup_size * outputDim);
        GridScheduler schedulerDp4a = new GridScheduler();
        schedulerDp4a.addWorkerGrid("s0_quant_kc.scales", workerQuant);
        schedulerDp4a.addWorkerGrid("s0_quant_kc.quantize", workerQuant);
        schedulerDp4a.addWorkerGrid("s0_quant_kc.dp4amatvec", workerDp4a);
        workerQuant.setLocalWork(local_workgroup_size, 1, 1);
        workerDp4a.setLocalWork(local_workgroup_size, 1, 1);

        // Transfer-mode codes 0 / 1 are kept as in the original.
        // NOTE(review): presumably DataTransferMode.FIRST_EXECUTION / EVERY_EXECUTION — confirm.
        TaskGraph taskGraph = new TaskGraph("s0_quant_kc")
                .transferToDevice(0, input, x_quant, x_scale, x_max, inv_scale, w_quant, w_scale)
                .task("scales", QuantizationTests::reductionCalculateMax, new KernelContext(), x_max, input, x_scale, inv_scale, local_workgroup_size, inputDim)
                .task("quantize", QuantizationTests::quantizeKernelContext, new KernelContext(), input, inv_scale, x_quant)
                .task("dp4amatvec", QuantizationTests::matrixVectorGenericDP4A4Way, new KernelContext(), w_quant, x_quant, output, inputDim, outputDim, local_workgroup_size, w_scale, x_scale)
                .transferToHost(1, output, x_quant, x_scale, inv_scale, x_max);

        ImmutableTaskGraph immutabletaskGraph = taskGraph.snapshot();
        try (TornadoExecutionPlan tornadoExecutor = new TornadoExecutionPlan(immutabletaskGraph)) {
            tornadoExecutor.withGridScheduler(schedulerDp4a).execute();
        }

        this.matrixVectorSequential(outputSeq, weights, input, inputDim, outputDim);
        for (int i = 0; i < output.getSize(); ++i) {
            Assert.assertEquals(outputSeq.get(i), output.get(i), 0.1f);
        }
    }

    /**
     * Runs the vectorized matrix-vector kernel ({@code matrixVectorGenericVectorized}) on
     * weights quantized by {@code quantizeWeightsToQ8} with a block size of 32, and compares
     * the result with {@code matrixVectorSequential}.
     *
     * <p>The task graph {@code s0} has a single task {@code t0}, scheduled on a 1D grid of
     * {@code outputDim * 64} threads with a local work size of 64. Each output element must
     * be within 0.1 of the sequential float reference.
     *
     * <p>{@code assertNotBackend} is invoked for SPIRV before any data is set up.
     *
     * @throws TornadoExecutionPlanException declared because the execution plan is used in
     *         a try-with-resources statement
     */
    @Test
    public void testMatrixVectorVectorized() throws TornadoExecutionPlanException {
        this.assertNotBackend(TornadoVMBackendType.SPIRV);

        final Random rng = new Random(42L);
        final int threadsPerGroup = 64;
        final int blockSize = 32;
        final int inputDim = 8192;
        final int outputDim = 2048;

        // Host-side float data; the fill order is kept (input first, then weights) so the
        // seeded generator yields the same values.
        FloatArray input = new FloatArray(inputDim);
        FloatArray weights = new FloatArray(inputDim * outputDim);
        FloatArray output = new FloatArray(outputDim);
        FloatArray reference = new FloatArray(outputDim);
        QuantizationTests.fillRandomData(input, -1.0f, 1.0f, rng);
        QuantizationTests.fillRandomData(weights, -0.1f, 0.1f, rng);

        // NOTE(review): the results of this quantization are not passed to the task graph
        // below - confirm whether the call is still needed before removing it.
        Int8Array perTensorQuantized = new Int8Array(weights.getSize());
        FloatArray perTensorScale = new FloatArray(1);
        QuantizationTests.quantizeFloatArray(weights, perTensorQuantized, perTensorScale);

        // Block-wise quantization: one half-float scale per block of blockSize weights in a row.
        int blocksPerRow = inputDim / blockSize;
        Int8Array blockQuantizedWeights = new Int8Array(inputDim * outputDim);
        HalfFloatArray blockScales = new HalfFloatArray(outputDim * blocksPerRow);
        QuantizationTests.quantizeWeightsToQ8(weights, blockQuantizedWeights, blockScales, outputDim, inputDim, blockSize);

        WorkerGrid1D grid = new WorkerGrid1D(outputDim * threadsPerGroup);
        grid.setLocalWork(threadsPerGroup, 1, 1);
        GridScheduler scheduler = new GridScheduler("s0.t0", grid);

        // The literal transfer modes 0 (to device) and 1 (to host) are kept exactly as found;
        // presumably they map to DataTransferMode constants - TODO confirm.
        TaskGraph graph = new TaskGraph("s0")
                .transferToDevice(0, input, blockQuantizedWeights, blockScales)
                .task("t0", QuantizationTests::matrixVectorGenericVectorized, new KernelContext(), input, output, blockQuantizedWeights, blockScales, inputDim, outputDim, threadsPerGroup)
                .transferToHost(1, output);

        ImmutableTaskGraph snapshot = graph.snapshot();
        try (TornadoExecutionPlan plan = new TornadoExecutionPlan(snapshot)) {
            plan.withGridScheduler(scheduler).execute();
        }

        // Float reference computed sequentially from the un-quantized data.
        this.matrixVectorSequential(reference, weights, input, inputDim, outputDim);
        for (int row = 0; row < output.getSize(); row++) {
            float expected = reference.get(row);
            float actual = output.get(row);
            Assert.assertEquals(expected, actual, 0.1f);
        }
    }
}

