/*
 * Decompiled with CFR 0.152.
 */
package uk.ac.manchester.tornado.unittests.kernelcontext.api;

import java.util.stream.IntStream;
import org.junit.Assert;
import org.junit.Test;
import uk.ac.manchester.tornado.api.GridScheduler;
import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
import uk.ac.manchester.tornado.api.KernelContext;
import uk.ac.manchester.tornado.api.TaskGraph;
import uk.ac.manchester.tornado.api.TornadoExecutionPlan;
import uk.ac.manchester.tornado.api.WorkerGrid;
import uk.ac.manchester.tornado.api.WorkerGrid1D;
import uk.ac.manchester.tornado.api.exceptions.TornadoExecutionPlanException;
import uk.ac.manchester.tornado.api.types.arrays.IntArray;
import uk.ac.manchester.tornado.unittests.common.TornadoTestBase;

/**
 * Tests that task graphs may mix loop-style tasks (plain Java loops over the
 * whole array) with {@link KernelContext}-style tasks (one element per
 * invocation, indexed by {@code context.globalIdx}), while a
 * {@link GridScheduler} supplies worker grids for only some of the tasks.
 *
 * <p>Every test runs the same three-stage pipeline on the device,
 * {@code c = a + b; c = c * b; c = c - b}, and compares the result with the
 * same pipeline executed sequentially on the host through the loop-style
 * methods.
 */
public class TestCombinedTaskGraph extends TornadoTestBase {

    /** Number of elements in every array used by these tests. */
    private static final int SIZE = 16;

    /**
     * Data transfer mode passed to {@code transferToDevice}/{@code transferToHost}.
     * NOTE(review): presumably corresponds to DataTransferMode.EVERY_EXECUTION -
     * confirm against the TornadoVM API and switch to the named constant.
     */
    private static final int TRANSFER_MODE = 1;

    /**
     * Loop-style element-wise addition: {@code c[i] = a[i] + b[i]} for every
     * index below {@code c.getSize()}.
     *
     * @param a first operand
     * @param b second operand
     * @param c destination; may be the same array as an operand
     */
    public static void vectorAddV1(IntArray a, IntArray b, IntArray c) {
        for (int i = 0; i < c.getSize(); i++) {
            c.set(i, a.get(i) + b.get(i));
        }
    }

    /**
     * Kernel-context element-wise addition: writes {@code a[idx] + b[idx]} to
     * {@code c[idx]} where {@code idx} is {@code context.globalIdx}.
     *
     * @param context supplies the index of the element to process
     * @param a first operand
     * @param b second operand
     * @param c destination; may be the same array as an operand
     */
    public static void vectorAddV2(KernelContext context, IntArray a, IntArray b, IntArray c) {
        int idx = context.globalIdx;
        c.set(idx, a.get(idx) + b.get(idx));
    }

    /**
     * Loop-style element-wise multiplication: {@code c[i] = a[i] * b[i]} for
     * every index below {@code c.getSize()}.
     *
     * @param a first operand
     * @param b second operand
     * @param c destination; may be the same array as an operand
     */
    public static void vectorMulV1(IntArray a, IntArray b, IntArray c) {
        for (int i = 0; i < c.getSize(); i++) {
            c.set(i, a.get(i) * b.get(i));
        }
    }

    /**
     * Kernel-context element-wise multiplication: writes {@code a[idx] * b[idx]}
     * to {@code c[idx]} where {@code idx} is {@code context.globalIdx}.
     *
     * @param context supplies the index of the element to process
     * @param a first operand
     * @param b second operand
     * @param c destination; may be the same array as an operand
     */
    public static void vectorMulV2(KernelContext context, IntArray a, IntArray b, IntArray c) {
        int idx = context.globalIdx;
        c.set(idx, a.get(idx) * b.get(idx));
    }

    /**
     * Loop-style element-wise subtraction: {@code c[i] = a[i] - b[i]} for every
     * index below {@code c.getSize()}.
     *
     * @param a minuend
     * @param b subtrahend
     * @param c destination; may be the same array as an operand
     */
    public static void vectorSubV1(IntArray a, IntArray b, IntArray c) {
        for (int i = 0; i < c.getSize(); i++) {
            c.set(i, a.get(i) - b.get(i));
        }
    }

