/*
 * Decompiled with CFR 0.152.
 */
package uk.ac.manchester.tornado.unittests.kernelcontext.matrices;

import java.util.Random;
import java.util.stream.IntStream;
import org.junit.Assert;
import org.junit.Test;
import uk.ac.manchester.tornado.api.GridScheduler;
import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
import uk.ac.manchester.tornado.api.KernelContext;
import uk.ac.manchester.tornado.api.TaskGraph;
import uk.ac.manchester.tornado.api.TornadoExecutionPlan;
import uk.ac.manchester.tornado.api.WorkerGrid;
import uk.ac.manchester.tornado.api.WorkerGrid1D;
import uk.ac.manchester.tornado.api.WorkerGrid2D;
import uk.ac.manchester.tornado.api.exceptions.TornadoExecutionPlanException;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
import uk.ac.manchester.tornado.unittests.common.TornadoTestBase;

/**
 * Unit tests for matrix-multiplication kernels written against the {@link KernelContext} API.
 *
 * <p>Every matrix is square ({@code size x size}) and stored row-major in a flat {@link FloatArray}:
 * element {@code (row, col)} lives at index {@code row * size + col}. Each kernel is validated against
 * {@link #matrixMultiplicationJava}, a sequential reference computing {@code c = a * b}.
 */
public class TestMatrixMultiplicationKernelContext extends TornadoTestBase {

    /**
     * Tile edge length used by {@link #matrixMultiplication2D02}. The 2D local work size used to launch
     * that kernel must be {@code TS x TS}.
     */
    private static final int TS = 4;

    /**
     * Sequential reference: computes {@code c = a * b} for square row-major matrices.
     *
     * @param a    left operand, {@code size * size} elements
     * @param b    right operand, {@code size * size} elements
     * @param c    output, overwritten with the product
     * @param size matrix dimension
     */
    public static void matrixMultiplicationJava(FloatArray a, FloatArray b, FloatArray c, int size) {
        for (int i = 0; i < size; i++) {
            for (int j = 0; j < size; j++) {
                float sum = 0.0f;
                for (int k = 0; k < size; k++) {
                    sum += a.get(i * size + k) * b.get(k * size + j);
                }
                c.set(i * size + j, sum);
            }
        }
    }

    /**
     * 1D kernel computing {@code c = a * b}: the thread with {@code context.globalIdx == idx} computes the
     * whole row {@code idx} of {@code c}. Launch with a 1D grid of {@code size} threads.
     *
     * @param context Tornado kernel context supplying the thread index
     * @param a       left operand, row-major
     * @param b       right operand, row-major
     * @param c       output, row-major
     * @param size    matrix dimension
     */
    public static void matrixMultiplication1D(KernelContext context, FloatArray a, FloatArray b, FloatArray c, int size) {
        int idx = context.globalIdx;
        for (int jdx = 0; jdx < size; jdx++) {
            float sum = 0.0f;
            for (int k = 0; k < size; k++) {
                sum += a.get(idx * size + k) * b.get(k * size + jdx);
            }
            c.set(idx * size + jdx, sum);
        }
    }

    /**
     * 2D kernel computing {@code c = a * b}: each thread computes the single element {@code c(idx, jdx)},
     * where {@code idx = context.globalIdx} is the row and {@code jdx = context.globalIdy} the column.
     * Launch with a {@code size x size} 2D grid.
     *
     * @param context Tornado kernel context supplying the thread indices
     * @param a       left operand, row-major
     * @param b       right operand, row-major
     * @param c       output, row-major
     * @param size    matrix dimension
     */
    public static void matrixMultiplication2D01(KernelContext context, FloatArray a, FloatArray b, FloatArray c, int size) {
        int idx = context.globalIdx;
        int jdx = context.globalIdy;
        float sum = 0.0f;
        for (int k = 0; k < size; k++) {
            // a(idx, k) * b(k, jdx): the same row-major indexing as matrixMultiplicationJava.
            sum += a.get(idx * size + k) * b.get(k * size + jdx);
        }
        c.set(idx * size + jdx, sum);
    }

