/*
 * Decompiled with CFR 0.152.
 */
package uk.ac.manchester.tornado.unittests.compute;

import java.awt.Color;
import java.util.Random;
import java.util.stream.IntStream;
import org.junit.Assert;
import org.junit.Test;
import uk.ac.manchester.tornado.api.GridScheduler;
import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
import uk.ac.manchester.tornado.api.TaskGraph;
import uk.ac.manchester.tornado.api.TornadoExecutionPlan;
import uk.ac.manchester.tornado.api.WorkerGrid;
import uk.ac.manchester.tornado.api.WorkerGrid1D;
import uk.ac.manchester.tornado.api.exceptions.TornadoExecutionPlanException;
import uk.ac.manchester.tornado.api.math.TornadoMath;
import uk.ac.manchester.tornado.api.types.HalfFloat;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
import uk.ac.manchester.tornado.api.types.arrays.IntArray;
import uk.ac.manchester.tornado.api.types.arrays.LongArray;
import uk.ac.manchester.tornado.api.types.arrays.ShortArray;
import uk.ac.manchester.tornado.api.types.collections.VectorFloat;
import uk.ac.manchester.tornado.api.types.collections.VectorFloat4;
import uk.ac.manchester.tornado.api.types.images.ImageByte3;
import uk.ac.manchester.tornado.api.types.images.ImageFloat3;
import uk.ac.manchester.tornado.api.types.matrix.Matrix2DFloat;
import uk.ac.manchester.tornado.api.types.matrix.Matrix2DFloat4;
import uk.ac.manchester.tornado.api.types.vectors.Byte3;
import uk.ac.manchester.tornado.api.types.vectors.Float3;
import uk.ac.manchester.tornado.api.types.vectors.Float4;
import uk.ac.manchester.tornado.unittests.common.TornadoTestBase;

