/*
 * Decompiled with CFR 0.152.
 */
package uk.ac.manchester.tornado.examples.compute;

import java.util.ArrayList;
import java.util.LongSummaryStatistics;
import java.util.Random;
import uk.ac.manchester.tornado.api.GridScheduler;
import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
import uk.ac.manchester.tornado.api.KernelContext;
import uk.ac.manchester.tornado.api.TaskGraph;
import uk.ac.manchester.tornado.api.TornadoExecutionPlan;
import uk.ac.manchester.tornado.api.WorkerGrid;
import uk.ac.manchester.tornado.api.WorkerGrid1D;
import uk.ac.manchester.tornado.api.enums.TornadoVMBackendType;
import uk.ac.manchester.tornado.api.exceptions.TornadoAPIException;
import uk.ac.manchester.tornado.api.math.TornadoMath;
import uk.ac.manchester.tornado.api.runtime.TornadoRuntimeProvider;
import uk.ac.manchester.tornado.api.types.HalfFloat;
import uk.ac.manchester.tornado.api.types.arrays.ByteArray;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
import uk.ac.manchester.tornado.api.types.arrays.Int8Array;
import uk.ac.manchester.tornado.api.utils.QuantizationUtils;

public class MatrixVectorRowMajor {
    /** Small tolerance constant; presumably for validating FP32 results (no use is visible in this part of the file — confirm). */
    private static final float DELTA = 1.0E-4f;
    /** Looser tolerance constant; presumably for validating quantized results (no use is visible in this part of the file — confirm). */
    private static final float DELTA_Q = 0.1f;
    /** Warm-up iteration count; {@code main} prints the same value (140) in its configuration banner. */
    private static final int WARM_UP_ITERATIONS = 140;
    /** Timed iteration count; {@code main} prints the same value (120) in its configuration banner. */
    private static final int BENCHMARK_ITERATIONS = 120;
    /** Shared pseudo-random source with a fixed seed (42), used by the fillRandomData* helpers. */
    private static final Random random = new Random(42L);
    /** Local work-group size; defaults to 32 and is overwritten by {@code main} from {@code args[2]} when three arguments are supplied. */
    private static int LOCAL_WORK_GROUP_SIZE = 32;
    /** Tile length constant; the kernels in this file use the literal 128 directly (presumably inlined from this constant — confirm). */
    private static final int TILE_SIZE = 128;
    /** Quantization block length constant; the Q8 code in this file uses the literal 32 directly (presumably inlined from this constant — confirm). */
    private static final int BLOCK_SIZE = 32;

    /**
     * Overwrites every element of {@code array} with {@code min + r * (max - min)},
     * where {@code r} is the next value of the shared generator's {@code nextFloat()}.
     *
     * @param array destination array; all {@code array.getSize()} elements are written
     * @param min   lower end of the target interval
     * @param max   upper end of the target interval
     */
    private static void fillRandomData(FloatArray array, float min, float max) {
        final float span = max - min;
        int index = 0;
        while (index < array.getSize()) {
            float sample = min + random.nextFloat() * span;
            array.set(index, sample);
            index++;
        }
    }

    /**
     * Overwrites every element of {@code array} with a half-precision value
     * {@code min + r * (max - min)}, where {@code r} is the next value of the shared
     * generator's {@code nextFloat()}.
     *
     * <p>The sample must come from {@code nextFloat()} (a value in {@code [0, 1)}), exactly as
     * in {@link #fillRandomData}. Using {@code nextInt()} would scale the range by values of
     * magnitude up to 2^31, which is far outside the requested interval and outside what a
     * half-precision float can represent.
     *
     * @param array destination array; all {@code array.getSize()} elements are written
     * @param min   lower end of the target interval
     * @param max   upper end of the target interval
     */
    private static void fillRandomDataFp16(HalfFloatArray array, float min, float max) {
        float range = max - min;
        for (int i = 0; i < array.getSize(); ++i) {
            // nextFloat() keeps the sample inside [min, max); see the Javadoc above.
            array.set(i, new HalfFloat(min + random.nextFloat() * range));
        }
    }

    /**
     * Reference matrix-vector product: for each of the {@code d} rows of the row-major
     * matrix {@code w} (row length {@code n}), accumulates the dot product with {@code x}
     * and stores it in {@code hb}.
     *
     * @param x  input vector, read at indices {@code 0..n-1}
     * @param hb output vector, written at indices {@code 0..d-1}
     * @param w  matrix, read at index {@code row * n + col}
     * @param n  number of columns (row length)
     * @param d  number of rows
     */
    public static void matrixVectorSequential(FloatArray x, FloatArray hb, FloatArray w, int n, int d) {
        for (int row = 0; row < d; row++) {
            final int rowStart = row * n;
            float dot = 0.0f;
            for (int col = 0; col < n; col++) {
                dot += w.get(rowStart + col) * x.get(col);
            }
            hb.set(row, dot);
        }
    }

    /**
     * Loop-based matrix-vector product used as the task body of the "s1" task graph in
     * {@code main}. Computes, for every row {@code r < d}, the dot product of row {@code r}
     * of {@code w} (row-major, row length {@code n}) with {@code x}, and writes it to
     * {@code hb}.
     *
     * <p>NOTE(review): no loop annotation is visible on the outer loop in this decompiled
     * source; if the original carried one, decompilation may have dropped it — confirm
     * against the original sources.
     *
     * @param x  input vector, read at indices {@code 0..n-1}
     * @param hb output vector, written at indices {@code 0..d-1}
     * @param w  matrix, read at index {@code r * n + c}
     * @param n  number of columns (row length)
     * @param d  number of rows
     */
    public static void matrixVectorParallel(FloatArray x, FloatArray hb, FloatArray w, int n, int d) {
        for (int r = 0; r < d; r++) {
            final int base = r * n;
            float accumulator = 0.0f;
            for (int c = 0; c < n; c++) {
                float product = w.get(base + c) * x.get(c);
                accumulator += product;
            }
            hb.set(r, accumulator);
        }
    }

    /**
     * Kernel entry point for the FP32 matrix-vector product: the row handled is
     * {@code context.groupIdx}. Groups whose index is {@code >= d} do nothing; otherwise the
     * row's dot product is obtained from {@link #matrixVectorRowMajorOptimized} and the
     * thread with local index 0 stores it in {@code hb}.
     *
     * @param context            kernel context supplying group and local indices
     * @param x                  input vector
     * @param hb                 output vector, one element per row
     * @param w                  row-major weight matrix
     * @param n                  number of columns (row length)
     * @param d                  number of rows
     * @param localWorkGroupSize size forwarded to the reduction helper
     */
    public static void matrixVectorGeneric(KernelContext context, FloatArray x, FloatArray hb, FloatArray w, int n, int d, int localWorkGroupSize) {
        final int row = context.groupIdx;
        final int lane = context.localIdx;
        if (row < d) {
            float rowResult = MatrixVectorRowMajor.matrixVectorRowMajorOptimized(context, localWorkGroupSize, x, w, n);
            if (lane == 0) {
                hb.set(row, rowResult);
            }
        }
    }

    /**
     * Computes the dot product of row {@code context.groupIdx} of {@code w} with {@code x}.
     *
     * <p>Each thread accumulates the columns {@code localIdx, localIdx + localSize, ...},
     * stores its partial sum in a local array of {@code localSize} floats, and the partial
     * sums are then combined by repeatedly halving a stride (starting at
     * {@code localSize / 2}) with a local barrier after every step. Because the stride is
     * halved with integer division, a {@code localSize} that is not a power of two leaves
     * some partial sums out of the final value.
     *
     * @param context   kernel context supplying indices, the local array and barriers
     * @param localSize number of partial-sum slots and the column step per thread
     * @param x         input vector
     * @param w         row-major weight matrix
     * @param n         number of columns (row length)
     * @return the value of slot 0 of the local array after the reduction
     */
    public static float matrixVectorRowMajorOptimized(KernelContext context, int localSize, FloatArray x, FloatArray w, int n) {
        final int row = context.groupIdx;
        final int lane = context.localIdx;
        float[] partials = context.allocateFloatLocalArray(localSize);
        final int rowBase = row * n;

        // Strided accumulation over this thread's share of the columns.
        float acc = 0.0f;
        for (int col = lane; col < n; col += localSize) {
            acc += w.get(rowBase + col) * x.get(col);
        }
        partials[lane] = acc;
        context.localBarrier();

        // Halving-stride reduction; every thread reaches each barrier.
        for (int half = localSize / 2; half > 0; half >>= 1) {
            if (lane < half) {
                partials[lane] += partials[lane + half];
            }
            context.localBarrier();
        }
        return partials[0];
    }

    /**
     * Kernel entry point for the matrix-vector product with half-precision weights: the row
     * handled is {@code context.groupIdx}. Groups whose index is {@code >= d} do nothing;
     * otherwise the row's dot product comes from {@link #matrixVectorRowMajorOptimizedFP16}
     * and the thread with local index 0 stores it in {@code hb}.
     *
     * @param context            kernel context supplying group and local indices
     * @param x                  input vector (FP32)
     * @param hb                 output vector, one element per row
     * @param w                  row-major weight matrix stored as half floats
     * @param n                  number of columns (row length)
     * @param d                  number of rows
     * @param localWorkGroupSize size forwarded to the reduction helper
     */
    public static void matrixVectorGenericFP16(KernelContext context, FloatArray x, FloatArray hb, HalfFloatArray w, int n, int d, int localWorkGroupSize) {
        final int row = context.groupIdx;
        final int lane = context.localIdx;
        if (row < d) {
            float rowResult = MatrixVectorRowMajor.matrixVectorRowMajorOptimizedFP16(context, localWorkGroupSize, x, w, n);
            if (lane == 0) {
                hb.set(row, rowResult);
            }
        }
    }

    /**
     * Computes the dot product of row {@code context.groupIdx} of the half-precision matrix
     * {@code w} with the FP32 vector {@code x}. Each weight is widened with
     * {@code getFloat32()} before being multiplied.
     *
     * <p>Each thread accumulates the columns {@code localIdx, localIdx + localSize, ...}
     * into a slot of a local array; the slots are then combined with a halving-stride
     * reduction, with a local barrier after every step. A {@code localSize} that is not a
     * power of two leaves some partial sums out of the final value.
     *
     * @param context   kernel context supplying indices, the local array and barriers
     * @param localSize number of partial-sum slots and the column step per thread
     * @param x         input vector
     * @param w         row-major weight matrix stored as half floats
     * @param n         number of columns (row length)
     * @return the value of slot 0 of the local array after the reduction
     */
    public static float matrixVectorRowMajorOptimizedFP16(KernelContext context, int localSize, FloatArray x, HalfFloatArray w, int n) {
        final int row = context.groupIdx;
        final int lane = context.localIdx;
        float[] partials = context.allocateFloatLocalArray(localSize);
        final int rowBase = row * n;

        // Strided accumulation over this thread's share of the columns.
        float acc = 0.0f;
        for (int col = lane; col < n; col += localSize) {
            acc += w.get(rowBase + col).getFloat32() * x.get(col);
        }
        partials[lane] = acc;
        context.localBarrier();

        // Halving-stride reduction; every thread reaches each barrier.
        for (int half = localSize / 2; half > 0; half >>= 1) {
            if (lane < half) {
                partials[lane] += partials[lane + half];
            }
            context.localBarrier();
        }
        return partials[0];
    }

