/*
 * Decompiled with CFR 0.152.
 */
package uk.ac.manchester.tornado.unittests.kernelcontext.api;

import org.junit.Assert;
import org.junit.Test;
import uk.ac.manchester.tornado.api.GridScheduler;
import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
import uk.ac.manchester.tornado.api.KernelContext;
import uk.ac.manchester.tornado.api.TaskGraph;
import uk.ac.manchester.tornado.api.TornadoExecutionPlan;
import uk.ac.manchester.tornado.api.WorkerGrid;
import uk.ac.manchester.tornado.api.WorkerGrid1D;
import uk.ac.manchester.tornado.api.exceptions.TornadoExecutionPlanException;
import uk.ac.manchester.tornado.api.types.arrays.IntArray;
import uk.ac.manchester.tornado.unittests.common.TornadoTestBase;

/**
 * Vector-addition tests for kernels written against the {@link KernelContext} API.
 *
 * <p>Tests 01-04 pass the {@link KernelContext} in each of the four possible positions of
 * a four-argument kernel. Test 05 repeats test 01 with {@code withPreCompilation()}
 * enabled on the execution plan. Every device result is compared element by element
 * against the sequential reference {@link #vectorAddJava(IntArray, IntArray, IntArray)}.
 */
public class TestVectorAdditionKernelContext extends TornadoTestBase {

    /** Number of elements in every vector used by these tests. */
    private static final int SIZE = 16;

    /**
     * Data-transfer mode passed to {@code transferToDevice}/{@code transferToHost}.
     *
     * <p>NOTE(review): the decompiler inlined this constant as {@code 1}. It is presumably a
     * {@code DataTransferMode} constant; confirm which one and replace the literal with it.
     */
    private static final int TRANSFER_MODE = 1;

    /**
     * Sequential reference implementation: {@code c[i] = a[i] + b[i]} for every index of
     * {@code c}.
     *
     * @param a first input vector (must be at least as long as {@code c})
     * @param b second input vector (must be at least as long as {@code c})
     * @param c output vector; every element is overwritten
     */
    public static void vectorAddJava(IntArray a, IntArray b, IntArray c) {
        for (int i = 0; i < c.getSize(); ++i) {
            c.set(i, a.get(i) + b.get(i));
        }
    }

    /**
     * Kernel with the {@link KernelContext} as the first parameter. Each invocation writes
     * {@code c[idx] = a[idx] + b[idx]} for {@code idx = context.globalIdx}.
     */
    public static void vectorAdd(KernelContext context, IntArray a, IntArray b, IntArray c) {
        int idx = context.globalIdx;
        c.set(idx, a.get(idx) + b.get(idx));
    }

    /** Same kernel as above, with the {@link KernelContext} as the second parameter. */
    public static void vectorAdd(IntArray a, KernelContext context, IntArray b, IntArray c) {
        int idx = context.globalIdx;
        c.set(idx, a.get(idx) + b.get(idx));
    }

    /** Same kernel as above, with the {@link KernelContext} as the third parameter. */
    public static void vectorAdd(IntArray a, IntArray b, KernelContext context, IntArray c) {
        int idx = context.globalIdx;
        c.set(idx, a.get(idx) + b.get(idx));
    }

    /** Same kernel as above, with the {@link KernelContext} as the last parameter. */
    public static void vectorAdd(IntArray a, IntArray b, IntArray c, KernelContext context) {
        int idx = context.globalIdx;
        c.set(idx, a.get(idx) + b.get(idx));
    }

    /** Allocates an {@link IntArray} of {@link #SIZE} elements, all set to {@code value}. */
    private static IntArray filledArray(int value) {
        IntArray array = new IntArray(SIZE);
        array.init(value);
        return array;
    }