    /**
     * Tiled 2D kernel computing {@code C = A2 * B2} (all row-major). Each work-group produces one
     * {@code TS x TS} tile of {@code C}, and each thread produces one element of that tile.
     *
     * <p>For every step along the shared dimension, the group cooperatively copies one {@code TS x TS}
     * tile of each operand into arrays obtained from {@link KernelContext#allocateFloatLocalArray}. Each
     * thread copies exactly one element of each tile. The group then accumulates partial dot products from
     * those copies.
     *
     * <p>Requirements, which follow from the indexing:
     * <ul>
     *   <li>The local work size must be {@code TS x TS}. {@code localIdx}/{@code localIdy} index the
     *       {@code TS * TS} tile arrays directly.</li>
     *   <li>{@code size} must be a multiple of {@code TS}. {@code numTiles = size / TS} truncates, so any
     *       trailing {@code size % TS} terms of each dot product would be skipped.</li>
     * </ul>
     *
     * @param context Tornado kernel context supplying local/group indices and the barrier
     * @param A2      left operand, row-major
     * @param B2      right operand, row-major
     * @param C       output, row-major
     * @param size    matrix dimension
     */
    public static void matrixMultiplication2D02(KernelContext context, FloatArray A2, FloatArray B2, FloatArray C, int size) {
        int row = context.localIdx;
        int col = context.localIdy;
        int globalRow = TS * context.groupIdx + row;
        int globalCol = TS * context.groupIdy + col;
        float[] aSub = context.allocateFloatLocalArray(TS * TS);
        float[] bSub = context.allocateFloatLocalArray(TS * TS);
        float sum = 0.0f;
        int numTiles = size / TS;
        for (int tileIndex = 0; tileIndex < numTiles; tileIndex++) {
            int tiledRow = TS * tileIndex + row;
            int tiledCol = TS * tileIndex + col;
            // aSub(row, col) = A2(globalRow, tiledCol); bSub(row, col) = B2(tiledRow, globalCol).
            aSub[row * TS + col] = A2.get(globalRow * size + tiledCol);
            bSub[row * TS + col] = B2.get(tiledRow * size + globalCol);
            // Wait until the whole group has written both tiles before anyone reads them.
            context.localBarrier();
            for (int k = 0; k < TS; k++) {
                sum += aSub[row * TS + k] * bSub[k * TS + col];
            }
            // Wait until everyone has finished reading before the next iteration overwrites the tiles.
            context.localBarrier();
        }
        C.set(globalRow * size + globalCol, sum);
    }

    /**
     * Runs {@link #matrixMultiplication1D} on a 1D grid of {@code size} threads and compares the result
     * with the sequential reference.
     *
     * @throws TornadoExecutionPlanException if the execution plan fails
     */
    @Test
    public void mxm1DKernelContext() throws TornadoExecutionPlanException {
        final int size = 16;
        final int elements = size * size;
        FloatArray a = new FloatArray(elements);
        FloatArray b = new FloatArray(elements);
        FloatArray cJava = new FloatArray(elements);
        FloatArray cTornado = new FloatArray(elements);
        fillRandom(a, b, elements);

        WorkerGrid worker = new WorkerGrid1D(size);
        GridScheduler gridScheduler = new GridScheduler("s0.t0", worker);
        KernelContext context = new KernelContext();

        // NOTE(review): transfer mode 1 presumably corresponds to DataTransferMode.EVERY_EXECUTION; confirm.
        TaskGraph taskGraph = new TaskGraph("s0")
                .transferToDevice(1, a, b)
                .task("t0", TestMatrixMultiplicationKernelContext::matrixMultiplication1D, context, a, b, cTornado, size)
                .transferToHost(1, cTornado);
        execute(taskGraph, gridScheduler);

        matrixMultiplicationJava(a, b, cJava, size);
        assertMatricesEqual(cJava, cTornado, elements, 0.01f);
    }

