/*
 * Decompiled with CFR 0.152.
 */
package uk.ac.manchester.tornado.unittests.compute;

import java.util.Random;
import org.junit.Assert;
import org.junit.Test;
import uk.ac.manchester.tornado.api.GridScheduler;
import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
import uk.ac.manchester.tornado.api.KernelContext;
import uk.ac.manchester.tornado.api.TaskGraph;
import uk.ac.manchester.tornado.api.TornadoExecutionPlan;
import uk.ac.manchester.tornado.api.WorkerGrid;
import uk.ac.manchester.tornado.api.WorkerGrid1D;
import uk.ac.manchester.tornado.api.enums.TornadoVMBackendType;
import uk.ac.manchester.tornado.api.exceptions.TornadoExecutionPlanException;
import uk.ac.manchester.tornado.api.math.TornadoMath;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
import uk.ac.manchester.tornado.api.types.arrays.IntArray;
import uk.ac.manchester.tornado.unittests.common.TornadoTestBase;

public class TransformerKernelsTest
extends TornadoTestBase {
    private static final float DELTA = 0.001f;
    private final Random random = new Random(7L);

    /**
     * Kernel that squares the elements of {@code x}, sums them per work-group and lets global
     * thread 0 turn the per-group sums into {@code 1 / sqrt(sum / size + ermsNorm)}.
     *
     * <p>Output layout: slot {@code groupId + 1} receives that group's partial sum, slot 0 receives
     * the final factor.
     *
     * @param context      kernel context supplying the thread/group indices, the local array and the barrier
     * @param output       result array; slot 0 is the factor, slots 1.. hold the per-group partial sums
     * @param x            input values
     * @param size         number of valid elements in {@code x}; threads with {@code gid >= size} contribute 0
     * @param ermsNorm     value added to the mean of squares before the square root is taken
     * @param localMemSize length of the local scratch array, which is indexed by {@code localIdx}
     */
    public static void reductionOneBlockWithLayer(KernelContext context, FloatArray output, FloatArray x, int size, float ermsNorm, int localMemSize) {
        int gid = context.globalIdx;
        int lid = context.localIdx;
        int groupId = context.groupIdx;
        int groupSize = context.localGroupSizeX;
        float[] localX = context.allocateFloatLocalArray(localMemSize);
        // Each thread stores the square of its element, or 0 when it is past the end of the input.
        if (gid < size) {
            localX[lid] = x.get(gid);
            localX[lid] = localX[lid] * localX[lid];
        } else {
            localX[lid] = 0.0f;
        }
        // Pairwise reduction: the stride halves every round and a barrier is issued at the top of
        // each round, before that round reads localX.
        for (int stride = groupSize / 2; stride > 0; stride /= 2) {
            context.localBarrier();
            if (lid >= stride) continue;
            int n = lid;
            localX[n] = localX[n] + localX[lid + stride];
        }
        // The first thread of every group publishes the group's sum; slot 0 is kept for the final factor.
        if (lid == 0) {
            output.set(groupId + 1, localX[0]);
        }
        // NOTE(review): only localBarrier() calls are visible in this method, so nothing here orders
        // the reads below after the output.set(groupId + 1, ...) writes issued by other groups.
        // Confirm this is acceptable; reductionPartialSums + reductionFinalNormalization perform the
        // same computation as two separate tasks.
        if (gid == 0) {
            float ss = 0.0f;
            for (int i = 1; i < output.getSize(); ++i) {
                ss += output.get(i);
            }
            ss /= (float)size;
            ss += ermsNorm;
            ss = 1.0f / TornadoMath.sqrt((float)ss);
            output.set(0, ss);
        }
    }

    /**
     * Kernel that writes {@code weights[i] * (temp[0] * x[i])} to {@code output[i]}, where
     * {@code i} is the calling thread's global index.
     *
     * @param context kernel context supplying the global thread index
     * @param output  destination array
     * @param x       input values
     * @param weights per-element weights
     * @param temp    array whose slot 0 holds the scale factor
     */
    public static void reductionOneBlock2WithLayer(KernelContext context, FloatArray output, FloatArray x, FloatArray weights, FloatArray temp) {
        int idx = context.globalIdx;
        float scale = temp.get(0);
        float weight = weights.get(idx);
        float scaled = scale * x.get(idx);
        output.set(idx, weight * scaled);
    }

    /**
     * Copies {@code srcKey} and {@code srcValue} into the key and value caches at the slot
     * {@code layer * contextLength * kvDim + position * kvDim}, where {@code position} is read
     * from {@code positioNlayer[0]}. The number of copied elements is {@code srcValue.getSize()}.
     *
     * @param destKeyCache   key cache to write into
     * @param srcKey         key vector to copy
     * @param destValueCache value cache to write into
     * @param srcValue       value vector to copy
     * @param positioNlayer  array whose slot 0 holds the position
     * @param kvDim          stride of one position inside the caches
     * @param layer          layer index used to compute the layer offset
     * @param contextLength  number of positions per layer used to compute the layer offset
     */
    public static void copyToCache(FloatArray destKeyCache, FloatArray srcKey, FloatArray destValueCache, FloatArray srcValue, IntArray positioNlayer, int kvDim, int layer, int contextLength) {
        int pos = positioNlayer.get(0);
        int layerBase = layer * contextLength * kvDim;
        int writeBase = layerBase + pos * kvDim;
        for (int k = 0; k < srcValue.getSize(); k++) {
            int target = writeBase + k;
            destKeyCache.set(target, srcKey.get(k));
            destValueCache.set(target, srcValue.get(k));
        }
    }

    /**
     * Kernel that rotates one adjacent pair of elements per thread. The pair starts at index
     * {@code 2 * globalIdx}; the rotation angle is {@code positionHolder[0] / 50000^(headDim / head_size)}
     * with {@code headDim = pairStart % head_size}.
     *
     * <p>{@code sq} is always rotated; {@code sk} is rotated only when the pair index is below
     * both {@code kv_dim} and {@code sk.getSize()}.
     *
     * @param context        kernel context supplying the global thread index
     * @param positionHolder array whose slot 0 holds the position
     * @param sq             query vector, rotated in place
     * @param sk             key vector, rotated in place where applicable
     * @param kv_dim         upper bound (exclusive) on pair indices for which {@code sk} is rotated
     * @param head_size      size of one head, used to derive the frequency
     */
    public static void ropeRotation(KernelContext context, IntArray positionHolder, FloatArray sq, FloatArray sk, int kv_dim, int head_size) {
        int pairStart = context.globalIdx * 2;
        int dimInHead = pairStart % head_size;

        float exponent = (float) dimInHead / (float) head_size;
        float frequency = 1.0f / TornadoMath.pow(50000.0f, exponent);
        float angle = (float) positionHolder.get(0) * frequency;
        float cosAngle = TornadoMath.cos(angle);
        float sinAngle = TornadoMath.sin(angle);

        float q0 = sq.get(pairStart);
        float q1 = sq.get(pairStart + 1);
        sq.set(pairStart, q0 * cosAngle - q1 * sinAngle);
        sq.set(pairStart + 1, q0 * sinAngle + q1 * cosAngle);

        if (pairStart < kv_dim && pairStart < sk.getSize()) {
            float k0 = sk.get(pairStart);
            float k1 = sk.get(pairStart + 1);
            sk.set(pairStart, k0 * cosAngle - k1 * sinAngle);
            sk.set(pairStart + 1, k0 * sinAngle + k1 * cosAngle);
        }
    }

    /**
     * Runs {@link #processHeadTornado} once for every head in {@code [0, nHeads)}, using the
     * position stored in {@code positionHolder[0]} and the layer offset
     * {@code layer * contextLength * kvDim}. {@code seqLen} is accepted but not read here.
     */
    public static void processHeadsParallel(FloatArray q, FloatArray key_cache, FloatArray value_cache, FloatArray xb, int nHeads, int headSize, int kvDim, int kvMul, int seqLen, IntArray positionHolder, FloatArray wrapAtt, int layer, int contextLength) {
        int position = positionHolder.get(0);
        int layerOffset = layer * contextLength * kvDim;
        for (int head = 0; head < nHeads; head++) {
            processHeadTornado(q, key_cache, value_cache, xb, head, headSize, kvDim, kvMul, layerOffset, position, wrapAtt);
        }
    }

    /**
     * Computes attention for head {@code h} over time steps {@code 0..pos}:
     * scaled dot-product scores, a max-shifted softmax stored in {@code wrapAtt}, and the
     * softmax-weighted sum of cached values written to {@code allXb}.
     *
     * <p>This head's weights occupy {@code wrapAtt[h * (pos + 1) .. h * (pos + 1) + pos]}; its
     * query and output occupy {@code [h * headSize, (h + 1) * headSize)}. The key/value head
     * used is {@code h / kvMul}.
     */
    private static void processHeadTornado(FloatArray allQ, FloatArray key_cache, FloatArray value_cache, FloatArray allXb, int h, int headSize, int kvDim, int kvMul, long loff, int pos, FloatArray wrapAtt) {
        int attBase = h * (pos + 1);
        int queryBase = h * headSize;
        int kvHead = h / kvMul;

        // Step 1: scaled dot product of the query with every cached key.
        for (int step = 0; step <= pos; step++) {
            int keyBase = (int) (loff + (long) (step * kvDim) + (long) (kvHead * headSize));
            float dot = 0.0f;
            for (int k = 0; k < headSize; k++) {
                dot += allQ.get(queryBase + k) * key_cache.get(keyBase + k);
            }
            dot /= TornadoMath.sqrt((float) headSize);
            wrapAtt.set(attBase + step, dot);
        }

        // Step 2: largest score, subtracted before exponentiation.
        float peak = wrapAtt.get(attBase);
        for (int step = 1; step <= pos; step++) {
            float candidate = wrapAtt.get(attBase + step);
            if (candidate > peak) {
                peak = candidate;
            }
        }

        // Step 3: exponentiate in place and accumulate the denominator.
        float total = 0.0f;
        for (int step = 0; step <= pos; step++) {
            int slot = attBase + step;
            float exponentiated = TornadoMath.exp(wrapAtt.get(slot) - peak);
            wrapAtt.set(slot, exponentiated);
            total += exponentiated;
        }

        // Step 4: normalise; a non-positive denominator falls back to uniform weights.
        float scale;
        if (total > 0.0f) {
            scale = 1.0f / total;
        } else {
            scale = 1.0f / (float) (pos + 1);
        }
        for (int step = 0; step <= pos; step++) {
            int slot = attBase + step;
            wrapAtt.set(slot, wrapAtt.get(slot) * scale);
        }

        // Step 5: weighted sum of the cached values, one output element at a time.
        for (int k = 0; k < headSize; k++) {
            float accumulated = 0.0f;
            for (int step = 0; step <= pos; step++) {
                int valueBase = (int) (loff + (long) (step * kvDim) + (long) (kvHead * headSize));
                accumulated += wrapAtt.get(attBase + step) * value_cache.get(valueBase + k);
            }
            allXb.set(queryBase + k, accumulated);
        }
    }

    /**
     * Kernel in which each work-group computes one row of {@code w * x}; the thread with local
     * index 0 stores the row result in {@code hb[groupIdx]}. Groups whose index is not below
     * {@code d} do nothing.
     *
     * @param context            kernel context supplying group and local indices
     * @param x                  input vector of length {@code n}
     * @param hb                 output vector
     * @param w                  row-major matrix with rows of length {@code n}
     * @param n                  row length
     * @param d                  number of rows
     * @param localWorkGroupSize work-group size passed on to the row reduction
     */
    public static void matrixVectorGeneric(KernelContext context, FloatArray x, FloatArray hb, FloatArray w, int n, int d, int localWorkGroupSize) {
        int row = context.groupIdx;
        int lane = context.localIdx;
        if (row >= d) {
            return;
        }
        float dot = matrixVectorRowMajorOptimized(context, localWorkGroupSize, x, w, n, d);
        if (lane == 0) {
            hb.set(row, dot);
        }
    }

    /**
     * Same row-per-work-group product as {@link #matrixVectorGeneric}, except that the row result
     * is added to the value already stored in {@code hb[groupIdx]} instead of replacing it.
     *
     * @param context            kernel context supplying group and local indices
     * @param x                  input vector of length {@code n}
     * @param hb                 vector that is read and updated in place
     * @param w                  row-major matrix with rows of length {@code n}
     * @param n                  row length
     * @param d                  number of rows
     * @param localWorkGroupSize work-group size passed on to the row reduction
     */
    public static void matrixVectorGenericWithResidual(KernelContext context, FloatArray x, FloatArray hb, FloatArray w, int n, int d, int localWorkGroupSize) {
        int row = context.groupIdx;
        int lane = context.localIdx;
        if (row >= d) {
            return;
        }
        float dot = matrixVectorRowMajorOptimized(context, localWorkGroupSize, x, w, n, d);
        if (lane == 0) {
            float previous = hb.get(row);
            hb.set(row, previous + dot);
        }
    }

    /**
     * Kernel in which each work-group computes the same row of {@code w1 * x} and {@code w3 * x};
     * the thread with local index 0 stores {@code silu(w1 row) * (w3 row)} in {@code hb[groupIdx]}.
     * Groups whose index is not below {@code d} do nothing.
     *
     * @param context            kernel context supplying group and local indices
     * @param x                  input vector of length {@code n}
     * @param hb                 output vector
     * @param w1                 row-major matrix whose row product goes through SiLU
     * @param w3                 row-major matrix whose row product multiplies the SiLU result
     * @param n                  row length
     * @param d                  number of rows
     * @param localWorkGroupSize work-group size passed on to the row reductions
     */
    public static void fusedFeedForwardWithSiLUAndGLUActivation(KernelContext context, FloatArray x, FloatArray hb, FloatArray w1, FloatArray w3, int n, int d, int localWorkGroupSize) {
        int row = context.groupIdx;
        int lane = context.localIdx;
        if (row >= d) {
            return;
        }
        float gate = matrixVectorRowMajorOptimized(context, localWorkGroupSize, x, w1, n, d);
        float up = matrixVectorRowMajorOptimized(context, localWorkGroupSize, x, w3, n, d);
        if (lane == 0) {
            float activated = siluActivation(gate);
            hb.set(row, activated * up);
        }
    }

    /**
     * Tanh-based GELU: {@code 0.5 * x * (1 + tanh(0.797885 * (x + 0.044715 * x^3)))}.
     *
     * @param x input value
     * @return the activation of {@code x}
     */
    public static float geluActivation(float x) {
        float cube = x * x * x;
        float inner = 0.797885f * (x + 0.044715f * cube);
        return 0.5f * x * (1.0f + TornadoMath.tanh(inner));
    }

    /**
     * SiLU: {@code x * 1 / (1 + exp(-x))}.
     *
     * @param x input value
     * @return the activation of {@code x}
     */
    public static float siluActivation(float x) {
        float sigmoid = 1.0f / (1.0f + TornadoMath.exp(-x));
        return x * sigmoid;
    }

    /**
     * Computes the dot product of row {@code context.groupIdx} of the row-major matrix {@code w}
     * with the vector {@code x}, cooperatively across the threads of one work-group.
     *
     * <p>Each thread accumulates the columns {@code localIdx, localIdx + localSize, ...}, stores
     * its partial sum in a local array, and the partial sums are then combined by pairwise
     * reduction with a barrier after every round.
     *
     * @param context   kernel context supplying group/local indices, the local array and the barrier
     * @param localSize column step per thread and length of the local array; the reduction starts
     *                  at {@code localSize / 2} (presumably equal to the launched work-group size — confirm with callers)
     * @param x         input vector, read at indices {@code [0, n)}
     * @param w         row-major matrix; row {@code r} starts at {@code r * n}
     * @param n         row length
     * @param d         number of rows; not read by this method
     * @return the value of {@code localSum[0]} after the reduction
     */
    public static float matrixVectorRowMajorOptimized(KernelContext context, int localSize, FloatArray x, FloatArray w, int n, int d) {
        int rowId = context.groupIdx;
        int localId = context.localIdx;
        float[] localSum = context.allocateFloatLocalArray(localSize);
        int rowOffset = rowId * n;
        float partialSum = 0.0f;
        // Strided walk over the row: thread localId handles columns localId, localId + localSize, ...
        for (int j = localId; j < n; j += localSize) {
            int matrixIdx = rowOffset + j;
            partialSum += w.get(matrixIdx) * x.get(j);
        }
        localSum[localId] = partialSum;
        // Barrier between publishing the partial sums and the first reduction round.
        context.localBarrier();
        for (int stride = localSize / 2; stride > 0; stride >>= 1) {
            if (localId < stride) {
                int n2 = localId;
                localSum[n2] = localSum[n2] + localSum[localId + stride];
            }
            // Barrier at the end of every round, including the last one before localSum[0] is returned.
            context.localBarrier();
        }
        return localSum[0];
    }

    /**
     * Kernel in which only global thread 0 works: it sums the squares of {@code x[0..size)} and
     * stores {@code 1 / sqrt(sum / size + epsilon)} in {@code output[0]}.
     *
     * @param context kernel context supplying the global thread index
     * @param output  array whose slot 0 receives the factor
     * @param x       input values
     * @param size    number of elements of {@code x} to include
     * @param epsilon value added to the mean of squares before the square root
     */
    public static void serialRmsNorm(KernelContext context, FloatArray output, FloatArray x, int size, float epsilon) {
        if (context.globalIdx != 0) {
            return;
        }
        float squares = 0.0f;
        for (int k = 0; k < size; k++) {
            float element = x.get(k);
            squares += element * element;
        }
        float meanSquare = squares / (float) size;
        float factor = 1.0f / TornadoMath.sqrt(meanSquare + epsilon);
        output.set(0, factor);
    }

    /**
     * Kernel that squares the elements of {@code x}, sums them per work-group and stores each
     * group's sum in {@code output[groupIdx + 1]}. Slot 0 of {@code output} is not written here.
     *
     * @param context      kernel context supplying thread/group indices, the local array and the barrier
     * @param output       array receiving one partial sum per group, starting at slot 1
     * @param x            input values
     * @param size         number of valid elements in {@code x}; threads past the end contribute 0
     * @param localMemSize length of the local scratch array, which is indexed by {@code localIdx}
     */
    public static void reductionPartialSums(KernelContext context, FloatArray output, FloatArray x, int size, int localMemSize) {
        int globalId = context.globalIdx;
        int lane = context.localIdx;
        int group = context.groupIdx;
        int lanes = context.localGroupSizeX;
        float[] scratch = context.allocateFloatLocalArray(localMemSize);

        if (globalId < size) {
            float element = x.get(globalId);
            scratch[lane] = element * element;
        } else {
            scratch[lane] = 0.0f;
        }

        // Pairwise reduction; the barrier opens every round.
        for (int stride = lanes / 2; stride > 0; stride /= 2) {
            context.localBarrier();
            if (lane < stride) {
                scratch[lane] = scratch[lane] + scratch[lane + stride];
            }
        }

        if (lane == 0) {
            output.set(group + 1, scratch[0]);
        }
    }

    /**
     * Kernel in which only global thread 0 works: it adds up {@code output[1..]} and stores
     * {@code 1 / sqrt(sum / size + ermsNorm)} in {@code output[0]}.
     *
     * @param context  kernel context supplying the global thread index
     * @param output   array holding partial sums from slot 1 onwards; slot 0 receives the factor
     * @param size     divisor applied to the total
     * @param ermsNorm value added to the mean before the square root
     */
    public static void reductionFinalNormalization(KernelContext context, FloatArray output, int size, float ermsNorm) {
        if (context.globalIdx != 0) {
            return;
        }
        float total = 0.0f;
        for (int slot = 1; slot < output.getSize(); slot++) {
            total += output.get(slot);
        }
        float meanSquare = total / (float) size;
        float shifted = meanSquare + ermsNorm;
        output.set(0, 1.0f / TornadoMath.sqrt(shifted));
    }

    /**
     * Sequential reference for the two-step reduction: stores
     * {@code 1 / sqrt(sum(x^2) / size + ermsNorm)} in {@code output[0]} and the sum of squares of
     * every block of 128 consecutive elements in {@code output[1..]}.
     *
     * <p>The number of blocks is rounded up, so a trailing block shorter than 128 elements is
     * included. The previous truncating division silently dropped it whenever {@code size} was
     * not a multiple of 128. {@code output} must therefore hold at least
     * {@code ceil(size / 128) + 1} elements.
     *
     * @param output   slot 0 receives the factor, slot {@code g + 1} the partial sum of block {@code g}
     * @param x        input values
     * @param size     number of elements of {@code x} to include
     * @param ermsNorm value added to the mean of squares before the square root
     */
    public static void reductionOneBlockSequentialX(FloatArray output, FloatArray x, int size, float ermsNorm) {
        float sum = 0.0f;
        for (int i = 0; i < size; ++i) {
            float val = x.get(i);
            sum += val * val;
        }
        sum /= (float) size;
        sum += ermsNorm;
        sum = 1.0f / (float) Math.sqrt(sum);
        output.set(0, sum);

        int localSize = 128;
        // Ceiling division so that a partial last block is not lost.
        int numWorkGroups = (size + localSize - 1) / localSize;
        for (int g = 0; g < numWorkGroups; ++g) {
            int start = g * localSize;
            // Clamp the block end to the input length for the (possibly shorter) last block.
            int end = Math.min(start + localSize, size);
            float partialSum = 0.0f;
            for (int i = start; i < end; ++i) {
                float val = x.get(i);
                partialSum += val * val;
            }
            output.set(g + 1, partialSum);
        }
    }

    /**
     * Fills every element of {@code array} with {@code min + r * (max - min)}, where {@code r}
     * is the next float drawn from this test's seeded {@link Random}.
     */
    private void fillRandomData(FloatArray array, float min, float max) {
        int length = array.getSize();
        for (int slot = 0; slot < length; slot++) {
            float sample = random.nextFloat();
            array.set(slot, min + sample * (max - min));
        }
    }

    /**
     * Sequential reference: stores {@code 1 / sqrt(sum(x^2) / size + ermsNorm)} in {@code output[0]}.
     */
    private void reductionOneBlockSequential(FloatArray output, FloatArray x, int size, float ermsNorm) {
        float squares = 0.0f;
        for (int k = 0; k < size; k++) {
            float element = x.get(k);
            squares += element * element;
        }
        float meanSquare = squares / (float) size;
        float shifted = meanSquare + ermsNorm;
        output.set(0, 1.0f / (float) Math.sqrt(shifted));
    }

    /**
     * Sequential reference: writes {@code weights[i] * (temp[0] * x[i])} to {@code output[i]} for
     * every index of {@code x}.
     */
    private void reductionOneBlock2Sequential(FloatArray output, FloatArray x, FloatArray weights, FloatArray temp) {
        float factor = temp.get(0);
        int length = x.getSize();
        for (int k = 0; k < length; k++) {
            float scaled = factor * x.get(k);
            output.set(k, weights.get(k) * scaled);
        }
    }

    /**
     * Sequential reference for {@link #copyToCache}: copies {@code srcKey} and {@code srcValue}
     * into the caches at {@code layer * contextLength * kvDim + positionNlayer[0] * kvDim}.
     */
    private void copyToCacheSequential(FloatArray destKeyCache, FloatArray srcKey, FloatArray destValueCache, FloatArray srcValue, IntArray positionNlayer, int kvDim, int layer, int contextLength) {
        int pos = positionNlayer.get(0);
        int layerBase = layer * contextLength * kvDim;
        int writeBase = layerBase + pos * kvDim;
        int count = srcValue.getSize();
        for (int k = 0; k < count; k++) {
            int target = writeBase + k;
            destKeyCache.set(target, srcKey.get(k));
            destValueCache.set(target, srcValue.get(k));
        }
    }

    /**
     * Sequential reference for {@link #ropeRotation}: for every even index below {@code kv_dim}
     * it rotates the pair {@code (i, i + 1)} of {@code sq}, and of {@code sk} when {@code i} is
     * below {@code sk.getSize()}, using {@code java.lang.Math} for the trigonometry.
     */
    private void ropeRotationSequential(IntArray positionHolder, FloatArray sq, FloatArray sk, int kv_dim, int head_size) {
        for (int pairStart = 0; pairStart < kv_dim; pairStart += 2) {
            int dimInHead = pairStart % head_size;
            float frequency = 1.0f / (float) Math.pow(50000.0, (float) dimInHead / (float) head_size);
            float angle = (float) positionHolder.get(0) * frequency;
            float cosAngle = (float) Math.cos(angle);
            float sinAngle = (float) Math.sin(angle);

            float q0 = sq.get(pairStart);
            float q1 = sq.get(pairStart + 1);
            sq.set(pairStart, q0 * cosAngle - q1 * sinAngle);
            sq.set(pairStart + 1, q0 * sinAngle + q1 * cosAngle);

            if (pairStart < sk.getSize()) {
                float k0 = sk.get(pairStart);
                float k1 = sk.get(pairStart + 1);
                sk.set(pairStart, k0 * cosAngle - k1 * sinAngle);
                sk.set(pairStart + 1, k0 * sinAngle + k1 * cosAngle);
            }
        }
    }

    /**
     * Sequential reference for {@link #processHeadsParallel}: for every head it computes scaled
     * dot-product scores over time steps {@code 0..pos}, a max-shifted softmax stored in
     * {@code wrapAtt}, and the weighted sum of cached values written to {@code xb}.
     * {@code pos} is read from {@code positionNlayer[0]}; {@code seqLen} is not read.
     */
    private void processHeadsSequential(FloatArray q, FloatArray key_cache, FloatArray value_cache, FloatArray xb, int nHeads, int headSize, int kvDim, int kvMul, int seqLen, IntArray positionNlayer, FloatArray wrapAtt, int layer, int contextLength) {
        int pos = positionNlayer.get(0);
        int layerOffset = layer * contextLength * kvDim;

        for (int head = 0; head < nHeads; head++) {
            int attBase = head * (pos + 1);
            int queryBase = head * headSize;
            int kvHead = head / kvMul;

            // Scaled dot product of the query with every cached key.
            for (int step = 0; step <= pos; step++) {
                int keyBase = layerOffset + step * kvDim + kvHead * headSize;
                float dot = 0.0f;
                for (int k = 0; k < headSize; k++) {
                    dot += q.get(queryBase + k) * key_cache.get(keyBase + k);
                }
                dot /= (float) Math.sqrt(headSize);
                wrapAtt.set(attBase + step, dot);
            }

            // Largest score, subtracted before exponentiation.
            float peak = wrapAtt.get(attBase);
            for (int step = 1; step <= pos; step++) {
                float candidate = wrapAtt.get(attBase + step);
                if (candidate > peak) {
                    peak = candidate;
                }
            }

            // Exponentiate in place and accumulate the denominator.
            float total = 0.0f;
            for (int step = 0; step <= pos; step++) {
                int slot = attBase + step;
                float exponentiated = (float) Math.exp(wrapAtt.get(slot) - peak);
                wrapAtt.set(slot, exponentiated);
                total += exponentiated;
            }

            // Normalise; a non-positive denominator falls back to uniform weights.
            float scale;
            if (total > 0.0f) {
                scale = 1.0f / total;
            } else {
                scale = 1.0f / (float) (pos + 1);
            }
            for (int step = 0; step <= pos; step++) {
                int slot = attBase + step;
                wrapAtt.set(slot, wrapAtt.get(slot) * scale);
            }

            // Weighted sum of the cached values, one output element at a time.
            for (int k = 0; k < headSize; k++) {
                float accumulated = 0.0f;
                for (int step = 0; step <= pos; step++) {
                    int valueBase = layerOffset + step * kvDim + kvHead * headSize;
                    accumulated += wrapAtt.get(attBase + step) * value_cache.get(valueBase + k);
                }
                xb.set(queryBase + k, accumulated);
            }
        }
    }

    /**
     * Sequential reference: {@code hb[row] = sum over col of w[row * n + col] * x[col]} for
     * {@code d} rows of length {@code n}.
     */
    private void matrixVectorSequential(FloatArray x, FloatArray hb, FloatArray w, int n, int d) {
        for (int row = 0; row < d; row++) {
            int rowStart = row * n;
            float dot = 0.0f;
            for (int col = 0; col < n; col++) {
                dot += w.get(rowStart + col) * x.get(col);
            }
            hb.set(row, dot);
        }
    }

    /**
     * Sequential reference: adds row {@code row} of {@code w * x} to the value already stored in
     * {@code hb[row]}, for {@code d} rows of length {@code n}.
     */
    private void matrixVectorWithResidualSequential(FloatArray x, FloatArray hb, FloatArray w, int n, int d) {
        for (int row = 0; row < d; row++) {
            int rowStart = row * n;
            float dot = 0.0f;
            for (int col = 0; col < n; col++) {
                dot += w.get(rowStart + col) * x.get(col);
            }
            float previous = hb.get(row);
            hb.set(row, previous + dot);
        }
    }

    /**
     * Sequential reference: for each of {@code d} rows, computes the row products of {@code w1}
     * and {@code w3} with {@code x} and stores {@code silu(w1 row) * (w3 row)} in {@code hb[row]}.
     */
    private void fusedFeedForwardSequential(FloatArray x, FloatArray hb, FloatArray w1, FloatArray w3, int n, int d) {
        for (int row = 0; row < d; row++) {
            int rowStart = row * n;
            float gate = 0.0f;
            float up = 0.0f;
            for (int col = 0; col < n; col++) {
                int cell = rowStart + col;
                gate += w1.get(cell) * x.get(col);
                up += w3.get(cell) * x.get(col);
            }
            float sigmoid = 1.0f / (1.0f + (float) Math.exp(-gate));
            float activated = gate * sigmoid;
            hb.set(row, activated * up);
        }
    }

    /**
     * Runs {@link #reductionOneBlockWithLayer} on 1024 random values with work-groups of 128 and
     * compares {@code output[0]} with the sequential reference.
     *
     * <p>The execution plan is opened with try-with-resources, like the other tests in this
     * class, so it is closed even when execution throws. Previously
     * {@code freeDeviceMemory()} was only reached when both the execution and the assertion
     * succeeded.
     */
    @Test
    public void testReductionOneBlockWithLayer() throws TornadoExecutionPlanException {
        int size = 1024;
        int localSize = 128;
        int numWorkGroups = size / localSize;
        float ermsNorm = 1.0E-5f;

        // One slot for the final factor plus one per work-group.
        FloatArray input = new FloatArray(size);
        FloatArray output = new FloatArray(numWorkGroups + 1);
        FloatArray outputSeq = new FloatArray(numWorkGroups + 1);
        fillRandomData(input, -2.0f, 2.0f);
        output.init(0.0f);

        reductionOneBlockSequential(outputSeq, input, size, ermsNorm);

        WorkerGrid1D worker = new WorkerGrid1D(size);
        GridScheduler scheduler = new GridScheduler("s0.t0", worker);
        worker.setLocalWork(localSize, 1L, 1L);

        TaskGraph taskGraph = new TaskGraph("s0")
                .transferToDevice(0, input)
                .task("t0", TransformerKernelsTest::reductionOneBlockWithLayer, new KernelContext(), output, input, size, ermsNorm, localSize)
                .transferToHost(1, output);
        ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();

        try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
            executionPlan.withGridScheduler(scheduler).execute();
        }

        Assert.assertEquals(outputSeq.get(0), output.get(0), DELTA);
    }

    /**
     * Runs {@link #reductionOneBlock2WithLayer} over 1024 elements with a fixed scale of 0.1 and
     * compares every output element with the sequential reference.
     */
    @Test
    public void testReductionOneBlock2WithLayer() throws TornadoExecutionPlanException {
        int size = 1024;

        FloatArray input = new FloatArray(size);
        FloatArray weights = new FloatArray(size);
        FloatArray output = new FloatArray(size);
        FloatArray outputSeq = new FloatArray(size);
        FloatArray temp = new FloatArray(1);
        fillRandomData(input, -2.0f, 2.0f);
        fillRandomData(weights, -1.0f, 1.0f);
        temp.set(0, 0.1f);

        reductionOneBlock2Sequential(outputSeq, input, weights, temp);

        WorkerGrid1D worker = new WorkerGrid1D(size);
        GridScheduler scheduler = new GridScheduler("s0.t0", worker);

        TaskGraph taskGraph = new TaskGraph("s0")
                .transferToDevice(0, input, weights, temp)
                .task("t0", TransformerKernelsTest::reductionOneBlock2WithLayer, new KernelContext(), output, input, weights, temp)
                .transferToHost(1, output);
        ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();

        try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
            executionPlan.withGridScheduler(scheduler).execute();
        }

        for (int k = 0; k < size; k++) {
            Assert.assertEquals(outputSeq.get(k), output.get(k), DELTA);
        }
    }

    /**
     * Runs {@link #copyToCache} for layer 2, position 5 and compares the written cache slot
     * (key and value) with the sequential reference.
     */
    @Test
    public void testCopyToCache() throws TornadoExecutionPlanException {
        int kvDim = 128;
        int layer = 2;
        int contextLength = 16;
        int position = 5;
        // 4 * 16 * 128 = 8192 elements per cache.
        int cacheSize = 4 * contextLength * kvDim;

        FloatArray srcKey = new FloatArray(kvDim);
        FloatArray srcValue = new FloatArray(kvDim);
        FloatArray destKeyCache = new FloatArray(cacheSize);
        FloatArray destValueCache = new FloatArray(cacheSize);
        FloatArray destKeyCacheSeq = new FloatArray(cacheSize);
        FloatArray destValueCacheSeq = new FloatArray(cacheSize);
        IntArray positionNlayer = new IntArray(2);
        fillRandomData(srcKey, -1.0f, 1.0f);
        fillRandomData(srcValue, -1.0f, 1.0f);
        positionNlayer.set(0, position);
        positionNlayer.set(1, layer);

        copyToCacheSequential(destKeyCacheSeq, srcKey, destValueCacheSeq, srcValue, positionNlayer, kvDim, layer, contextLength);

        TaskGraph taskGraph = new TaskGraph("s0")
                .transferToDevice(0, srcKey, srcValue, positionNlayer)
                .task("t0", TransformerKernelsTest::copyToCache, destKeyCache, srcKey, destValueCache, srcValue, positionNlayer, kvDim, layer, contextLength)
                .transferToHost(1, destKeyCache, destValueCache);
        ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();

        try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
            executionPlan.execute();
        }

        // Start of the slot that was written: 2 * 16 * 128 + 5 * 128 = 4736.
        int offset = layer * contextLength * kvDim + position * kvDim;
        for (int k = 0; k < kvDim; k++) {
            int slot = offset + k;
            Assert.assertEquals(destKeyCacheSeq.get(slot), destKeyCache.get(slot), DELTA);
            Assert.assertEquals(destValueCacheSeq.get(slot), destValueCache.get(slot), DELTA);
        }
    }

    /**
     * Runs {@link #ropeRotation} with one thread per element pair (64 pairs for 128 elements) at
     * position 3 and compares the rotated query and key vectors with the sequential reference.
     */
    @Test
    public void testRopeRotation() throws TornadoExecutionPlanException {
        int kvDim = 128;
        int headSize = 64;
        int position = 3;

        FloatArray sq = new FloatArray(kvDim);
        FloatArray sk = new FloatArray(kvDim);
        FloatArray sqSeq = new FloatArray(kvDim);
        FloatArray skSeq = new FloatArray(kvDim);
        IntArray positionHolder = new IntArray(1);
        fillRandomData(sq, -1.0f, 1.0f);
        fillRandomData(sk, -1.0f, 1.0f);

        // The reference works on copies because both implementations rotate in place.
        for (int k = 0; k < kvDim; k++) {
            sqSeq.set(k, sq.get(k));
            skSeq.set(k, sk.get(k));
        }
        positionHolder.set(0, position);

        ropeRotationSequential(positionHolder, sqSeq, skSeq, kvDim, headSize);

        int numPairs = kvDim / 2;
        WorkerGrid1D worker = new WorkerGrid1D(numPairs);
        GridScheduler scheduler = new GridScheduler("s0.t0", worker);

        TaskGraph taskGraph = new TaskGraph("s0")
                .transferToDevice(0, sq, sk, positionHolder)
                .task("t0", TransformerKernelsTest::ropeRotation, new KernelContext(), positionHolder, sq, sk, kvDim, headSize)
                .transferToHost(1, sq, sk);
        ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();

        try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
            executionPlan.withGridScheduler(scheduler).execute();
        }

        for (int k = 0; k < kvDim; k++) {
            Assert.assertEquals(sqSeq.get(k), sq.get(k), DELTA);
            Assert.assertEquals(skSeq.get(k), sk.get(k), DELTA);
        }
    }

    /**
     * Runs {@link #processHeadsParallel} for 8 heads of size 64 at position 3 in layer 1 and
     * compares the attention output with the sequential reference. The attention weights are
     * transferred back but not asserted.
     */
    @Test
    public void testProcessHeadsParallel() throws TornadoExecutionPlanException {
        int nHeads = 8;
        int headSize = 64;
        int kvMul = 2;
        int kvDim = 256;
        int seqLen = 8;
        int contextLength = 16;
        int pos = 3;
        int layer = 1;

        int dim = nHeads * headSize;                // 512
        int cacheSize = 2 * contextLength * kvDim;  // 8192
        int attentionSize = nHeads * (pos + 1);     // 32

        FloatArray query = new FloatArray(dim);
        FloatArray keyCache = new FloatArray(cacheSize);
        FloatArray valueCache = new FloatArray(cacheSize);
        FloatArray output = new FloatArray(dim);
        FloatArray outputSeq = new FloatArray(dim);
        FloatArray attentionWeights = new FloatArray(attentionSize);
        FloatArray attentionWeightsSeq = new FloatArray(attentionSize);
        IntArray positionHolder = new IntArray(2);
        fillRandomData(query, -1.0f, 1.0f);
        fillRandomData(keyCache, -0.5f, 0.5f);
        fillRandomData(valueCache, -0.5f, 0.5f);
        positionHolder.set(0, pos);
        positionHolder.set(1, layer);
        for (int k = 0; k < query.getSize(); k++) {
            outputSeq.set(k, 0.0f);
        }

        processHeadsSequential(query, keyCache, valueCache, outputSeq, nHeads, headSize, kvDim, kvMul, seqLen, positionHolder, attentionWeightsSeq, layer, contextLength);

        WorkerGrid1D worker = new WorkerGrid1D(nHeads);
        GridScheduler scheduler = new GridScheduler("s0.t0", worker);

        TaskGraph taskGraph = new TaskGraph("s0")
                .transferToDevice(0, query, keyCache, valueCache, positionHolder)
                .task("t0", TransformerKernelsTest::processHeadsParallel, query, keyCache, valueCache, output, nHeads, headSize, kvDim, kvMul, seqLen, positionHolder, attentionWeights, layer, contextLength)
                .transferToHost(1, output, attentionWeights);
        ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();

        try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
            executionPlan.withGridScheduler(scheduler).execute();
        }

        for (int k = 0; k < output.getSize(); k++) {
            Assert.assertEquals(outputSeq.get(k), output.get(k), DELTA);
        }
    }

    /**
     * Runs {@link #matrixVectorGeneric} for a 128 x 64 matrix with one work-group of 32 threads
     * per row and compares every output element with the sequential reference. Not run on the
     * SPIR-V backend.
     */
    @Test
    public void testMatrixVectorGeneric() throws TornadoExecutionPlanException {
        assertNotBackend(TornadoVMBackendType.SPIRV);

        int inputDim = 64;
        int outputDim = 128;
        int localWorkGroupSize = 32;

        FloatArray input = new FloatArray(inputDim);
        FloatArray weights = new FloatArray(inputDim * outputDim);
        FloatArray output = new FloatArray(outputDim);
        FloatArray outputSeq = new FloatArray(outputDim);
        fillRandomData(input, -1.0f, 1.0f);
        fillRandomData(weights, -0.1f, 0.1f);

        matrixVectorSequential(input, outputSeq, weights, inputDim, outputDim);

        // One work-group per output row: 128 * 32 = 4096 threads in total.
        WorkerGrid1D worker = new WorkerGrid1D(outputDim * localWorkGroupSize);
        GridScheduler scheduler = new GridScheduler("s0.t0", worker);
        worker.setLocalWork(localWorkGroupSize, 1L, 1L);

        TaskGraph taskGraph = new TaskGraph("s0")
                .transferToDevice(0, input, weights)
                .task("t0", TransformerKernelsTest::matrixVectorGeneric, new KernelContext(), input, output, weights, inputDim, outputDim, localWorkGroupSize)
                .transferToHost(1, output);
        ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();

        try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
            executionPlan.withGridScheduler(scheduler).execute();
        }

        for (int row = 0; row < outputDim; row++) {
            Assert.assertEquals(outputSeq.get(row), output.get(row), DELTA);
        }
    }

    /**
     * Runs {@link #matrixVectorGenericWithResidual} for a 128 x 64 matrix on a pre-filled output
     * vector and compares every element with the sequential reference, which starts from a copy
     * of the same initial values. Not run on the SPIR-V backend.
     */
    @Test
    public void testMatrixVectorGenericWithResidual() throws TornadoExecutionPlanException {
        assertNotBackend(TornadoVMBackendType.SPIRV);

        int inputDim = 64;
        int outputDim = 128;
        int localWorkGroupSize = 32;

        FloatArray input = new FloatArray(inputDim);
        FloatArray weights = new FloatArray(inputDim * outputDim);
        FloatArray output = new FloatArray(outputDim);
        FloatArray outputSeq = new FloatArray(outputDim);
        fillRandomData(input, -1.0f, 1.0f);
        fillRandomData(weights, -0.1f, 0.1f);
        fillRandomData(output, -0.5f, 0.5f);

        // Both implementations accumulate in place, so the reference gets its own copy.
        for (int row = 0; row < outputDim; row++) {
            outputSeq.set(row, output.get(row));
        }

        matrixVectorWithResidualSequential(input, outputSeq, weights, inputDim, outputDim);

        // One work-group per output row: 128 * 32 = 4096 threads in total.
        WorkerGrid1D worker = new WorkerGrid1D(outputDim * localWorkGroupSize);
        GridScheduler scheduler = new GridScheduler("s0.t0", worker);
        worker.setLocalWork(localWorkGroupSize, 1L, 1L);

        TaskGraph taskGraph = new TaskGraph("s0")
                .transferToDevice(0, input, weights, output)
                .task("t0", TransformerKernelsTest::matrixVectorGenericWithResidual, new KernelContext(), input, output, weights, inputDim, outputDim, localWorkGroupSize)
                .transferToHost(1, output);
        ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();

        try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
            executionPlan.withGridScheduler(scheduler).execute();
        }

        for (int row = 0; row < outputDim; row++) {
            Assert.assertEquals(outputSeq.get(row), output.get(row), DELTA);
        }
    }

    /**
     * Runs {@link #fusedFeedForwardWithSiLUAndGLUActivation} for two 128 x 64 matrices and
     * compares every output element with the sequential reference. Not run on the SPIR-V backend.
     */
    @Test
    public void testFusedFeedForwardWithSiLUAndGLUActivation() throws TornadoExecutionPlanException {
        assertNotBackend(TornadoVMBackendType.SPIRV);

        int inputDim = 64;
        int hiddenDim = 128;
        int localWorkGroupSize = 32;

        FloatArray input = new FloatArray(inputDim);
        FloatArray w1 = new FloatArray(inputDim * hiddenDim);
        FloatArray w3 = new FloatArray(inputDim * hiddenDim);
        FloatArray output = new FloatArray(hiddenDim);
        FloatArray outputSeq = new FloatArray(hiddenDim);
        fillRandomData(input, -1.0f, 1.0f);
        fillRandomData(w1, -0.1f, 0.1f);
        fillRandomData(w3, -0.1f, 0.1f);

        fusedFeedForwardSequential(input, outputSeq, w1, w3, inputDim, hiddenDim);

        // One work-group per hidden row: 128 * 32 = 4096 threads in total.
        WorkerGrid1D worker = new WorkerGrid1D(hiddenDim * localWorkGroupSize);
        GridScheduler scheduler = new GridScheduler("s0.t0", worker);
        worker.setLocalWork(localWorkGroupSize, 1L, 1L);

        TaskGraph taskGraph = new TaskGraph("s0")
                .transferToDevice(0, input, w1, w3)
                .task("t0", TransformerKernelsTest::fusedFeedForwardWithSiLUAndGLUActivation, new KernelContext(), input, output, w1, w3, inputDim, hiddenDim, localWorkGroupSize)
                .transferToHost(1, output);
        ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();

        try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
            executionPlan.withGridScheduler(scheduler).execute();
        }

        for (int row = 0; row < hiddenDim; row++) {
            Assert.assertEquals(outputSeq.get(row), output.get(row), DELTA);
        }
    }

    /**
     * Compares {@link #geluActivation} and {@link #siluActivation} with the same formulas
     * evaluated through {@code java.lang.Math} for a fixed set of sample inputs.
     */
    @Test
    public void testActivationFunctions() {
        float[] samples = {-5.0f, -2.0f, -1.0f, -0.5f, 0.0f, 0.5f, 1.0f, 2.0f, 5.0f};
        for (int k = 0; k < samples.length; k++) {
            float sample = samples[k];
            float actualGelu = geluActivation(sample);
            float actualSilu = siluActivation(sample);

            float cube = sample * sample * sample;
            float expectedGelu = 0.5f * sample * (1.0f + (float) Math.tanh(0.797885f * (sample + 0.044715f * cube)));
            float expectedSilu = sample * (1.0f / (1.0f + (float) Math.exp(-sample)));

            Assert.assertEquals(expectedGelu, actualGelu, DELTA);
            Assert.assertEquals(expectedSilu, actualSilu, DELTA);
        }
    }

    /**
     * Sequential reference for {@link #serialRmsNorm}: stores
     * {@code 1 / sqrt(sum(x^2) / size + epsilon)} in {@code output[0]}, using
     * {@code TornadoMath.sqrt} for the square root.
     */
    private void serialRmsNormSequential(FloatArray output, FloatArray x, int size, float epsilon) {
        float squares = 0.0f;
        for (int k = 0; k < size; k++) {
            float element = x.get(k);
            squares += element * element;
        }
        float meanSquare = squares / (float) size;
        float factor = 1.0f / TornadoMath.sqrt(meanSquare + epsilon);
        output.set(0, factor);
    }

    /**
     * Runs {@link #serialRmsNorm} with a single-thread grid over 1024 random values and compares
     * {@code output[0]} with the sequential reference.
     */
    @Test
    public void testSerialRmsNorm() throws TornadoExecutionPlanException {
        int size = 1024;
        float epsilon = 1.0E-5f;

        FloatArray input = new FloatArray(size);
        FloatArray output = new FloatArray(1);
        FloatArray outputSeq = new FloatArray(1);
        fillRandomData(input, -2.0f, 2.0f);

        serialRmsNormSequential(outputSeq, input, size, epsilon);

        WorkerGrid1D worker = new WorkerGrid1D(1);
        GridScheduler scheduler = new GridScheduler("s0.t0", worker);

        TaskGraph taskGraph = new TaskGraph("s0")
                .transferToDevice(0, input)
                .task("t0", TransformerKernelsTest::serialRmsNorm, new KernelContext(), output, input, size, epsilon)
                .transferToHost(1, output);
        ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();

        try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
            executionPlan.withGridScheduler(scheduler).execute();
        }

        Assert.assertEquals(outputSeq.get(0), output.get(0), DELTA);
    }

    /**
     * Runs the reduction as two tasks in one graph: {@link #reductionPartialSums} on 1024 threads
     * in groups of 128, then {@link #reductionFinalNormalization} on a single thread. It compares
     * {@code output[0]} with the sequential reference.
     *
     * <p>The execution plan is opened with try-with-resources, like the other tests in this
     * class, so it is closed even when execution throws. Previously
     * {@code freeDeviceMemory()} was only reached when both the execution and the assertion
     * succeeded. The scheduler that was created but never handed to the plan has been removed;
     * both worker grids are registered on the single scheduler that is used.
     */
    @Test
    public void testReductionOneBlockTwoStepApproach() throws TornadoExecutionPlanException {
        int size = 1024;
        int localSize = 128;
        int numWorkGroups = size / localSize;
        float ermsNorm = 1.0E-5f;

        // One slot for the final factor plus one per work-group.
        FloatArray input = new FloatArray(size);
        FloatArray output = new FloatArray(numWorkGroups + 1);
        FloatArray outputSeq = new FloatArray(numWorkGroups + 1);
        fillRandomData(input, -2.0f, 2.0f);
        output.init(0.0f);

        reductionOneBlockSequentialX(outputSeq, input, size, ermsNorm);

        WorkerGrid1D partialSumsGrid = new WorkerGrid1D(size);
        partialSumsGrid.setLocalWork(localSize, 1L, 1L);
        WorkerGrid1D normalizationGrid = new WorkerGrid1D(1);
        normalizationGrid.setLocalWork(1L, 1L, 1L);

        GridScheduler scheduler = new GridScheduler("s0.t1", normalizationGrid);
        scheduler.addWorkerGrid("s0.t0", partialSumsGrid);

        TaskGraph taskGraph = new TaskGraph("s0")
                .transferToDevice(0, input, output)
                .task("t0", TransformerKernelsTest::reductionPartialSums, new KernelContext(), output, input, size, localSize)
                .task("t1", TransformerKernelsTest::reductionFinalNormalization, new KernelContext(), output, size, ermsNorm)
                .transferToHost(1, output);
        ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();

        try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
            executionPlan.withGridScheduler(scheduler).execute();
        }

        Assert.assertEquals(outputSeq.get(0), output.get(0), DELTA);
    }
}

