/*
 * Decompiled with CFR 0.152.
 */
package uk.ac.manchester.tornado.unittests.loops;

import java.util.Locale;
import java.util.Random;
import java.util.stream.IntStream;
import org.junit.Assert;
import org.junit.Test;
import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
import uk.ac.manchester.tornado.api.TaskGraph;
import uk.ac.manchester.tornado.api.TornadoBackend;
import uk.ac.manchester.tornado.api.TornadoExecutionPlan;
import uk.ac.manchester.tornado.api.exceptions.TornadoExecutionPlanException;
import uk.ac.manchester.tornado.api.runtime.TornadoRuntimeProvider;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
import uk.ac.manchester.tornado.unittests.common.TornadoTestBase;

/**
 * Tests that run small kernels through a {@link TornadoExecutionPlan} with the
 * {@code tornado.experimental.partial.unroll} property set to {@code "True"} and
 * compare the produced output against the same kernel invoked directly on the host.
 */
public class TestLoopTransformations extends TornadoTestBase {
    /** Runtime property switched on by every test in this class. */
    private static final String PARTIAL_UNROLL_PROPERTY = "tornado.experimental.partial.unroll";

    /** Runtime property set to "32" when an NVIDIA platform is found. */
    private static final String UNROLL_FACTOR_PROPERTY = "tornado.unroll.factor";

    // NOTE(review): the two transfer modes appear as bare literals here; they
    // presumably correspond to DataTransferMode.FIRST_EXECUTION (0) and
    // DataTransferMode.EVERY_EXECUTION (1) in the TornadoVM API -- confirm.
    private static final int TRANSFER_TO_DEVICE_MODE = 0;
    private static final int TRANSFER_TO_HOST_MODE = 1;

    /** Index of the backend queried for devices and used to select the default device. */
    private static final int BACKEND_INDEX = 0;

    /** Absolute tolerance used when comparing matrix-vector results. */
    private static final float MATRIX_VECTOR_DELTA = 0.01f;

    /** Absolute tolerance used when comparing transpose results. */
    private static final double TRANSPOSE_DELTA = 0.1;

    /**
     * Kernel: multiplies the row-major {@code size x size} matrix {@code A2} by the
     * vector {@code B2} and stores the {@code size} resulting values in {@code C}.
     *
     * @param A2   row-major input matrix, read at indices {@code [0, size * size)}
     * @param B2   input vector, read at indices {@code [0, size)}
     * @param C    output vector, written at indices {@code [0, size)}
     * @param size number of rows and columns
     */
    private static void matrixVectorMultiplication(FloatArray A2, FloatArray B2, FloatArray C, int size) {
        for (int i = 0; i < size; i++) {
            float sum = 0.0f;
            for (int j = 0; j < size; j++) {
                sum += A2.get(i * size + j) * B2.get(j);
            }
            C.set(i, sum);
        }
    }

    /**
     * Kernel: writes the transpose of the row-major {@code size x size} matrix
     * {@code A2} into {@code B2}.
     *
     * @param A2   row-major input matrix
     * @param B2   row-major output matrix
     * @param size number of rows and columns
     */
    private static void matrixTranspose(FloatArray A2, FloatArray B2, int size) {
        for (int i = 0; i < size; i++) {
            for (int j = 0; j < size; j++) {
                B2.set(i * size + j, A2.get(j * size + i));
            }
        }
    }

    /**
     * Fills the first {@code count} elements of {@code array} with values from
     * {@code random.nextFloat()}, using a parallel index stream.
     */
    private static void fillRandom(FloatArray array, int count, Random random) {
        IntStream.range(0, count).parallel().forEach(idx -> array.set(idx, random.nextFloat()));
    }

