/*
 * Decompiled with CFR 0.152.
 */
package uk.ac.manchester.tornado.unittests.api;

import java.util.Random;
import org.junit.Assert;
import org.junit.Test;
import uk.ac.manchester.tornado.api.GridScheduler;
import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
import uk.ac.manchester.tornado.api.KernelContext;
import uk.ac.manchester.tornado.api.TaskGraph;
import uk.ac.manchester.tornado.api.TornadoExecutionPlan;
import uk.ac.manchester.tornado.api.WorkerGrid;
import uk.ac.manchester.tornado.api.WorkerGrid1D;
import uk.ac.manchester.tornado.api.exceptions.TornadoExecutionPlanException;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
import uk.ac.manchester.tornado.unittests.common.TornadoTestBase;

/**
 * Tests that a single {@link GridScheduler} holding worker grids for tasks in two different
 * task graphs is applied correctly when both graphs live in one {@link TornadoExecutionPlan}
 * and are executed one after the other via {@code withGraph(int)}.
 */
public class TestChainOfGridSchedulers extends TornadoTestBase {

    /** Number of elements in every array used by these tests. */
    private static final int SIZE = 1024;

    /** Fixed seed so that any failure can be reproduced with identical input data. */
    private static final long SEED = 71L;

    // Integer transfer modes exactly as passed by the original code (0 for the inputs,
    // 1 for the outputs). NOTE(review): these presumably correspond to
    // DataTransferMode.FIRST_EXECUTION / DataTransferMode.EVERY_EXECUTION — confirm and
    // replace with the API constants.
    private static final int FIRST_EXECUTION = 0;
    private static final int EVERY_EXECUTION = 1;

    // Graph and task names; the grid-scheduler keys are derived from these so the two
    // can never drift apart ("s0.vectorAdd", "s1.vectorMul").
    private static final String ADD_GRAPH = "s0";
    private static final String ADD_TASK = "vectorAdd";
    private static final String MUL_GRAPH = "s1";
    private static final String MUL_TASK = "vectorMul";

    /**
     * Element-wise addition kernel: {@code c[i] = a[i] + b[i]}.
     * Threads whose global index is beyond {@code c.getSize()} do nothing.
     *
     * @param context kernel context providing the global thread index
     * @param a first input
     * @param b second input
     * @param c output
     */
    public static void vectorAdd(KernelContext context, FloatArray a, FloatArray b, FloatArray c) {
        int idx = context.globalIdx;
        if (idx < c.getSize()) {
            c.set(idx, a.get(idx) + b.get(idx));
        }
    }

    /**
     * Element-wise multiplication kernel: {@code c[i] = a[i] * b[i]}.
     * Threads whose global index is beyond {@code c.getSize()} do nothing.
     *
     * @param context kernel context providing the global thread index
     * @param a first input
     * @param b second input
     * @param c output
     */
    public static void vectorMul(KernelContext context, FloatArray a, FloatArray b, FloatArray c) {
        int idx = context.globalIdx;
        if (idx < c.getSize()) {
            c.set(idx, a.get(idx) * b.get(idx));
        }
    }

    /**
     * Runs both graphs, setting the grid scheduler before selecting each graph.
     */
    @Test
    public void testMultipleTaskGraphs() throws TornadoExecutionPlanException {
        FloatArray a = new FloatArray(SIZE);
        FloatArray b = new FloatArray(SIZE);
        FloatArray c1 = new FloatArray(SIZE);
        FloatArray c2 = new FloatArray(SIZE);
        fillRandom(new Random(SEED), a, b);

        KernelContext context = new KernelContext();
        GridScheduler grid = createGridScheduler();
        try (TornadoExecutionPlan executionPlan = createExecutionPlan(context, a, b, c1, c2)) {
            executionPlan.withGridScheduler(grid).withGraph(0).execute();
            executionPlan.withGridScheduler(grid).withGraph(1).execute();
        }

        assertResults(a, b, c1, c2, 1.0e-6f);
    }