    /**
     * Kernel that derives a quantization scale from the largest absolute value in {@code x}.
     *
     * <p>Per group: each thread loads {@code |x[gid]|} (or 0 when {@code gid >= arraySize})
     * into a local array, the group reduces it to its maximum, and the thread with local
     * index 0 writes that maximum to {@code max[groupId]}. The thread with global index 0
     * then folds {@code max[0..numGroups-1]} into a single value, stores it in
     * {@code max[0]}, and writes {@code scale = maxAbs / 127} (or 1 when the maximum is 0)
     * to {@code x_scale[0]} and {@code 1 / scale} to {@code inv_scale[0]}.
     *
     * @param context      kernel context supplying indices, group size, local array, barriers
     * @param max          per-group maxima; element 0 is overwritten with the overall maximum
     * @param x            input values
     * @param x_scale      one-element output receiving the scale
     * @param inv_scale    one-element output receiving the reciprocal of the scale
     * @param localMemSize number of floats allocated for the local array
     * @param arraySize    number of valid elements in {@code x}
     */
    public static void reductionCalculateMax(KernelContext context, FloatArray max, FloatArray x, FloatArray x_scale, FloatArray inv_scale, int localMemSize, int arraySize) {
        int gid = context.globalIdx;
        int lid = context.localIdx;
        int groupId = context.groupIdx;
        int groupSize = context.localGroupSizeX;
        float[] localX = context.allocateFloatLocalArray(localMemSize);
        // Out-of-range threads contribute 0, which cannot exceed any absolute value.
        localX[lid] = gid < arraySize ? TornadoMath.abs((float)x.get(gid)) : 0.0f;
        // Halving-stride max reduction; the barrier sits at the top of each step so the
        // previous step's writes are complete before they are read.
        for (int stride = groupSize / 2; stride > 0; stride /= 2) {
            context.localBarrier();
            if (lid >= stride) continue;
            localX[lid] = TornadoMath.max((float)localX[lid], (float)localX[lid + stride]);
        }
        // Slot 0 was last written by this same thread, so no further barrier is needed here.
        if (lid == 0) {
            max.set(groupId, localX[0]);
        }
        if (gid == 0) {
            // NOTE(review): this reads max[i] written by other groups, but the only
            // synchronization visible in this method is context.localBarrier(); whether
            // those writes are guaranteed to be complete here cannot be determined from
            // this file — confirm.
            int numGroups = (arraySize + groupSize - 1) / groupSize;
            float max_abs_val = 0.0f;
            for (int i = 0; i < numGroups; ++i) {
                max_abs_val = TornadoMath.max((float)max_abs_val, (float)max.get(i));
            }
            max.set(0, max_abs_val);
            // A zero maximum falls back to scale 1 so the reciprocal below is finite.
            float scale = max_abs_val == 0.0f ? 1.0f : max_abs_val / 127.0f;
            inv_scale.set(0, 1.0f / scale);
            x_scale.set(0, scale);
        }
    }

    /**
     * Kernel that quantizes one element per thread: element {@code context.globalIdx} of
     * {@code x} is multiplied by {@code inv_scale[0]}, offset by 0.5, floored, narrowed to
     * {@code byte} and stored at the same index of {@code x_quant}.
     *
     * <p>No bounds check is performed on the global index.
     *
     * @param context   kernel context supplying the global index
     * @param x         input values
     * @param inv_scale one-element array holding the multiplier
     * @param x_quant   output array of quantized values
     */
    public static void quantizeKernelContext(KernelContext context, FloatArray x, FloatArray inv_scale, Int8Array x_quant) {
        final int index = context.globalIdx;
        final float multiplier = inv_scale.get(0);
        final float shifted = x.get(index) * multiplier + 0.5f;
        x_quant.set(index, (byte)TornadoMath.floor(shifted));
    }

    /**
     * Quantizes an FP32 row-major matrix into 8-bit values in blocks of 32 columns and
     * writes the result in two layouts.
     *
     * <p>For each block the scale is {@code maxAbs / 127} (0 when the block is all zeros)
     * and each value is {@code round(value / scale)} clamped to {@code [-127, 127]}.
     * <ul>
     *   <li>{@code outQ} receives the quantized value at the same index as the source
     *       weight, and {@code outScales} receives one half-float scale per block at index
     *       {@code row * blocksPerRow + block}.</li>
     *   <li>{@code outQ8ByteArray} receives 34 bytes per block: the half-float scale at the
     *       block's byte offset, followed by the 32 quantized bytes starting 2 bytes
     *       later.</li>
     * </ul>
     *
     * <p>{@code blocksPerRow} is {@code cols / 32} (integer division), so columns beyond the
     * last full block of a row are not visited.
     *
     * @param weightsFP32    source weights, read at {@code row * cols + col}
     * @param outQ           destination for quantized values (same indexing as the source)
     * @param outScales      destination for per-block scales
     * @param outQ8ByteArray destination for the interleaved scale-plus-quants layout
     * @param rows           number of rows
     * @param cols           number of columns per row
     */
    public static void quantizeWeightsToQ8(FloatArray weightsFP32, Int8Array outQ, HalfFloatArray outScales, ByteArray outQ8ByteArray, int rows, int cols) {
        final int blockLen = 32;
        final int bytesPerBlock = 34;
        final int blocksPerRow = cols / blockLen;
        for (int row = 0; row < rows; row++) {
            final int rowStart = row * cols;
            for (int block = 0; block < blocksPerRow; block++) {
                final int first = rowStart + block * blockLen;

                // Largest magnitude in the block; comparisons with NaN are false, so NaN is ignored.
                float peak = 0.0f;
                for (int k = 0; k < blockLen; k++) {
                    float magnitude = Math.abs(weightsFP32.get(first + k));
                    if (magnitude > peak) {
                        peak = magnitude;
                    }
                }

                final float scale = (peak == 0.0f) ? 0.0f : peak / 127.0f;
                final int blockIndex = row * blocksPerRow + block;
                final int byteBase = blockIndex * bytesPerBlock;
                outScales.set(blockIndex, new HalfFloat(scale));
                outQ8ByteArray.setHalfFloat(byteBase, new HalfFloat(scale));

                // A zero scale maps every value in the block to 0.
                final float inverse = (scale == 0.0f) ? 0.0f : 1.0f / scale;
                for (int k = 0; k < blockLen; k++) {
                    int rounded = Math.round(weightsFP32.get(first + k) * inverse);
                    int clamped = Math.max(-127, Math.min(127, rounded));
                    outQ.set(first + k, (byte)clamped);
                    outQ8ByteArray.set(byteBase + 2 + k, (byte)clamped);
                }
            }
        }
    }

    /**
     * Kernel entry point for the Q8 matrix-vector product with separate quant and scale
     * arrays: the row handled is {@code context.groupIdx}. Groups whose index is
     * {@code >= dim0} do nothing; otherwise the row's result comes from
     * {@link #matrixVectorRowMajorOptimizedQ8_0Final} and the thread with local index 0
     * stores it in {@code output}.
     *
     * @param context            kernel context supplying group and local indices
     * @param x                  input vector (FP32)
     * @param output             output vector, one element per row
     * @param weightsQ           quantized weights, row-major
     * @param weightScales       per-block half-float scales
     * @param dim1               number of columns (row length)
     * @param dim0               number of rows
     * @param localWorkGroupSize size forwarded to the reduction helper
     */
    public static void matrixVectorGenericFinal(KernelContext context, FloatArray x, FloatArray output, Int8Array weightsQ, HalfFloatArray weightScales, int dim1, int dim0, int localWorkGroupSize) {
        final int row = context.groupIdx;
        final int lane = context.localIdx;
        if (row < dim0) {
            float rowResult = MatrixVectorRowMajor.matrixVectorRowMajorOptimizedQ8_0Final(context, localWorkGroupSize, x, weightsQ, weightScales, dim1);
            if (lane == 0) {
                output.set(row, rowResult);
            }
        }
    }

    /**
     * Kernel entry point for the Q8 matrix-vector product over the interleaved byte layout:
     * the row handled is {@code context.groupIdx}. Groups whose index is {@code >= dim0} do
     * nothing; otherwise the row's result comes from
     * {@link #matrixVectorRowMajorOptimizedQ8_0Byte} and the thread with local index 0
     * stores it in {@code output}.
     *
     * @param context            kernel context supplying group and local indices
     * @param x                  input vector (FP32)
     * @param output             output vector, one element per row
     * @param q                  byte array holding scale-plus-quants blocks
     * @param dim1               number of columns (row length)
     * @param dim0               number of rows
     * @param localWorkGroupSize size forwarded to the reduction helper
     */
    public static void matrixVectorGenericQ8Byte(KernelContext context, FloatArray x, FloatArray output, ByteArray q, int dim1, int dim0, int localWorkGroupSize) {
        final int row = context.groupIdx;
        final int lane = context.localIdx;
        if (row < dim0) {
            float rowResult = MatrixVectorRowMajor.matrixVectorRowMajorOptimizedQ8_0Byte(context, localWorkGroupSize, x, q, dim1);
            if (lane == 0) {
                output.set(row, rowResult);
            }
        }
    }

