/*
 * Decompiled with CFR 0.152.
 */
package uk.ac.manchester.tornado.unittests.multithreaded;

import org.junit.Assert;
import org.junit.Test;
import uk.ac.manchester.tornado.api.GridScheduler;
import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
import uk.ac.manchester.tornado.api.KernelContext;
import uk.ac.manchester.tornado.api.TaskGraph;
import uk.ac.manchester.tornado.api.TornadoExecutionPlan;
import uk.ac.manchester.tornado.api.WorkerGrid;
import uk.ac.manchester.tornado.api.WorkerGrid1D;
import uk.ac.manchester.tornado.api.enums.ProfilerMode;
import uk.ac.manchester.tornado.api.exceptions.TornadoExecutionPlanException;
import uk.ac.manchester.tornado.api.math.TornadoMath;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
import uk.ac.manchester.tornado.unittests.common.TornadoTestBase;

/**
 * Tests that build TornadoVM execution plans and execute them from several
 * Java threads at the same time.
 *
 * <p>Each test starts its workers through {@link #runConcurrently(Runnable...)},
 * which re-raises any failure that happened on a worker thread on the JUnit
 * thread. Without that step an exception thrown inside a worker would only
 * terminate that worker and the test would still be reported as passed.
 */
public class TestMultiThreadedExecutionPlans
extends TornadoTestBase {
    /**
     * Per-element computation driven by a {@link KernelContext}: reads the
     * element index from {@code context.globalIdx} and stores
     * {@code input[idx] * 100 * sqrt(input[idx])} into {@code output[idx]}.
     *
     * @param input   source values
     * @param output  destination array, written at the same index that is read
     * @param context supplies the index of the element to process
     */
    private static void computeForThread1(FloatArray input, FloatArray output, KernelContext context) {
        int idx = context.globalIdx;
        float value = input.get(idx) * 100.0f * TornadoMath.sqrt(input.get(idx));
        output.set(idx, value);
    }

    /**
     * Loop version of the same computation: for every index {@code i} below
     * {@code input.getSize()} stores {@code input[i] * 100 * sqrt(input[i])}
     * into {@code output[i]}.
     *
     * @param input  source values
     * @param output destination array, written at the same indices that are read
     */
    private static void computeForThread2(FloatArray input, FloatArray output) {
        for (int i = 0; i < input.getSize(); ++i) {
            float value = input.get(i) * 100.0f * TornadoMath.sqrt(input.get(i));
            output.set(i, value);
        }
    }

    /**
     * Runs every job on its own thread, waits for all of them, and then
     * rethrows the first recorded worker failure on the calling thread.
     *
     * <p>Each worker writes only to its own slot of the failure array, and the
     * slots are read only after {@link Thread#join()} has returned for every
     * worker, so no additional synchronisation is required.
     *
     * @param jobs the work items; all threads are started before any is joined
     * @throws InterruptedException if the calling thread is interrupted while
     *                              waiting for a worker to finish
     */
    private static void runConcurrently(Runnable... jobs) throws InterruptedException {
        Thread[] workers = new Thread[jobs.length];
        Throwable[] failures = new Throwable[jobs.length];
        for (int i = 0; i < jobs.length; i++) {
            final int slot = i;
            workers[slot] = new Thread(() -> {
                try {
                    jobs[slot].run();
                } catch (Throwable t) {
                    // Keep the failure so the test thread can report it.
                    failures[slot] = t;
                }
            });
        }
        for (Thread worker : workers) {
            worker.start();
        }
        for (Thread worker : workers) {
            worker.join();
        }
        for (Throwable failure : failures) {
            if (failure == null) {
                continue;
            }
            // Rethrow with the original type so the test runner sees the real cause.
            if (failure instanceof Error) {
                throw (Error) failure;
            }
            if (failure instanceof RuntimeException) {
                throw (RuntimeException) failure;
            }
            throw new RuntimeException(failure);
        }
    }

    /**
     * Takes a snapshot of {@code taskGraph}, wraps it in a new execution plan
     * and executes it once; the plan is closed afterwards.
     *
     * @param taskGraph     graph to snapshot and execute
     * @param gridScheduler scheduler to attach before executing, or
     *                      {@code null} to execute the plan without one
     * @throws RuntimeException wrapping any {@link TornadoExecutionPlanException}
     */
    private static void executeSnapshot(TaskGraph taskGraph, GridScheduler gridScheduler) {
        ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
        try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
            if (gridScheduler != null) {
                executionPlan.withGridScheduler(gridScheduler).execute();
            } else {
                executionPlan.execute();
            }
        } catch (TornadoExecutionPlanException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Two threads each snapshot the same task graph (context-based task with a
     * 1D worker grid) and execute their own execution plan concurrently.
     */
    @Test
    public void test01() throws InterruptedException {
        KernelContext context = new KernelContext();
        final int size = 0x4000000;
        FloatArray input = new FloatArray(size);
        FloatArray output = new FloatArray(size);
        input.init(1.0f);
        // NOTE(review): the literal 1 passed to the transfer calls is presumably an
        // inlined transfer-mode constant lost by the decompiler - confirm.
        TaskGraph taskGraph = new TaskGraph("check")
                .transferToDevice(1, input)
                .task("compute01", TestMultiThreadedExecutionPlans::computeForThread1, input, output, context)
                .transferToHost(1, output);
        WorkerGrid workerGrid = new WorkerGrid1D(size);
        GridScheduler gridScheduler = new GridScheduler("check.compute01", workerGrid);
        runConcurrently(() -> {
            System.out.print("Running thread t0");
            executeSnapshot(taskGraph, gridScheduler);
        }, () -> {
            System.out.print("Running thread t1");
            executeSnapshot(taskGraph, gridScheduler);
        });
    }

    /**
     * Two threads each snapshot the same task graph (loop-based task, no grid
     * scheduler) and execute their own execution plan concurrently.
     */
    @Test
    public void test02() throws InterruptedException {
        final int size = 0x2000000;
        FloatArray input = new FloatArray(size);
        input.init(1.0f);
        FloatArray output = new FloatArray(size);
        // NOTE(review): the literal 1 passed to the transfer calls is presumably an
        // inlined transfer-mode constant lost by the decompiler - confirm.
        TaskGraph taskGraph = new TaskGraph("check")
                .transferToDevice(1, input)
                .task("compute01", TestMultiThreadedExecutionPlans::computeForThread2, input, output)
                .transferToHost(1, output);
        runConcurrently(
                () -> executeSnapshot(taskGraph, null),
                () -> executeSnapshot(taskGraph, null));
    }

    /**
     * Builds a fresh task graph named {@code "loop" + (id + 100)} over newly
     * allocated arrays of {@code size} elements and executes it once.
     *
     * @param size      number of elements in the input and output arrays
     * @param id        value used to derive the task-graph name
     * @param profiling when {@code true}, the plan is configured with
     *                  {@link ProfilerMode#SILENT} before execution
     * @throws TornadoExecutionPlanException propagated from the execution plan
     */
    private void compute(int size, int id, boolean profiling) throws TornadoExecutionPlanException {
        FloatArray input = new FloatArray(size);
        input.init(1.0f);
        FloatArray output = new FloatArray(size);
        TaskGraph taskGraph = new TaskGraph("loop" + (id + 100))
                .transferToDevice(1, input)
                .task("compute01", TestMultiThreadedExecutionPlans::computeForThread2, input, output)
                .transferToHost(1, output);
        ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
        try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
            if (profiling) {
                executionPlan.withProfiler(ProfilerMode.SILENT);
            }
            executionPlan.execute();
        }
    }

    /**
     * Adapter around {@link #compute(int, int, boolean)} for use inside a
     * {@link Runnable}: converts the checked exception into a
     * {@link RuntimeException} with the original as its cause.
     */
    private void computeUnchecked(int size, int id, boolean profiling) {
        try {
            compute(size, id, profiling);
        } catch (TornadoExecutionPlanException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Performs 100 rounds; in each round two threads call
     * {@link #compute(int, int, boolean)} concurrently with ids {@code round}
     * and {@code round + 100}. The next round starts only after both threads
     * of the current round have finished.
     *
     * @param size      array size handed to every {@code compute} call
     * @param profiling profiling flag handed to every {@code compute} call
     */
    private void runComputeRounds(int size, boolean profiling) {
        for (int i = 0; i < 100; ++i) {
            final int round = i;
            try {
                runConcurrently(
                        () -> computeUnchecked(size, round, profiling),
                        () -> computeUnchecked(size, round + 100, profiling));
            } catch (InterruptedException e) {
                // Restore the interrupt flag before reporting the failure.
                Thread.currentThread().interrupt();
                Assert.fail("Error");
            }
        }
    }

    /** Repeated concurrent execution of independent task graphs, profiler off. */
    @Test
    public void test03() {
        runComputeRounds(1038336, false);
    }

    /** Repeated concurrent execution of independent, larger task graphs, profiler on. */
    @Test
    public void test04() {
        runComputeRounds(66453504, true);
    }
}