public class ComputeTests
extends TornadoTestBase {
    // NOTE: this file is decompiler output; the visible method bodies use the literal values
    // directly rather than referring to these constants by name.

    // Same value as the 0.005f step factor applied in nBody.
    private static final float DELTA = 0.005f;
    // Same value as the 500.0f term added to the squared distance in nBody.
    private static final float ESP_SQR = 500.0f;
    // Same value as the 1000.0f starting counter in juliaSetTornado.
    private static final int MAX_ITERATIONS = 1000;
    // Not referenced in the visible code — presumably a Julia-set zoom factor; TODO confirm.
    private static final float ZOOM = 1.0f;
    // Same values as the constants added to the real/imaginary parts in juliaSetTornado.
    private static final float CX = -0.7f;
    private static final float CY = 0.27015f;
    // Same value as the 0.0f offsets added to the initial coordinates in juliaSetTornado.
    private static final float MOVE_X = 0.0f;
    private static final float MOVE_Y = 0.0f;
    // Dimensions of the matrix produced and checked by testHilbert.
    private static int NROWS = 1024;
    private static int NCOLS = 1024;

    /**
     * Advances an N-body system by one time step, updating both arrays in place.
     *
     * <p>Each body occupies four consecutive floats. Slots 0..2 are treated as the three spatial
     * components; slot 3 of {@code refPos} scales the pull a body exerts on the others and is only
     * read. Slot 3 of {@code refVel} is never touched.
     *
     * <p>Positions are advanced with {@code p += v * dt + 0.5 * a * dt^2} and velocities with
     * {@code v += a * dt}. The previous version read {@code refPos} where the velocity was
     * required, so the velocity array never contributed to the integration.
     *
     * @param numBodies number of bodies; both arrays must hold at least {@code 4 * numBodies} floats
     * @param refPos    positions (and per-body weight in slot 3), updated in place
     * @param refVel    velocities, updated in place
     */
    private static void nBody(int numBodies, FloatArray refPos, FloatArray refVel) {
        final float dt = 0.005f;
        // Added to the squared distance so a body's interaction with itself (distance 0) stays finite.
        final float softening = 500.0f;
        for (int i = 0; i < numBodies; ++i) {
            int body = 4 * i;
            float[] acc = new float[]{0.0f, 0.0f, 0.0f};
            for (int j = 0; j < numBodies; ++j) {
                float[] r = new float[3];
                int index = 4 * j;
                float distSqr = 0.0f;
                for (int k = 0; k < 3; ++k) {
                    r[k] = refPos.get(index + k) - refPos.get(body + k);
                    distSqr += r[k] * r[k];
                }
                float invDist = (float) (1.0 / Math.sqrt(distSqr + softening));
                float invDistCube = invDist * invDist * invDist;
                float s = refPos.get(index + 3) * invDistCube;
                for (int k = 0; k < 3; ++k) {
                    acc[k] += s * r[k];
                }
            }
            for (int k = 0; k < 3; ++k) {
                // Position must be updated first: it uses the velocity from before this step.
                refPos.set(body + k, refPos.get(body + k) + refVel.get(body + k) * dt + 0.5f * acc[k] * dt * dt);
                refVel.set(body + k, refVel.get(body + k) + acc[k] * dt);
            }
        }
    }

    /**
     * Asserts that the accelerated N-body results match the sequential reference.
     *
     * <p>Compares the first {@code 4 * numBodies} floats of each position and velocity array with
     * an absolute tolerance of 0.1.
     *
     * @param numBodies     number of bodies (four floats each)
     * @param posTornadoVM  positions computed through the task graph
     * @param velTornadoVM  velocities computed through the task graph
     * @param posSequential reference positions
     * @param velSequential reference velocities
     */
    public static void validate(int numBodies, FloatArray posTornadoVM, FloatArray velTornadoVM, FloatArray posSequential, FloatArray velSequential) {
        final float tolerance = 0.1f;
        final int totalFloats = numBodies * 4;
        int idx = 0;
        while (idx < totalFloats) {
            Assert.assertEquals(posSequential.get(idx), posTornadoVM.get(idx), tolerance);
            Assert.assertEquals(velSequential.get(idx), velTornadoVM.get(idx), tolerance);
            idx++;
        }
    }

    /**
     * Computes a discrete Fourier transform using {@link Math#cos} and {@link Math#sin} in double
     * precision, narrowing each running sum back to float after every term.
     *
     * @param inreal  real parts of the input; its size defines the transform length
     * @param inimag  imaginary parts of the input
     * @param outreal receives the real part of every output bin
     * @param outimag receives the imaginary part of every output bin
     */
    public static void computeDFT(FloatArray inreal, FloatArray inimag, FloatArray outreal, FloatArray outimag) {
        final int length = inreal.getSize();
        for (int bin = 0; bin < length; bin++) {
            float realAcc = 0.0f;
            float imagAcc = 0.0f;
            for (int sample = 0; sample < length; sample++) {
                // The angle is deliberately narrowed to float before the trigonometric calls.
                final float angle = (float) (Math.PI * 2 * (double) sample * (double) bin / (double) length);
                final double cosine = Math.cos(angle);
                final double sine = Math.sin(angle);
                final double re = inreal.get(sample);
                final double im = inimag.get(sample);
                realAcc = (float) ((double) realAcc + (re * cosine + im * sine));
                imagAcc = (float) ((double) imagAcc + (-re * sine + im * cosine));
            }
            outreal.set(bin, realAcc);
            outimag.set(bin, imagAcc);
        }
    }

    /**
     * Single-precision discrete Fourier transform built on {@code TornadoMath} trigonometry.
     *
     * @param inreal  real parts of the input; its size defines the transform length
     * @param inimag  imaginary parts of the input
     * @param outreal receives the real part of every output bin
     * @param outimag receives the imaginary part of every output bin
     */
    public static void computeDFTFloat(FloatArray inreal, FloatArray inimag, FloatArray outreal, FloatArray outimag) {
        final int length = inreal.getSize();
        for (int bin = 0; bin < length; bin++) {
            float realAcc = 0.0f;
            float imagAcc = 0.0f;
            for (int sample = 0; sample < length; sample++) {
                final float angle = 2.0f * TornadoMath.floatPI() * (float) sample * (float) bin / (float) length;
                final float cosine = TornadoMath.cos(angle);
                final float sine = TornadoMath.sin(angle);
                final float re = inreal.get(sample);
                final float im = inimag.get(sample);
                realAcc += re * cosine + im * sine;
                imagAcc += -re * sine + im * cosine;
            }
            outreal.set(bin, realAcc);
            outimag.set(bin, imagAcc);
        }
    }

    /**
     * Discrete Fourier transform over {@code Float4} elements: each input element is scaled by
     * the same cosine/sine factors and accumulated with {@code Float4} arithmetic.
     *
     * @param inreal  real parts of the input; its length defines the transform length
     * @param inimag  imaginary parts of the input
     * @param outreal receives the real part of every output bin
     * @param outimag receives the imaginary part of every output bin
     */
    public static void computeDFTVector(VectorFloat4 inreal, VectorFloat4 inimag, VectorFloat4 outreal, VectorFloat4 outimag) {
        final int length = inreal.getLength();
        for (int bin = 0; bin < length; bin++) {
            Float4 realAcc = new Float4();
            Float4 imagAcc = new Float4();
            for (int sample = 0; sample < length; sample++) {
                final float angle = 2.0f * TornadoMath.floatPI() * (float) sample * (float) bin / (float) length;
                final float cosine = TornadoMath.cos(angle);
                final float sine = TornadoMath.sin(angle);

                // Real part: re * cos + im * sin
                Float4 realTerm = Float4.add(Float4.mult(inreal.get(sample), cosine), Float4.mult(inimag.get(sample), sine));
                realAcc = Float4.add(realAcc, realTerm);

                // Imaginary part: -re * sin + im * cos
                Float4 negatedReal = Float4.mult(inreal.get(sample), new Float4(-1.0f, -1.0f, -1.0f, -1.0f));
                Float4 imagTerm = Float4.add(Float4.mult(negatedReal, sine), Float4.mult(inimag.get(sample), cosine));
                imagAcc = Float4.add(imagAcc, imagTerm);
            }
            outreal.set(bin, realAcc);
            outimag.set(bin, imagAcc);
        }
    }

    /**
     * Fills {@code output} with a {@code rows x cols} Hilbert matrix stored row by row, where
     * entry (i, j) equals {@code 1 / (i + j + 1)} for zero-based indices.
     *
     * <p>The row stride is {@code cols}. The previous version used {@code rows} as the stride,
     * which is only correct for square matrices and indexes past {@code rows * cols} whenever
     * {@code rows > cols}.
     *
     * @param output destination holding at least {@code rows * cols} floats
     * @param rows   number of matrix rows
     * @param cols   number of matrix columns
     */
    public static void hilbertComputation(FloatArray output, int rows, int cols) {
        for (int i = 0; i < rows; ++i) {
            for (int j = 0; j < cols; ++j) {
                output.set(i * cols + j, 1.0f / (float) (i + j + 1));
            }
        }
    }

    /**
     * Polynomial approximation of the cumulative standard normal distribution.
     *
     * <p>The approximation is evaluated for {@code |X|} and mirrored ({@code 1 - y}) for negative
     * arguments. The coefficients are named once and used in the formula; the previous version
     * declared them but repeated the raw literals in the expression, leaving every local unused.
     *
     * @param X point at which to evaluate the distribution
     * @return the approximated probability
     */
    private static float cnd(float X) {
        final float c1 = 0.31938154f;
        final float c2 = -0.35656378f;
        final float c3 = 1.7814779f;
        final float c4 = -1.8212559f;
        final float c5 = 1.3302745f;
        final float gamma = 0.2316419f;
        final float oneBySqrt2pi = 0.3989423f;

        float absX = TornadoMath.abs(X);
        float t = 1.0f / (1.0f + gamma * absX);
        // Horner evaluation of c1 + c2*t + c3*t^2 + c4*t^3 + c5*t^4.
        float poly = c1 + t * (c2 + t * (c3 + t * (c4 + t * c5)));
        float y = 1.0f - oneBySqrt2pi * TornadoMath.exp(-X * X / 2.0f) * t * poly;
        return X < 0.0f ? 1.0f - y : y;
    }

    /**
     * Prices one European call and one put per input element with the Black-Scholes formula.
     *
     * <p>Each input value is used as an interpolation weight between a lower and an upper limit
     * to derive the spot price, strike, maturity, rate and volatility for that element.
     *
     * <p>{@code d1} is computed as {@code (ln(S/K) + (r + v^2/2) * T) / (v * sqrt(T))}. The
     * previous expression ended in {@code / v * sqrt(T)}, which multiplies by {@code sqrt(T)}
     * instead of dividing by it.
     *
     * @param input      interpolation weights, one per option
     * @param callResult receives the call prices; its size drives the loop
     * @param putResult  receives the put prices
     */
    private static void blackScholesKernel(FloatArray input, FloatArray callResult, FloatArray putResult) {
        final float sLowerLimit = 10.0f;
        final float sUpperLimit = 100.0f;
        final float kLowerLimit = 10.0f;
        final float kUpperLimit = 100.0f;
        final float tLowerLimit = 1.0f;
        final float tUpperLimit = 10.0f;
        final float rLowerLimit = 0.01f;
        final float rUpperLimit = 0.05f;
        final float sigmaLowerLimit = 0.01f;
        final float sigmaUpperLimit = 0.1f;
        for (int idx = 0; idx < callResult.getSize(); ++idx) {
            float rand = input.get(idx);
            float S = sLowerLimit * rand + sUpperLimit * (1.0f - rand);
            float K = kLowerLimit * rand + kUpperLimit * (1.0f - rand);
            float T = tLowerLimit * rand + tUpperLimit * (1.0f - rand);
            float r = rLowerLimit * rand + rUpperLimit * (1.0f - rand);
            float v = sigmaLowerLimit * rand + sigmaUpperLimit * (1.0f - rand);

            float volSqrtT = v * TornadoMath.sqrt(T);
            float d1 = (TornadoMath.log(S / K) + (r + v * v / 2.0f) * T) / volSqrtT;
            float d2 = d1 - volSqrtT;
            // Strike discounted to present value; shared by the call and the put.
            float discountedStrike = K * TornadoMath.exp(-r * T);

            callResult.set(idx, S * ComputeTests.cnd(d1) - discountedStrike * ComputeTests.cnd(d2));
            putResult.set(idx, discountedStrike * ComputeTests.cnd(-d2) - S * ComputeTests.cnd(-d1));
        }
    }

    /**
     * Asserts that two sets of call/put prices agree element-wise within an absolute tolerance
     * of 1.8.
     *
     * @param call      expected call prices; its size drives the comparison
     * @param put       expected put prices
     * @param callPrice actual call prices
     * @param putPrice  actual put prices
     */
    private static void checkBlackScholes(FloatArray call, FloatArray put, FloatArray callPrice, FloatArray putPrice) {
        final double tolerance = 1.8;
        final int count = call.getSize();
        for (int option = 0; option < count; option++) {
            final double expectedCall = call.get(option);
            final double actualCall = callPrice.get(option);
            Assert.assertEquals(expectedCall, actualCall, tolerance);

            final double expectedPut = put.get(option);
            final double actualPut = putPrice.get(option);
            Assert.assertEquals(expectedPut, actualPut, tolerance);
        }
    }

    /**
     * For every sample index, derives a pseudo-random point (x, y) from the index itself and
     * records 1.0f when the point's distance from the origin is at most 1, otherwise 0.0f.
     *
     * <p>The generator is a 48-bit linear congruential step applied twice per coordinate; the low
     * 28 bits of the state are scaled to a float.
     *
     * @param output     receives one 0/1 flag per sample
     * @param iterations number of samples to generate
     */
    private static void computeMontecarlo(FloatArray output, int iterations) {
        final long multiplier = 0x5DEECE66DL;
        final long increment = 0xBL;
        final long stateMask = (1L << 48) - 1;
        final long coordMask = 0xFFFFFFFL;
        final float coordScale = 2.6843546E8f;
        for (int sample = 0; sample < iterations; sample++) {
            long state = sample;

            state = (state * multiplier + increment) & stateMask;
            state = (state * multiplier + increment) & stateMask;
            final float x = (float) (state & coordMask) / coordScale;

            state = (state * multiplier + increment) & stateMask;
            state = (state * multiplier + increment) & stateMask;
            final float y = (float) (state & coordMask) / coordScale;

            final float distance = (float) Math.sqrt(x * x + y * y);
            if (distance <= 1.0f) {
                output.set(sample, 1.0f);
            } else {
                output.set(sample, 0.0f);
            }
        }
    }

    /**
     * Renders a {@code size x size} Mandelbrot escape-count image.
     *
     * <p>For each pixel the iteration count (capped at 10000) is scaled to the range 0..255 and
     * stored as a short at {@code row * size + column}.
     *
     * @param size   width and height of the image
     * @param output destination holding at least {@code size * size} shorts
     */
    public static void mandelbrotFractal(int size, ShortArray output) {
        final int maxIterations = 10000;
        final float step = 2.0f / (float) size;
        for (int row = 0; row < size; row++) {
            for (int col = 0; col < size; col++) {
                float zReal = 0.0f;
                float zImag = 0.0f;
                final float cReal = (float) col * step - 1.5f;
                final float cImag = (float) row * step - 1.0f;
                float zRealSq = 0.0f;
                float zImagSq = 0.0f;
                int escapeCount = 0;
                for (int iter = 0; iter < maxIterations; iter++) {
                    if (zImagSq + zRealSq <= 4.0f) {
                        zImag = 2.0f * zReal * zImag + cImag;
                        zReal = zRealSq - zImagSq + cReal;
                        zImagSq = zImag * zImag;
                        zRealSq = zReal * zReal;
                        escapeCount++;
                    } else {
                        // Force the loop to terminate without using a break statement.
                        iter = maxIterations;
                    }
                }
                final float scaled = (float) (escapeCount * 255) / (float) maxIterations;
                output.set(row * size + col, (short) scaled);
            }
        }
    }

    /**
     * Searches for index tuples (a, b, c, d, e) with {@code a <= b <= c <= d} such that
     * {@code five[a] + five[b] + five[c] + five[d] == five[e]} and records every match at
     * position {@code e} of the output arrays (later matches overwrite earlier ones).
     *
     * <p>{@code e} and {@code a} range over {@code 1 .. five.getSize() - 1}; {@code b}, {@code c}
     * and {@code d} are bounded by {@code size}.
     *
     * @param size    exclusive upper bound for b, c and d
     * @param five    lookup table of values to combine
     * @param outputA receives a at index e
     * @param outputB receives b at index e
     * @param outputC receives c at index e
     * @param outputD receives d at index e
     * @param outputE receives e at index e
     */
    private static void euler(int size, LongArray five, LongArray outputA, LongArray outputB, LongArray outputC, LongArray outputD, LongArray outputE) {
        for (int e = 1; e < five.getSize(); e++) {
            final long target = five.get(e);
            for (int a = 1; a < five.getSize(); a++) {
                final long termA = five.get(a);
                for (int b = a; b < size; b++) {
                    final long termB = five.get(b);
                    for (int c = b; c < size; c++) {
                        final long termC = five.get(c);
                        for (int d = c; d < size; d++) {
                            final long termD = five.get(d);
                            if (termA + termB + termC + termD == target) {
                                outputA.set(e, (long) a);
                                outputB.set(e, (long) b);
                                outputC.set(e, (long) c);
                                outputD.set(e, (long) d);
                                outputE.set(e, (long) e);
                            }
                        }
                    }
                }
            }
        }
    }

    /**
     * Maps the third component of every input pixel, truncated to an int, to a fixed
     * {@code Byte3} value and writes it to the same coordinates of {@code output}.
     *
     * <p>Codes 1 and -1 through -5 each have a dedicated value; any other code uses the default.
     *
     * @param output destination image
     * @param input  source image; its X and Y extents drive the loops
     */
    public static void renderTrack(ImageByte3 output, ImageFloat3 input) {
        for (int y = 0; y < input.Y(); y++) {
            for (int x = 0; x < input.X(); x++) {
                final int code = (int) input.get(x, y).getS2();
                final Byte3 pixel = switch (code) {
                    case 1 -> new Byte3(-128, -128, -128);
                    case -1 -> new Byte3(0, 0, 0);
                    case -2 -> new Byte3(-1, 0, 0);
                    case -3 -> new Byte3(0, -1, 0);
                    case -4 -> new Byte3(0, 0, -1);
                    case -5 -> new Byte3(-1, -1, 0);
                    default -> new Byte3(-1, -128, -128);
                };
                output.set(x, y, pixel);
            }
        }
    }

    /**
     * Computes a {@code size x size} Julia-set image as two float planes.
     *
     * <p>For each pixel a counter starts at 1000 and is decremented while the iterated point
     * stays inside radius 2 and the counter is not negative. {@code hue} receives
     * {@code 1000 / counter} and {@code brightness} receives 1 when the counter is still positive,
     * otherwise 0.
     *
     * @param size       width and height of the image
     * @param hue        destination for the hue plane ({@code size * size} floats)
     * @param brightness destination for the brightness plane ({@code size * size} floats)
     */
    public static void juliaSetTornado(int size, FloatArray hue, FloatArray brightness) {
        final float maxIterations = 1000.0f;
        final float cReal = -0.7f;
        final float cImag = 0.27015f;
        final float moveX = 0.0f;
        final float moveY = 0.0f;
        for (int ix = 0; ix < size; ix++) {
            for (int jx = 0; jx < size; jx++) {
                float zx = 1.5f * (float) (ix - size / 2) / (0.5f * (float) size) + moveX;
                float zy = (float) (jx - size / 2) / (0.5f * (float) size) + moveY;
                float remaining = maxIterations;
                while (zx * zx + zy * zy < 4.0f && !(remaining < 0.0f)) {
                    final float nextZx = zx * zx - zy * zy + cReal;
                    zy = 2.0f * zx * zy + cImag;
                    zx = nextZx;
                    remaining -= 1.0f;
                }
                final int pixel = ix * size + jx;
                hue.set(pixel, maxIterations / remaining);
                if (remaining > 0.0f) {
                    brightness.set(pixel, 1.0f);
                } else {
                    brightness.set(pixel, 0.0f);
                }
            }
        }
    }

    /**
     * Multiplies a matrix by a vector: {@code output[row] = sum over col of matrix[row][col] * vector[col]}.
     *
     * @param matrix left operand
     * @param vector right operand, indexed by matrix column
     * @param output receives one dot product per matrix row
     */
    private static void computeMatrixVector(Matrix2DFloat matrix, VectorFloat vector, VectorFloat output) {
        for (int row = 0; row < matrix.getNumRows(); row++) {
            float dot = 0.0f;
            for (int col = 0; col < matrix.getNumColumns(); col++) {
                final float product = matrix.get(row, col) * vector.get(col);
                dot += product;
            }
            output.set(row, dot);
        }
    }

    /**
     * Matrix-vector product over {@code Float4} elements: for every row, the element-wise
     * products of matrix and vector entries are reduced with {@code Float4.sum} and accumulated
     * into a single float.
     *
     * @param matrix left operand
     * @param vector right operand, indexed by matrix column
     * @param output receives one scalar per matrix row
     */
    private static void computeMatrixVectorFloat4(Matrix2DFloat4 matrix, VectorFloat4 vector, VectorFloat output) {
        for (int row = 0; row < matrix.getNumRows(); row++) {
            float dot = 0.0f;
            for (int col = 0; col < matrix.getNumColumns(); col++) {
                final Float4 product = Float4.mult(matrix.get(row, col), vector.get(col));
                dot += Float4.sum(product);
            }
            output.set(row, dot);
        }
    }

    /**
     * Multiplies two square matrices stored row by row, doing every multiply and add with
     * {@code HalfFloat} arithmetic.
     *
     * @param A2   left operand ({@code size * size} elements)
     * @param B2   right operand ({@code size * size} elements)
     * @param C    receives the product
     * @param size matrix dimension
     */
    private static void matrixMultiplicationHalfFloats(HalfFloatArray A2, HalfFloatArray B2, HalfFloatArray C, int size) {
        for (int row = 0; row < size; row++) {
            for (int col = 0; col < size; col++) {
                HalfFloat acc = new HalfFloat(0.0f);
                for (int inner = 0; inner < size; inner++) {
                    final HalfFloat left = A2.get(row * size + inner);
                    final HalfFloat right = B2.get(inner * size + col);
                    acc = HalfFloat.add(acc, HalfFloat.mult(left, right));
                }
                C.set(row * size + col, acc);
            }
        }
    }

    /**
     * Multiplies two square matrices with {@code HalfFloat} arithmetic and stores each result as
     * a 32-bit float obtained from {@code getFloat32()}.
     *
     * @param A2   left operand ({@code size * size} elements)
     * @param B2   right operand ({@code size * size} elements)
     * @param C    receives the product as floats
     * @param size matrix dimension
     */
    private static void matrixMultiplicationHalfFloatToFloat(HalfFloatArray A2, HalfFloatArray B2, FloatArray C, int size) {
        for (int row = 0; row < size; row++) {
            for (int col = 0; col < size; col++) {
                HalfFloat acc = new HalfFloat(0.0f);
                for (int inner = 0; inner < size; inner++) {
                    final HalfFloat left = A2.get(row * size + inner);
                    final HalfFloat right = B2.get(inner * size + col);
                    acc = HalfFloat.add(acc, HalfFloat.mult(left, right));
                }
                C.set(row * size + col, acc.getFloat32());
            }
        }
    }

    /**
     * Multiplies two square half-float matrices, converting every operand to a 32-bit float
     * first so that the multiply and the accumulation both run in float arithmetic.
     *
     * @param A2   left operand ({@code size * size} elements)
     * @param B2   right operand ({@code size * size} elements)
     * @param C    receives the product as floats
     * @param size matrix dimension
     */
    private static void matrixMultiplicationHalfFloatToFloat2(HalfFloatArray A2, HalfFloatArray B2, FloatArray C, int size) {
        for (int row = 0; row < size; row++) {
            for (int col = 0; col < size; col++) {
                float acc = 0.0f;
                for (int inner = 0; inner < size; inner++) {
                    final float left = A2.get(row * size + inner).getFloat32();
                    final float right = B2.get(inner * size + col).getFloat32();
                    acc += left * right;
                }
                C.set(row * size + col, acc);
            }
        }
    }

    /**
     * Runs the per-head attention routine for every head.
     *
     * <p>The position is read from {@code positionNlayer[0]} and the layer from
     * {@code positionNlayer[1]}. The layer offset is {@code layer * seqLen * kvDim}, computed in
     * int arithmetic and then widened to long.
     *
     * @param q              query values, {@code headSize} floats per head
     * @param key_cache      cached keys
     * @param value_cache    cached values
     * @param xb             receives the attention output, {@code headSize} floats per head
     * @param nHeads         number of heads to process
     * @param headSize       number of floats per head
     * @param kvDim          stride between consecutive positions in the caches
     * @param kvMul          number of query heads that share one key/value head
     * @param seqLen         sequence length used to size one layer of the caches
     * @param positionNlayer two-element array holding position and layer
     * @param wrapAtt        scratch buffer for attention weights
     */
    public static void processHeadsParallel(FloatArray q, FloatArray key_cache, FloatArray value_cache, FloatArray xb, int nHeads, int headSize, int kvDim, int kvMul, int seqLen, IntArray positionNlayer, FloatArray wrapAtt) {
        final int position = positionNlayer.get(0);
        final int layerIndex = positionNlayer.get(1);
        final long layerOffset = layerIndex * seqLen * kvDim;
        int head = 0;
        while (head < nHeads) {
            ComputeTests.processHeadTornado(q, key_cache, value_cache, xb, head, headSize, kvDim, kvMul, layerOffset, position, wrapAtt);
            head++;
        }
    }

    /**
     * Computes attention for a single head.
     *
     * <p>Steps, all over positions {@code 0..pos}: (1) scaled dot products between the head's
     * query and each cached key, stored in {@code wrapAtt}; (2) maximum score; (3) exponentials
     * of {@code score - max} and their sum; (4) normalisation by that sum (falling back to a
     * uniform weight when the sum is not positive); (5) weighted sum of the cached values,
     * written to {@code allXb}.
     *
     * @param allQ        query values for all heads
     * @param key_cache   cached keys
     * @param value_cache cached values
     * @param allXb       receives {@code headSize} output floats at {@code h * headSize}
     * @param h           head index
     * @param headSize    number of floats per head
     * @param kvDim       stride between consecutive positions in the caches
     * @param kvMul       number of query heads that share one key/value head
     * @param loff        offset of the current layer inside the caches
     * @param pos         last position to attend to (inclusive)
     * @param wrapAtt     scratch buffer; this head uses {@code pos + 1} slots from {@code h * (pos + 1)}
     */
    private static void processHeadTornado(FloatArray allQ, FloatArray key_cache, FloatArray value_cache, FloatArray allXb, int h, int headSize, int kvDim, int kvMul, long loff, int pos, FloatArray wrapAtt) {
        final int attBase = h * (pos + 1);
        final int headBase = h * headSize;

        // (1) Scaled query-key dot products.
        for (int step = 0; step <= pos; step++) {
            final int kvHead = h / kvMul;
            final int keyBase = (int) (loff + (long) (step * kvDim) + (long) (kvHead * headSize));
            float dot = 0.0f;
            for (int d = 0; d < headSize; d++) {
                dot += allQ.get(headBase + d) * key_cache.get(keyBase + d);
            }
            dot /= (float) Math.sqrt(headSize);
            wrapAtt.set(attBase + step, dot);
        }

        // (2) Maximum score, subtracted below before exponentiating.
        float peak = wrapAtt.get(attBase);
        for (int step = 1; step <= pos; step++) {
            final float candidate = wrapAtt.get(attBase + step);
            if (candidate > peak) {
                peak = candidate;
            }
        }

        // (3) Exponentials and their sum.
        float total = 0.0f;
        for (int step = 0; step <= pos; step++) {
            final int slot = attBase + step;
            final float weight = (float) Math.exp(wrapAtt.get(slot) - peak);
            wrapAtt.set(slot, weight);
            total += weight;
        }

        // (4) Normalise; use a uniform weight when the sum is not positive.
        final float scale;
        if (total > 0.0f) {
            scale = 1.0f / total;
        } else {
            scale = 1.0f / (float) (pos + 1);
        }
        for (int step = 0; step <= pos; step++) {
            final int slot = attBase + step;
            wrapAtt.set(slot, wrapAtt.get(slot) * scale);
        }

        // (5) Attention-weighted sum of the cached values.
        for (int d = 0; d < headSize; d++) {
            float mixed = 0.0f;
            for (int step = 0; step <= pos; step++) {
                final int kvHead = h / kvMul;
                final int valueBase = (int) (loff + (long) (step * kvDim) + (long) (kvHead * headSize));
                mixed += wrapAtt.get(attBase + step) * value_cache.get(valueBase + d);
            }
            allXb.set(headBase + d, mixed);
        }
    }

    /**
     * Sequential attention over all heads, keeping the scores of each head in a local array.
     *
     * <p>For every head: scaled query-key dot products for positions {@code 0..pos}, a softmax
     * over those scores (with the maximum subtracted before exponentiating), then the weighted
     * sum of the cached values written to {@code output}.
     *
     * @param query      query values, {@code headSize} floats per head
     * @param keyCache   cached keys
     * @param valueCache cached values
     * @param output     receives {@code headSize} floats per head
     * @param nHeads     number of heads
     * @param headSize   number of floats per head
     * @param kvDim      stride between consecutive positions in the caches
     * @param kvMul      number of query heads that share one key/value head
     * @param seqLen     sequence length used to size one layer of the caches
     * @param pos        last position to attend to (inclusive)
     * @param layer      layer index inside the caches
     */
    private static void processAttentionSequential(FloatArray query, FloatArray keyCache, FloatArray valueCache, FloatArray output, int nHeads, int headSize, int kvDim, int kvMul, int seqLen, int pos, int layer) {
        final int layerBase = layer * seqLen * kvDim;
        for (int head = 0; head < nHeads; head++) {
            final int headBase = head * headSize;
            final float[] weights = new float[pos + 1];

            // Scaled dot products.
            for (int step = 0; step <= pos; step++) {
                final int kvHead = head / kvMul;
                final int keyBase = layerBase + step * kvDim + kvHead * headSize;
                float dot = 0.0f;
                for (int d = 0; d < headSize; d++) {
                    dot += query.get(headBase + d) * keyCache.get(keyBase + d);
                }
                dot /= (float) Math.sqrt(headSize);
                weights[step] = dot;
            }

            // Softmax: find the maximum, exponentiate, normalise.
            float peak = Float.NEGATIVE_INFINITY;
            for (int step = 0; step <= pos; step++) {
                if (weights[step] > peak) {
                    peak = weights[step];
                }
            }
            float total = 0.0f;
            for (int step = 0; step <= pos; step++) {
                weights[step] = (float) Math.exp(weights[step] - peak);
                total += weights[step];
            }
            for (int step = 0; step <= pos; step++) {
                weights[step] /= total;
            }

            // Weighted sum of the cached values.
            for (int d = 0; d < headSize; d++) {
                float mixed = 0.0f;
                for (int step = 0; step <= pos; step++) {
                    final int kvHead = head / kvMul;
                    final int valueIndex = layerBase + step * kvDim + kvHead * headSize + d;
                    mixed += weights[step] * valueCache.get(valueIndex);
                }
                output.set(headBase + d, mixed);
            }
        }
    }

    /**
     * Runs the N-body kernel for 16384 bodies through a task graph with an explicit 1D worker
     * grid (local work size 32) and compares the result with a sequential run on identical input.
     *
     * <p>Task arguments are passed without {@code (Object)} casts so that the task's type
     * parameters are inferred from the kernel signature; with the casts the method reference
     * {@code ComputeTests::nBody} cannot be matched against an all-{@code Object} task type.
     */
    @Test
    public void testNBody() throws TornadoExecutionPlanException {
        final int numBodies = 16384;
        final int numFloats = numBodies * 4;

        FloatArray posSeq = new FloatArray(numFloats);
        FloatArray velSeq = new FloatArray(numFloats);
        for (int i = 0; i < posSeq.getSize(); ++i) {
            posSeq.set(i, (float) Math.random());
        }
        velSeq.init(0.0f);

        // Copy the inputs before the sequential run mutates them in place.
        FloatArray posTornadoVM = new FloatArray(numFloats);
        FloatArray velTornadoVM = new FloatArray(numFloats);
        for (int i = 0; i < numFloats; ++i) {
            posTornadoVM.set(i, posSeq.get(i));
            velTornadoVM.set(i, velSeq.get(i));
        }

        ComputeTests.nBody(numBodies, posSeq, velSeq);

        WorkerGrid1D workerGrid = new WorkerGrid1D(numBodies);
        GridScheduler gridScheduler = new GridScheduler("s0.t0", workerGrid);
        workerGrid.setGlobalWork(numBodies, 1, 1);
        workerGrid.setLocalWork(32, 1, 1);

        // Transfer mode 1 is kept as the raw value found in the decompiled code.
        TaskGraph taskGraph = new TaskGraph("s0")
                .transferToDevice(1, posTornadoVM, velTornadoVM)
                .task("t0", ComputeTests::nBody, numBodies, posTornadoVM, velTornadoVM)
                .transferToHost(1, posTornadoVM, velTornadoVM);
        ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
        try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
            executionPlan.withGridScheduler(gridScheduler).execute();
        }

        ComputeTests.validate(numBodies, posTornadoVM, velTornadoVM, posSeq, velSeq);
    }

    /**
     * Runs the N-body kernel for 2048 bodies through a task graph with an explicit 1D worker
     * grid (local work size 32) and compares the result with a sequential run on identical input.
     *
     * <p>Task arguments are passed without {@code (Object)} casts so that the task's type
     * parameters are inferred from the kernel signature; with the casts the method reference
     * {@code ComputeTests::nBody} cannot be matched against an all-{@code Object} task type.
     */
    @Test
    public void testNBodySmall() throws TornadoExecutionPlanException {
        final int numBodies = 2048;
        final int numFloats = numBodies * 4;

        FloatArray posSeq = new FloatArray(numFloats);
        FloatArray velSeq = new FloatArray(numFloats);
        for (int i = 0; i < posSeq.getSize(); ++i) {
            posSeq.set(i, (float) Math.random());
        }
        velSeq.init(0.0f);

        // Copy the inputs before the sequential run mutates them in place.
        FloatArray posTornadoVM = new FloatArray(numFloats);
        FloatArray velTornadoVM = new FloatArray(numFloats);
        for (int i = 0; i < numFloats; ++i) {
            posTornadoVM.set(i, posSeq.get(i));
            velTornadoVM.set(i, velSeq.get(i));
        }

        ComputeTests.nBody(numBodies, posSeq, velSeq);

        WorkerGrid1D workerGrid = new WorkerGrid1D(numBodies);
        GridScheduler gridScheduler = new GridScheduler("s0.t0", workerGrid);
        workerGrid.setGlobalWork(numBodies, 1, 1);
        workerGrid.setLocalWork(32, 1, 1);

        // Transfer mode 1 is kept as the raw value found in the decompiled code.
        TaskGraph taskGraph = new TaskGraph("s0")
                .transferToDevice(1, posTornadoVM, velTornadoVM)
                .task("t0", ComputeTests::nBody, numBodies, posTornadoVM, velTornadoVM)
                .transferToHost(1, posTornadoVM, velTornadoVM);
        ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
        try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
            executionPlan.withGridScheduler(gridScheduler).execute();
        }

        ComputeTests.validate(numBodies, posTornadoVM, velTornadoVM, posSeq, velSeq);
    }

    /**
     * Runs the N-body kernel for 8192 bodies through a task graph without a grid scheduler and
     * compares the result with a sequential run on identical input.
     *
     * <p>Task arguments are passed without {@code (Object)} casts so that the task's type
     * parameters are inferred from the kernel signature; with the casts the method reference
     * {@code ComputeTests::nBody} cannot be matched against an all-{@code Object} task type.
     */
    @Test
    public void testNBodyBigNoWorker() throws TornadoExecutionPlanException {
        final int numBodies = 8192;
        final int numFloats = numBodies * 4;

        FloatArray posSeq = new FloatArray(numFloats);
        FloatArray velSeq = new FloatArray(numFloats);
        for (int i = 0; i < posSeq.getSize(); ++i) {
            posSeq.set(i, (float) Math.random());
        }
        velSeq.init(0.0f);

        // Copy the inputs before the sequential run mutates them in place.
        FloatArray posTornadoVM = new FloatArray(numFloats);
        FloatArray velTornadoVM = new FloatArray(numFloats);
        for (int i = 0; i < numFloats; ++i) {
            posTornadoVM.set(i, posSeq.get(i));
            velTornadoVM.set(i, velSeq.get(i));
        }

        ComputeTests.nBody(numBodies, posSeq, velSeq);

        // Transfer mode 1 is kept as the raw value found in the decompiled code.
        TaskGraph taskGraph = new TaskGraph("compute")
                .transferToDevice(1, posTornadoVM, velTornadoVM)
                .task("nbody", ComputeTests::nBody, numBodies, posTornadoVM, velTornadoVM)
                .transferToHost(1, posTornadoVM, velTornadoVM);
        ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
        try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
            executionPlan.execute();
        }

        ComputeTests.validate(numBodies, posTornadoVM, velTornadoVM, posSeq, velSeq);
    }

    /**
     * Recomputes the DFT sequentially with {@code computeDFT} and asserts that the supplied
     * outputs match it element-wise within an absolute tolerance of 0.1.
     *
     * @param size    number of elements to compare
     * @param inReal  real parts of the input
     * @param inImag  imaginary parts of the input
     * @param outReal real parts to verify
     * @param outImag imaginary parts to verify
     */
    private void validateDFT(int size, FloatArray inReal, FloatArray inImag, FloatArray outReal, FloatArray outImag) {
        final float tolerance = 0.1f;
        final FloatArray expectedReal = new FloatArray(size);
        final FloatArray expectedImag = new FloatArray(size);
        ComputeTests.computeDFT(inReal, inImag, expectedReal, expectedImag);
        int bin = 0;
        while (bin < size) {
            Assert.assertEquals(expectedImag.get(bin), outImag.get(bin), tolerance);
            Assert.assertEquals(expectedReal.get(bin), outReal.get(bin), tolerance);
            bin++;
        }
    }

    /**
     * Recomputes the vector DFT sequentially with {@code computeDFTVector} and asserts that the
     * supplied outputs match it.
     *
     * <p>Every component of every {@code Float4} is compared with an absolute tolerance of 0.1,
     * the same tolerance the scalar DFT validation uses. The previous version called
     * {@code Float4.isEqual} and discarded the result, so it could never fail.
     *
     * @param size    number of elements to compare
     * @param inReal  real parts of the input
     * @param inImag  imaginary parts of the input
     * @param outReal real parts to verify
     * @param outImag imaginary parts to verify
     */
    private void validateDFTVector(int size, VectorFloat4 inReal, VectorFloat4 inImag, VectorFloat4 outReal, VectorFloat4 outImag) {
        final float tolerance = 0.1f;
        VectorFloat4 outRealSeq = new VectorFloat4(size);
        VectorFloat4 outImagSeq = new VectorFloat4(size);
        ComputeTests.computeDFTVector(inReal, inImag, outRealSeq, outImagSeq);
        for (int i = 0; i < size; ++i) {
            Float4 expectedImag = outImagSeq.get(i);
            Float4 actualImag = outImag.get(i);
            Assert.assertEquals(expectedImag.getX(), actualImag.getX(), tolerance);
            Assert.assertEquals(expectedImag.getY(), actualImag.getY(), tolerance);
            Assert.assertEquals(expectedImag.getZ(), actualImag.getZ(), tolerance);
            Assert.assertEquals(expectedImag.getW(), actualImag.getW(), tolerance);

            Float4 expectedReal = outRealSeq.get(i);
            Float4 actualReal = outReal.get(i);
            Assert.assertEquals(expectedReal.getX(), actualReal.getX(), tolerance);
            Assert.assertEquals(expectedReal.getY(), actualReal.getY(), tolerance);
            Assert.assertEquals(expectedReal.getZ(), actualReal.getZ(), tolerance);
            Assert.assertEquals(expectedReal.getW(), actualReal.getW(), tolerance);
        }
    }

    /**
     * Runs the double-precision DFT kernel on 4096 elements through a task graph and validates
     * the result against a sequential run.
     *
     * <p>Task arguments are passed without {@code (Object)} casts so that the task's type
     * parameters are inferred from the kernel signature.
     */
    @Test
    public void testDFTDouble() throws TornadoExecutionPlanException {
        final int size = 4096;
        FloatArray inReal = new FloatArray(size);
        FloatArray inImag = new FloatArray(size);
        FloatArray outReal = new FloatArray(size);
        FloatArray outImag = new FloatArray(size);
        for (int i = 0; i < size; ++i) {
            inReal.set(i, 1.0f / (float) (i + 2));
            inImag.set(i, 1.0f / (float) (i + 2));
        }

        // Transfer modes 0 and 1 are kept as the raw values found in the decompiled code.
        TaskGraph taskGraph = new TaskGraph("s0")
                .transferToDevice(0, inReal, inImag)
                .task("t0", ComputeTests::computeDFT, inReal, inImag, outReal, outImag)
                .transferToHost(1, outReal, outImag);
        ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
        try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
            executionPlan.execute();
        }

        this.validateDFT(size, inReal, inImag, outReal, outImag);
    }

    /**
     * Runs the {@code Float4} DFT kernel on 4096 elements through a task graph and validates the
     * result against a sequential run.
     *
     * <p>Task arguments are passed without {@code (Object)} casts so that the task's type
     * parameters are inferred from the kernel signature.
     */
    @Test
    public void testDFTVectorTypes() throws TornadoExecutionPlanException {
        final int size = 4096;
        VectorFloat4 inReal = new VectorFloat4(size);
        VectorFloat4 inImag = new VectorFloat4(size);
        VectorFloat4 outReal = new VectorFloat4(size);
        VectorFloat4 outImag = new VectorFloat4(size);
        for (int i = 0; i < size; ++i) {
            float valA = 1.0f / (float) (i + 2);
            float valB = 1.0f / (float) (i + 2);
            inReal.set(i, new Float4(valA, valA, valA, valA));
            inImag.set(i, new Float4(valB, valB, valB, valB));
        }

        // Transfer modes 0 and 1 are kept as the raw values found in the decompiled code.
        TaskGraph taskGraph = new TaskGraph("dft")
                .transferToDevice(0, inReal, inImag)
                .task("withVectors", ComputeTests::computeDFTVector, inReal, inImag, outReal, outImag)
                .transferToHost(1, outReal, outImag);
        ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
        try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
            executionPlan.execute();
        }

        this.validateDFTVector(size, inReal, inImag, outReal, outImag);
    }

    /**
     * Runs the single-precision DFT kernel on 4096 elements through a task graph and validates
     * the result against the sequential double-precision reference.
     *
     * <p>Task arguments are passed without {@code (Object)} casts so that the task's type
     * parameters are inferred from the kernel signature.
     */
    @Test
    public void testDFTFloat() throws TornadoExecutionPlanException {
        final int size = 4096;
        FloatArray inReal = new FloatArray(size);
        FloatArray inImag = new FloatArray(size);
        FloatArray outReal = new FloatArray(size);
        FloatArray outImag = new FloatArray(size);
        for (int i = 0; i < size; ++i) {
            inReal.set(i, 1.0f / (float) (i + 2));
            inImag.set(i, 1.0f / (float) (i + 2));
        }

        // Transfer modes 0 and 1 are kept as the raw values found in the decompiled code.
        TaskGraph taskGraph = new TaskGraph("s0")
                .transferToDevice(0, inReal, inImag)
                .task("t0", ComputeTests::computeDFTFloat, inReal, inImag, outReal, outImag)
                .transferToHost(1, outReal, outImag);
        ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
        try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
            executionPlan.execute();
        }

        this.validateDFT(size, inReal, inImag, outReal, outImag);
    }

    /**
     * Builds an {@code NROWS x NCOLS} Hilbert matrix through a task graph and compares it with a
     * sequential run, element by element, within an absolute tolerance of 0.1.
     *
     * <p>The comparison walks the matrix with {@code NCOLS} as the row stride, which is the
     * layout of a {@code NROWS * NCOLS} row-by-row buffer; the previous {@code NROWS} stride was
     * only correct because both dimensions are equal. Task arguments are passed without
     * {@code (Object)} casts so that the task's type parameters are inferred from the kernel
     * signature.
     */
    @Test
    public void testHilbert() throws TornadoExecutionPlanException {
        FloatArray output = new FloatArray(NROWS * NCOLS);

        // Transfer mode 1 is kept as the raw value found in the decompiled code.
        TaskGraph taskGraph = new TaskGraph("s0")
                .task("t0", ComputeTests::hilbertComputation, output, NROWS, NCOLS)
                .transferToHost(1, output);
        ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
        try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
            executionPlan.execute();
        }

        FloatArray seq = new FloatArray(NROWS * NCOLS);
        ComputeTests.hilbertComputation(seq, NROWS, NCOLS);
        for (int i = 0; i < NROWS; ++i) {
            for (int j = 0; j < NCOLS; ++j) {
                int index = i * NCOLS + j;
                Assert.assertEquals(seq.get(index), output.get(index), 0.1f);
            }
        }
    }

    /**
     * Prices 8192 options through a task graph and compares call and put prices with a
     * sequential run of the same kernel on the same random input.
     *
     * <p>The sequential reference is computed once; the previous version ran it twice on the
     * same buffers, and the second run only overwrote identical values. Task arguments are passed
     * without {@code (Object)} casts so that the task's type parameters are inferred from the
     * kernel signature.
     */
    @Test
    public void testBlackScholes() throws TornadoExecutionPlanException {
        final int size = 8192;
        Random random = new Random();
        FloatArray input = new FloatArray(size);
        FloatArray callPrice = new FloatArray(size);
        FloatArray putPrice = new FloatArray(size);
        FloatArray seqCall = new FloatArray(size);
        FloatArray seqPut = new FloatArray(size);
        IntStream.range(0, size).forEach(i -> input.set(i, random.nextFloat()));

        // Transfer mode 1 is kept as the raw value found in the decompiled code.
        TaskGraph taskGraph = new TaskGraph("s0")
                .transferToDevice(1, input)
                .task("t0", ComputeTests::blackScholesKernel, input, callPrice, putPrice)
                .transferToHost(1, callPrice, putPrice);
        ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
        try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
            executionPlan.execute();
        }

        ComputeTests.blackScholesKernel(input, seqCall, seqPut);
        ComputeTests.checkBlackScholes(seqCall, seqPut, callPrice, putPrice);
    }

    /**
     * Runs the Monte Carlo kernel for 8192 samples through a task graph and compares four times
     * the number of hits with the same quantity from a sequential run, within 0.1.
     *
     * <p>Task arguments are passed without {@code (Object)} casts so that the task's type
     * parameters are inferred from the kernel signature.
     */
    @Test
    public void testMontecarlo() throws TornadoExecutionPlanException {
        final int size = 8192;
        FloatArray output = new FloatArray(size);
        FloatArray seq = new FloatArray(size);

        // Transfer mode 1 is kept as the raw value found in the decompiled code.
        TaskGraph taskGraph = new TaskGraph("s0")
                .task("t0", ComputeTests::computeMontecarlo, output, size)
                .transferToHost(1, output);
        ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
        try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
            executionPlan.execute();
        }

        float sumTornado = 0.0f;
        for (int j = 0; j < size; ++j) {
            sumTornado += output.get(j);
        }
        sumTornado *= 4.0f;

        ComputeTests.computeMontecarlo(seq, size);
        float sumSeq = 0.0f;
        for (int j = 0; j < size; ++j) {
            sumSeq += seq.get(j);
        }
        sumSeq *= 4.0f;

        Assert.assertEquals((double) sumSeq, (double) sumTornado, 0.1);
    }

    /**
     * Recomputes the fractal by calling {@code mandelbrotFractal} directly and asserts
     * that every element of {@code output} equals the recomputed value.
     *
     * @param size
     *     edge length passed to {@code mandelbrotFractal}; {@code size * size} elements are compared
     * @param output
     *     array to validate
     */
    private void validateMandelbrot(int size, ShortArray output) {
        ShortArray expected = new ShortArray(size * size);
        ComputeTests.mandelbrotFractal(size, expected);
        for (int row = 0; row < size; ++row) {
            int rowStart = row * size;
            for (int col = 0; col < size; ++col) {
                int idx = rowStart + col;
                Assert.assertEquals((long)expected.get(idx), (long)output.get(idx));
            }
        }
    }

    /**
     * Runs {@code mandelbrotFractal} for a 512 x 512 output through a TornadoVM execution
     * plan and validates the result with {@code validateMandelbrot}.
     *
     * @throws TornadoExecutionPlanException
     *     propagated from the try-with-resources block that owns the execution plan
     */
    @Test
    public void testMandelbrot() throws TornadoExecutionPlanException {
        final int size = 512;
        ShortArray output = new ShortArray(size * size);

        TaskGraph taskGraph = new TaskGraph("s0")
                .task("t0", ComputeTests::mandelbrotFractal, (Object)size, (Object)output)
                .transferToHost(1, new Object[]{output});
        ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
        try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(new ImmutableTaskGraph[]{immutableTaskGraph})) {
            executionPlan.execute();
        }

        this.validateMandelbrot(size, output);
    }

    /**
     * Builds an array whose element {@code i} holds {@code i} raised to the fifth power,
     * computed in {@code long} arithmetic.
     *
     * @param size
     *     number of elements to allocate and fill
     * @return the populated array
     */
    private LongArray init(int size) {
        LongArray powers = new LongArray(size);
        for (int k = 0; k < size; ++k) {
            long base = k;
            long squared = base * base;
            // (k^2)^2 * k == k^5
            powers.set(k, squared * squared * base);
        }
        return powers;
    }

    /**
     * Runs {@code euler} over the fifth-power input from {@code init} through a TornadoVM
     * execution plan and directly in Java, then asserts that all five output arrays match
     * element by element.
     *
     * @throws TornadoExecutionPlanException
     *     propagated from the try-with-resources block that owns the execution plan
     */
    @Test
    public void testEuler() throws TornadoExecutionPlanException {
        final int size = 128;
        LongArray input = this.init(size);

        // Outputs written by the execution plan.
        LongArray outputA = new LongArray(size);
        LongArray outputB = new LongArray(size);
        LongArray outputC = new LongArray(size);
        LongArray outputD = new LongArray(size);
        LongArray outputE = new LongArray(size);

        // NOTE(review): the task id "s0" is identical to the graph name; the other tests
        // in this class name their task "t0". Kept as found - confirm it is intentional.
        TaskGraph taskGraph = new TaskGraph("s0")
                .transferToDevice(1, new Object[]{input})
                .task("s0", ComputeTests::euler, (Object)size, (Object)input, (Object)outputA, (Object)outputB, (Object)outputC, (Object)outputD, (Object)outputE)
                .transferToHost(1, new Object[]{outputA, outputB, outputC, outputD, outputE});
        ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
        try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(new ImmutableTaskGraph[]{immutableTaskGraph})) {
            executionPlan.execute();
        }

        // Reference outputs from the direct Java call.
        LongArray outputAT = new LongArray(size);
        LongArray outputBT = new LongArray(size);
        LongArray outputCT = new LongArray(size);
        LongArray outputDT = new LongArray(size);
        LongArray outputET = new LongArray(size);
        ComputeTests.euler(size, input, outputAT, outputBT, outputCT, outputDT, outputET);

        for (int idx = 0; idx < size; ++idx) {
            Assert.assertEquals((long)outputAT.get(idx), (long)outputA.get(idx));
            Assert.assertEquals((long)outputBT.get(idx), (long)outputB.get(idx));
            Assert.assertEquals((long)outputCT.get(idx), (long)outputC.get(idx));
            Assert.assertEquals((long)outputDT.get(idx), (long)outputD.get(idx));
            Assert.assertEquals((long)outputET.get(idx), (long)outputE.get(idx));
        }
    }

    /**
     * Fills a 2048 x 2048 {@code ImageFloat3} with {@code (i, j, value)} triples, where
     * {@code value} is a random integer in [0, 10) multiplied by -1, runs
     * {@code renderTrack} through a TornadoVM execution plan and directly in Java, and
     * asserts that every output pixel matches on all three channels within 0.1.
     *
     * @throws TornadoExecutionPlanException
     *     propagated from the try-with-resources block that owns the execution plan
     */
    @Test
    public void testRenderTrack() throws TornadoExecutionPlanException {
        final int n = 2048;
        final int m = 2048;
        ImageByte3 outputTornadoVM = new ImageByte3(n, m);
        ImageByte3 outputJava = new ImageByte3(n, m);
        ImageFloat3 input = new ImageFloat3(n, m);

        Random generator = new Random();
        for (int col = 0; col < input.X(); ++col) {
            for (int row = 0; row < input.Y(); ++row) {
                float value = (float)generator.nextInt(10) * -1.0f;
                Float3 pixel = new Float3((float)col, (float)row, value);
                input.set(col, row, pixel);
            }
        }

        TaskGraph taskGraph = new TaskGraph("s0")
                .transferToDevice(1, new Object[]{input})
                .task("t0", ComputeTests::renderTrack, (Object)outputTornadoVM, (Object)input)
                .transferToHost(1, new Object[]{outputTornadoVM});
        ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
        try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(new ImmutableTaskGraph[]{immutableTaskGraph})) {
            executionPlan.execute();
        }

        ComputeTests.renderTrack(outputJava, input);

        for (int x = 0; x < n; ++x) {
            for (int y = 0; y < m; ++y) {
                Byte3 expected = outputJava.get(x, y);
                Byte3 actual = outputTornadoVM.get(x, y);
                Assert.assertEquals((double)expected.getX(), (double)actual.getX(), 0.1);
                Assert.assertEquals((double)expected.getY(), (double)actual.getY(), 0.1);
                Assert.assertEquals((double)expected.getZ(), (double)actual.getZ(), 0.1);
            }
        }
    }

    /**
     * Runs {@code juliaSetTornado} for a 1024 x 1024 grid through a TornadoVM execution
     * plan and directly in Java, and asserts that the hue and brightness arrays of both
     * runs agree within 0.01.
     *
     * @throws TornadoExecutionPlanException
     *     propagated from the try-with-resources block that owns the execution plan
     */
    @Test
    public void testJuliaSets() throws TornadoExecutionPlanException {
        final int size = 1024;
        final int pixels = size * size;
        FloatArray hue = new FloatArray(pixels);
        FloatArray brightness = new FloatArray(pixels);
        IntArray result = new IntArray(pixels);

        TaskGraph taskGraph = new TaskGraph("s0")
                .task("t0", ComputeTests::juliaSetTornado, (Object)size, (Object)hue, (Object)brightness)
                .transferToHost(1, new Object[]{hue, brightness});
        ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
        try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(new ImmutableTaskGraph[]{immutableTaskGraph})) {
            executionPlan.execute();
        }

        // Convert the plan's hue/brightness output to packed RGB.
        // NOTE(review): 'result' is filled here but never asserted on below.
        for (int row = 0; row < size; ++row) {
            for (int col = 0; col < size; ++col) {
                int idx = row * size + col;
                int rgb = Color.HSBtoRGB(hue.get(idx) % 1.0f, 1.0f, brightness.get(idx));
                result.set(idx, rgb);
            }
        }

        FloatArray hueSeq = new FloatArray(pixels);
        FloatArray brightnessSeq = new FloatArray(pixels);
        ComputeTests.juliaSetTornado(size, hueSeq, brightnessSeq);

        final float delta = 0.01f;
        for (int idx = 0; idx < hueSeq.getSize(); ++idx) {
            Assert.assertEquals((float)hueSeq.get(idx), (float)hue.get(idx), delta);
            Assert.assertEquals((float)brightnessSeq.get(idx), (float)brightness.get(idx), delta);
        }
    }

    /**
     * Fills a 4096 x 4096 matrix and a 4096-element vector with random floats, runs
     * {@code computeMatrixVector} through a TornadoVM execution plan and directly in Java,
     * and asserts that both result vectors agree within 0.01.
     *
     * @throws TornadoExecutionPlanException
     *     propagated from the try-with-resources block that owns the execution plan
     */
    @Test
    public void matrixVector() throws TornadoExecutionPlanException {
        final int size = 4096;
        Matrix2DFloat matrix2DFloat = new Matrix2DFloat(size, size);
        VectorFloat vectorFloat = new VectorFloat(size);
        VectorFloat result = new VectorFloat(size);
        VectorFloat resultSeq = new VectorFloat(size);

        // Vector first, then the matrix in row-major order, drawing from one generator.
        Random generator = new Random();
        for (int idx = 0; idx < size; ++idx) {
            vectorFloat.set(idx, generator.nextFloat());
        }
        for (int row = 0; row < size; ++row) {
            for (int col = 0; col < size; ++col) {
                matrix2DFloat.set(row, col, generator.nextFloat());
            }
        }

        TaskGraph taskGraph = new TaskGraph("graph")
                .transferToDevice(1, new Object[]{matrix2DFloat, vectorFloat})
                .task("mv", ComputeTests::computeMatrixVector, (Object)matrix2DFloat, (Object)vectorFloat, (Object)result)
                .transferToHost(1, new Object[]{result});
        ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
        try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(new ImmutableTaskGraph[]{immutableTaskGraph})) {
            executionPlan.execute();
        }

        ComputeTests.computeMatrixVector(matrix2DFloat, vectorFloat, resultSeq);

        for (int idx = 0; idx < vectorFloat.size(); ++idx) {
            Assert.assertEquals((float)resultSeq.get(idx), (float)result.get(idx), 0.01f);
        }
    }

    /**
     * Checks {@code computeMatrixVectorFloat4} (run through a TornadoVM execution plan on
     * {@code Float4}-typed inputs) against {@code computeMatrixVector} (run directly in
     * Java on scalar copies of the same data). Every {@code Float4} element is
     * {@code (0, 1, 2, 3)}; the scalar copies store its four components at consecutive
     * indices. Results must agree within 0.01.
     *
     * @throws TornadoExecutionPlanException
     *     propagated from the try-with-resources block that owns the execution plan
     */
    @Test
    public void matrixVectorFloat4() throws TornadoExecutionPlanException {
        final int M = 2048;
        final int N = 4096;

        // Float4-typed inputs for the task under test.
        Matrix2DFloat4 matrix2DFloat = new Matrix2DFloat4(M, N);
        VectorFloat4 vectorFloat = new VectorFloat4(N);
        VectorFloat result = new VectorFloat(M);

        // Scalar mirrors of the same data for the reference computation.
        Matrix2DFloat inputA = new Matrix2DFloat(M, N * 4);
        VectorFloat inputB = new VectorFloat(N * 4);
        VectorFloat resultSeq = new VectorFloat(M);

        // The previous version also created a seeded Random here that was never read;
        // all inputs are the fixed pattern below, so it has been removed.
        for (int i = 0; i < vectorFloat.getLength(); ++i) {
            Float4 f = new Float4(0.0f, 1.0f, 2.0f, 3.0f);
            int base = i * 4;
            inputB.set(base, f.getX());
            inputB.set(base + 1, f.getY());
            inputB.set(base + 2, f.getZ());
            inputB.set(base + 3, f.getW());
            vectorFloat.set(i, f);
        }

        for (int i = 0; i < matrix2DFloat.getNumRows(); ++i) {
            for (int j = 0; j < matrix2DFloat.getNumColumns(); ++j) {
                Float4 f = new Float4(0.0f, 1.0f, 2.0f, 3.0f);
                matrix2DFloat.set(i, j, f);
                int base = j * 4;
                inputA.set(i, base, f.getX());
                inputA.set(i, base + 1, f.getY());
                inputA.set(i, base + 2, f.getZ());
                inputA.set(i, base + 3, f.getW());
            }
        }

        TaskGraph taskGraph = new TaskGraph("graph")
                .transferToDevice(1, new Object[]{matrix2DFloat, vectorFloat})
                .task("mv", ComputeTests::computeMatrixVectorFloat4, (Object)matrix2DFloat, (Object)vectorFloat, (Object)result)
                .transferToHost(1, new Object[]{result});
        ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
        try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(new ImmutableTaskGraph[]{immutableTaskGraph})) {
            executionPlan.execute();
        }

        ComputeTests.computeMatrixVector(inputA, inputB, resultSeq);

        for (int i = 0; i < result.getLength(); ++i) {
            Assert.assertEquals((float)resultSeq.get(i), (float)result.get(i), 0.01f);
        }
    }

    /**
     * Multiplies two 256 x 256 {@code HalfFloatArray} matrices (all 2.5 and all 3.5) with
     * {@code matrixMultiplicationHalfFloats}, once through a TornadoVM execution plan and
     * once directly in Java, and asserts the half-float results agree within 0.005 after
     * conversion via {@code getFloat32()}.
     *
     * @throws TornadoExecutionPlanException
     *     propagated from the try-with-resources block that owns the execution plan
     */
    @Test
    public void testHalfFloatMatrixMultiplication() throws TornadoExecutionPlanException {
        final int n = 256;
        final int elements = n * n;
        HalfFloatArray matrixA = new HalfFloatArray(elements);
        HalfFloatArray matrixB = new HalfFloatArray(elements);
        HalfFloatArray matrixCSeq = new HalfFloatArray(elements);
        HalfFloatArray matrixC = new HalfFloatArray(elements);

        IntStream.range(0, elements).parallel().forEach(idx -> {
            matrixA.set(idx, new HalfFloat(2.5f));
            matrixB.set(idx, new HalfFloat(3.5f));
        });

        // Input transfer uses mode 0, output transfer mode 1 - both reproduced as found.
        TaskGraph taskGraph = new TaskGraph("s0")
                .transferToDevice(0, new Object[]{matrixA, matrixB})
                .task("t0", ComputeTests::matrixMultiplicationHalfFloats, (Object)matrixA, (Object)matrixB, (Object)matrixC, (Object)n)
                .transferToHost(1, new Object[]{matrixC});
        ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
        try (TornadoExecutionPlan executor = new TornadoExecutionPlan(new ImmutableTaskGraph[]{immutableTaskGraph})) {
            executor.execute();
        }

        ComputeTests.matrixMultiplicationHalfFloats(matrixA, matrixB, matrixCSeq, n);

        for (int idx = 0; idx < elements; ++idx) {
            float expected = matrixCSeq.get(idx).getFloat32();
            float actual = matrixC.get(idx).getFloat32();
            Assert.assertEquals(expected, actual, 0.005f);
        }
    }

    /**
     * Multiplies two 256 x 256 {@code HalfFloatArray} matrices (all 2.5 and all 3.5) into
     * a {@code FloatArray} with {@code matrixMultiplicationHalfFloatToFloat}, once through
     * a TornadoVM execution plan and once directly in Java, and asserts the results agree
     * within 0.005.
     *
     * @throws TornadoExecutionPlanException
     *     propagated from the try-with-resources block that owns the execution plan
     */
    @Test
    public void testHalfFloatToFloatMatrixMultiplication() throws TornadoExecutionPlanException {
        final int n = 256;
        final int elements = n * n;
        HalfFloatArray matrixA = new HalfFloatArray(elements);
        HalfFloatArray matrixB = new HalfFloatArray(elements);
        FloatArray matrixCSeq = new FloatArray(elements);
        FloatArray matrixC = new FloatArray(elements);

        IntStream.range(0, elements).parallel().forEach(idx -> {
            matrixA.set(idx, new HalfFloat(2.5f));
            matrixB.set(idx, new HalfFloat(3.5f));
        });

        // Input transfer uses mode 0, output transfer mode 1 - both reproduced as found.
        TaskGraph taskGraph = new TaskGraph("s0")
                .transferToDevice(0, new Object[]{matrixA, matrixB})
                .task("t0", ComputeTests::matrixMultiplicationHalfFloatToFloat, (Object)matrixA, (Object)matrixB, (Object)matrixC, (Object)n)
                .transferToHost(1, new Object[]{matrixC});
        ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
        try (TornadoExecutionPlan executor = new TornadoExecutionPlan(new ImmutableTaskGraph[]{immutableTaskGraph})) {
            executor.execute();
        }

        ComputeTests.matrixMultiplicationHalfFloatToFloat(matrixA, matrixB, matrixCSeq, n);

        for (int idx = 0; idx < elements; ++idx) {
            Assert.assertEquals((float)matrixCSeq.get(idx), (float)matrixC.get(idx), 0.005f);
        }
    }

    /**
     * Same setup as the other half-float-to-float multiplication test in this class, but
     * exercising {@code matrixMultiplicationHalfFloatToFloat2}: 256 x 256 half-float
     * inputs (all 2.5 and all 3.5), {@code FloatArray} output, execution-plan result
     * compared with the direct Java result within 0.005.
     *
     * @throws TornadoExecutionPlanException
     *     propagated from the try-with-resources block that owns the execution plan
     */
    @Test
    public void testHalfFloatToFloatMatrixMultiplication2() throws TornadoExecutionPlanException {
        final int n = 256;
        final int elements = n * n;
        HalfFloatArray matrixA = new HalfFloatArray(elements);
        HalfFloatArray matrixB = new HalfFloatArray(elements);
        FloatArray matrixCSeq = new FloatArray(elements);
        FloatArray matrixC = new FloatArray(elements);

        IntStream.range(0, elements).parallel().forEach(idx -> {
            matrixA.set(idx, new HalfFloat(2.5f));
            matrixB.set(idx, new HalfFloat(3.5f));
        });

        // Input transfer uses mode 0, output transfer mode 1 - both reproduced as found.
        TaskGraph taskGraph = new TaskGraph("s0")
                .transferToDevice(0, new Object[]{matrixA, matrixB})
                .task("t0", ComputeTests::matrixMultiplicationHalfFloatToFloat2, (Object)matrixA, (Object)matrixB, (Object)matrixC, (Object)n)
                .transferToHost(1, new Object[]{matrixC});
        ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
        try (TornadoExecutionPlan executor = new TornadoExecutionPlan(new ImmutableTaskGraph[]{immutableTaskGraph})) {
            executor.execute();
        }

        ComputeTests.matrixMultiplicationHalfFloatToFloat2(matrixA, matrixB, matrixCSeq, n);

        for (int idx = 0; idx < elements; ++idx) {
            Assert.assertEquals((float)matrixCSeq.get(idx), (float)matrixC.get(idx), 0.005f);
        }
    }

    /**
     * Compares {@code processHeadsParallel}, run through a TornadoVM execution plan,
     * against {@code processAttentionSequential}, run directly in Java, on deterministic
     * inputs ({@code query[i] = 0.01 * i}, {@code keyCache[i] = valueCache[i] = 0.005 * i}).
     * The two 2048-element outputs must agree within 0.005.
     *
     * <p>The execution plan is now opened in a try-with-resources statement, as in the
     * other tests of this class, so it is closed even when {@code execute()} throws.
     *
     * @throws TornadoExecutionPlanException
     *     propagated from the try-with-resources block that owns the execution plan
     */
    @Test
    public void testGroupedQueryAttention() throws TornadoExecutionPlanException {
        // Configuration. Derived values equal the literals used previously.
        // NOTE(review): the derivations are reconstructed from those literals and the
        // variable names - confirm against the kernels.
        final int dim = 2048;
        final int nHeads = 32;
        final int headSize = dim / nHeads;        // 64
        final int numKVHeads = 8;
        final int kvMul = nHeads / numKVHeads;    // 4
        final int kvDim = numKVHeads * headSize;  // 512
        final int seqLen = 128;
        final int pos = 16;
        final int layer = 0;
        final int cacheSize = seqLen * kvDim;     // 65536

        FloatArray query = new FloatArray(dim);
        FloatArray keyCache = new FloatArray(cacheSize);
        FloatArray valueCache = new FloatArray(cacheSize);
        FloatArray output = new FloatArray(dim);
        FloatArray attentionWeights = new FloatArray(nHeads * (pos + 1)); // 544
        IntArray positionAndLayer = new IntArray(2);

        for (int i = 0; i < dim; ++i) {
            query.set(i, 0.01f * (float)i);
        }
        for (int i = 0; i < cacheSize; ++i) {
            float v = 0.005f * (float)i;
            keyCache.set(i, v);
            valueCache.set(i, v);
        }
        output.init(0.0f);
        positionAndLayer.set(0, pos);
        positionAndLayer.set(1, layer);

        // Reference result computed directly in Java.
        FloatArray expectedOutput = new FloatArray(dim);
        expectedOutput.init(0.0f);
        ComputeTests.processAttentionSequential(query, keyCache, valueCache, expectedOutput, nHeads, headSize, kvDim, kvMul, seqLen, pos, layer);

        TaskGraph taskGraph = new TaskGraph("gqaTest")
                .transferToDevice(0, new Object[]{query, keyCache, valueCache, output, attentionWeights, positionAndLayer})
                .task("parallel-attention", ComputeTests::processHeadsParallel, (Object)query, (Object)keyCache, (Object)valueCache, (Object)output, (Object)nHeads, (Object)headSize, (Object)kvDim, (Object)kvMul, (Object)seqLen, (Object)positionAndLayer, (Object)attentionWeights)
                .transferToHost(1, new Object[]{output});
        ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
        try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(new ImmutableTaskGraph[]{immutableTaskGraph})) {
            executionPlan.execute();
        }

        for (int i = 0; i < dim; ++i) {
            Assert.assertEquals((float)expectedOutput.get(i), (float)output.get(i), 0.005f);
        }
    }
}