    /**
     * Computes the dot product of row {@code context.groupIdx} with {@code x}, reading the
     * weights from a byte array organised as 34-byte blocks: a half-float scale at the
     * block's byte offset, then the quantized bytes starting 2 bytes later. A row occupies
     * {@code ceil(n / 32)} consecutive blocks.
     *
     * <p>The main loop processes four consecutive columns per iteration (one scale lookup,
     * four independent accumulators), with threads starting at {@code localIdx * 4} and
     * stepping by {@code localSize * 4} while {@code j < n - 3}. A second loop handles the
     * columns from {@code n / 4 * 4} onward one at a time. Partial sums are then combined
     * with a halving-stride reduction over a local array, with a local barrier after every
     * step.
     *
     * @param context   kernel context supplying indices, the local array and barriers
     * @param localSize number of partial-sum slots and the per-thread stepping unit
     * @param x         input vector (FP32)
     * @param q         byte array holding scale-plus-quants blocks
     * @param n         number of columns (row length)
     * @return the value of slot 0 of the local array after the reduction
     */
    public static float matrixVectorRowMajorOptimizedQ8_0Byte(KernelContext context, int localSize, FloatArray x, ByteArray q, int n) {
        final int row = context.groupIdx;
        final int lane = context.localIdx;
        final int blockLen = 32;
        final int bytesPerBlock = 34;
        float[] partials = context.allocateFloatLocalArray(localSize);
        final int blocksPerRow = (n + blockLen - 1) / blockLen;
        final int firstBlockOfRow = row * blocksPerRow;

        float acc0 = 0.0f;
        float acc1 = 0.0f;
        float acc2 = 0.0f;
        float acc3 = 0.0f;
        for (int col = lane * 4; col < n - 3; col += localSize * 4) {
            final int blockBase = (firstBlockOfRow + col / blockLen) * bytesPerBlock;
            final float blockScale = q.getHalfFloat(blockBase).getFloat32();
            // Skip the 2-byte scale, then index inside the block.
            final int quantBase = blockBase + 2 + col % blockLen;
            final byte w0 = q.get(quantBase);
            final byte w1 = q.get(quantBase + 1);
            final byte w2 = q.get(quantBase + 2);
            final byte w3 = q.get(quantBase + 3);
            acc0 += (float)w0 * blockScale * x.get(col);
            acc1 += (float)w1 * blockScale * x.get(col + 1);
            acc2 += (float)w2 * blockScale * x.get(col + 2);
            acc3 += (float)w3 * blockScale * x.get(col + 3);
        }
        float acc = acc0 + acc1 + acc2 + acc3;

        // Leftover columns past the last multiple of four, one per thread per step.
        for (int col = n / 4 * 4 + lane; col < n; col += localSize) {
            final int blockBase = (firstBlockOfRow + col / blockLen) * bytesPerBlock;
            final float blockScale = q.getHalfFloat(blockBase).getFloat32();
            final byte weight = q.get(blockBase + 2 + col % blockLen);
            acc += (float)weight * blockScale * x.get(col);
        }

        partials[lane] = acc;
        context.localBarrier();
        for (int half = localSize / 2; half > 0; half >>= 1) {
            if (lane < half) {
                partials[lane] += partials[lane + half];
            }
            context.localBarrier();
        }
        return partials[0];
    }

    /**
     * Computes the dot product of row {@code context.groupIdx} with {@code x}, reading
     * quantized weights from {@code weightsQ} (row-major, row length {@code n}) and one
     * half-float scale per 32-column block from {@code weightScales}. A row's scales start
     * at index {@code row * (n / 32)}.
     *
     * <p>The main loop processes four consecutive columns per iteration (one scale lookup,
     * four independent accumulators), with threads starting at {@code localIdx * 4} and
     * stepping by {@code localSize * 4} while {@code j < n - 3}. A second loop handles the
     * columns from {@code n / 4 * 4} onward one at a time. Partial sums are then combined
     * with a halving-stride reduction over a local array, with a local barrier after every
     * step.
     *
     * @param context      kernel context supplying indices, the local array and barriers
     * @param localSize    number of partial-sum slots and the per-thread stepping unit
     * @param x            input vector (FP32)
     * @param weightsQ     quantized weights
     * @param weightScales per-block scales
     * @param n            number of columns (row length)
     * @return the value of slot 0 of the local array after the reduction
     */
    public static float matrixVectorRowMajorOptimizedQ8_0Final(KernelContext context, int localSize, FloatArray x, Int8Array weightsQ, HalfFloatArray weightScales, int n) {
        final int row = context.groupIdx;
        final int lane = context.localIdx;
        final int blockLen = 32;
        float[] partials = context.allocateFloatLocalArray(localSize);
        final int rowBase = row * n;
        final int scaleBase = row * (n / blockLen);

        float acc0 = 0.0f;
        float acc1 = 0.0f;
        float acc2 = 0.0f;
        float acc3 = 0.0f;
        for (int col = lane * 4; col < n - 3; col += localSize * 4) {
            final float blockScale = weightScales.get(scaleBase + col / blockLen).getFloat32();
            final int w = rowBase + col;
            acc0 += (float)weightsQ.get(w) * blockScale * x.get(col);
            acc1 += (float)weightsQ.get(w + 1) * blockScale * x.get(col + 1);
            acc2 += (float)weightsQ.get(w + 2) * blockScale * x.get(col + 2);
            acc3 += (float)weightsQ.get(w + 3) * blockScale * x.get(col + 3);
        }
        float acc = acc0 + acc1 + acc2 + acc3;

        // Leftover columns past the last multiple of four, one per thread per step.
        for (int col = n / 4 * 4 + lane; col < n; col += localSize) {
            final float blockScale = weightScales.get(scaleBase + col / blockLen).getFloat32();
            acc += (float)weightsQ.get(rowBase + col) * blockScale * x.get(col);
        }

        partials[lane] = acc;
        context.localBarrier();
        for (int half = localSize / 2; half > 0; half >>= 1) {
            if (lane < half) {
                partials[lane] += partials[lane + half];
            }
            context.localBarrier();
        }
        return partials[0];
    }

    /**
     * Computes the integer dot product of row {@code context.groupIdx} of {@code w_quant}
     * with {@code x_quant} via {@code QuantizationUtils.dp4a}, reduces the per-thread sums
     * across the group, and dequantizes the total with the two scales.
     *
     * <p>Only the thread with local index 0 returns the dequantized value; every other
     * thread returns {@code 0.0f}.
     *
     * @param context   kernel context supplying indices, the local array and barriers
     * @param localSize number of partial-sum slots and the per-thread stepping unit
     * @param w_quant   quantized weights, row-major with row length {@code n}
     * @param x_quant   quantized input vector
     * @param n         number of columns (row length)
     * @param w_scale   scale of the weights, forwarded to the dequantization call
     * @param x_scale   scale of the input, forwarded to the dequantization call
     * @return the dequantized row result for local index 0, otherwise {@code 0.0f}
     */
    public static float matrixVectorRowMajorOptimizedDP4A(KernelContext context, int localSize, Int8Array w_quant, Int8Array x_quant, int n, float w_scale, float x_scale) {
        int rowId = context.groupIdx;
        int localId = context.localIdx;
        // NOTE(review): this early return sits before the localBarrier() calls below; if a
        // thread could ever have localIdx >= localSize it would skip those barriers —
        // confirm the launch configuration always makes this branch unreachable.
        if (localId >= localSize) {
            return 0.0f;
        }
        int[] localSum = context.allocateIntLocalArray(localSize);
        int rowOffset = rowId * n;
        int partialSum = 0;
        // Offsets advance in steps of 4 per call. NOTE(review): presumably each dp4a call
        // consumes 4 consecutive int8 values from each array; there is no guard here for n
        // not being a multiple of 4 — confirm against QuantizationUtils.
        for (int j = localId * 4; j < n; j += localSize * 4) {
            partialSum = QuantizationUtils.dp4a((Int8Array)w_quant, (long)(rowOffset + j), (Int8Array)x_quant, (long)j, (int)partialSum);
        }
        localSum[localId] = partialSum;
        context.localBarrier();
        // Halving-stride reduction of the integer partial sums.
        for (int stride = localSize / 2; stride > 0; stride >>= 1) {
            if (localId < stride) {
                int n2 = localId;
                localSum[n2] = localSum[n2] + localSum[localId + stride];
            }
            context.localBarrier();
        }
        // Only one thread converts the integer total back to a float result.
        if (localId == 0) {
            return QuantizationUtils.dequantizeFusedResult((int)localSum[0], (float)w_scale, (float)x_scale);
        }
        return 0.0f;
    }

    /**
     * Kernel entry point for the DP4A matrix-vector product: the row handled is
     * {@code context.groupIdx}. Groups whose index is {@code >= d} do nothing; otherwise the
     * row's result comes from {@link #matrixVectorRowMajorOptimizedDP4A}, called with the
     * first element of each scale array, and the thread with local index 0 stores it in
     * {@code output}.
     *
     * @param context            kernel context supplying group and local indices
     * @param w_quant            quantized weights, row-major
     * @param x_quant            quantized input vector
     * @param output             output vector, one element per row
     * @param n                  number of columns (row length)
     * @param d                  number of rows
     * @param localWorkGroupSize size forwarded to the reduction helper
     * @param w_scale            array whose element 0 is the weight scale
     * @param x_scale            array whose element 0 is the input scale
     */
    public static void matrixVectorGenericDP4A(KernelContext context, Int8Array w_quant, Int8Array x_quant, FloatArray output, int n, int d, int localWorkGroupSize, FloatArray w_scale, FloatArray x_scale) {
        final int row = context.groupIdx;
        final int lane = context.localIdx;
        if (row < d) {
            float rowResult = MatrixVectorRowMajor.matrixVectorRowMajorOptimizedDP4A(context, localWorkGroupSize, w_quant, x_quant, n, w_scale.get(0), x_scale.get(0));
            if (lane == 0) {
                output.set(row, rowResult);
            }
        }
    }

    /**
     * Kernel entry point for the tiled DP4A matrix-vector product: the row handled is
     * {@code context.groupIdx}. Groups whose index is {@code >= d} do nothing; otherwise the
     * integer row sum comes from {@link #matrixVectorDP4ALocalMemory}, and the thread with
     * local index 0 multiplies it by {@code w_scale[0]} and {@code x_scale[0]} and stores
     * the product in {@code output}.
     *
     * @param context            kernel context supplying group and local indices
     * @param w_quant            quantized weights, row-major
     * @param x_quant            quantized input vector
     * @param output             output vector, one element per row
     * @param n                  number of columns (row length)
     * @param d                  number of rows
     * @param localWorkGroupSize size forwarded to the reduction helper
     * @param w_scale            array whose element 0 is the weight scale
     * @param x_scale            array whose element 0 is the input scale
     */
    public static void matrixVectorGenericLocalMemory(KernelContext context, Int8Array w_quant, Int8Array x_quant, FloatArray output, int n, int d, int localWorkGroupSize, FloatArray w_scale, FloatArray x_scale) {
        final int row = context.groupIdx;
        if (row < d) {
            final int rowSum = MatrixVectorRowMajor.matrixVectorDP4ALocalMemory(context, localWorkGroupSize, w_quant, x_quant, n);
            if (context.localIdx == 0) {
                output.set(row, (float)rowSum * w_scale.get(0) * x_scale.get(0));
            }
        }
    }

