/*
 * Decompiled with CFR 0.152.
 */
package uk.ac.manchester.tornado.unittests.grid;

import java.util.Random;
import java.util.stream.IntStream;
import org.junit.Assert;
import org.junit.Test;
import uk.ac.manchester.tornado.api.GridScheduler;
import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
import uk.ac.manchester.tornado.api.TaskGraph;
import uk.ac.manchester.tornado.api.TornadoExecutionPlan;
import uk.ac.manchester.tornado.api.WorkerGrid;
import uk.ac.manchester.tornado.api.WorkerGrid1D;
import uk.ac.manchester.tornado.api.WorkerGrid2D;
import uk.ac.manchester.tornado.api.exceptions.TornadoExecutionPlanException;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
import uk.ac.manchester.tornado.api.types.matrix.Matrix2DInt;
import uk.ac.manchester.tornado.unittests.arrays.TestArrays;
import uk.ac.manchester.tornado.unittests.common.TornadoTestBase;
import uk.ac.manchester.tornado.unittests.matrices.TestMatrixTypes;

/**
 * Tests for running TornadoVM task graphs with an explicit {@link GridScheduler}, including
 * changing the global or local work size of a worker grid between two executions of the same
 * execution plan.
 *
 * <p>NOTE(review): this file was produced by a decompiler (see the file header). Annotations on
 * local-variable declarations are never retained in class files, so any such annotations present
 * in the original source (for example on the kernel loops below) are missing here — confirm
 * against the original source. The integer data-transfer modes passed to
 * {@code transferToDevice}/{@code transferToHost} ({@code 0} and {@code 1}) look like inlined
 * constants; confirm their meaning against the TornadoVM API.
 */
public class TestGrid extends TornadoTestBase {

    /** Number of elements used by the one-dimensional vector-addition tests. */
    static final int NUM_ELEMENTS = 4096;

    /** Global work size applied before the second execution of the vector-addition tests. */
    private static final int REDUCED_GLOBAL_WORK = 512;

    /**
     * Computes {@code c = a * b} for square matrices of dimension {@code size} stored in
     * row-major order. Used both as the device task and as the sequential reference.
     *
     * @param a    left operand, {@code size * size} elements
     * @param b    right operand, {@code size * size} elements
     * @param c    output, {@code size * size} elements; every element is overwritten
     * @param size number of rows (and columns) of each matrix
     */
    private static void matrixMultiplication(FloatArray a, FloatArray b, FloatArray c, int size) {
        for (int i = 0; i < size; i++) {
            for (int j = 0; j < size; j++) {
                float sum = 0.0f;
                for (int k = 0; k < size; k++) {
                    sum += a.get(i * size + k) * b.get(k * size + j);
                }
                c.set(i * size + j, sum);
            }
        }
    }

    /**
     * Fills indices {@code [0, count)} of {@code a} and {@code b} with values from
     * {@link Math#random()}, alternating between the two arrays at each index.
     */
    private static void fillRandom(FloatArray a, FloatArray b, int count) {
        for (int i = 0; i < count; i++) {
            a.set(i, (float) Math.random());
            b.set(i, (float) Math.random());
        }
    }

    /** Asserts {@code c[i] == a[i] + b[i]} (within 0.01) for every index in {@code [0, count)}. */
    private static void assertVectorSum(FloatArray a, FloatArray b, FloatArray c, int count) {
        for (int i = 0; i < count; i++) {
            Assert.assertEquals(a.get(i) + b.get(i), c.get(i), 0.01f);
        }
    }

