/*
 * Decompiled with CFR 0.152.
 */
package uk.ac.manchester.tornado.unittests.grid;

import java.util.stream.IntStream;
import org.junit.Assert;
import org.junit.Test;
import uk.ac.manchester.tornado.api.GridScheduler;
import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
import uk.ac.manchester.tornado.api.TaskGraph;
import uk.ac.manchester.tornado.api.TornadoExecutionPlan;
import uk.ac.manchester.tornado.api.WorkerGrid;
import uk.ac.manchester.tornado.api.WorkerGrid1D;
import uk.ac.manchester.tornado.api.exceptions.TornadoExecutionPlanException;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;

/**
 * Tests for {@link GridScheduler}. A single {@link WorkerGrid1D} is registered for task
 * {@code "s0.t0"}, while other tasks in the same graph, and other task graphs, run without
 * an explicit grid.
 */
public class TestGridScheduler {

    /** Number of elements in every array used by these tests. */
    private static final int SIZE = 1024;

    /**
     * Data-transfer mode passed to {@code transferToDevice}/{@code transferToHost}.
     *
     * <p>NOTE(review): this is the integer literal recovered by the decompiler. Because the
     * tests read results on the host after {@code transferToHost} with this mode, it is
     * presumably {@code DataTransferMode.EVERY_EXECUTION}. Confirm, and replace it with the
     * named constant.
     */
    private static final int TRANSFER_MODE = 1;

    /**
     * Computes {@code c = a + b} element-wise on the host and returns the sum of the result.
     *
     * @param a first input vector
     * @param b second input vector
     * @param c output vector; every element is overwritten with {@code a[i] + b[i]}
     * @return the sum of all elements written to {@code c}
     */
    public static float computeSequential(FloatArray a, FloatArray b, FloatArray c) {
        vectorAddFloat(a, b, c);
        float acc = 0.0f;
        for (int i = 0; i < c.getSize(); i++) {
            acc += c.get(i);
        }
        return acc;
    }

    /**
     * Element-wise vector addition {@code c[i] = a[i] + b[i]} for every index of {@code c}.
     *
     * <p>Used both as a TornadoVM task and as the host-side reference.
     *
     * <p>NOTE(review): the loop index carries no {@code @Parallel} annotation in this
     * decompiled source. Annotations on local variables are easily lost by decompilers, so
     * confirm whether the upstream kernel had one.
     *
     * @param a first input vector (must have at least {@code c.getSize()} elements)
     * @param b second input vector (must have at least {@code c.getSize()} elements)
     * @param c output vector
     */
    public static void vectorAddFloat(FloatArray a, FloatArray b, FloatArray c) {
        for (int i = 0; i < c.getSize(); i++) {
            c.set(i, a.get(i) + b.get(i));
        }
    }

    /**
     * Sums the first {@code size} elements of {@code array} with a single sequential loop and
     * stores the total in {@code array[0]}.
     *
     * @param array vector to reduce in place; element 0 is always read and written
     * @param size number of leading elements to include in the sum
     */
    public static void reduceAdd(FloatArray array, int size) {
        float acc = array.get(0);
        for (int i = 1; i < size; i++) {
            acc += array.get(i);
        }
        array.set(0, acc);
    }

    /** Fills the shared test inputs: {@code a[i] = i} and {@code b[i] = 2}. */
    private static void initInputs(FloatArray a, FloatArray b) {
        for (int i = 0; i < a.getSize(); i++) {
            a.set(i, (float) i);
        }
        for (int i = 0; i < b.getSize(); i++) {
            b.set(i, 2.0f);
        }
    }

    /** Asserts that {@code actual} has the same size and exactly the same elements as {@code expected}. */
    private static void assertSameElements(FloatArray expected, FloatArray actual) {
        Assert.assertEquals("array size", expected.getSize(), actual.getSize());
        for (int i = 0; i < expected.getSize(); i++) {
            Assert.assertEquals("element " + i, expected.get(i), actual.get(i), 0.0f);
        }
    }

