/*
 * Decompiled with CFR 0.152.
 */
package uk.ac.manchester.tornado.unittests.virtualization;

import java.util.stream.IntStream;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Ignore;
import org.junit.Test;
import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
import uk.ac.manchester.tornado.api.TaskGraph;
import uk.ac.manchester.tornado.api.TornadoBackend;
import uk.ac.manchester.tornado.api.TornadoExecutionPlan;
import uk.ac.manchester.tornado.api.exceptions.TornadoExecutionPlanException;
import uk.ac.manchester.tornado.api.runtime.TornadoRuntimeProvider;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
import uk.ac.manchester.tornado.api.types.arrays.IntArray;
import uk.ac.manchester.tornado.unittests.common.TornadoTestBase;
import uk.ac.manchester.tornado.unittests.common.TornadoVMMultiDeviceNotSupported;

/**
 * Tests that run task graphs on several devices of the same backend: switching devices
 * between executions, migrating data and tasks, and pinning tasks to devices.
 *
 * <p>Every test requires backend 0 to expose at least two devices (see {@link #enoughDevices()}).
 */
public class TestsVirtualLayer extends TornadoTestBase {

    // NOTE(review): these name the integer transfer modes the decompiled code passed to
    // transferToDevice/transferToHost (presumably DataTransferMode.FIRST_EXECUTION and
    // DataTransferMode.EVERY_EXECUTION) — confirm and switch to the API constants.
    private static final int FIRST_EXECUTION = 0;
    private static final int EVERY_EXECUTION = 1;

    /**
     * Adds {@code value} to every element of {@code a}, in place.
     *
     * @param a     array updated in place
     * @param value amount added to each element
     */
    public static void accumulator(IntArray a, int value) {
        for (int i = 0; i < a.getSize(); ++i) {
            a.set(i, a.get(i) + value);
        }
    }

    /**
     * Writes {@code alpha * x[i]} into {@code y[i]} for every index of {@code y}.
     *
     * <p>Despite the name, the previous content of {@code y} is NOT added; the expected values
     * in {@link #testTaskMigration()} depend on this.
     *
     * @param alpha scale factor
     * @param x     input array (must be at least as long as {@code y})
     * @param y     output array, overwritten
     */
    public static void saxpy(float alpha, FloatArray x, FloatArray y) {
        for (int i = 0; i < y.getSize(); ++i) {
            y.set(i, alpha * x.get(i));
        }
    }

    /**
     * Adds {@code value} to every element of {@code a}, in place.
     *
     * @param a     array updated in place
     * @param value amount added to each element
     */
    public static void testA(IntArray a, int value) {
        for (int i = 0; i < a.getSize(); ++i) {
            a.set(i, a.get(i) + value);
        }
    }

    /**
     * Multiplies every element of {@code a} by {@code value}, in place.
     *
     * @param a     array updated in place
     * @param value multiplier applied to each element
     */
    public static void testB(IntArray a, int value) {
        for (int i = 0; i < a.getSize(); ++i) {
            a.set(i, a.get(i) * value);
        }
    }

    /**
     * Runs the base-class setup, then aborts the test if backend 0 has fewer than two devices.
     *
     * @throws TornadoVMMultiDeviceNotSupported if backend 0 exposes fewer than two devices
     */
    @Before
    public void enoughDevices() {
        super.before();
        TornadoBackend backend = getTornadoRuntime().getBackend(0);
        if (backend.getNumDevices() < 2) {
            throw new TornadoVMMultiDeviceNotSupported("Not enough devices to run tests");
        }
    }

    /** Backend 0 must return non-null handles for its first two devices. */
    @Test
    public void testDevices() {
        TornadoBackend backend = getTornadoRuntime().getBackend(0);
        Assert.assertNotNull(backend.getDevice(0));
        Assert.assertNotNull(backend.getDevice(1));
    }

    /** Every backend, and every device of every backend, must be non-null. */
    @Test
    public void testDriverAndDevices() {
        int numBackends = getTornadoRuntime().getNumBackends();
        for (int backendIndex = 0; backendIndex < numBackends; ++backendIndex) {
            TornadoBackend backend = getTornadoRuntime().getBackend(backendIndex);
            Assert.assertNotNull(backend);
            int numDevices = backend.getNumDevices();
            for (int deviceIndex = 0; deviceIndex < numDevices; ++deviceIndex) {
                Assert.assertNotNull(backend.getDevice(deviceIndex));
            }
        }
    }