    /**
     * Computes the integer dot product of row {@code context.groupIdx} of {@code w_quant}
     * with {@code x_quant}, processing the input in tiles of 128 elements.
     *
     * <p>For each tile the threads first copy the in-range part of {@code x_quant} into a
     * 128-byte local array, then, after a local barrier, each thread calls
     * {@code QuantizationUtils.dp4a} at offsets {@code localIdx * 4, + localSize * 4, ...}
     * within the tile, skipping offsets whose global column is {@code >= n}. A second
     * barrier ends the tile before the local array is overwritten. The per-thread totals
     * are finally combined with a halving-stride reduction.
     *
     * @param context   kernel context supplying indices, local arrays and barriers
     * @param localSize number of partial-sum slots and the per-thread stepping unit
     * @param w_quant   quantized weights, row-major with row length {@code n}
     * @param x_quant   quantized input vector
     * @param n         number of columns (row length)
     * @return the value of slot 0 of the partial-sum array after the reduction
     */
    public static int matrixVectorDP4ALocalMemory(KernelContext context, int localSize, Int8Array w_quant, Int8Array x_quant, int n) {
        final int row = context.groupIdx;
        final int lane = context.localIdx;
        byte[] stagedX = context.allocateByteLocalArray(128);
        int[] partials = context.allocateIntLocalArray(localSize);
        final int rowBase = row * n;

        int laneTotal = 0;
        for (int tile = 0; tile < n; tile += 128) {
            // Stage this tile of the input vector; stop at the end of the vector.
            for (int k = lane; k < 128 && tile + k < n; k += localSize) {
                stagedX[k] = x_quant.get(tile + k);
            }
            context.localBarrier();

            int tileSum = 0;
            for (int k = lane * 4; k < 128; k += localSize * 4) {
                if (tile + k < n) {
                    tileSum = QuantizationUtils.dp4a(w_quant, (long)(rowBase + tile + k), stagedX, (long)k, tileSum);
                }
            }
            laneTotal += tileSum;
            // All reads of the staged tile must finish before the next tile overwrites it.
            context.localBarrier();
        }

        partials[lane] = laneTotal;
        context.localBarrier();
        for (int half = localSize / 2; half > 0; half >>= 1) {
            if (lane < half) {
                partials[lane] += partials[lane + half];
            }
            context.localBarrier();
        }
        return partials[0];
    }

    /**
     * Kernel entry point for the packed DP4A matrix-vector product: the row handled is
     * {@code context.groupIdx}. Groups whose index is {@code >= d} do nothing; otherwise the
     * integer row sum comes from {@link #matrixVectorPacked}, and the thread with local
     * index 0 multiplies it by {@code w_scale[0]} and {@code x_scale[0]} and stores the
     * product in {@code output}.
     *
     * @param context            kernel context supplying group and local indices
     * @param w_quant            quantized weights, row-major
     * @param x_quant            quantized input vector
     * @param output             output vector, one element per row
     * @param n                  number of columns (row length)
     * @param d                  number of rows
     * @param localWorkGroupSize size forwarded to the reduction helper
     * @param w_scale            array whose element 0 is the weight scale
     * @param x_scale            array whose element 0 is the input scale
     */
    public static void matrixVectorGenericPacked(KernelContext context, Int8Array w_quant, Int8Array x_quant, FloatArray output, int n, int d, int localWorkGroupSize, FloatArray w_scale, FloatArray x_scale) {
        final int row = context.groupIdx;
        if (row < d) {
            final int rowSum = MatrixVectorRowMajor.matrixVectorPacked(context, localWorkGroupSize, w_quant, x_quant, n);
            if (context.localIdx == 0) {
                output.set(row, (float)rowSum * w_scale.get(0) * x_scale.get(0));
            }
        }
    }

    /**
     * Computes the integer dot product of row {@code context.groupIdx} of {@code w_quant}
     * with {@code x_quant} by packing four consecutive int8 values into one {@code int}
     * (lowest index in the lowest byte) and combining packed pairs with
     * {@code QuantizationUtils.dp4a_packed}.
     *
     * <p>The input is processed in tiles of 32 packed words. For each tile the threads pack
     * the input vector into a local array, synchronise, then pack the matching weights and
     * accumulate. The per-thread totals are finally combined with a halving-stride
     * reduction.
     *
     * <p>The number of packed words is {@code ceil(n / 4)}, so when {@code n} is not a
     * multiple of four the last word is zero-padded by the bounds guards below and the
     * trailing 1–3 elements still contribute to the result. (With {@code n / 4} those
     * elements were silently dropped and the guards could never trigger.) For {@code n}
     * divisible by four the behaviour is unchanged.
     *
     * @param context   kernel context supplying indices, local arrays and barriers
     * @param localSize number of partial-sum slots and the per-thread stepping unit
     * @param w_quant   quantized weights, row-major with row length {@code n}
     * @param x_quant   quantized input vector
     * @param n         number of columns (row length)
     * @return the value of slot 0 of the partial-sum array after the reduction
     */
    public static int matrixVectorPacked(KernelContext context, int localSize, Int8Array w_quant, Int8Array x_quant, int n) {
        int rowId = context.groupIdx;
        int localId = context.localIdx;
        int[] x_tile_packed = context.allocateIntLocalArray(32);
        int[] localSum = context.allocateIntLocalArray(localSize);
        int rowOffset = rowId * n;
        int rowEnd = rowOffset + n;
        int totalSum = 0;
        // Round up so a partial trailing group of 1-3 elements gets its own packed word.
        int nPacked = (n + 3) / 4;
        for (int tileStart = 0; tileStart < nPacked; tileStart += 32) {
            int tileSize = Math.min(32, nPacked - tileStart);
            // Stage the packed input vector for this tile.
            for (int i = localId; i < tileSize; i += localSize) {
                int globalIdx = (tileStart + i) * 4;
                if (globalIdx >= n) continue;
                byte b0 = x_quant.get(globalIdx);
                // Bytes past the end of the vector are padded with zero.
                byte b1 = globalIdx + 1 < n ? x_quant.get(globalIdx + 1) : (byte)0;
                byte b2 = globalIdx + 2 < n ? x_quant.get(globalIdx + 2) : (byte)0;
                byte b3 = globalIdx + 3 < n ? x_quant.get(globalIdx + 3) : (byte)0;
                x_tile_packed[i] = b0 & 0xFF | (b1 & 0xFF) << 8 | (b2 & 0xFF) << 16 | (b3 & 0xFF) << 24;
            }
            context.localBarrier();
            int partialSum = 0;
            for (int i = localId; i < tileSize; i += localSize) {
                int w_globalIdx = rowOffset + (tileStart + i) * 4;
                if (w_globalIdx >= rowEnd) continue;
                byte b0 = w_quant.get(w_globalIdx);
                // Bytes past the end of the row are padded with zero so they add nothing.
                byte b1 = w_globalIdx + 1 < rowEnd ? w_quant.get(w_globalIdx + 1) : (byte)0;
                byte b2 = w_globalIdx + 2 < rowEnd ? w_quant.get(w_globalIdx + 2) : (byte)0;
                byte b3 = w_globalIdx + 3 < rowEnd ? w_quant.get(w_globalIdx + 3) : (byte)0;
                int w_packed = b0 & 0xFF | (b1 & 0xFF) << 8 | (b2 & 0xFF) << 16 | (b3 & 0xFF) << 24;
                partialSum = QuantizationUtils.dp4a_packed(w_packed, x_tile_packed[i], partialSum);
            }
            totalSum += partialSum;
            // All reads of the staged tile must finish before the next tile overwrites it.
            context.localBarrier();
        }
        localSum[localId] = totalSum;
        context.localBarrier();
        // Halving-stride reduction of the per-thread totals.
        for (int stride = localSize / 2; stride > 0; stride >>= 1) {
            if (localId < stride) {
                localSum[localId] += localSum[localId + stride];
            }
            context.localBarrier();
        }
        return localSum[0];
    }

    /**
     * Host-side quantization of a whole array with a single scale.
     *
     * <p>The scale is {@code maxAbs / 127}, or 1 when every element is zero. Each element is
     * multiplied by the reciprocal of the scale, offset by 0.5, floored, narrowed to
     * {@code byte} and stored at the same index of {@code x_quant}. The scale is written to
     * {@code x_scale[0]}.
     *
     * @param x       values to quantize
     * @param x_quant destination for the quantized values
     * @param x_scale one-element destination for the scale
     */
    private static void quantizeFloatArray(FloatArray x, Int8Array x_quant, FloatArray x_scale) {
        float peak = 0.0f;
        for (int k = 0; k < x.getSize(); k++) {
            float magnitude = TornadoMath.abs(x.get(k));
            peak = TornadoMath.max(peak, magnitude);
        }

        final float scale;
        if (peak == 0.0f) {
            scale = 1.0f;
        } else {
            scale = peak / 127.0f;
        }
        final float reciprocal = 1.0f / scale;

        for (int k = 0; k < x.getSize(); k++) {
            float shifted = x.get(k) * reciprocal + 0.5f;
            x_quant.set(k, (byte)TornadoMath.floor(shifted));
        }
        x_scale.set(0, scale);
    }

    /**
     * Kernel entry point for the 4-way unrolled DP4A matrix-vector product: the row handled
     * is {@code context.groupIdx}. Groups whose index is {@code >= d} do nothing; otherwise
     * the row's result comes from {@link #matrixVectorDP4A4Way}, called with the first
     * element of each scale array, and the thread with local index 0 stores it in
     * {@code output}.
     *
     * @param context            kernel context supplying group and local indices
     * @param w_quant            quantized weights, row-major
     * @param x_quant            quantized input vector
     * @param output             output vector, one element per row
     * @param n                  number of columns (row length)
     * @param d                  number of rows
     * @param localWorkGroupSize size forwarded to the reduction helper
     * @param w_scale            array whose element 0 is the weight scale
     * @param x_scale            array whose element 0 is the input scale
     */
    public static void matrixVectorGeneric4WayDP4A(KernelContext context, Int8Array w_quant, Int8Array x_quant, FloatArray output, int n, int d, int localWorkGroupSize, FloatArray w_scale, FloatArray x_scale) {
        final int row = context.groupIdx;
        final int lane = context.localIdx;
        if (row < d) {
            float rowResult = MatrixVectorRowMajor.matrixVectorDP4A4Way(context, localWorkGroupSize, w_quant, x_quant, n, w_scale.get(0), x_scale.get(0));
            if (lane == 0) {
                output.set(row, rowResult);
            }
        }
    }