    /**
     * Same as {@link #testMultipleTaskGraphs()}, but for the first graph the graph is selected
     * before the grid scheduler is set, checking that the call order does not matter.
     */
    @Test
    public void testMultipleTaskGraphsSchedulerReverse() throws TornadoExecutionPlanException {
        FloatArray a = new FloatArray(SIZE);
        FloatArray b = new FloatArray(SIZE);
        FloatArray c1 = new FloatArray(SIZE);
        FloatArray c2 = new FloatArray(SIZE);
        c1.init(0.0f);
        c2.init(0.0f);
        // Previously an unseeded Random: failures were not reproducible.
        fillRandom(new Random(SEED), a, b);

        KernelContext context = new KernelContext();
        GridScheduler grid = createGridScheduler();
        try (TornadoExecutionPlan executionPlan = createExecutionPlan(context, a, b, c1, c2)) {
            executionPlan.withGraph(0).withGridScheduler(grid).execute();
            executionPlan.withGridScheduler(grid).withGraph(1).execute();
        }

        assertResults(a, b, c1, c2, 0.001f);
    }

    /** Fills {@code a} and {@code b} with random floats, drawing a[i] then b[i] for each i. */
    private static void fillRandom(Random random, FloatArray a, FloatArray b) {
        for (int i = 0; i < SIZE; i++) {
            a.set(i, random.nextFloat());
            b.set(i, random.nextFloat());
        }
    }

    /**
     * Builds a scheduler with a 1D grid of {@link #SIZE} threads for both tasks: local size
     * 256 for the add task and 128 for the mul task.
     */
    private static GridScheduler createGridScheduler() {
        WorkerGrid addWorker = new WorkerGrid1D(SIZE);
        addWorker.setLocalWork(256L, 1L, 1L);
        WorkerGrid mulWorker = new WorkerGrid1D(SIZE);
        mulWorker.setLocalWork(128L, 1L, 1L);

        GridScheduler grid = new GridScheduler();
        grid.addWorkerGrid(ADD_GRAPH + "." + ADD_TASK, addWorker);
        grid.addWorkerGrid(MUL_GRAPH + "." + MUL_TASK, mulWorker);
        return grid;
    }

    /**
     * Creates an execution plan holding two graphs: index 0 computes {@code cAdd = a + b},
     * index 1 computes {@code cMul = a * b}. The caller owns (and must close) the plan.
     */
    private static TornadoExecutionPlan createExecutionPlan(KernelContext context, FloatArray a, FloatArray b,
            FloatArray cAdd, FloatArray cMul) {
        TaskGraph addGraph = new TaskGraph(ADD_GRAPH)
                .transferToDevice(FIRST_EXECUTION, a, b)
                .task(ADD_TASK, TestChainOfGridSchedulers::vectorAdd, context, a, b, cAdd)
                .transferToHost(EVERY_EXECUTION, cAdd);
        TaskGraph mulGraph = new TaskGraph(MUL_GRAPH)
                .transferToDevice(FIRST_EXECUTION, a, b)
                .task(MUL_TASK, TestChainOfGridSchedulers::vectorMul, context, a, b, cMul)
                .transferToHost(EVERY_EXECUTION, cMul);
        ImmutableTaskGraph[] graphs = { addGraph.snapshot(), mulGraph.snapshot() };
        return new TornadoExecutionPlan(graphs);
    }

    /** Checks {@code cAdd == a + b} and {@code cMul == a * b} element-wise within {@code delta}. */
    private static void assertResults(FloatArray a, FloatArray b, FloatArray cAdd, FloatArray cMul, float delta) {
        for (int i = 0; i < SIZE; i++) {
            Assert.assertEquals("vectorAdd mismatch at index " + i, a.get(i) + b.get(i), cAdd.get(i), delta);
            Assert.assertEquals("vectorMul mismatch at index " + i, a.get(i) * b.get(i), cMul.get(i), delta);
        }
    }
}