    /**
     * Builds the matrix-vector task graph, executes it, and checks the result
     * against a direct host invocation of {@link #matrixVectorMultiplication}.
     *
     * <p>All buffers are allocated with {@code size * size} elements, and the
     * comparison covers every element of the output buffers.
     *
     * @param size number of rows and columns of the input matrix
     * @throws TornadoExecutionPlanException propagated from the execution plan
     */
    private static void runMatrixVectorAndVerify(int size) throws TornadoExecutionPlanException {
        final int elements = size * size;
        FloatArray matrixA = new FloatArray(elements);
        FloatArray matrixB = new FloatArray(elements);
        FloatArray matrixC = new FloatArray(elements);
        FloatArray resultSeq = new FloatArray(elements);

        Random random = new Random();
        fillRandom(matrixA, elements, random);
        // Only the first `size` entries of B are read by the kernel.
        fillRandom(matrixB, size, random);

        TaskGraph taskGraph = new TaskGraph("s0")
                .transferToDevice(TRANSFER_TO_DEVICE_MODE, matrixA, matrixB)
                .task("t0", TestLoopTransformations::matrixVectorMultiplication, matrixA, matrixB, matrixC, size)
                .transferToHost(TRANSFER_TO_HOST_MODE, matrixC);
        ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
        try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
            executionPlan.execute();
        }

        // Host-side reference result.
        matrixVectorMultiplication(matrixA, matrixB, resultSeq, size);

        // JUnit order is (expected, actual, delta): the sequential run is the reference.
        for (int idx = 0; idx < elements; idx++) {
            Assert.assertEquals(resultSeq.get(idx), matrixC.get(idx), MATRIX_VECTOR_DELTA);
        }
    }

    /**
     * Runs the matrix-vector kernel with partial unrolling enabled and no
     * explicit unroll factor or device selection.
     */
    @Test
    public void testPartialUnrollDefault() throws TornadoExecutionPlanException {
        TornadoRuntimeProvider.setProperty(PARTIAL_UNROLL_PROPERTY, "True");
        runMatrixVectorAndVerify(512);
    }

    /**
     * Runs the matrix-vector kernel with partial unrolling enabled. Every device
     * of backend 0 whose platform name contains "nvidia" (case-insensitive) is
     * made the default device and the unroll factor property is set to "32";
     * when several devices match, the last one remains the default.
     */
    @Test
    public void testPartialUnrollNvidia32() throws TornadoExecutionPlanException {
        TornadoRuntimeProvider.setProperty(PARTIAL_UNROLL_PROPERTY, "True");

        TornadoBackend backend = TornadoRuntimeProvider.getTornadoRuntime().getBackend(BACKEND_INDEX);
        for (int deviceIndex = 0; deviceIndex < backend.getNumDevices(); deviceIndex++) {
            // Locale.ROOT keeps the match stable under locales with special
            // casing rules (e.g. Turkish, where 'I' lowercases to a dotless i).
            String platformName = backend.getDevice(deviceIndex).getPlatformName().toLowerCase(Locale.ROOT);
            if (platformName.contains("nvidia")) {
                backend.setDefaultDevice(deviceIndex);
                TornadoRuntimeProvider.setProperty(UNROLL_FACTOR_PROPERTY, "32");
                System.setProperty(UNROLL_FACTOR_PROPERTY, "32");
            }
        }

        runMatrixVectorAndVerify(512);
    }

    /**
     * Runs the transpose kernel (two nested loops) with partial unrolling enabled
     * and compares the output against a direct host invocation of
     * {@link #matrixTranspose}.
     */
    @Test
    public void testPartialUnrollParallelLoops() throws TornadoExecutionPlanException {
        final int size = 256;
        final int elements = size * size;
        FloatArray matrixA = new FloatArray(elements);
        FloatArray matrixB = new FloatArray(elements);
        FloatArray resultSeq = new FloatArray(elements);

        TornadoRuntimeProvider.setProperty(PARTIAL_UNROLL_PROPERTY, "True");

        Random random = new Random();
        fillRandom(matrixA, elements, random);
        // B is filled with random values on the host before the run; only A is
        // listed in transferToDevice, and B is the buffer copied back.
        fillRandom(matrixB, elements, random);

        TaskGraph taskGraph = new TaskGraph("s0")
                .transferToDevice(TRANSFER_TO_DEVICE_MODE, matrixA)
                .task("t0", TestLoopTransformations::matrixTranspose, matrixA, matrixB, size)
                .transferToHost(TRANSFER_TO_HOST_MODE, matrixB);
        ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
        try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
            executionPlan.execute();
        }

        // Host-side reference result.
        matrixTranspose(matrixA, resultSeq, size);

        for (int idx = 0; idx < elements; idx++) {
            Assert.assertEquals(resultSeq.get(idx), matrixB.get(idx), TRANSPOSE_DELTA);
        }
    }
}