    /**
     * Computes the dot product of row {@code context.groupIdx} of {@code w_quant} with
     * {@code x_quant} using four {@code QuantizationUtils.dp4a} calls per main-loop
     * iteration (offsets +0, +4, +8, +12), then scales the integer total by
     * {@code w_scale * x_scale}.
     *
     * <p>The main loop starts at {@code localIdx * 16} and steps by {@code localSize * 16}
     * while {@code j < n - 15}. A second loop starts at {@code n / 16 * 16 + localIdx * 4},
     * steps by {@code localSize * 4}, and issues a single {@code dp4a} call only when
     * {@code j + 3 < n}; positions that fail that test are skipped. The per-thread sums are
     * combined with a halving-stride reduction over a local array. Every thread returns the
     * scaled value of slot 0.
     *
     * @param context   kernel context supplying indices, the local array and barriers
     * @param localSize number of partial-sum slots and the per-thread stepping unit
     * @param w_quant   quantized weights, row-major with row length {@code n}
     * @param x_quant   quantized input vector
     * @param n         number of columns (row length)
     * @param w_scale   scale of the weights
     * @param x_scale   scale of the input
     * @return slot 0 of the reduced sums multiplied by {@code w_scale * x_scale}
     */
    public static float matrixVectorDP4A4Way(KernelContext context, int localSize, Int8Array w_quant, Int8Array x_quant, int n, float w_scale, float x_scale) {
        final int row = context.groupIdx;
        final int lane = context.localIdx;
        int[] partials = context.allocateIntLocalArray(localSize);
        final int rowBase = row * n;
        final float combinedScale = w_scale * x_scale;

        int acc = 0;
        final int mainStep = localSize * 16;
        final int mainEnd = n - 15;
        for (int col = lane * 16; col < mainEnd; col += mainStep) {
            final int w = rowBase + col;
            acc = QuantizationUtils.dp4a(w_quant, (long)w, x_quant, (long)col, acc);
            acc = QuantizationUtils.dp4a(w_quant, (long)(w + 4), x_quant, (long)(col + 4), acc);
            acc = QuantizationUtils.dp4a(w_quant, (long)(w + 8), x_quant, (long)(col + 8), acc);
            acc = QuantizationUtils.dp4a(w_quant, (long)(w + 12), x_quant, (long)(col + 12), acc);
        }

        // Tail after the last multiple of 16: one dp4a call per full group of four.
        final int tailStart = n / 16 * 16 + lane * 4;
        for (int col = tailStart; col < n; col += localSize * 4) {
            if (col + 3 < n) {
                acc = QuantizationUtils.dp4a(w_quant, (long)(rowBase + col), x_quant, (long)col, acc);
            }
        }

        partials[lane] = acc;
        context.localBarrier();
        for (int half = localSize / 2; half > 0; half >>= 1) {
            if (lane < half) {
                partials[lane] += partials[lane + half];
            }
            context.localBarrier();
        }
        return (float)partials[0] * combinedScale;
    }

    /**
     * Reports whether the backend of the runtime's default device is
     * {@code TornadoVMBackendType.PTX}.
     *
     * @return {@code true} when the default device's backend type is PTX
     */
    private static boolean isPTXBackend() {
        final int backendIndex = TornadoRuntimeProvider.getTornadoRuntime().getDefaultDevice().getBackendIndex();
        return TornadoVMBackendType.PTX == TornadoRuntimeProvider.getTornadoRuntime().getBackendType(backendIndex);
    }

    /**
     * Guards DP4A-only code paths: does nothing when {@link #isPTXBackend()} is true.
     *
     * @throws TornadoAPIException when the default device's backend is not PTX
     */
    private static void assertBackend() {
        if (MatrixVectorRowMajor.isPTXBackend()) {
            return;
        }
        throw new TornadoAPIException("DP4A is a PTX instruction. It is not supported for other backends.", new Exception());
    }