    /**
     * Kernel-context element-wise subtraction: writes {@code a[idx] - b[idx]}
     * to {@code c[idx]} where {@code idx} is {@code context.globalIdx}.
     *
     * @param context supplies the index of the element to process
     * @param a minuend
     * @param b subtrahend
     * @param c destination; may be the same array as an operand
     */
    public static void vectorSubV2(KernelContext context, IntArray a, IntArray b, IntArray c) {
        int idx = context.globalIdx;
        c.set(idx, a.get(idx) - b.get(idx));
    }

    /** Creates an array of {@code size} elements where element {@code i} holds {@code i}. */
    private static IntArray newIndexFilledArray(int size) {
        IntArray array = new IntArray(size);
        IntStream.range(0, array.getSize()).sequential().forEach(i -> array.set(i, i));
        return array;
    }

    /** Snapshots {@code taskGraph} and executes it once with {@code gridScheduler}, closing the plan afterwards. */
    private static void runWithGridScheduler(TaskGraph taskGraph, GridScheduler gridScheduler) throws TornadoExecutionPlanException {
        ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
        try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
            executionPlan.withGridScheduler(gridScheduler).execute();
        }
    }

    /** Recomputes add/mul/sub on the host with the loop-style methods and asserts {@code actual} matches element by element. */
    private static void assertMatchesSequentialPipeline(IntArray a, IntArray b, IntArray actual) {
        IntArray expected = new IntArray(actual.getSize());
        vectorAddV1(a, b, expected);
        vectorMulV1(expected, b, expected);
        vectorSubV1(expected, b, expected);
        for (int i = 0; i < expected.getSize(); i++) {
            Assert.assertEquals("Mismatch at index " + i, expected.get(i), actual.get(i));
        }
    }

    /**
     * Three loop-style tasks; the grid scheduler only holds a worker grid for
     * {@code s01.t0} (global work 16, local work 16), leaving {@code t1} and
     * {@code t2} without an entry.
     */
    @Test
    public void combinedAPI01() throws TornadoExecutionPlanException {
        IntArray a = newIndexFilledArray(SIZE);
        IntArray b = newIndexFilledArray(SIZE);
        IntArray cTornado = new IntArray(SIZE);

        WorkerGrid1D worker = new WorkerGrid1D(SIZE);
        worker.setGlobalWork(SIZE, 1, 1);
        worker.setLocalWork(SIZE, 1, 1);
        GridScheduler gridScheduler = new GridScheduler("s01.t0", worker);

        TaskGraph taskGraph = new TaskGraph("s01")
                .transferToDevice(TRANSFER_MODE, a, b)
                .task("t0", TestCombinedTaskGraph::vectorAddV1, a, b, cTornado)
                .task("t1", TestCombinedTaskGraph::vectorMulV1, cTornado, b, cTornado)
                .task("t2", TestCombinedTaskGraph::vectorSubV1, cTornado, b, cTornado)
                .transferToHost(TRANSFER_MODE, cTornado);

        runWithGridScheduler(taskGraph, gridScheduler);

        assertMatchesSequentialPipeline(a, b, cTornado);
    }

    /**
     * Three kernel-context tasks, all registered in the grid scheduler with
     * the same worker grid.
     */
    @Test
    public void combinedAPI02() throws TornadoExecutionPlanException {
        IntArray a = newIndexFilledArray(SIZE);
        IntArray b = newIndexFilledArray(SIZE);
        IntArray cTornado = new IntArray(SIZE);

        WorkerGrid1D worker = new WorkerGrid1D(SIZE);
        GridScheduler gridScheduler = new GridScheduler();
        gridScheduler.addWorkerGrid("s02.t0", worker);
        gridScheduler.addWorkerGrid("s02.t1", worker);
        gridScheduler.addWorkerGrid("s02.t2", worker);

        KernelContext context = new KernelContext();
        TaskGraph taskGraph = new TaskGraph("s02")
                .transferToDevice(TRANSFER_MODE, a, b)
                .task("t0", TestCombinedTaskGraph::vectorAddV2, context, a, b, cTornado)
                .task("t1", TestCombinedTaskGraph::vectorMulV2, context, cTornado, b, cTornado)
                .task("t2", TestCombinedTaskGraph::vectorSubV2, context, cTornado, b, cTornado)
                .transferToHost(TRANSFER_MODE, cTornado);

        runWithGridScheduler(taskGraph, gridScheduler);

        assertMatchesSequentialPipeline(a, b, cTornado);
    }

    /**
     * A loop-style task ({@code t0}, no worker grid) followed by two
     * kernel-context tasks ({@code t1}, {@code t2}) that share a worker grid.
     */
    @Test
    public void combinedAPI03() throws TornadoExecutionPlanException {
        IntArray a = newIndexFilledArray(SIZE);
        IntArray b = newIndexFilledArray(SIZE);
        IntArray cTornado = new IntArray(SIZE);

        WorkerGrid1D worker = new WorkerGrid1D(SIZE);
        GridScheduler gridScheduler = new GridScheduler();
        gridScheduler.addWorkerGrid("s03.t1", worker);
        gridScheduler.addWorkerGrid("s03.t2", worker);

        KernelContext context = new KernelContext();
        TaskGraph taskGraph = new TaskGraph("s03")
                .transferToDevice(TRANSFER_MODE, a, b)
                .task("t0", TestCombinedTaskGraph::vectorAddV1, a, b, cTornado)
                .task("t1", TestCombinedTaskGraph::vectorMulV2, context, cTornado, b, cTornado)
                .task("t2", TestCombinedTaskGraph::vectorSubV2, context, cTornado, b, cTornado)
                .transferToHost(TRANSFER_MODE, cTornado);

        runWithGridScheduler(taskGraph, gridScheduler);

        assertMatchesSequentialPipeline(a, b, cTornado);
    }

    /**
     * Two kernel-context tasks ({@code t0}, {@code t1}) that share a worker
     * grid, followed by a loop-style task ({@code t2}, no worker grid).
     */
    @Test
    public void combinedAPI04() throws TornadoExecutionPlanException {
        IntArray a = newIndexFilledArray(SIZE);
        IntArray b = newIndexFilledArray(SIZE);
        IntArray cTornado = new IntArray(SIZE);

        WorkerGrid1D worker = new WorkerGrid1D(SIZE);
        GridScheduler gridScheduler = new GridScheduler();
        gridScheduler.addWorkerGrid("s04.t0", worker);
        gridScheduler.addWorkerGrid("s04.t1", worker);

        KernelContext context = new KernelContext();
        TaskGraph taskGraph = new TaskGraph("s04")
                .transferToDevice(TRANSFER_MODE, a, b)
                .task("t0", TestCombinedTaskGraph::vectorAddV2, context, a, b, cTornado)
                .task("t1", TestCombinedTaskGraph::vectorMulV2, context, cTornado, b, cTornado)
                .task("t2", TestCombinedTaskGraph::vectorSubV1, cTornado, b, cTornado)
                .transferToHost(TRANSFER_MODE, cTornado);

        runWithGridScheduler(taskGraph, gridScheduler);

        assertMatchesSequentialPipeline(a, b, cTornado);
    }

    /**
     * Same task layout as {@link #combinedAPI04()}, but each kernel-context
     * task gets its own worker grid: {@code t0} uses global work 16 with local
     * work 8, while {@code t1} uses global work 16 and has its local work
     * reset through {@code setLocalWorkToNull()}.
     */
    @Test
    public void combinedAPI05() throws TornadoExecutionPlanException {
        IntArray a = newIndexFilledArray(SIZE);
        IntArray b = newIndexFilledArray(SIZE);
        IntArray cTornado = new IntArray(SIZE);

        WorkerGrid1D workerT0 = new WorkerGrid1D(SIZE);
        workerT0.setGlobalWork(SIZE, 1, 1);
        workerT0.setLocalWork(SIZE / 2, 1, 1);

        WorkerGrid1D workerT1 = new WorkerGrid1D(SIZE);
        workerT1.setGlobalWork(SIZE, 1, 1);
        // NOTE(review): presumably leaves the local work size for the runtime to choose - confirm.
        workerT1.setLocalWorkToNull();

        GridScheduler gridScheduler = new GridScheduler();
        gridScheduler.addWorkerGrid("s05.t0", workerT0);
        gridScheduler.addWorkerGrid("s05.t1", workerT1);

        KernelContext context = new KernelContext();
        TaskGraph taskGraph = new TaskGraph("s05")
                .transferToDevice(TRANSFER_MODE, a, b)
                .task("t0", TestCombinedTaskGraph::vectorAddV2, context, a, b, cTornado)
                .task("t1", TestCombinedTaskGraph::vectorMulV2, context, cTornado, b, cTornado)
                .task("t2", TestCombinedTaskGraph::vectorSubV1, cTornado, b, cTornado)
                .transferToHost(TRANSFER_MODE, cTornado);

        runWithGridScheduler(taskGraph, gridScheduler);

        assertMatchesSequentialPipeline(a, b, cTornado);
    }
}