    /**
     * Snapshots {@code taskGraph}, executes it once with {@code gridScheduler}, and then checks
     * {@code cTornado} against the sequential reference computed from {@code a} and {@code b}.
     *
     * @param taskGraph      fully built task graph to run
     * @param gridScheduler  scheduler to attach to the execution plan
     * @param preCompile     when {@code true}, {@code withPreCompilation()} is applied before
     *                       the scheduler is attached
     * @param a              first input vector
     * @param b              second input vector
     * @param cTornado       output vector written by the task graph
     * @throws TornadoExecutionPlanException if building, running or closing the plan fails
     */
    private static void executeAndVerify(TaskGraph taskGraph, GridScheduler gridScheduler, boolean preCompile,
            IntArray a, IntArray b, IntArray cTornado) throws TornadoExecutionPlanException {
        ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
        try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
            // Keep the exact call chain of the original tests: the plan returned by
            // withPreCompilation() is the one the scheduler is attached to.
            TornadoExecutionPlan plan = preCompile ? executionPlan.withPreCompilation() : executionPlan;
            plan.withGridScheduler(gridScheduler).execute();
        }
        assertMatchesReference(a, b, cTornado);
    }

    /** Asserts that {@code cTornado} equals {@code a + b} element by element. */
    private static void assertMatchesReference(IntArray a, IntArray b, IntArray cTornado) {
        IntArray cJava = new IntArray(SIZE);
        vectorAddJava(a, b, cJava);
        for (int i = 0; i < SIZE; ++i) {
            Assert.assertEquals("Mismatch at index " + i, cJava.get(i), cTornado.get(i));
        }
    }

    @Test
    public void vectorAddKernelContext01() throws TornadoExecutionPlanException {
        IntArray a = filledArray(10);
        IntArray b = filledArray(20);
        IntArray cTornado = new IntArray(SIZE);

        WorkerGrid1D worker = new WorkerGrid1D(SIZE);
        GridScheduler gridScheduler = new GridScheduler("s0.t0", worker);
        KernelContext context = new KernelContext();

        // Arguments are passed without casts: the overloaded method reference is resolved
        // from the inferred argument types, so casting them to Object would leave no
        // matching vectorAdd overload.
        TaskGraph taskGraph = new TaskGraph("s0") //
                .transferToDevice(TRANSFER_MODE, a, b) //
                .task("t0", TestVectorAdditionKernelContext::vectorAdd, context, a, b, cTornado) //
                .transferToHost(TRANSFER_MODE, cTornado);

        // The worker is reconfigured after it has been handed to the scheduler and before the
        // graph is snapshotted, as in the original test.
        worker.setGlobalWork(SIZE, 1L, 1L);
        worker.setLocalWorkToNull();

        executeAndVerify(taskGraph, gridScheduler, false, a, b, cTornado);
    }

    @Test
    public void vectorAddKernelContext02() throws TornadoExecutionPlanException {
        IntArray a = filledArray(10);
        IntArray b = filledArray(20);
        IntArray cTornado = new IntArray(SIZE);

        WorkerGrid1D worker = new WorkerGrid1D(SIZE);
        GridScheduler gridScheduler = new GridScheduler();
        gridScheduler.addWorkerGrid("s0.t0", worker);
        KernelContext context = new KernelContext();

        TaskGraph taskGraph = new TaskGraph("s0") //
                .transferToDevice(TRANSFER_MODE, a, b) //
                .task("t0", TestVectorAdditionKernelContext::vectorAdd, a, context, b, cTornado) //
                .transferToHost(TRANSFER_MODE, cTornado);

        executeAndVerify(taskGraph, gridScheduler, false, a, b, cTornado);
    }

    @Test
    public void vectorAddKernelContext03() throws TornadoExecutionPlanException {
        IntArray a = filledArray(10);
        IntArray b = filledArray(20);
        IntArray cTornado = new IntArray(SIZE);

        WorkerGrid1D worker = new WorkerGrid1D(SIZE);
        GridScheduler gridScheduler = new GridScheduler();
        gridScheduler.addWorkerGrid("s0.t0", worker);
        KernelContext context = new KernelContext();

        TaskGraph taskGraph = new TaskGraph("s0") //
                .transferToDevice(TRANSFER_MODE, a, b) //
                .task("t0", TestVectorAdditionKernelContext::vectorAdd, a, b, context, cTornado) //
                .transferToHost(TRANSFER_MODE, cTornado);

        executeAndVerify(taskGraph, gridScheduler, false, a, b, cTornado);
    }

    @Test
    public void vectorAddKernelContext04() throws TornadoExecutionPlanException {
        IntArray a = filledArray(10);
        IntArray b = filledArray(20);
        IntArray cTornado = new IntArray(SIZE);

        WorkerGrid1D worker = new WorkerGrid1D(SIZE);
        GridScheduler gridScheduler = new GridScheduler();
        gridScheduler.addWorkerGrid("s0.t0", worker);
        KernelContext context = new KernelContext();

        TaskGraph taskGraph = new TaskGraph("s0") //
                .transferToDevice(TRANSFER_MODE, a, b) //
                .task("t0", TestVectorAdditionKernelContext::vectorAdd, a, b, cTornado, context) //
                .transferToHost(TRANSFER_MODE, cTornado);

        executeAndVerify(taskGraph, gridScheduler, false, a, b, cTornado);
    }

    @Test
    public void vectorAddKernelContext05() throws TornadoExecutionPlanException {
        IntArray a = filledArray(10);
        IntArray b = filledArray(20);
        IntArray cTornado = new IntArray(SIZE);

        WorkerGrid1D worker = new WorkerGrid1D(SIZE);
        GridScheduler gridScheduler = new GridScheduler("s0.t0", worker);
        KernelContext context = new KernelContext();

        TaskGraph taskGraph = new TaskGraph("s0") //
                .transferToDevice(TRANSFER_MODE, a, b) //
                .task("t0", TestVectorAdditionKernelContext::vectorAdd, context, a, b, cTornado) //
                .transferToHost(TRANSFER_MODE, cTornado);

        worker.setGlobalWork(SIZE, 1L, 1L);
        worker.setLocalWorkToNull();

        // Same as test 01, but the plan is pre-compiled before execution.
        executeAndVerify(taskGraph, gridScheduler, true, a, b, cTornado);
    }
}

