/*
 * Decompiled with CFR 0.152.
 */
package uk.ac.manchester.tornado.unittests.codegen;

import java.util.stream.IntStream;
import org.junit.Assert;
import org.junit.Ignore;
import org.junit.Test;
import uk.ac.manchester.tornado.api.GridScheduler;
import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
import uk.ac.manchester.tornado.api.KernelContext;
import uk.ac.manchester.tornado.api.TaskGraph;
import uk.ac.manchester.tornado.api.TornadoExecutionPlan;
import uk.ac.manchester.tornado.api.WorkerGrid;
import uk.ac.manchester.tornado.api.WorkerGrid1D;
import uk.ac.manchester.tornado.api.common.TornadoDevice;
import uk.ac.manchester.tornado.api.enums.DataTransferMode;
import uk.ac.manchester.tornado.api.enums.TornadoDeviceType;
import uk.ac.manchester.tornado.api.enums.TornadoVMBackendType;
import uk.ac.manchester.tornado.api.exceptions.TornadoExecutionPlanException;
import uk.ac.manchester.tornado.api.math.TornadoMath;
import uk.ac.manchester.tornado.api.runtime.TornadoRuntimeProvider;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
import uk.ac.manchester.tornado.api.types.arrays.IntArray;
import uk.ac.manchester.tornado.unittests.common.TornadoTestBase;

