/*
 * Decompiled with CFR 0.152.
 */
package uk.ac.manchester.tornado.unittests.vm.concurrency;

import java.util.Random;
import java.util.stream.IntStream;
import org.junit.Assert;
import org.junit.Test;
import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
import uk.ac.manchester.tornado.api.TaskGraph;
import uk.ac.manchester.tornado.api.TornadoExecutionPlan;
import uk.ac.manchester.tornado.api.common.TornadoDevice;
import uk.ac.manchester.tornado.api.exceptions.TornadoExecutionPlanException;
import uk.ac.manchester.tornado.api.runtime.TornadoRuntimeProvider;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
import uk.ac.manchester.tornado.unittests.common.TornadoTestBase;
import uk.ac.manchester.tornado.unittests.tools.Exceptions.UnsupportedConfigurationException;

/**
 * Tests that one task graph can be split across several devices of backend 0, either serially
 * (each task pinned to a device with {@code withDevice}) or with {@code withConcurrentDevices()},
 * including re-targeting tasks and toggling concurrency between executions.
 *
 * <p>Every test builds the same graph named {@code "graph"}: {@code task0} runs
 * {@link #init(FloatArray)} on {@code a} and {@code task1} runs
 * {@link #multiply(FloatArray, float)} on {@code b}. Tests needing more devices than backend 0
 * reports throw {@link UnsupportedConfigurationException}.
 */
public class TestParallelTaskGraph extends TornadoTestBase {

    /** Number of elements in every array used by these tests. */
    final int SIZE = 1024;

    /** Offset passed to {@link #multiply(FloatArray, float)} by every test. */
    private static final float ALPHA = 0.12f;

    /** Fixed seed so the random input, and therefore the expected output, is reproducible. */
    private static final long SEED = 31L;

    /** Tolerance for values produced by a single application of a kernel. */
    private static final float DELTA_SINGLE_RUN = 0.001f;

    /**
     * Looser tolerance for {@code b} after two applications of {@code multiply}: the values grow
     * roughly with {@code i * i}, so absolute rounding differences between device and host grow too.
     */
    private static final float DELTA_TWO_RUNS = 0.5f;

    /**
     * Overwrites every element of {@code a} with its own index.
     *
     * @param a array to initialise in place
     */
    public static void init(FloatArray a) {
        for (int i = 0; i < a.getSize(); ++i) {
            a.set(i, (float) i);
        }
    }

    /**
     * Updates {@code a} in place with {@code a[i] = a[i] * i + alpha}.
     *
     * @param a     array to update in place
     * @param alpha offset added to every element
     */
    public static void multiply(FloatArray a, float alpha) {
        for (int i = 0; i < a.getSize(); ++i) {
            float temp = a.get(i) * (float) i + alpha;
            a.set(i, temp);
        }
    }

    /** Runs task0 on device 0 and task1 on device 1, once. */
    @Test
    public void testTwoDevicesSerial() throws TornadoExecutionPlanException {
        FloatArray a = new FloatArray(SIZE);
        FloatArray b = new FloatArray(SIZE);
        FloatArray refB = new FloatArray(SIZE);
        fillRandom(a, b, refB);

        ImmutableTaskGraph immutableTaskGraph = buildTaskGraph(0, a, b).snapshot();
        try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
            requireDevices(2, "Test requires at least two devices");
            executionPlan.withDevice("graph.task0", device(0)).withDevice("graph.task1", device(1));
            executionPlan.execute();
        }