    /**
     * Runs {@code TestMatrixTypes.computeMatrixSum} on a random {@code rows x cols} integer
     * matrix using a 2D worker grid of the same size, then checks that every output element
     * equals twice the corresponding input element.
     *
     * @param rows number of matrix rows (first grid dimension)
     * @param cols number of matrix columns (second grid dimension)
     * @throws TornadoExecutionPlanException if closing the execution plan fails
     */
    private static void testMatrixIntegers(int rows, int cols) throws TornadoExecutionPlanException {
        int[][] a = new int[rows][cols];
        Random random = new Random();
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                a[i][j] = random.nextInt();
            }
        }
        Matrix2DInt matrixA = new Matrix2DInt(a);
        Matrix2DInt matrixB = new Matrix2DInt(rows, cols);

        TaskGraph taskGraph = new TaskGraph("foo")
                .transferToDevice(0, matrixA)
                .task("bar", TestMatrixTypes::computeMatrixSum, matrixA, matrixB, rows, cols)
                .transferToHost(1, matrixB);

        WorkerGrid2D worker = new WorkerGrid2D(rows, cols);
        GridScheduler gridScheduler = new GridScheduler("foo.bar", worker);

        ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
        try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
            executionPlan.withGridScheduler(gridScheduler).execute();
        }

        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                // nextInt() spans the whole int range, so the expected value relies on
                // Java's wrap-around int addition matching the device result.
                Assert.assertEquals(matrixA.get(i, j) + matrixA.get(i, j), matrixB.get(i, j));
            }
        }
    }

    /**
     * Runs a vector addition on a 1D grid of {@link #NUM_ELEMENTS} work items, then reduces the
     * grid's global work size and executes the same plan again.
     */
    @Test
    public void testDynamicGrid01() throws TornadoExecutionPlanException {
        FloatArray a = new FloatArray(NUM_ELEMENTS);
        FloatArray b = new FloatArray(NUM_ELEMENTS);
        FloatArray c = new FloatArray(NUM_ELEMENTS);
        fillRandom(a, b, NUM_ELEMENTS);

        TaskGraph taskGraph = new TaskGraph("s0")
                .transferToDevice(1, a, b)
                .task("t0", TestArrays::vectorAddFloat, a, b, c)
                .transferToHost(1, c);

        WorkerGrid1D worker = new WorkerGrid1D(NUM_ELEMENTS);
        GridScheduler gridScheduler = new GridScheduler("s0.t0", worker);

        ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
        try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
            executionPlan.withGridScheduler(gridScheduler).execute();
            // Mutate the already-registered grid and re-run the same plan.
            worker.setGlobalWork(REDUCED_GLOBAL_WORK, 1, 1);
            executionPlan.execute();
        }

        assertVectorSum(a, b, c, REDUCED_GLOBAL_WORK);
    }

    /**
     * Runs a 256x256 matrix multiplication on a 2D grid, then sets a 32x32 local work size and
     * executes the same plan again; the device result is compared with the sequential one.
     */
    @Test
    public void testDynamicGrid02() throws TornadoExecutionPlanException {
        final int numElements = 256;
        final int totalElements = numElements * numElements;
        FloatArray a = new FloatArray(totalElements);
        FloatArray b = new FloatArray(totalElements);
        FloatArray c = new FloatArray(totalElements);
        FloatArray seq = new FloatArray(totalElements);
        // Fill every matrix entry, not just the first row, so each output element depends on
        // non-zero inputs and the comparison below actually exercises the kernel.
        fillRandom(a, b, totalElements);

        TaskGraph taskGraph = new TaskGraph("s0")
                .transferToDevice(1, a, b)
                .task("t1", TestGrid::matrixMultiplication, a, b, c, numElements)
                .transferToHost(1, c);

        WorkerGrid2D worker = new WorkerGrid2D(numElements, numElements);
        GridScheduler gridScheduler = new GridScheduler("s0.t1", worker);

        ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
        try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
            executionPlan.withGridScheduler(gridScheduler).execute();
            worker.setLocalWork(32, 32, 1);
            executionPlan.execute();
        }

        matrixMultiplication(a, b, seq, numElements);
        for (int i = 0; i < numElements; i++) {
            for (int j = 0; j < numElements; j++) {
                Assert.assertEquals(seq.get(i * numElements + j), c.get(i * numElements + j), 0.1f);
            }
        }
    }

    /** Integer matrix sum on a non-square (256 x 128) 2D grid. */
    @Test
    public void testDynamicGrid03() throws TornadoExecutionPlanException {
        testMatrixIntegers(256, 128);
    }

    /**
     * Registers two vector-addition tasks with the same {@link WorkerGrid1D} instance, then
     * reduces that grid's global work size and executes the plan again.
     */
    @Test
    public void testDynamicGrid04() throws TornadoExecutionPlanException {
        FloatArray a = new FloatArray(NUM_ELEMENTS);
        FloatArray b = new FloatArray(NUM_ELEMENTS);
        FloatArray c = new FloatArray(NUM_ELEMENTS);
        fillRandom(a, b, NUM_ELEMENTS);

        TaskGraph taskGraph = new TaskGraph("s0")
                .transferToDevice(1, a, b)
                .task("t0", TestArrays::vectorAddFloat, a, b, c)
                .task("t1", TestArrays::vectorAddFloat, a, b, c)
                .transferToHost(1, c);

        WorkerGrid1D worker = new WorkerGrid1D(NUM_ELEMENTS);
        GridScheduler gridScheduler = new GridScheduler("s0.t0", worker);
        gridScheduler.addWorkerGrid("s0.t1", worker);

        ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
        try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
            executionPlan.withGridScheduler(gridScheduler).execute();
            worker.setGlobalWork(REDUCED_GLOBAL_WORK, 1, 1);
            executionPlan.execute();
        }

        assertVectorSum(a, b, c, REDUCED_GLOBAL_WORK);
    }

    /**
     * Requests a 256x256 local work size for a 512x512 global grid. The test has no result
     * assertions: it passes if building and executing the plan completes without throwing.
     * NOTE(review): 256*256 work items per group presumably exceeds typical device limits and
     * this checks how the runtime copes — confirm the intended expectation.
     */
    @Test
    public void testOutOfRangeDimensions() throws TornadoExecutionPlanException {
        final int n = 512;
        FloatArray matrixA = new FloatArray(n * n);
        FloatArray matrixB = new FloatArray(n * n);
        FloatArray matrixC = new FloatArray(n * n);
        IntStream.range(0, n * n).parallel().forEach(idx -> {
            matrixA.set(idx, 2.5f);
            matrixB.set(idx, 3.5f);
        });

        TaskGraph taskGraph = new TaskGraph("s0")
                .transferToDevice(0, matrixA, matrixB)
                .task("mxm", TestGrid::matrixMultiplication, matrixA, matrixB, matrixC, n)
                .transferToHost(1, matrixC);

        WorkerGrid2D worker = new WorkerGrid2D(n, n);
        GridScheduler gridScheduler = new GridScheduler("s0.mxm", worker);
        worker.setGlobalWork(n, n, 1);
        worker.setLocalWork(256, 256, 1);

        ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
        try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
            executionPlan.withGridScheduler(gridScheduler).execute();
        }
    }
}