public class CodeGenTest
extends TornadoTestBase {
    public static void cascadeKernel(IntArray grayIntegralImage, int imageWidth, int imageHeight, IntArray resultsXY) {
        for (int y = 0; y < imageHeight; ++y) {
            for (int x = 0; x < imageWidth; ++x) {
                int n = grayIntegralImage.get(y * imageWidth + x);
            }
        }
    }

    public static void badCascadeKernel2() {
        for (int id = 0; id < 100; ++id) {
            boolean stillLooksLikeAFace = true;
            for (int stage = 0; stillLooksLikeAFace || stage < 100; ++stage) {
                for (int t = 0; t < id; ++t) {
                    stillLooksLikeAFace = t == 0;
                }
            }
        }
    }

    public static void badCascadeKernel3() {
        for (int id = 0; id < 100; ++id) {
            boolean stillLooksLikeAFace = true;
            for (int stage = 0; stillLooksLikeAFace || stage < 100; ++stage) {
                for (int t = 0; stillLooksLikeAFace && t < id; ++t) {
                    stillLooksLikeAFace = t == 0;
                }
            }
        }
    }

    public static void badCascadeKernel4() {
        for (int id = 0; id < 100; ++id) {
            boolean stillLooksLikeAFace = true;
            for (int stage = 0; stillLooksLikeAFace && stage < id; ++stage) {
                for (int t = 0; t < id; ++t) {
                    stillLooksLikeAFace = t == 0;
                }
            }
        }
    }

    private static void breakStatement(IntArray a) {
        for (int i = 0; i < a.getSize() && a.get(i) != 5; ++i) {
            a.set(i, a.get(i) + 5);
        }
        a.set(0, 0);
    }

    public static void testLocalMemoryAllocation(KernelContext context, int localWorkGroupSize) {
        int threadId = context.localIdx;
        int blockDim = context.localGroupSizeX;
        float[] localArray = context.allocateFloatLocalArray(localWorkGroupSize);
        localArray[threadId] = threadId;
        context.localBarrier();
        if (threadId == 0) {
            float sum = 0.0f;
            for (int i = 0; i < blockDim; ++i) {
                sum += localArray[i];
            }
        }
        context.localBarrier();
    }

    public static void processHeadsFlashAttention(KernelContext context, FloatArray q, FloatArray key_cache, FloatArray value_cache, FloatArray xb, int nHeads, int headSize, int kvDim, int kvMul, IntArray positionHolder, int layer, int contextLength) {
        int i;
        int tid = context.localIdx;
        int gid = context.globalIdx;
        int h = context.groupIdx;
        int localSize = context.localGroupSizeX;
        if (h >= nHeads) {
            return;
        }
        int pos = positionHolder.get(0);
        int loff = layer * contextLength * kvDim;
        int kvHeadIdx = h / kvMul;
        int BLOCK_SIZE_C = 4;
        float[] q_shared = context.allocateFloatLocalArray(headSize);
        float[] k_tile = context.allocateFloatLocalArray(BLOCK_SIZE_C * headSize);
        float[] v_tile = context.allocateFloatLocalArray(BLOCK_SIZE_C * headSize);
        float[] s_tile = context.allocateFloatLocalArray(BLOCK_SIZE_C);
        float[] shared_tile_max_holder = context.allocateFloatLocalArray(1);
        float maxScore = Float.NEGATIVE_INFINITY;
        float sumExp = 0.0f;
        float[] output = new float[headSize];
        for (i = 0; i < headSize; ++i) {
            output[i] = 0.0f;
        }
        for (i = tid; i < headSize; i += localSize) {
            q_shared[i] = q.get(h * headSize + i);
        }
        context.localBarrier();
        for (int tileC = 0; tileC <= pos; tileC += BLOCK_SIZE_C) {
            int d;
            int tIdxInSeq;
            int tileEnd = Math.min(tileC + BLOCK_SIZE_C - 1, pos);
            for (tIdxInSeq = tileC + tid; tIdxInSeq <= tileEnd; tIdxInSeq += localSize) {
                int k_v_idx_in_tile = tIdxInSeq - tileC;
                int tileMemOffset = k_v_idx_in_tile * headSize;
                for (d = 0; d < headSize; ++d) {
                    int kvCacheAbsolutePos = tIdxInSeq;
                    int kvOffset = loff + kvCacheAbsolutePos * kvDim + kvHeadIdx * headSize + d;
                    k_tile[tileMemOffset + d] = key_cache.get(kvOffset);
                    v_tile[tileMemOffset + d] = value_cache.get(kvOffset);
                }
            }
            context.localBarrier();
            for (tIdxInSeq = tileC + tid; tIdxInSeq <= tileEnd; tIdxInSeq += localSize) {
                int score_idx_in_tile = tIdxInSeq - tileC;
                float score = 0.0f;
                for (d = 0; d < headSize; ++d) {
                    score += q_shared[d] * k_tile[score_idx_in_tile * headSize + d];
                }
                s_tile[score_idx_in_tile] = score /= TornadoMath.sqrt((float)headSize);
            }
            context.localBarrier();
            float tileLocalMax = Float.NEGATIVE_INFINITY;
            for (int i2 = 0; i2 <= tileEnd - tileC; ++i2) {
                if (!(s_tile[i2] > tileLocalMax)) continue;
                tileLocalMax = s_tile[i2];
            }
            if (tid == 0) {
                shared_tile_max_holder[0] = tileLocalMax;
            }
            context.localBarrier();
            float currentTileMax = shared_tile_max_holder[0];
            float newMax = Math.max(maxScore, currentTileMax);
            if (newMax != maxScore && maxScore != Float.NEGATIVE_INFINITY) {
                float scale = TornadoMath.exp((float)(maxScore - newMax));
                sumExp *= scale;
                int d2 = 0;
                while (d2 < headSize) {
                    int n = d2++;
                    output[n] = output[n] * scale;
                }
            }
            maxScore = newMax;
            for (int t_idx_in_s_tile = 0; t_idx_in_s_tile <= tileEnd - tileC; ++t_idx_in_s_tile) {
                float expScore = TornadoMath.exp((float)(s_tile[t_idx_in_s_tile] - maxScore));
                sumExp += expScore;
                for (int d3 = 0; d3 < headSize; ++d3) {
                    int n = d3;
                    output[n] = output[n] + expScore * v_tile[t_idx_in_s_tile * headSize + d3];
                }
            }
            context.localBarrier();
        }
        float normFactor = sumExp > 0.0f ? 1.0f / sumExp : 0.0f;
        for (int d = tid; d < headSize; d += localSize) {
            xb.set(h * headSize + d, output[d] * normFactor);
        }
    }

    @Test
    public void test01() throws TornadoExecutionPlanException {
        TaskGraph taskGraph = new TaskGraph("foo");
        int imageWidth = 512;
        int imageHeight = 512;
        IntArray grayIntegralImage = new IntArray(imageHeight * imageWidth);
        IntArray resultsXY = new IntArray(imageHeight * imageWidth);
        IntStream.range(0, imageHeight * imageHeight).forEach(x -> grayIntegralImage.set(x, x));
        taskGraph.transferToDevice(0, new Object[]{grayIntegralImage}).task("bar", CodeGenTest::cascadeKernel, (Object)grayIntegralImage, (Object)imageWidth, (Object)imageHeight, (Object)resultsXY).transferToHost(1, new Object[]{resultsXY});
        ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
        try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(new ImmutableTaskGraph[]{immutableTaskGraph});){
            executionPlan.execute();
        }
    }

    private boolean isRunningOnCPU() {
        TornadoDevice device = TornadoRuntimeProvider.getTornadoRuntime().getBackend(0).getDevice(0);
        return device.getDeviceType() == TornadoDeviceType.CPU;
    }

    @Test
    public void test02() throws TornadoExecutionPlanException {
        if (this.isRunningOnCPU()) {
            return;
        }
        TaskGraph taskGraph = new TaskGraph("s0").task("t0", CodeGenTest::badCascadeKernel2);
        ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
        try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(new ImmutableTaskGraph[]{immutableTaskGraph});){
            executionPlan.withPreCompilation();
        }
    }

    @Test
    @Ignore
    public void test03() throws TornadoExecutionPlanException {
        if (this.isRunningOnCPU()) {
            return;
        }
        TaskGraph taskGraph = new TaskGraph("s0").task("t0", CodeGenTest::badCascadeKernel3);
        ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
        try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(new ImmutableTaskGraph[]{immutableTaskGraph});){
            executionPlan.withPreCompilation();
        }
    }

    @Test
    public void test04() throws TornadoExecutionPlanException {
        this.assertNotBackendOptimization(TornadoVMBackendType.SPIRV);
        if (this.isRunningOnCPU()) {
            return;
        }
        TaskGraph taskGraph = new TaskGraph("s0").task("t0", CodeGenTest::badCascadeKernel4);
        ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
        try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(new ImmutableTaskGraph[]{immutableTaskGraph});){
            executionPlan.withPreCompilation();
        }
    }

    @Test
    public void test05() throws TornadoExecutionPlanException {
        int size = 8192;
        IntArray a = new IntArray(8192);
        a.init(10);
        a.set(12, 5);
        IntArray serial = new IntArray(8192);
        serial.init(10);
        serial.set(12, 5);
        CodeGenTest.breakStatement(serial);
        TaskGraph taskGraph = new TaskGraph("break").transferToDevice(1, new Object[]{a}).task("task", CodeGenTest::breakStatement, (Object)a).transferToHost(1, new Object[]{a});
        ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
        try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(new ImmutableTaskGraph[]{immutableTaskGraph});){
            executionPlan.execute();
        }
        for (int i = 0; i < 8192; ++i) {
            Assert.assertEquals((long)serial.get(i), (long)a.get(i));
        }
    }

    @Test
    public void test06() throws TornadoExecutionPlanException {
        KernelContext context = new KernelContext();
        int localWorkGroupSize = 256;
        TaskGraph taskGraph = new TaskGraph("localMemoryAllocation").task("task", CodeGenTest::testLocalMemoryAllocation, (Object)context, (Object)localWorkGroupSize);
        ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
        try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(new ImmutableTaskGraph[]{immutableTaskGraph});){
            executionPlan.execute();
        }
    }

    @Test
    public void testFlashAttention() throws TornadoExecutionPlanException {
        int i;
        this.assertNotBackend(TornadoVMBackendType.SPIRV);
        int dim = 2048;
        int nHeads = 32;
        int headSize = 64;
        int numKVHeads = 8;
        int kvMul = 4;
        int kvDim = 512;
        int seqLen = 128;
        int pos = 16;
        boolean layer = false;
        FloatArray query = new FloatArray(2048);
        FloatArray keyCache = new FloatArray(65536);
        FloatArray valueCache = new FloatArray(65536);
        FloatArray output = new FloatArray(2048);
        FloatArray attentionWeights = new FloatArray(544);
        IntArray positionAndLayer = new IntArray(2);
        for (i = 0; i < 2048; ++i) {
            query.set(i, 0.01f * (float)i);
        }
        for (i = 0; i < 65536; ++i) {
            keyCache.set(i, 0.005f * (float)i);
            valueCache.set(i, 0.005f * (float)i);
        }
        for (i = 0; i < 2048; ++i) {
            output.set(i, 0.0f);
        }
        positionAndLayer.set(0, 16);
        positionAndLayer.set(1, 0);
        WorkerGrid1D parallelAttentionWorker = new WorkerGrid1D(32);
        GridScheduler gridScheduler = new GridScheduler("s0.t0", (WorkerGrid)parallelAttentionWorker);
        parallelAttentionWorker.setGlobalWork(128L, 1L, 1L);
        parallelAttentionWorker.setLocalWork(4L, 1L, 1L);
        KernelContext context = new KernelContext();
        TaskGraph taskGraph = new TaskGraph("s0").transferToDevice(0, new Object[]{query, keyCache, valueCache, output, attentionWeights, positionAndLayer}).task("t0", CodeGenTest::processHeadsFlashAttention, (Object)context, (Object)query, (Object)keyCache, (Object)valueCache, (Object)output, (Object)32, (Object)64, (Object)512, (Object)4, (Object)positionAndLayer, (Object)0, (Object)512).transferToHost(1, new Object[]{output});
        ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
        try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(new ImmutableTaskGraph[]{immutableTaskGraph});){
            executionPlan.withGridScheduler(gridScheduler).execute();
        }
    }
}