    public static void main(String[] args) {
        int i;
        int i2;
        System.out.println("Matrix-Vector Multiplication Benchmark");
        System.out.println("======================================");
        int inputDim = 8192;
        int outputDim = 2048;
        if (args.length >= 3) {
            try {
                inputDim = Integer.parseInt(args[0]);
                outputDim = Integer.parseInt(args[1]);
                LOCAL_WORK_GROUP_SIZE = Integer.parseInt(args[2]);
            }
            catch (NumberFormatException e) {
                System.err.println("Error parsing dimensions. Using defaults.");
            }
        }
        boolean supportsDP4A = MatrixVectorRowMajor.isPTXBackend();
        System.out.println("Configuration:");
        System.out.println("- Input dimension (columns): " + inputDim);
        System.out.println("- Output dimension (rows): " + outputDim);
        System.out.println("- Local work group size: " + LOCAL_WORK_GROUP_SIZE);
        System.out.println("- Backend: " + String.valueOf(TornadoRuntimeProvider.getTornadoRuntime().getBackendType(TornadoRuntimeProvider.getTornadoRuntime().getDefaultDevice().getBackendIndex())));
        System.out.println("- DP4A benchmarks enabled: " + supportsDP4A);
        System.out.println("- Warmup iterations: 140");
        System.out.println("- Benchmark iterations: 120");
        System.out.println();
        FloatArray input = new FloatArray(inputDim);
        FloatArray weights = new FloatArray(inputDim * outputDim);
        HalfFloatArray fp16weights = new HalfFloatArray(inputDim * outputDim);
        FloatArray outputParallel = new FloatArray(outputDim);
        FloatArray outputPureTornado = new FloatArray(outputDim);
        FloatArray outputSeq = new FloatArray(outputDim);
        FloatArray outputQ8Vec = new FloatArray(outputDim);
        FloatArray outputQ8Byte = new FloatArray(outputDim);
        FloatArray outputFp16 = new FloatArray(outputDim);
        FloatArray outputQ8DP4A = supportsDP4A ? new FloatArray(outputDim) : null;
        FloatArray outputQ8DP4APacked = supportsDP4A ? new FloatArray(outputDim) : null;
        FloatArray outputQ8DP4ALocal = supportsDP4A ? new FloatArray(outputDim) : null;
        FloatArray outputQ84DP4A = supportsDP4A ? new FloatArray(outputDim) : null;
        System.out.println("Initializing data...");
        MatrixVectorRowMajor.fillRandomData(input, -1.0f, 1.0f);
        MatrixVectorRowMajor.fillRandomData(weights, -0.1f, 0.1f);
        MatrixVectorRowMajor.fillRandomDataFp16(fp16weights, -0.1f, 0.1f);
        ArrayList<Long> sequentialTimers = new ArrayList<Long>();
        ArrayList<Long> kernelContextTimers = new ArrayList<Long>();
        ArrayList<Long> parallelTimers = new ArrayList<Long>();
        ArrayList<Long> q8VectorizedTimers = new ArrayList<Long>();
        ArrayList<Long> q8ByteTimers = new ArrayList<Long>();
        ArrayList<Long> fp16Timers = new ArrayList<Long>();
        ArrayList<Long> q8Dp4aTimers = supportsDP4A ? new ArrayList<Long>() : null;
        ArrayList<Long> q8Dp4aLocalTimers = supportsDP4A ? new ArrayList<Long>() : null;
        ArrayList<Long> q8Dp4aPackedTimers = supportsDP4A ? new ArrayList<Long>() : null;
        ArrayList<Long> q8Dp4a4WayTimers = supportsDP4A ? new ArrayList<Long>() : null;
        System.out.println("Setting up TornadoVM execution...");
        WorkerGrid1D worker = new WorkerGrid1D(outputDim * LOCAL_WORK_GROUP_SIZE);
        GridScheduler scheduler = new GridScheduler("s0.t0", (WorkerGrid)worker);
        worker.setLocalWork((long)LOCAL_WORK_GROUP_SIZE, 1L, 1L);
        TaskGraph taskGraph = new TaskGraph("s0").transferToDevice(0, new Object[]{input, weights}).task("t0", MatrixVectorRowMajor::matrixVectorGeneric, (Object)new KernelContext(), (Object)input, (Object)outputParallel, (Object)weights, (Object)inputDim, (Object)outputDim, (Object)LOCAL_WORK_GROUP_SIZE).transferToHost(1, new Object[]{outputParallel});
        ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
        TaskGraph taskGraphPure = new TaskGraph("s1").transferToDevice(0, new Object[]{input, weights}).task("t0", MatrixVectorRowMajor::matrixVectorParallel, (Object)input, (Object)outputPureTornado, (Object)weights, (Object)inputDim, (Object)outputDim).transferToHost(1, new Object[]{outputPureTornado});
        ImmutableTaskGraph immutableTaskGraphParallel = taskGraphPure.snapshot();
        WorkerGrid1D workerFp16 = new WorkerGrid1D(outputDim * LOCAL_WORK_GROUP_SIZE);
        GridScheduler schedulerFp16 = new GridScheduler("s3.t0", (WorkerGrid)workerFp16);
        workerFp16.setLocalWork((long)LOCAL_WORK_GROUP_SIZE, 1L, 1L);
        TaskGraph taskGraphFp16 = new TaskGraph("s3").transferToDevice(0, new Object[]{input, fp16weights}).task("t0", MatrixVectorRowMajor::matrixVectorGenericFP16, (Object)new KernelContext(), (Object)input, (Object)outputFp16, (Object)fp16weights, (Object)inputDim, (Object)outputDim, (Object)LOCAL_WORK_GROUP_SIZE).transferToHost(1, new Object[]{outputFp16});
        ImmutableTaskGraph immutableTaskGraphFp16 = taskGraphFp16.snapshot();
        Int8Array weightsQuantized = new Int8Array(inputDim * outputDim);
        int weightBlocksPerRow = inputDim / 32;
        HalfFloatArray weightsScales = new HalfFloatArray(outputDim * weightBlocksPerRow);
        int blockSize = 32;
        int blocksPerRow = (inputDim + blockSize - 1) / blockSize;
        int Q8_0_BLOCK_BYTES = 34;
        int totalQ8Bytes = outputDim * blocksPerRow * Q8_0_BLOCK_BYTES;
        ByteArray q8ByteArray = new ByteArray(totalQ8Bytes);
        MatrixVectorRowMajor.quantizeWeightsToQ8(weights, weightsQuantized, weightsScales, q8ByteArray, outputDim, inputDim);
        WorkerGrid1D q8VectorWorker = new WorkerGrid1D(outputDim * LOCAL_WORK_GROUP_SIZE);
        q8VectorWorker.setLocalWork((long)LOCAL_WORK_GROUP_SIZE, 1L, 1L);
        GridScheduler schedulerQ8Vectorized = new GridScheduler("vectorized.t0", (WorkerGrid)q8VectorWorker);
        TaskGraph taskGraphQ8Vectorized = new TaskGraph("vectorized").transferToDevice(0, new Object[]{input, weightsQuantized, weightsScales}).task("t0", MatrixVectorRowMajor::matrixVectorGenericFinal, (Object)new KernelContext(), (Object)input, (Object)outputQ8Vec, (Object)weightsQuantized, (Object)weightsScales, (Object)inputDim, (Object)outputDim, (Object)LOCAL_WORK_GROUP_SIZE).transferToHost(1, new Object[]{outputQ8Vec});
        ImmutableTaskGraph immutableTaskGraphQ8Vectorized = taskGraphQ8Vectorized.snapshot();
        WorkerGrid1D q8BytesWorker = new WorkerGrid1D(outputDim * LOCAL_WORK_GROUP_SIZE);
        q8BytesWorker.setLocalWork((long)LOCAL_WORK_GROUP_SIZE, 1L, 1L);
        GridScheduler schedulerQ8Bytes = new GridScheduler("q8bytes.t0", (WorkerGrid)q8BytesWorker);
        TaskGraph taskGraphQ8Bytes = new TaskGraph("q8bytes").transferToDevice(0, new Object[]{input, q8ByteArray}).task("t0", MatrixVectorRowMajor::matrixVectorGenericQ8Byte, (Object)new KernelContext(), (Object)input, (Object)outputQ8Byte, (Object)q8ByteArray, (Object)inputDim, (Object)outputDim, (Object)LOCAL_WORK_GROUP_SIZE).transferToHost(1, new Object[]{outputQ8Byte});
        ImmutableTaskGraph immutableTaskGraphQ8Bytes = taskGraphQ8Bytes.snapshot();
        ImmutableTaskGraph immutableTaskGraphDp4a = null;
        ImmutableTaskGraph immutableTaskGraphDp4aPacked = null;
        ImmutableTaskGraph immutableTaskGraphDp4aLocal = null;
        ImmutableTaskGraph immutableTaskGraphDp4a4Way = null;
        GridScheduler schedulerDp4a = null;
        GridScheduler schedulerDp4aPacked = null;
        GridScheduler schedulerDp4aLocalMem = null;
        GridScheduler schedulerDp4a4Way = null;
        if (supportsDP4A) {
            System.out.println("Setting up DP4A benchmarks...");
            Int8Array w_quant = new Int8Array(weights.getSize());
            FloatArray w_scale = new FloatArray(1);
            MatrixVectorRowMajor.quantizeFloatArray(weights, w_quant, w_scale);
            FloatArray x_scale = new FloatArray(1);
            int maxNumGroups = (inputDim + LOCAL_WORK_GROUP_SIZE - 1) / LOCAL_WORK_GROUP_SIZE;
            FloatArray x_max = new FloatArray(maxNumGroups);
            FloatArray inv_scale = new FloatArray(1);
            Int8Array x_quant = new Int8Array(input.getSize());
            WorkerGrid1D workerQuant = new WorkerGrid1D(inputDim);
            WorkerGrid1D workerDp4a = new WorkerGrid1D(LOCAL_WORK_GROUP_SIZE * inputDim);
            schedulerDp4a = new GridScheduler();
            schedulerDp4a.addWorkerGrid("s0_quant_kc.scales", (WorkerGrid)workerQuant);
            schedulerDp4a.addWorkerGrid("s0_quant_kc.quantize", (WorkerGrid)workerQuant);
            schedulerDp4a.addWorkerGrid("s0_quant_kc.dp4amatvec", (WorkerGrid)workerDp4a);
            workerQuant.setLocalWork((long)LOCAL_WORK_GROUP_SIZE, 1L, 1L);
            workerDp4a.setLocalWork((long)LOCAL_WORK_GROUP_SIZE, 1L, 1L);
            TaskGraph taskGraphDp4a = new TaskGraph("s0_quant_kc").transferToDevice(0, new Object[]{input, x_quant, x_scale, x_max, inv_scale, w_quant, w_scale}).task("scales", MatrixVectorRowMajor::reductionCalculateMax, (Object)new KernelContext(), (Object)x_max, (Object)input, (Object)x_scale, (Object)inv_scale, (Object)LOCAL_WORK_GROUP_SIZE, (Object)inputDim).task("quantize", MatrixVectorRowMajor::quantizeKernelContext, (Object)new KernelContext(), (Object)input, (Object)inv_scale, (Object)x_quant).task("dp4amatvec", MatrixVectorRowMajor::matrixVectorGenericDP4A, (Object)new KernelContext(), (Object)w_quant, (Object)x_quant, (Object)outputQ8DP4A, (Object)inputDim, (Object)outputDim, (Object)LOCAL_WORK_GROUP_SIZE, (Object)w_scale, (Object)x_scale).transferToHost(1, new Object[]{outputQ8DP4A});
            immutableTaskGraphDp4a = taskGraphDp4a.snapshot();
            WorkerGrid1D workerQuantPacked = new WorkerGrid1D(inputDim);
            WorkerGrid1D workerDp4aPacked = new WorkerGrid1D(LOCAL_WORK_GROUP_SIZE * inputDim);
            schedulerDp4aPacked = new GridScheduler();
            schedulerDp4aPacked.addWorkerGrid("s0_quant_kc_packed.scales", (WorkerGrid)workerQuantPacked);
            schedulerDp4aPacked.addWorkerGrid("s0_quant_kc_packed.quantize", (WorkerGrid)workerQuantPacked);
            schedulerDp4aPacked.addWorkerGrid("s0_quant_kc_packed.dp4amatvec", (WorkerGrid)workerDp4aPacked);
            workerQuantPacked.setLocalWork((long)LOCAL_WORK_GROUP_SIZE, 1L, 1L);
            workerDp4aPacked.setLocalWork((long)LOCAL_WORK_GROUP_SIZE, 1L, 1L);
            TaskGraph taskGraphDp4aPacked = new TaskGraph("s0_quant_kc_packed").transferToDevice(0, new Object[]{input, x_quant, x_scale, x_max, inv_scale, w_quant, w_scale}).task("scales", MatrixVectorRowMajor::reductionCalculateMax, (Object)new KernelContext(), (Object)x_max, (Object)input, (Object)x_scale, (Object)inv_scale, (Object)LOCAL_WORK_GROUP_SIZE, (Object)inputDim).task("quantize", MatrixVectorRowMajor::quantizeKernelContext, (Object)new KernelContext(), (Object)input, (Object)inv_scale, (Object)x_quant).task("dp4amatvec", MatrixVectorRowMajor::matrixVectorGenericPacked, (Object)new KernelContext(), (Object)w_quant, (Object)x_quant, (Object)outputQ8DP4APacked, (Object)inputDim, (Object)outputDim, (Object)LOCAL_WORK_GROUP_SIZE, (Object)w_scale, (Object)x_scale).transferToHost(1, new Object[]{outputQ8DP4APacked});
            immutableTaskGraphDp4aPacked = taskGraphDp4aPacked.snapshot();
            WorkerGrid1D workerQuantLocalMem = new WorkerGrid1D(inputDim);
            WorkerGrid1D workerDp4aLocalMem = new WorkerGrid1D(LOCAL_WORK_GROUP_SIZE * inputDim);
            schedulerDp4aLocalMem = new GridScheduler();
            schedulerDp4aLocalMem.addWorkerGrid("s0_quant_kc_local.scales", (WorkerGrid)workerQuantLocalMem);
            schedulerDp4aLocalMem.addWorkerGrid("s0_quant_kc_local.quantize", (WorkerGrid)workerQuantLocalMem);
            schedulerDp4aLocalMem.addWorkerGrid("s0_quant_kc_local.dp4amatvec", (WorkerGrid)workerDp4aLocalMem);
            workerQuantLocalMem.setLocalWork((long)LOCAL_WORK_GROUP_SIZE, 1L, 1L);
            workerDp4aLocalMem.setLocalWork((long)LOCAL_WORK_GROUP_SIZE, 1L, 1L);
            TaskGraph taskGraphDp4aLocal = new TaskGraph("s0_quant_kc_local").transferToDevice(0, new Object[]{input, x_quant, x_scale, x_max, inv_scale, w_quant, w_scale}).task("scales", MatrixVectorRowMajor::reductionCalculateMax, (Object)new KernelContext(), (Object)x_max, (Object)input, (Object)x_scale, (Object)inv_scale, (Object)LOCAL_WORK_GROUP_SIZE, (Object)inputDim).task("quantize", MatrixVectorRowMajor::quantizeKernelContext, (Object)new KernelContext(), (Object)input, (Object)inv_scale, (Object)x_quant).task("dp4amatvec", MatrixVectorRowMajor::matrixVectorGenericLocalMemory, (Object)new KernelContext(), (Object)w_quant, (Object)x_quant, (Object)outputQ8DP4ALocal, (Object)inputDim, (Object)outputDim, (Object)LOCAL_WORK_GROUP_SIZE, (Object)w_scale, (Object)x_scale).transferToHost(1, new Object[]{outputQ8DP4ALocal});
            immutableTaskGraphDp4aLocal = taskGraphDp4aLocal.snapshot();
            WorkerGrid1D workerQuant4Way = new WorkerGrid1D(inputDim);
            WorkerGrid1D workerDp4a4way = new WorkerGrid1D(LOCAL_WORK_GROUP_SIZE * inputDim);
            schedulerDp4a4Way = new GridScheduler();
            schedulerDp4a4Way.addWorkerGrid("s0_quant_kc_4way.scales", (WorkerGrid)workerQuant4Way);
            schedulerDp4a4Way.addWorkerGrid("s0_quant_kc_4way.quantize", (WorkerGrid)workerQuant4Way);
            schedulerDp4a4Way.addWorkerGrid("s0_quant_kc_4way.dp4amatvec", (WorkerGrid)workerDp4a4way);
            workerQuant4Way.setLocalWork((long)LOCAL_WORK_GROUP_SIZE, 1L, 1L);
            workerDp4a4way.setLocalWork((long)LOCAL_WORK_GROUP_SIZE, 1L, 1L);
            TaskGraph taskGraphDp4a4Way = new TaskGraph("s0_quant_kc_4way").transferToDevice(0, new Object[]{input, x_quant, x_scale, x_max, inv_scale, w_quant, w_scale}).task("scales", MatrixVectorRowMajor::reductionCalculateMax, (Object)new KernelContext(), (Object)x_max, (Object)input, (Object)x_scale, (Object)inv_scale, (Object)LOCAL_WORK_GROUP_SIZE, (Object)inputDim).task("quantize", MatrixVectorRowMajor::quantizeKernelContext, (Object)new KernelContext(), (Object)input, (Object)inv_scale, (Object)x_quant).task("dp4amatvec", MatrixVectorRowMajor::matrixVectorGeneric4WayDP4A, (Object)new KernelContext(), (Object)w_quant, (Object)x_quant, (Object)outputQ84DP4A, (Object)inputDim, (Object)outputDim, (Object)LOCAL_WORK_GROUP_SIZE, (Object)w_scale, (Object)x_scale).transferToHost(1, new Object[]{outputQ84DP4A});
            immutableTaskGraphDp4a4Way = taskGraphDp4a4Way.snapshot();
        }
        System.out.println("Warming up sequential implementation...");
        for (i2 = 0; i2 < 140; ++i2) {
            MatrixVectorRowMajor.matrixVectorSequential(input, outputSeq, weights, inputDim, outputDim);
        }
        System.out.println("Benchmarking sequential implementation...");
        for (i2 = 0; i2 < 120; ++i2) {
            long start = System.nanoTime();
            MatrixVectorRowMajor.matrixVectorSequential(input, outputSeq, weights, inputDim, outputDim);
            long end = System.nanoTime();
            sequentialTimers.add(end - start);
        }
        TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(new ImmutableTaskGraph[]{immutableTaskGraph});
        TornadoExecutionPlan executionPlan2 = new TornadoExecutionPlan(new ImmutableTaskGraph[]{immutableTaskGraphParallel});
        TornadoExecutionPlan executionPlan3 = new TornadoExecutionPlan(new ImmutableTaskGraph[]{immutableTaskGraphFp16});
        TornadoExecutionPlan executionPlanQ8Vectorized = new TornadoExecutionPlan(new ImmutableTaskGraph[]{immutableTaskGraphQ8Vectorized});
        TornadoExecutionPlan executionPlanQ8Byte = new TornadoExecutionPlan(new ImmutableTaskGraph[]{immutableTaskGraphQ8Bytes});
        TornadoExecutionPlan executionPlanQ8Dp4a = supportsDP4A ? new TornadoExecutionPlan(new ImmutableTaskGraph[]{immutableTaskGraphDp4a}) : null;
        TornadoExecutionPlan executionPlanQ8Dp4aPacked = supportsDP4A ? new TornadoExecutionPlan(new ImmutableTaskGraph[]{immutableTaskGraphDp4aPacked}) : null;
        TornadoExecutionPlan executionPlanQ8Dp4aLocal = supportsDP4A ? new TornadoExecutionPlan(new ImmutableTaskGraph[]{immutableTaskGraphDp4aLocal}) : null;
        TornadoExecutionPlan executionPlanQ8Dp4a4Way = supportsDP4A ? new TornadoExecutionPlan(new ImmutableTaskGraph[]{immutableTaskGraphDp4a4Way}) : null;
        System.out.println("Warming up parallel implementation...");
        executionPlan.withGridScheduler(scheduler);
        for (i = 0; i < 140; ++i) {
            executionPlan.withGridScheduler(scheduler).execute();
        }
        for (i = 0; i < 140; ++i) {
            executionPlan2.execute();
        }
        executionPlan3.withGridScheduler(schedulerFp16);
        for (i = 0; i < 140; ++i) {
            executionPlan3.withGridScheduler(schedulerFp16).execute();
        }
        executionPlanQ8Vectorized.withGridScheduler(schedulerQ8Vectorized);
        for (i = 0; i < 140; ++i) {
            executionPlanQ8Vectorized.withGridScheduler(schedulerQ8Vectorized).execute();
        }
        executionPlanQ8Byte.withGridScheduler(schedulerQ8Bytes);
        for (i = 0; i < 140; ++i) {
            executionPlanQ8Byte.withGridScheduler(schedulerQ8Bytes).execute();
        }
        if (supportsDP4A) {
            System.out.println("Warming up DP4A implementations...");
            executionPlanQ8Dp4a.withGridScheduler(schedulerDp4a);
            for (i = 0; i < 140; ++i) {
                executionPlanQ8Dp4a.withGridScheduler(schedulerDp4a).execute();
            }
            executionPlanQ8Dp4aPacked.withGridScheduler(schedulerDp4aPacked);
            for (i = 0; i < 140; ++i) {
                executionPlanQ8Dp4aPacked.withGridScheduler(schedulerDp4aPacked).execute();
            }
            executionPlanQ8Dp4aLocal.withGridScheduler(schedulerDp4aLocalMem);
            for (i = 0; i < 140; ++i) {
                executionPlanQ8Dp4aLocal.withGridScheduler(schedulerDp4aLocalMem).execute();
            }
            executionPlanQ8Dp4a4Way.withGridScheduler(schedulerDp4a4Way);
            for (i = 0; i < 140; ++i) {
                executionPlanQ8Dp4a4Way.withGridScheduler(schedulerDp4a4Way).execute();
            }
        }
        System.out.println("Benchmarking parallel implementation...");
        for (i = 0; i < 120; ++i) {
            long start = System.nanoTime();
            executionPlan.withGridScheduler(scheduler).execute();
            long end = System.nanoTime();
            kernelContextTimers.add(end - start);
        }
        for (i = 0; i < 120; ++i) {
            long start = System.nanoTime();
            executionPlan2.execute();
            long end = System.nanoTime();
            parallelTimers.add(end - start);
        }
        for (i = 0; i < 120; ++i) {
            long start = System.nanoTime();
            executionPlan3.execute();
            long end = System.nanoTime();
            fp16Timers.add(end - start);
        }
        for (i = 0; i < 120; ++i) {
            long start = System.nanoTime();
            executionPlanQ8Vectorized.withGridScheduler(schedulerQ8Vectorized).execute();
            long end = System.nanoTime();
            q8VectorizedTimers.add(end - start);
        }
        for (i = 0; i < 120; ++i) {
            long start = System.nanoTime();
            executionPlanQ8Byte.withGridScheduler(schedulerQ8Bytes).execute();
            long end = System.nanoTime();
            q8ByteTimers.add(end - start);
        }
        if (supportsDP4A) {
            System.out.println("Benchmarking DP4A implementations...");
            for (i = 0; i < 120; ++i) {
                long start = System.nanoTime();
                executionPlanQ8Dp4a.withGridScheduler(schedulerDp4a).execute();
                long end = System.nanoTime();
                q8Dp4aTimers.add(end - start);
            }
            for (i = 0; i < 120; ++i) {
                long start = System.nanoTime();
                executionPlanQ8Dp4aPacked.withGridScheduler(schedulerDp4aPacked).execute();
                long end = System.nanoTime();
                q8Dp4aPackedTimers.add(end - start);
            }
            for (i = 0; i < 120; ++i) {
                long start = System.nanoTime();
                executionPlanQ8Dp4aLocal.withGridScheduler(schedulerDp4aLocalMem).execute();
                long end = System.nanoTime();
                q8Dp4aLocalTimers.add(end - start);
            }
            for (i = 0; i < 120; ++i) {
                long start = System.nanoTime();
                executionPlanQ8Dp4a4Way.withGridScheduler(schedulerDp4a4Way).execute();
                long end = System.nanoTime();
                q8Dp4a4WayTimers.add(end - start);
            }
        }
        System.out.println("Validating results...");
        boolean isValid = true;
        float maxError = 0.0f;
        float maxError2 = 0.0f;
        float maxError3 = 0.0f;
        float maxError4 = 0.0f;
        float maxError5 = 0.0f;
        float maxError6 = 0.0f;
        float maxError7 = 0.0f;
        float maxError8 = 0.0f;
        float maxError9 = 0.0f;
        for (int i3 = 0; i3 < outputDim; ++i3) {
            float error = Math.abs(outputSeq.get(i3) - outputParallel.get(i3));
            maxError = Math.max(maxError, error);
            float error2 = Math.abs(outputSeq.get(i3) - outputPureTornado.get(i3));
            maxError2 = Math.max(maxError2, error2);
            float error3 = Math.abs(outputSeq.get(i3) - outputFp16.get(i3));
            maxError3 = Math.max(maxError3, error3);
            float error4 = Math.abs(outputSeq.get(i3) - outputQ8Vec.get(i3));
            maxError4 = Math.max(maxError4, error4);
            float errorQ8Byte = Math.abs(outputSeq.get(i3) - outputQ8Byte.get(i3));
            maxError9 = Math.max(maxError9, errorQ8Byte);
            if (error > 1.0E-4f) {
                System.out.printf("[KernelContext] Error at index %d: Expected %.6f, Actual %.6f, Diff %.6f\n", i3, Float.valueOf(outputSeq.get(i3)), Float.valueOf(outputParallel.get(i3)), Float.valueOf(error));
                isValid = false;
            }
            if (error2 > 1.0E-4f) {
                System.out.printf("[@Parallel] Error at index %d: Expected %.6f, Actual %.6f, Diff %.6f\n", i3, Float.valueOf(outputSeq.get(i3)), Float.valueOf(outputPureTornado.get(i3)), Float.valueOf(error2));
                isValid = false;
            }
            if (error3 > 1.0E-4f) {
                System.out.printf("[KernelContext FP16] Error at index %d: Expected %.6f, Actual %.6f, Diff %.6f\n", i3, Float.valueOf(outputSeq.get(i3)), Float.valueOf(outputFp16.get(i3)), Float.valueOf(error3));
                isValid = false;
            }
            if (error4 > 0.1f) {
                System.out.printf("[Q8 Vectorized] Error at index %d: Expected %.6f, Actual %.6f, Diff %.6f\n", i3, Float.valueOf(outputSeq.get(i3)), Float.valueOf(outputQ8Vec.get(i3)), Float.valueOf(error4));
                isValid = false;
            }
            if (errorQ8Byte > 0.1f) {
                System.out.printf("[Q8 Byte] Error at index %d: Expected %.6f, Actual %.6f, Diff %.6f\n", i3, Float.valueOf(outputSeq.get(i3)), Float.valueOf(outputQ8Byte.get(i3)), Float.valueOf(errorQ8Byte));
                isValid = false;
            }
            if (!supportsDP4A) continue;
            float error5 = Math.abs(outputSeq.get(i3) - outputQ8DP4A.get(i3));
            maxError5 = Math.max(maxError5, error5);
            float error6 = Math.abs(outputSeq.get(i3) - outputQ8DP4APacked.get(i3));
            maxError6 = Math.max(maxError6, error6);
            float error7 = Math.abs(outputSeq.get(i3) - outputQ8DP4ALocal.get(i3));
            maxError7 = Math.max(maxError7, error7);
            float error8 = Math.abs(outputSeq.get(i3) - outputQ84DP4A.get(i3));
            maxError8 = Math.max(maxError8, error8);
            if (error5 > 0.1f) {
                System.out.printf("[DP4A] Error at index %d: Expected %.6f, Actual %.6f, Diff %.6f\n", i3, Float.valueOf(outputSeq.get(i3)), Float.valueOf(outputQ8DP4A.get(i3)), Float.valueOf(error5));
                isValid = false;
            }
            if (error6 > 0.1f) {
                System.out.printf("[DP4A Packed] Error at index %d: Expected %.6f, Actual %.6f, Diff %.6f\n", i3, Float.valueOf(outputSeq.get(i3)), Float.valueOf(outputQ8DP4APacked.get(i3)), Float.valueOf(error6));
                isValid = false;
            }
            if (error7 > 0.1f) {
                System.out.printf("[DP4A Local Memory] Error at index %d: Expected %.6f, Actual %.6f, Diff %.6f\n", i3, Float.valueOf(outputSeq.get(i3)), Float.valueOf(outputQ8DP4ALocal.get(i3)), Float.valueOf(error7));
                isValid = false;
            }
            if (!(error8 > 0.1f)) continue;
            System.out.printf("[DP4A 4-WAY] Error at index %d: Expected %.6f, Actual %.6f, Diff %.6f\n", i3, Float.valueOf(outputSeq.get(i3)), Float.valueOf(outputQ84DP4A.get(i3)), Float.valueOf(error8));
            isValid = false;
        }
        if (isValid) {
            System.out.println("Validation PASSED \u2713");
        } else {
            System.out.println("[KernelContext] Maximum error: " + maxError);
            System.out.println("[@Parallel] Maximum error: " + maxError2);
            System.out.println("[KernelContext FP16] Maximum error: " + maxError3);
            System.out.println("[Q8 Vectorized] Maximum error: " + maxError4);
            if (supportsDP4A) {
                System.out.println("[Q8 DP4A] Maximum error: " + maxError5);
                System.out.println("[Q8 DP4A Packed] Maximum error: " + maxError6);
                System.out.println("[Q8 DP4A Local Memory] Maximum error: " + maxError7);
                System.out.println("[Q8 DP4A 4-WAY] Maximum error: " + maxError8);
            }
        }
        LongSummaryStatistics statsSeq = sequentialTimers.stream().mapToLong(Long::longValue).summaryStatistics();
        LongSummaryStatistics statsKernelContext = kernelContextTimers.stream().mapToLong(Long::longValue).summaryStatistics();
        LongSummaryStatistics statsParallel = parallelTimers.stream().mapToLong(Long::longValue).summaryStatistics();
        LongSummaryStatistics statsFp16 = fp16Timers.stream().mapToLong(Long::longValue).summaryStatistics();
        LongSummaryStatistics statsQ8Vectorized = q8VectorizedTimers.stream().mapToLong(Long::longValue).summaryStatistics();
        LongSummaryStatistics statsQ8Byte = q8ByteTimers.stream().mapToLong(Long::longValue).summaryStatistics();
        LongSummaryStatistics statsQ8Dp4a = supportsDP4A ? q8Dp4aTimers.stream().mapToLong(Long::longValue).summaryStatistics() : null;
        LongSummaryStatistics statsQ8Dp4aPacked = supportsDP4A ? q8Dp4aPackedTimers.stream().mapToLong(Long::longValue).summaryStatistics() : null;
        LongSummaryStatistics statsQ8Dp4aLocal = supportsDP4A ? q8Dp4aLocalTimers.stream().mapToLong(Long::longValue).summaryStatistics() : null;
        LongSummaryStatistics statsQ8Dp4a4Way = supportsDP4A ? q8Dp4a4WayTimers.stream().mapToLong(Long::longValue).summaryStatistics() : null;
        long flopsPerRow = 2L * (long)inputDim;
        long totalFlops = flopsPerRow * (long)outputDim;
        double seqGFlops = (double)totalFlops * 1.0E-9 / (statsSeq.getAverage() * 1.0E-9);
        double kernelContextGFlops = (double)totalFlops * 1.0E-9 / (statsKernelContext.getAverage() * 1.0E-9);
        double parallelGFlops = (double)totalFlops * 1.0E-9 / (statsParallel.getAverage() * 1.0E-9);
        double fp16GFlops = (double)totalFlops * 1.0E-9 / (statsFp16.getAverage() * 1.0E-9);
        double q8VectorizedGFlops = (double)totalFlops * 1.0E-9 / (statsQ8Vectorized.getAverage() * 1.0E-9);
        double q8ByteGFlops = (double)totalFlops * 1.0E-9 / (statsQ8Byte.getAverage() * 1.0E-9);
        Double q8Dp4aGFlops = supportsDP4A ? Double.valueOf((double)totalFlops * 1.0E-9 / (statsQ8Dp4a.getAverage() * 1.0E-9)) : null;
        Double q8Dp4aPackedGFlops = supportsDP4A ? Double.valueOf((double)totalFlops * 1.0E-9 / (statsQ8Dp4aPacked.getAverage() * 1.0E-9)) : null;
        Double q8Dp4aLocalGFlops = supportsDP4A ? Double.valueOf((double)totalFlops * 1.0E-9 / (statsQ8Dp4aLocal.getAverage() * 1.0E-9)) : null;
        Double q8Dp4a4WayGFlops = supportsDP4A ? Double.valueOf((double)totalFlops * 1.0E-9 / (statsQ8Dp4a4Way.getAverage() * 1.0E-9)) : null;
        System.out.println("\nPerformance Results:");
        System.out.println("====================");
        System.out.printf("Matrix size: %d x %d\n", outputDim, inputDim);
        System.out.println("Sequential Implementation:");
        System.out.printf("  Average time: %.3f ms\n", statsSeq.getAverage() / 1000000.0);
        System.out.printf("  Min time: %.3f ms\n", (double)statsSeq.getMin() / 1000000.0);
        System.out.printf("  Max time: %.3f ms\n", (double)statsSeq.getMax() / 1000000.0);
        System.out.printf("  Performance: %.2f GFLOP/s\n", seqGFlops);
        System.out.println("Parallel Implementation (TornadoVM):");
        System.out.printf("  Average time: %.3f ms\n", statsKernelContext.getAverage() / 1000000.0);
        System.out.printf("  Min time: %.3f ms\n", (double)statsKernelContext.getMin() / 1000000.0);
        System.out.printf("  Max time: %.3f ms\n", (double)statsKernelContext.getMax() / 1000000.0);
        System.out.printf("  Performance: %.2f GFLOP/s\n", kernelContextGFlops);
        System.out.println("Pure TornadoVM @Parallel Implementation (TornadoVM):");
        System.out.printf("  Average time: %.3f ms\n", statsParallel.getAverage() / 1000000.0);
        System.out.printf("  Min time: %.3f ms\n", (double)statsParallel.getMin() / 1000000.0);
        System.out.printf("  Max time: %.3f ms\n", (double)statsParallel.getMax() / 1000000.0);
        System.out.printf("  Performance: %.2f GFLOP/s\n", parallelGFlops);
        System.out.println("Parallel Implementation FP16 (TornadoVM):");
        System.out.printf("  Average time: %.3f ms\n", statsFp16.getAverage() / 1000000.0);
        System.out.printf("  Min time: %.3f ms\n", (double)statsFp16.getMin() / 1000000.0);
        System.out.printf("  Max time: %.3f ms\n", (double)statsFp16.getMax() / 1000000.0);
        System.out.printf("  Performance: %.2f GFLOP/s\n", fp16GFlops);
        System.out.println("Q8 Vectorized:");
        System.out.printf("  Average time: %.3f ms\n", statsQ8Vectorized.getAverage() / 1000000.0);
        System.out.printf("  Min time: %.3f ms\n", (double)statsQ8Vectorized.getMin() / 1000000.0);
        System.out.printf("  Max time: %.3f ms\n", (double)statsQ8Vectorized.getMax() / 1000000.0);
        System.out.printf("  Performance: %.2f GFLOP/s\n", q8VectorizedGFlops);
        System.out.println("Q8 ByteArray:");
        System.out.printf("  Average time: %.3f ms\n", statsQ8Byte.getAverage() / 1000000.0);
        System.out.printf("  Min time: %.3f ms\n", (double)statsQ8Byte.getMin() / 1000000.0);
        System.out.printf("  Max time: %.3f ms\n", (double)statsQ8Byte.getMax() / 1000000.0);
        System.out.printf("  Performance: %.2f GFLOP/s\n", q8ByteGFlops);
        if (supportsDP4A) {
            System.out.println("Q8 DP4A:");
            System.out.printf("  Average time: %.3f ms\n", statsQ8Dp4a.getAverage() / 1000000.0);
            System.out.printf("  Min time: %.3f ms\n", (double)statsQ8Dp4a.getMin() / 1000000.0);
            System.out.printf("  Max time: %.3f ms\n", (double)statsQ8Dp4a.getMax() / 1000000.0);
            System.out.printf("  Performance: %.2f GFLOP/s\n", q8Dp4aGFlops);
            System.out.println("Q8 DP4A Packed:");
            System.out.printf("  Average time: %.3f ms\n", statsQ8Dp4aPacked.getAverage() / 1000000.0);
            System.out.printf("  Min time: %.3f ms\n", (double)statsQ8Dp4aPacked.getMin() / 1000000.0);
            System.out.printf("  Max time: %.3f ms\n", (double)statsQ8Dp4aPacked.getMax() / 1000000.0);
            System.out.printf("  Performance: %.2f GFLOP/s\n", q8Dp4aPackedGFlops);
            System.out.println("Q8 DP4A Local Memory:");
            System.out.printf("  Average time: %.3f ms\n", statsQ8Dp4aLocal.getAverage() / 1000000.0);
            System.out.printf("  Min time: %.3f ms\n", (double)statsQ8Dp4aLocal.getMin() / 1000000.0);
            System.out.printf("  Max time: %.3f ms\n", (double)statsQ8Dp4aLocal.getMax() / 1000000.0);
            System.out.printf("  Performance: %.2f GFLOP/s\n", q8Dp4aLocalGFlops);
            System.out.println("Q8 DP4A 4-WAY:");
            System.out.printf("  Average time: %.3f ms\n", statsQ8Dp4a4Way.getAverage() / 1000000.0);
            System.out.printf("  Min time: %.3f ms\n", (double)statsQ8Dp4a4Way.getMin() / 1000000.0);
            System.out.printf("  Max time: %.3f ms\n", (double)statsQ8Dp4a4Way.getMax() / 1000000.0);
            System.out.printf("  Performance: %.2f GFLOP/s\n", q8Dp4a4WayGFlops);
        }
        double speedup = statsSeq.getAverage() / statsKernelContext.getAverage();
        System.out.printf("\nSpeedup: KernelContext vs Java %.2fx\n", speedup);
        double speedup2 = statsSeq.getAverage() / statsParallel.getAverage();
        System.out.printf("Speedup: @Parallel vs Java %.2fx\n", speedup2);
        double speedup3 = statsParallel.getAverage() / statsKernelContext.getAverage();
        System.out.printf("Speedup: KernelContext vs @Parallel %.2fx\n", speedup3);
        double speedup4 = statsKernelContext.getAverage() / statsQ8Vectorized.getAverage();
        System.out.printf("Speedup: Q8 Vectorized vs KernelContext %.2fx\n", speedup4);
        double speedup5 = statsFp16.getAverage() / statsQ8Vectorized.getAverage();
        System.out.printf("Speedup: Q8 Vectorized vs KernelContext FP16 %.2fx\n", speedup5);
        double speedupQ8Byte = statsKernelContext.getAverage() / statsQ8Byte.getAverage();
        System.out.printf("Speedup: Q8 ByteArray vs KernelContext %.2fx\n", speedupQ8Byte);
        double speedupQ8ByteVsFp16 = statsFp16.getAverage() / statsQ8Byte.getAverage();
        System.out.printf("Speedup: Q8 ByteArray vs KernelContext FP16 %.2fx\n", speedupQ8ByteVsFp16);
        double speedupQ8ByteVsVectorized = statsQ8Vectorized.getAverage() / statsQ8Byte.getAverage();
        System.out.printf("Speedup: Q8 ByteArray vs Q8 Vectorized %.2fx\n", speedupQ8ByteVsVectorized);
        if (supportsDP4A) {
            double speedup6 = statsFp16.getAverage() / statsQ8Dp4a.getAverage();
            System.out.printf("Speedup: Q8 DP4A vs KernelContext FP16 %.2fx\n", speedup6);
            double speedup7 = statsFp16.getAverage() / statsQ8Dp4aPacked.getAverage();
            System.out.printf("Speedup: Q8 DP4A Packed vs KernelContext FP16 %.2fx\n", speedup7);
            double speedup8 = statsFp16.getAverage() / statsQ8Dp4aLocal.getAverage();
            System.out.printf("Speedup: Q8 DP4A Local vs KernelContext FP16 %.2fx\n", speedup8);
            double speedup9 = statsFp16.getAverage() / statsQ8Dp4a4Way.getAverage();
            System.out.printf("Speedup: Q8 DP4A 4-Way vs KernelContext FP16 %.2fx\n", speedup9);
        } else {
            System.out.println("\n[DP4A benchmarks skipped - not running on PTX backend]");
        }
    }
}