    /**
     * Tests two tasks in one graph. {@code t0} (vector add) runs with the configured grid, and
     * {@code t1} (reduction into element 0) has no grid registered.
     *
     * <p>The device result must match the host reference exactly. With these inputs every
     * element and partial sum is an integer far below 2^24, so float arithmetic is exact.
     */
    @Test
    public void testMultipleTasksWithinTaskGraph() throws TornadoExecutionPlanException {
        FloatArray a = new FloatArray(SIZE);
        FloatArray b = new FloatArray(SIZE);
        FloatArray sequentialC = new FloatArray(SIZE);
        FloatArray tornadoC = new FloatArray(SIZE);
        initInputs(a, b);
        float sequential = computeSequential(a, b, sequentialC);

        WorkerGrid1D worker = new WorkerGrid1D(SIZE);
        worker.setGlobalWork(SIZE, 1L, 1L);
        worker.setLocalWork(1L, 1L, 1L);
        GridScheduler gridScheduler = new GridScheduler("s0.t0", worker);

        TaskGraph taskGraph = new TaskGraph("s0")
                .transferToDevice(TRANSFER_MODE, a, b, SIZE)
                .task("t0", TestGridScheduler::vectorAddFloat, a, b, tornadoC)
                .task("t1", TestGridScheduler::reduceAdd, tornadoC, SIZE)
                .transferToHost(TRANSFER_MODE, tornadoC);

        ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
        try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
            executionPlan.withGridScheduler(gridScheduler).execute();
        }

        Assert.assertEquals(sequential, tornadoC.get(0), 0.0f);
    }

    /**
     * Tests two task graphs in one execution plan. Graph 0 ({@code "s0"}) runs with the grid
     * scheduler, and graph 1 ({@code "s1"}) runs without it.
     *
     * <p>Both graphs compute {@code output = a + b}. The host copy is checked after each
     * execution.
     */
    @Test
    public void testMultiTaskGraphs() throws TornadoExecutionPlanException {
        FloatArray a = new FloatArray(SIZE);
        FloatArray b = new FloatArray(SIZE);
        FloatArray output = new FloatArray(SIZE);
        FloatArray expected = new FloatArray(SIZE);
        initInputs(a, b);
        vectorAddFloat(a, b, expected);

        WorkerGrid1D worker = new WorkerGrid1D(SIZE);
        worker.setLocalWork(1L, 1L, 1L);
        GridScheduler gridScheduler = new GridScheduler("s0.t0", worker);

        TaskGraph tg1 = new TaskGraph("s0")
                .transferToDevice(TRANSFER_MODE, a, b, SIZE)
                .task("t0", TestGridScheduler::vectorAddFloat, a, b, output)
                .transferToHost(TRANSFER_MODE, output);
        TaskGraph tg2 = new TaskGraph("s1")
                .transferToDevice(TRANSFER_MODE, a, b, SIZE)
                .task("t1", TestGridScheduler::vectorAddFloat, a, b, output)
                .task("t2", TestGridScheduler::vectorAddFloat, a, b, output)
                .transferToHost(TRANSFER_MODE, output);

        try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(tg1.snapshot(), tg2.snapshot())) {
            executionPlan.withGraph(0).withGridScheduler(gridScheduler).execute();
            assertSameElements(expected, output);

            executionPlan.withGraph(1).execute();
            assertSameElements(expected, output);
        }
    }

    /**
     * Same computation as {@link #testMultipleTasksWithinTaskGraph()}, but split across two
     * graphs in two separate execution plans.
     *
     * <p>Graph {@code "s0"} (with the grid scheduler) computes the vector add. Graph
     * {@code "s1"} (no grid) then reduces the host copy of the result into element 0.
     */
    @Test
    public void testMultipleTasksSeparateTaskGraphs() throws TornadoExecutionPlanException {
        FloatArray a = new FloatArray(SIZE);
        FloatArray b = new FloatArray(SIZE);
        FloatArray sequentialC = new FloatArray(SIZE);
        FloatArray tornadoC = new FloatArray(SIZE);
        initInputs(a, b);
        float sequential = computeSequential(a, b, sequentialC);

        WorkerGrid1D worker = new WorkerGrid1D(SIZE);
        worker.setGlobalWork(SIZE, 1L, 1L);
        worker.setLocalWork(1L, 1L, 1L);
        GridScheduler gridScheduler = new GridScheduler("s0.t0", worker);

        TaskGraph s0 = new TaskGraph("s0")
                .transferToDevice(TRANSFER_MODE, a, b, SIZE)
                .task("t0", TestGridScheduler::vectorAddFloat, a, b, tornadoC)
                .transferToHost(TRANSFER_MODE, tornadoC);
        ImmutableTaskGraph immutableTaskGraph = s0.snapshot();
        try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
            executionPlan.withGridScheduler(gridScheduler).execute();
        }

        TaskGraph s1 = new TaskGraph("s1")
                .transferToDevice(TRANSFER_MODE, tornadoC, SIZE)
                .task("t0", TestGridScheduler::reduceAdd, tornadoC, SIZE)
                .transferToHost(TRANSFER_MODE, tornadoC);
        ImmutableTaskGraph immutableTaskGraph1 = s1.snapshot();
        try (TornadoExecutionPlan executionPlan1 = new TornadoExecutionPlan(immutableTaskGraph1)) {
            executionPlan1.execute();
        }

        Assert.assertEquals(sequential, tornadoC.get(0), 0.0f);
    }
}