    /**
     * Runs {@link #matrixMultiplication2D01} on a {@code size x size} 2D grid and compares the result with
     * the sequential reference.
     *
     * @throws TornadoExecutionPlanException if the execution plan fails
     */
    @Test
    public void mxm2DKernelContext01() throws TornadoExecutionPlanException {
        final int size = 16;
        final int elements = size * size;
        FloatArray a = new FloatArray(elements);
        FloatArray b = new FloatArray(elements);
        FloatArray cJava = new FloatArray(elements);
        FloatArray cTornado = new FloatArray(elements);
        fillRandom(a, b, elements);

        WorkerGrid2D worker = new WorkerGrid2D(size, size);
        GridScheduler gridScheduler = new GridScheduler();
        gridScheduler.addWorkerGrid("s0.t0", worker);
        KernelContext context = new KernelContext();

        TaskGraph taskGraph = new TaskGraph("s0")
                .transferToDevice(1, a, b)
                .task("t0", TestMatrixMultiplicationKernelContext::matrixMultiplication2D01, context, a, b, cTornado, size)
                .transferToHost(1, cTornado);
        execute(taskGraph, gridScheduler);

        matrixMultiplicationJava(a, b, cJava, size);
        assertMatricesEqual(cJava, cTornado, elements, 0.01f);
    }

    /**
     * Runs the tiled {@link #matrixMultiplication2D02} on a {@code size x size} 2D grid with
     * {@code TS x TS} work-groups and compares the result with the sequential reference.
     *
     * @throws TornadoExecutionPlanException if the execution plan fails
     */
    @Test
    public void mxm2DKernelContext02() throws TornadoExecutionPlanException {
        final int size = 16; // must be a multiple of TS
        final int elements = size * size;
        FloatArray a = new FloatArray(elements);
        FloatArray b = new FloatArray(elements);
        FloatArray cJava = new FloatArray(elements);
        FloatArray cTornado = new FloatArray(elements);
        fillRandom(a, b, elements);

        WorkerGrid2D worker = new WorkerGrid2D(size, size);
        GridScheduler gridScheduler = new GridScheduler();
        gridScheduler.addWorkerGrid("s0.t0", worker);
        // The kernel's tile arrays are TS x TS, so each work-group must be exactly TS x TS threads.
        worker.setLocalWork(TS, TS, 1);
        KernelContext context = new KernelContext();

        TaskGraph taskGraph = new TaskGraph("s0")
                .transferToDevice(1, a, b)
                .task("t0", TestMatrixMultiplicationKernelContext::matrixMultiplication2D02, context, a, b, cTornado, size)
                .transferToHost(1, cTornado);
        execute(taskGraph, gridScheduler);

        matrixMultiplicationJava(a, b, cJava, size);
        assertMatricesEqual(cJava, cTornado, elements, 0.1f);
    }

    /** Fills the first {@code elements} entries of both operands with random floats from {@link Random#nextFloat()}. */
    private static void fillRandom(FloatArray a, FloatArray b, int elements) {
        Random random = new Random();
        IntStream.range(0, elements).forEach(i -> {
            a.set(i, random.nextFloat());
            b.set(i, random.nextFloat());
        });
    }

    /** Snapshots {@code taskGraph}, runs it once with {@code gridScheduler}, and closes the execution plan. */
    private static void execute(TaskGraph taskGraph, GridScheduler gridScheduler) throws TornadoExecutionPlanException {
        ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
        try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
            executionPlan.withGridScheduler(gridScheduler).execute();
        }
    }

    /** Asserts element-wise equality of the first {@code elements} entries, reporting the first mismatching index. */
    private static void assertMatricesEqual(FloatArray expected, FloatArray actual, int elements, float delta) {
        for (int i = 0; i < elements; i++) {
            Assert.assertEquals("mismatch at flat index " + i, expected.get(i), actual.get(i), delta);
        }
    }
}

