/*
 * Decompiled with CFR 0.152.
 */
package uk.ac.manchester.tornado.unittests.prebuilt;

import java.util.List;
import java.util.stream.IntStream;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import uk.ac.manchester.tornado.api.AccessorParameters;
import uk.ac.manchester.tornado.api.GridScheduler;
import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
import uk.ac.manchester.tornado.api.KernelContext;
import uk.ac.manchester.tornado.api.TaskGraph;
import uk.ac.manchester.tornado.api.TornadoBackend;
import uk.ac.manchester.tornado.api.TornadoDeviceMap;
import uk.ac.manchester.tornado.api.TornadoExecutionPlan;
import uk.ac.manchester.tornado.api.WorkerGrid;
import uk.ac.manchester.tornado.api.WorkerGrid1D;
import uk.ac.manchester.tornado.api.common.Access;
import uk.ac.manchester.tornado.api.common.TornadoDevice;
import uk.ac.manchester.tornado.api.enums.TornadoVMBackendType;
import uk.ac.manchester.tornado.api.exceptions.TornadoExecutionPlanException;
import uk.ac.manchester.tornado.api.exceptions.TornadoRuntimeException;
import uk.ac.manchester.tornado.api.runtime.TornadoRuntimeProvider;
import uk.ac.manchester.tornado.api.types.arrays.IntArray;
import uk.ac.manchester.tornado.api.types.arrays.TornadoNativeArray;
import uk.ac.manchester.tornado.unittests.common.TornadoTestBase;
import uk.ac.manchester.tornado.unittests.common.TornadoVMMultiDeviceNotSupported;
import uk.ac.manchester.tornado.unittests.common.TornadoVMPTXNotSupported;

/**
 * Tests for prebuilt kernels: kernel files loaded from {@code $TORNADOVM_HOME/examples/generated/}
 * and attached to a task graph with {@code TaskGraph.prebuiltTask}, instead of being compiled from
 * Java methods.
 */
public class PrebuiltTests extends TornadoTestBase {

    /** Name of the environment variable that points at the TornadoVM SDK root. */
    private static final String TORNADOVM_HOME = "TORNADOVM_HOME";

    // Integer data-transfer modes passed to transferToDevice/transferToHost. The values are the
    // literals the compiled tests used; presumably they are DataTransferMode.FIRST_EXECUTION (0)
    // and DataTransferMode.EVERY_EXECUTION (1) -- TODO(review): confirm and use the API constants.
    private static final int FIRST_EXECUTION = 0;
    private static final int EVERY_EXECUTION = 1;

    /** Device 0 of backend 0; the device most tests run on. */
    private static TornadoDevice defaultDevice;

    /** Type of backend 0; selects the kernel file extension in {@link #getPrebuiltKernelPath(String)}. */
    private static TornadoVMBackendType backendType;

    /**
     * True when {@code TornadoNativeArray.ARRAY_HEADER} is 16 bytes (presumably meaning compressed
     * oops are in use). When false, the {@code _uncompressed} variants of the kernels are loaded.
     */
    private static boolean coops;

    @BeforeClass
    public static void init() {
        backendType = TornadoRuntimeProvider.getTornadoRuntime().getBackendType(0);
        defaultDevice = TornadoRuntimeProvider.getTornadoRuntime().getBackend(0).getDevice(0);
        coops = TornadoNativeArray.ARRAY_HEADER == 16L;
    }

    /**
     * Returns the path of a prebuilt kernel whose extension matches the type of backend 0:
     * {@code .ptx} for PTX, {@code .cl} for OpenCL and {@code .spv} for SPIR-V.
     *
     * @param kernelName file stem of the kernel, without the {@code _uncompressed} suffix or extension
     * @return the kernel path under {@code $TORNADOVM_HOME/examples/generated/}
     * @throws TornadoRuntimeException if the backend type has no prebuilt kernel extension, or if
     *         the {@code TORNADOVM_HOME} environment variable is not set
     */
    private String getPrebuiltKernelPath(String kernelName) {
        String extension = switch (backendType) {
            case PTX -> ".ptx";
            case OPENCL -> ".cl";
            case SPIRV -> ".spv";
            default -> throw new TornadoRuntimeException("Backend not supported");
        };
        return getPrebuiltKernelPath(kernelName, extension);
    }