        multiply(refB, ALPHA);
        assertResults(a, b, refB, DELTA_SINGLE_RUN);
    }

    /** Runs on devices 0/1, then re-targets both tasks to device 2 and runs again. */
    @Test
    public void testTwoDevicesSerial1() throws TornadoExecutionPlanException {
        FloatArray a = new FloatArray(SIZE);
        FloatArray b = new FloatArray(SIZE);
        FloatArray refB = new FloatArray(SIZE);
        fillRandom(a, b, refB);

        ImmutableTaskGraph immutableTaskGraph = buildTaskGraph(0, a, b).snapshot();
        try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
            requireDevices(3, "Test requires at least three devices");
            TornadoDevice device0 = device(0);
            TornadoDevice device1 = device(1);
            TornadoDevice device2 = device(2);
            executionPlan.withDevice("graph.task0", device0).withDevice("graph.task1", device1);
            executionPlan.execute();
            executionPlan.withDevice("graph.task0", device2).withDevice("graph.task1", device2);
            executionPlan.execute();
        }

        // Two executions => multiply was applied to b twice.
        multiply(refB, ALPHA);
        multiply(refB, ALPHA);
        assertResults(a, b, refB, DELTA_TWO_RUNS);
    }

    /** Runs with task0/task1 on devices 0/1, then swaps the assignment and runs again. */
    @Test
    public void testTwoDevicesSerial2() throws TornadoExecutionPlanException {
        FloatArray a = new FloatArray(SIZE);
        FloatArray b = new FloatArray(SIZE);
        FloatArray refB = new FloatArray(SIZE);
        fillRandom(a, b, refB);

        // Unlike the other tests, this graph uses transfer mode 1 for the device copy.
        ImmutableTaskGraph immutableTaskGraph = buildTaskGraph(1, a, b).snapshot();
        try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
            requireDevices(2, "Test requires at least two devices");
            TornadoDevice device0 = device(0);
            TornadoDevice device1 = device(1);
            executionPlan.withDevice("graph.task0", device0).withDevice("graph.task1", device1);
            executionPlan.execute();
            executionPlan.withDevice("graph.task0", device1).withDevice("graph.task1", device0);
            executionPlan.execute();
        }

        multiply(refB, ALPHA);
        multiply(refB, ALPHA);
        assertResults(a, b, refB, DELTA_TWO_RUNS);
    }

    /** Runs task0/task1 on devices 0/1 with concurrent devices enabled, once. */
    @Test
    public void testTwoDevicesConcurrent() throws TornadoExecutionPlanException {
        FloatArray a = new FloatArray(SIZE);
        FloatArray b = new FloatArray(SIZE);
        FloatArray refB = new FloatArray(SIZE);
        fillRandom(a, b, refB);

        ImmutableTaskGraph immutableTaskGraph = buildTaskGraph(0, a, b).snapshot();
        try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
            requireDevices(2, "Test requires at least two devices");
            executionPlan.withConcurrentDevices().withDevice("graph.task0", device(0)).withDevice("graph.task1", device(1));
            executionPlan.execute();
        }

        multiply(refB, ALPHA);
        assertResults(a, b, refB, DELTA_SINGLE_RUN);
    }

    /** Runs once with concurrent devices enabled, then once more with them disabled. */
    @Test
    public void testTwoDevicesConcurrentOnAndOff() throws TornadoExecutionPlanException {
        FloatArray a = new FloatArray(SIZE);
        FloatArray b = new FloatArray(SIZE);
        FloatArray refB = new FloatArray(SIZE);
        fillRandom(a, b, refB);

        ImmutableTaskGraph immutableTaskGraph = buildTaskGraph(0, a, b).snapshot();
        try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
            requireDevices(2, "Test requires at least two devices");
            executionPlan.withConcurrentDevices().withDevice("graph.task0", device(0)).withDevice("graph.task1", device(1));
            executionPlan.execute();
            executionPlan.withoutConcurrentDevices().execute();
        }

        multiply(refB, ALPHA);
        multiply(refB, ALPHA);
        assertResults(a, b, refB, DELTA_TWO_RUNS);
    }

    /**
     * Fills {@code a} and {@code b} from a fixed-seed generator and copies {@code b} into
     * {@code refB}. Values are drawn interleaved ({@code a[i]} then {@code b[i]}) so the contents
     * of {@code b} are identical on every run.
     */
    private void fillRandom(FloatArray a, FloatArray b, FloatArray refB) {
        Random random = new Random(SEED);
        for (int i = 0; i < SIZE; i++) {
            a.set(i, random.nextFloat());
            b.set(i, random.nextFloat());
            refB.set(i, b.get(i));
        }
    }

    /**
     * Builds the task graph {@code "graph"} shared by all tests: copy {@code a} and {@code b} to the
     * device, run {@code task0} = init(a) and {@code task1} = multiply(b, ALPHA), copy both back.
     *
     * @param transferToDeviceMode mode passed to {@code transferToDevice}; the tests use 0 and 1
     *                             (NOTE(review): presumably DataTransferMode FIRST_EXECUTION /
     *                             EVERY_EXECUTION — confirm)
     */
    private static TaskGraph buildTaskGraph(int transferToDeviceMode, FloatArray a, FloatArray b) {
        return new TaskGraph("graph")
                .transferToDevice(transferToDeviceMode, a, b)
                .task("task0", TestParallelTaskGraph::init, a)
                .task("task1", TestParallelTaskGraph::multiply, b, ALPHA)
                .transferToHost(1, a, b);
    }

    /**
     * Throws {@link UnsupportedConfigurationException} with {@code message} when backend 0 exposes
     * fewer than {@code required} devices.
     */
    private static void requireDevices(int required, String message) {
        int deviceCount = TornadoRuntimeProvider.getTornadoRuntime().getBackend(0).getNumDevices();
        if (deviceCount < required) {
            throw new UnsupportedConfigurationException(message);
        }
    }

    /** Returns device {@code index} of backend 0. */
    private static TornadoDevice device(int index) {
        return TornadoRuntimeProvider.getTornadoRuntime().getBackend(0).getDevice(index);
    }

    /**
     * Asserts that {@code a[i] == i} (init ran) and that {@code b} matches {@code expectedB}
     * within {@code deltaB}.
     */
    private static void assertResults(FloatArray a, FloatArray b, FloatArray expectedB, float deltaB) {
        for (int i = 0; i < a.getSize(); i++) {
            Assert.assertEquals((float) i, a.get(i), DELTA_SINGLE_RUN);
            Assert.assertEquals(expectedB.get(i), b.get(i), deltaB);
        }
    }
}