    /**
     * Executes the same plan on device 0 and then on device 1; the array must carry the result
     * of the first run into the second (each run adds 1).
     */
    @Test
    public void testArrayMigration() throws TornadoExecutionPlanException {
        final int numElements = 8;
        final int numKernels = 1;
        IntArray data = new IntArray(numElements);
        int initValue = 0;

        TaskGraph taskGraph = new TaskGraph("s0");
        for (int k = 0; k < numKernels; ++k) {
            taskGraph.task("t" + k, TestsVirtualLayer::accumulator, data, 1);
        }
        taskGraph.transferToHost(EVERY_EXECUTION, data);

        TornadoBackend backend = getTornadoRuntime().getBackend(0);
        ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
        try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
            executionPlan.withDevice(backend.getDevice(0));
            executionPlan.execute();
            for (int i = 0; i < numElements; ++i) {
                Assert.assertEquals(initValue + 1, data.get(i));
            }
            initValue++;

            executionPlan.withDevice(backend.getDevice(1));
            executionPlan.execute();
        }
        for (int i = 0; i < numElements; ++i) {
            Assert.assertEquals(initValue + 1, data.get(i));
        }
    }

    /** Runs the same scaling task on device 0 and then device 1; both runs must produce 900. */
    @Test
    public void testTaskMigration() throws TornadoExecutionPlanException {
        TornadoBackend backend = getTornadoRuntime().getBackend(0);
        final int numElements = 512;
        final float alpha = 2.0f;
        final float expected = alpha * 450.0f;
        FloatArray x = new FloatArray(numElements);
        FloatArray y = new FloatArray(numElements);
        IntStream.range(0, numElements).forEach(i -> x.set(i, 450.0f));

        // A single transferToHost is enough; the original registered the same transfer twice.
        TaskGraph taskGraph = new TaskGraph("s0")
                .transferToDevice(FIRST_EXECUTION, x)
                .task("t0", TestsVirtualLayer::saxpy, alpha, x, y)
                .transferToHost(EVERY_EXECUTION, y);
        ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
        try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
            executionPlan.withDevice(backend.getDevice(0)).execute();
            for (int i = 0; i < numElements; ++i) {
                Assert.assertEquals(expected, y.get(i), 0.001f);
            }
            executionPlan.withDevice(backend.getDevice(1)).execute();
        }
        for (int i = 0; i < numElements; ++i) {
            Assert.assertEquals(expected, y.get(i), 0.001f);
        }
    }

    /**
     * Device switch followed by an attempt to add a second task to the graph. Has no assertions.
     *
     * <p>NOTE(review): {@code taskGraph} is modified after {@code snapshot()}, but the plan was
     * built from the earlier snapshot, so task {@code t1} is presumably never executed by
     * {@code executionPlan} — confirm; this may be why the test is ignored.
     */
    @Test
    @Ignore
    public void testVirtualLayer01() throws TornadoExecutionPlanException {
        TornadoBackend backend = getTornadoRuntime().getBackend(0);
        final int numElements = 128;
        IntArray data = new IntArray(numElements);
        data.init(100);
        TaskGraph taskGraph = new TaskGraph("s0")
                .transferToDevice(FIRST_EXECUTION, data)
                .task("t0", TestsVirtualLayer::testA, data, 1)
                .transferToHost(EVERY_EXECUTION, data);
        ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
        try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
            executionPlan.withDevice(backend.getDevice(0)).execute();
            executionPlan.withDevice(backend.getDevice(1)).execute();
            taskGraph.transferToDevice(FIRST_EXECUTION, data)
                    .task("t1", TestsVirtualLayer::testA, data, 10)
                    .transferToHost(EVERY_EXECUTION, data);
            executionPlan.execute();
        }
    }

    /**
     * Like {@link #testVirtualLayer01()}, but expects {@code 100 + 1 + 10 = 111} at the end.
     *
     * <p>NOTE(review): the snapshot used by the plan has no transferToHost, and {@code t1} is
     * added to {@code taskGraph} only after {@code snapshot()}; presumably the plan never sees
     * it, so the 111 expectation cannot hold as written — confirm before re-enabling.
     */
    @Test
    @Ignore
    public void testVirtualLayer02() throws TornadoExecutionPlanException {
        TornadoBackend backend = getTornadoRuntime().getBackend(0);
        final int numElements = 128;
        IntArray data = new IntArray(numElements);
        data.init(100);
        TaskGraph taskGraph = new TaskGraph("s0")
                .transferToDevice(FIRST_EXECUTION, data)
                .task("t0", TestsVirtualLayer::testA, data, 1);
        ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
        try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
            executionPlan.withDevice(backend.getDevice(0)).execute();
            executionPlan.withDevice(backend.getDevice(1)).execute();
            taskGraph.transferToDevice(FIRST_EXECUTION, data)
                    .task("t1", TestsVirtualLayer::testA, data, 10)
                    .transferToHost(EVERY_EXECUTION, data);
            executionPlan.withDevice(backend.getDevice(0)).execute();
        }
        for (int i = 0; i < numElements; ++i) {
            Assert.assertEquals(111L, data.get(i));
        }
    }

    /** Two independent tasks in one graph on the default device: A += 1, B += 10. */
    @Test
    public void testVirtualLayer03() throws TornadoExecutionPlanException {
        final int numElements = 128;
        IntArray dataA = new IntArray(numElements);
        IntArray dataB = new IntArray(numElements);
        dataA.init(100);
        dataB.init(200);
        TaskGraph taskGraph = new TaskGraph("s0")
                .transferToDevice(FIRST_EXECUTION, dataA, dataB)
                .task("t0", TestsVirtualLayer::testA, dataA, 1)
                .task("t1", TestsVirtualLayer::testA, dataB, 10)
                .transferToHost(EVERY_EXECUTION, dataA)
                .transferToHost(EVERY_EXECUTION, dataB);
        ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
        try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
            executionPlan.execute();
        }
        for (int i = 0; i < numElements; ++i) {
            Assert.assertEquals(101L, dataA.get(i));
            Assert.assertEquals(210L, dataB.get(i));
        }
    }

    /**
     * Runs an increment-by-one task once on every device of every backend. The final array must
     * equal {@code 100 + total number of devices}.
     */
    @Test
    public void testDynamicDeviceSwitch() throws TornadoExecutionPlanException {
        final int numElements = 128;
        IntArray data = new IntArray(numElements);
        data.init(100);
        int totalNumDevices = 0;
        int numBackends = getTornadoRuntime().getNumBackends();
        for (int backendIndex = 0; backendIndex < numBackends; ++backendIndex) {
            TornadoBackend backend = getTornadoRuntime().getBackend(backendIndex);
            int numDevices = backend.getNumDevices();
            totalNumDevices += numDevices;

            TaskGraph taskGraph = new TaskGraph("s" + backendIndex)
                    .transferToDevice(FIRST_EXECUTION, data)
                    .task("t0", TestsVirtualLayer::testA, data, 1)
                    .transferToHost(EVERY_EXECUTION, data);
            ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
            try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
                for (int deviceIndex = 0; deviceIndex < numDevices; ++deviceIndex) {
                    executionPlan.withDevice(backend.getDevice(deviceIndex)).execute();
                }
            }
        }
        for (int i = 0; i < numElements; ++i) {
            Assert.assertEquals(100L + totalNumDevices, data.get(i));
        }
    }

    /**
     * Pins one task to device 0:0 and another to device 0:1 through runtime properties. Both
     * arrays get the same increment, so they must end up equal.
     *
     * <p>NOTE(review): the {@code s0.t0.device} / {@code s1.t1.device} properties are never
     * reset here. If {@code TornadoRuntimeProvider.setProperty} is process-global, they may
     * affect later tests that reuse these graph/task names — confirm and clean up if so.
     */
    @Test
    public void testSchedulerDevices() throws TornadoExecutionPlanException {
        TornadoBackend backend = getTornadoRuntime().getBackend(0);
        final int numElements = 128;
        IntArray dataA = new IntArray(numElements);
        IntArray dataB = new IntArray(numElements);
        dataA.init(100);
        dataB.init(100);
        if (backend.getNumDevices() < 2) {
            Assert.fail("The current driver has less than 2 devices");
        }

        TornadoRuntimeProvider.setProperty("s0.t0.device", "0:0");
        TaskGraph taskGraph = new TaskGraph("s0")
                .transferToDevice(FIRST_EXECUTION, dataA, dataB)
                .task("t0", TestsVirtualLayer::testA, dataA, 1)
                .transferToHost(EVERY_EXECUTION, dataA);
        ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
        try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
            executionPlan.execute();
        }

        TornadoRuntimeProvider.setProperty("s1.t1.device", "0:1");
        TaskGraph taskGraph2 = new TaskGraph("s1")
                .task("t1", TestsVirtualLayer::testA, dataB, 1)
                .transferToHost(EVERY_EXECUTION, dataB);
        ImmutableTaskGraph immutableTaskGraph2 = taskGraph2.snapshot();
        try (TornadoExecutionPlan executionPlan2 = new TornadoExecutionPlan(immutableTaskGraph2)) {
            executionPlan2.execute();
        }

        for (int i = 0; i < numElements; ++i) {
            Assert.assertEquals(dataA.get(i), dataB.get(i));
        }
    }
}