    /**
     * Returns the path of a prebuilt kernel with an explicit file extension, independent of the
     * backend in use (e.g. to load a SPIR-V binary through the OpenCL runtime).
     *
     * @param kernelName file stem of the kernel, without the {@code _uncompressed} suffix or extension
     * @param extension  file extension, with or without the leading dot
     * @return the kernel path under {@code $TORNADOVM_HOME/examples/generated/}
     * @throws TornadoRuntimeException if the {@code TORNADOVM_HOME} environment variable is not set
     */
    private String getPrebuiltKernelPath(String kernelName, String extension) {
        String sdkPath = System.getenv(TORNADOVM_HOME);
        if (sdkPath == null) {
            // Without this check the path would silently start with "null/".
            throw new TornadoRuntimeException("Environment variable " + TORNADOVM_HOME + " is not set; cannot locate prebuilt kernels");
        }
        String fileStem = coops ? kernelName : kernelName + "_uncompressed";
        String finalExtension = extension.startsWith(".") ? extension : "." + extension;
        return sdkPath + "/examples/generated/" + fileStem + finalExtension;
    }

    /** Asserts that {@code c[j] == a[j] + b[j]} for every index of {@code c}. */
    private static void assertElementwiseSum(IntArray a, IntArray b, IntArray c) {
        for (int j = 0; j < c.getSize(); j++) {
            Assert.assertEquals((long) (a.get(j) + b.get(j)), (long) c.get(j));
        }
    }

    @Test
    public void testPrebuilt01() throws TornadoExecutionPlanException {
        final int numElements = 8;
        IntArray a = new IntArray(numElements);
        IntArray b = new IntArray(numElements);
        IntArray c = new IntArray(numElements);
        a.init(1);
        b.init(2);

        AccessorParameters accessorParameters = new AccessorParameters(3);
        accessorParameters.set(0, a, Access.READ_ONLY);
        accessorParameters.set(1, b, Access.READ_ONLY);
        accessorParameters.set(2, c, Access.WRITE_ONLY);

        String kernelFile = getPrebuiltKernelPath("add");
        TaskGraph taskGraph = new TaskGraph("s0") //
                .transferToDevice(FIRST_EXECUTION, a, b) //
                .prebuiltTask("t0", "add", kernelFile, accessorParameters) //
                .transferToHost(EVERY_EXECUTION, c);

        WorkerGrid workerGrid = new WorkerGrid1D(numElements);
        GridScheduler gridScheduler = new GridScheduler("s0.t0", workerGrid);

        ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
        try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
            executionPlan.withGridScheduler(gridScheduler).withDevice(defaultDevice).execute();
        }

        assertElementwiseSum(a, b, c);
    }

    @Test
    public void testPrebuilt01MultiIterations() throws TornadoExecutionPlanException {
        final int numElements = 8;
        IntArray a = new IntArray(numElements);
        IntArray b = new IntArray(numElements);
        IntArray c = new IntArray(numElements);
        a.init(1);
        b.init(2);

        AccessorParameters accessorParameters = new AccessorParameters(3);
        accessorParameters.set(0, a, Access.READ_WRITE);
        accessorParameters.set(1, b, Access.READ_WRITE);
        accessorParameters.set(2, c, Access.WRITE_ONLY);

        String kernelFile = getPrebuiltKernelPath("add");
        // `a` is modified on the host between executions, so the inputs use transfer mode 1.
        TaskGraph taskGraph = new TaskGraph("s0") //
                .transferToDevice(EVERY_EXECUTION, a, b) //
                .prebuiltTask("t0", "add", kernelFile, accessorParameters) //
                .transferToHost(EVERY_EXECUTION, c);

        WorkerGrid workerGrid = new WorkerGrid1D(numElements);
        GridScheduler gridScheduler = new GridScheduler("s0.t0", workerGrid);

        ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
        try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
            executionPlan.withGridScheduler(gridScheduler).withDevice(defaultDevice).execute();
            for (int i = 0; i < 10; i++) {
                executionPlan.execute();
                assertElementwiseSum(a, b, c);
                // Feed the result back as the next input: a = c.
                IntStream.range(0, numElements).forEach(k -> a.set(k, c.get(k)));
            }
        }
    }

    @Test
    public void testPrebuilt02SPIRV() throws TornadoExecutionPlanException {
        assertNotBackend(TornadoVMBackendType.PTX);
        assertNotBackend(TornadoVMBackendType.OPENCL);
        String kernelFile = getPrebuiltKernelPath("reduce03");

        final int size = 512;
        final int localSize = 256;
        float[] input = new float[size];
        // size / localSize partial sums are read back (presumably one per work-group).
        float[] reduce = new float[size / localSize];
        IntStream.range(0, input.length).forEach(i -> input[i] = 1.0f);

        WorkerGrid1D worker = new WorkerGrid1D(size);
        worker.setLocalWork(localSize, 1L, 1L);
        GridScheduler gridScheduler = new GridScheduler("s0.t0", worker);

        KernelContext context = new KernelContext();
        AccessorParameters accessorParameters = new AccessorParameters(3);
        accessorParameters.set(0, context, Access.READ_ONLY);
        accessorParameters.set(1, input, Access.READ_ONLY);
        accessorParameters.set(2, reduce, Access.WRITE_ONLY);

        TaskGraph taskGraph = new TaskGraph("s0") //
                .transferToDevice(FIRST_EXECUTION, input) //
                .prebuiltTask("t0", "floatReductionAddLocalMemory", kernelFile, accessorParameters) //
                .transferToHost(EVERY_EXECUTION, reduce);

        ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
        try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
            executionPlan.withGridScheduler(gridScheduler).withDevice(defaultDevice).execute();
        }

        float finalSum = 0.0f;
        for (float partial : reduce) {
            finalSum += partial;
        }
        // Every input element is 1.0f, so the total equals the number of elements.
        Assert.assertEquals((float) size, finalSum, 0.0f);
    }

    @Test
    public void testPrebuilt03SPIRV() throws TornadoExecutionPlanException {
        assertNotBackend(TornadoVMBackendType.PTX);
        assertNotBackend(TornadoVMBackendType.OPENCL);
        String kernelFile = getPrebuiltKernelPath("reduce04");

        final int size = 32;
        final int localSize = 32;
        final int value = 2;
        int[] input = new int[size];
        int[] output = new int[1];
        IntStream.range(0, input.length).forEach(i -> input[i] = value);

        WorkerGrid1D worker = new WorkerGrid1D(size);
        worker.setLocalWork(localSize, 1L, 1L);
        GridScheduler gridScheduler = new GridScheduler("a.b", worker);

        KernelContext context = new KernelContext();
        AccessorParameters accessorParameters = new AccessorParameters(3);
        accessorParameters.set(0, context, Access.READ_ONLY);
        accessorParameters.set(1, input, Access.READ_ONLY);
        accessorParameters.set(2, output, Access.WRITE_ONLY);

        TaskGraph taskGraph = new TaskGraph("a") //
                .transferToDevice(FIRST_EXECUTION, input) //
                .prebuiltTask("b", "intReductionAddGlobalMemory", kernelFile, accessorParameters) //
                .transferToHost(EVERY_EXECUTION, output);

        ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
        try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
            executionPlan.withGridScheduler(gridScheduler).withDevice(defaultDevice).execute();
        }

        float finalSum = 0.0f;
        for (int partial : output) {
            finalSum += (float) partial;
        }
        Assert.assertEquals((float) (value * size), finalSum, 0.0f);
    }

    @Test
    public void testPrebuilt04SPIRVThroughOpenCLRuntime() throws TornadoExecutionPlanException {
        assertNotBackend(TornadoVMBackendType.PTX);
        TornadoDevice device = getSPIRVSupportedDevice();
        if (device == null) {
            assertNotBackend(TornadoVMBackendType.OPENCL, "No SPIRV supported device found with the current OpenCL backend. The OpenCL version must be >= 2.1 to support SPIR-V execution.");
        }
        // NOTE(review): the guard above presumably only stops the test on the OpenCL backend; if
        // getSPIRVSupportedDevice() can return null on another backend, `device` is still null
        // when passed to withDevice below -- confirm.
        String kernelFile = getPrebuiltKernelPath("reduce03", ".spv");

        final int size = 512;
        final int localSize = 256;
        float[] input = new float[size];
        // size / localSize partial sums are read back (presumably one per work-group).
        float[] output = new float[size / localSize];
        IntStream.range(0, input.length).forEach(i -> input[i] = 1.0f);

        WorkerGrid1D worker = new WorkerGrid1D(size);
        worker.setLocalWork(localSize, 1L, 1L);
        GridScheduler gridScheduler = new GridScheduler("s0.t0", worker);

        KernelContext context = new KernelContext();
        AccessorParameters accessorParameters = new AccessorParameters(3);
        accessorParameters.set(0, context, Access.READ_ONLY);
        accessorParameters.set(1, input, Access.READ_ONLY);
        accessorParameters.set(2, output, Access.WRITE_ONLY);

        TaskGraph taskGraph = new TaskGraph("s0") //
                .transferToDevice(FIRST_EXECUTION, input) //
                .prebuiltTask("t0", "floatReductionAddLocalMemory", kernelFile, accessorParameters) //
                .transferToHost(EVERY_EXECUTION, output);

        ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
        try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
            executionPlan.withGridScheduler(gridScheduler).withDevice(device).execute();
        }

        float finalSum = 0.0f;
        for (float partial : output) {
            finalSum += partial;
        }
        // Every input element is 1.0f, so the total equals the number of elements.
        Assert.assertEquals((float) size, finalSum, 0.0f);
    }

    @Test
    public void testPrebuiltMutiBackend() throws TornadoExecutionPlanException {
        // Check the environment first so unsupported configurations are reported as such,
        // before any kernel-path resolution can fail.
        TornadoDeviceMap tornadoDeviceMap = TornadoExecutionPlan.getTornadoDeviceMap();
        if (tornadoDeviceMap.getNumBackends() < 2) {
            throw new TornadoVMMultiDeviceNotSupported("Test designed to run with multiple backends");
        }
        List<? extends TornadoBackend> ptxBackends = tornadoDeviceMap.getBackendsWithPredicate(backend -> backend.getBackendType() == TornadoVMBackendType.PTX);
        if (ptxBackends == null || ptxBackends.isEmpty()) {
            throw new TornadoVMPTXNotSupported("Test designed to run with multiple backends, including a PTX backend");
        }
        TornadoDevice device = ptxBackends.getFirst().getDevice(0);

        final int numElements = 8;
        IntArray a = new IntArray(numElements);
        IntArray b = new IntArray(numElements);
        IntArray c = new IntArray(numElements);
        a.init(1);
        b.init(2);

        AccessorParameters accessorParameters = new AccessorParameters(3);
        accessorParameters.set(0, a, Access.READ_WRITE);
        accessorParameters.set(1, b, Access.READ_WRITE);
        accessorParameters.set(2, c, Access.WRITE_ONLY);

        // The PTX kernel is requested explicitly because the task runs on the PTX device,
        // which need not be backend 0.
        String kernelFile = getPrebuiltKernelPath("add", ".ptx");
        TaskGraph taskGraph = new TaskGraph("s0") //
                .transferToDevice(EVERY_EXECUTION, a, b) //
                .prebuiltTask("t0", "add", kernelFile, accessorParameters) //
                .transferToHost(EVERY_EXECUTION, c);

        WorkerGrid workerGrid = new WorkerGrid1D(numElements);
        GridScheduler gridScheduler = new GridScheduler("s0.t0", workerGrid);

        ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
        try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
            executionPlan.withGridScheduler(gridScheduler).withDevice(device).execute();
            for (int i = 0; i < 10; i++) {
                executionPlan.execute();
                assertElementwiseSum(a, b, c);
                // Feed the result back as the next input: a = c.
                IntStream.range(0, numElements).forEach(k -> a.set(k, c.get(k)));
            }
        }
    }
}

