/*
 * Decompiled with CFR 0.152.
 */
package uk.ac.manchester.tornado.examples.kernelcontext.compute;

import java.util.stream.IntStream;
import uk.ac.manchester.tornado.api.GridScheduler;
import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
import uk.ac.manchester.tornado.api.KernelContext;
import uk.ac.manchester.tornado.api.TaskGraph;
import uk.ac.manchester.tornado.api.TornadoExecutionPlan;
import uk.ac.manchester.tornado.api.WorkerGrid;
import uk.ac.manchester.tornado.api.WorkerGrid2D;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;

/**
 * Square matrix multiplication example that runs the same computation twice:
 * once through a {@link TornadoExecutionPlan} using a {@link KernelContext}
 * kernel scheduled on a 2D worker grid, and once as a plain sequential Java
 * loop. It then prints throughput, speedup and a verification flag.
 *
 * <p>Usage: the optional first command-line argument is the matrix dimension
 * (default {@value #DEFAULT_SIZE}).
 */
public class MatrixMultiplication2DV1 {
    /** Number of untimed runs executed before each timed measurement. */
    private static final int WARMING_UP_ITERATIONS = 15;

    /** Matrix dimension used when no valid size is given on the command line. */
    private static final int DEFAULT_SIZE = 512;

    /** Largest absolute per-element difference accepted by {@link #verify}. */
    private static final float VERIFICATION_TOLERANCE = 0.01f;

    /**
     * Kernel body: each invocation computes a single element of {@code C}, selected
     * by {@code context.globalIdx} and {@code context.globalIdy}.
     *
     * <p>NOTE(review): for flat row-major storage this indexing evaluates
     * {@code C = B x A} (equivalently {@code C = A x B} if the arrays are read as
     * column-major), whereas the sequential reference below evaluates row-major
     * {@code C = A x B}. Both agree in this example only because {@code main}
     * fills each input with a single constant value.
     *
     * @param context supplies the global thread indices for this invocation
     * @param A       first input matrix, {@code size * size} elements
     * @param B       second input matrix, {@code size * size} elements
     * @param C       output matrix, {@code size * size} elements
     * @param size    matrix dimension
     */
    public static void matrixMultiplication(KernelContext context, FloatArray A, FloatArray B, FloatArray C, int size) {
        int globalRow = context.globalIdx;
        int globalCol = context.globalIdy;
        float sum = 0.0f;
        for (int k = 0; k < size; ++k) {
            sum += A.get(k * size + globalRow) * B.get(globalCol * size + k);
        }
        C.set(globalCol * size + globalRow, sum);
    }

    /**
     * Sequential reference implementation: row-major {@code C = A x B} using the
     * classic triple loop.
     *
     * @param A    first input matrix, {@code size * size} elements
     * @param B    second input matrix, {@code size * size} elements
     * @param C    output matrix, {@code size * size} elements
     * @param size matrix dimension
     */
    private static void matrixMultiplication(FloatArray A, FloatArray B, FloatArray C, int size) {
        for (int i = 0; i < size; ++i) {
            for (int j = 0; j < size; ++j) {
                float sum = 0.0f;
                for (int k = 0; k < size; ++k) {
                    sum += A.get(i * size + k) * B.get(k * size + j);
                }
                C.set(i * size + j, sum);
            }
        }
    }

    /**
     * Reads the matrix dimension from the first command-line argument.
     *
     * <p>A missing argument silently selects the default. A non-numeric or
     * non-positive argument is reported on standard error and the default is used,
     * so the example still runs.
     *
     * @param args command-line arguments passed to {@link #main}
     * @return a strictly positive matrix dimension
     */
    private static int parseSize(String[] args) {
        if (args.length < 1) {
            return DEFAULT_SIZE;
        }
        try {
            int requested = Integer.parseInt(args[0]);
            if (requested > 0) {
                return requested;
            }
            System.err.println("Matrix size must be positive but was " + requested + "; using default " + DEFAULT_SIZE);
        } catch (NumberFormatException e) {
            System.err.println("Invalid matrix size '" + args[0] + "'; using default " + DEFAULT_SIZE);
        }
        return DEFAULT_SIZE;
    }

    /**
     * Converts an operation count and an elapsed time into GFlop/s.
     *
     * @param flops        number of floating-point operations performed
     * @param elapsedNanos elapsed time in nanoseconds
     * @return {@code flops / elapsedNanos}, with the elapsed time clamped to at
     *         least 1 ns so the result is always finite
     */
    private static double gigaFlops(double flops, long elapsedNanos) {
        // One operation per nanosecond is exactly one GFlop/s.
        return flops / (double) Math.max(1L, elapsedNanos);
    }

    /**
     * Entry point: builds the inputs, runs the warmed-up TornadoVM and sequential
     * versions, and prints throughput, elapsed time, speedup and verification.
     *
     * @param args optional matrix dimension as the first element
     * @throws ArithmeticException if {@code size * size} does not fit in an {@code int}
     */
    public static void main(String[] args) {
        int size = parseSize(args);
        System.out.println("Computing MxM of " + size + "x" + size);

        // Fail loudly instead of wrapping around for very large dimensions.
        int elements = Math.multiplyExact(size, size);
        FloatArray matrixA = new FloatArray(elements);
        FloatArray matrixB = new FloatArray(elements);
        FloatArray matrixC = new FloatArray(elements);
        FloatArray resultSeq = new FloatArray(elements);
        IntStream.range(0, elements).parallel().forEach(idx -> {
            matrixA.set(idx, 2.5f);
            matrixB.set(idx, 3.5f);
        });

        WorkerGrid2D workerGrid = new WorkerGrid2D(size, size);
        GridScheduler gridScheduler = new GridScheduler("s0.t0", workerGrid);
        KernelContext context = new KernelContext();
        // NOTE(review): a fixed 16x16 local work size presumably requires size to be
        // a multiple of 16 on some backends - confirm against the TornadoVM docs.
        workerGrid.setLocalWork(16L, 16L, 1L);

        // The task arguments are passed with their static types (no Object casts) so
        // that the overloaded method reference resolves to the five-parameter kernel.
        // NOTE(review): transfer modes 0 and 1 are presumably
        // DataTransferMode.FIRST_EXECUTION and EVERY_EXECUTION - confirm.
        TaskGraph taskGraph = new TaskGraph("s0")
                .transferToDevice(0, new Object[]{matrixA, matrixB})
                .task("t0", MatrixMultiplication2DV1::matrixMultiplication, context, matrixA, matrixB, matrixC, size)
                .transferToHost(1, new Object[]{matrixC});
        ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
        TornadoExecutionPlan executor = new TornadoExecutionPlan(new ImmutableTaskGraph[]{immutableTaskGraph});
        executor.withGridScheduler(gridScheduler);

        // Warm up, then time a single TornadoVM execution with a monotonic clock.
        for (int i = 0; i < WARMING_UP_ITERATIONS; ++i) {
            executor.execute();
        }
        long tornadoStart = System.nanoTime();
        executor.execute();
        long tornadoNanos = System.nanoTime() - tornadoStart;

        // Same protocol for the sequential reference.
        for (int i = 0; i < WARMING_UP_ITERATIONS; ++i) {
            MatrixMultiplication2DV1.matrixMultiplication(matrixA, matrixB, resultSeq, size);
        }
        long sequentialStart = System.nanoTime();
        MatrixMultiplication2DV1.matrixMultiplication(matrixA, matrixB, resultSeq, size);
        long sequentialNanos = System.nanoTime() - sequentialStart;

        // One multiply and one add per inner-loop iteration: 2 * size^3 operations.
        double flops = 2.0 * Math.pow(size, 3.0);
        double tornadoVMGigaFlops = gigaFlops(flops, tornadoNanos);
        double sequentialGigaFlops = gigaFlops(flops, sequentialNanos);
        double speedup = (double) Math.max(1L, sequentialNanos) / (double) Math.max(1L, tornadoNanos);
        long tornadoMillis = tornadoNanos / 1_000_000L;
        long sequentialMillis = sequentialNanos / 1_000_000L;

        String formatTornadoVMGFlops = String.format("%.2f", tornadoVMGigaFlops);
        String formatSequentialGFlops = String.format("%.2f", sequentialGigaFlops);
        System.out.println("\tSequential Execution: " + formatSequentialGFlops + " GFlops, Total time = " + sequentialMillis + " ms");
        System.out.println("\tTornadoVM Execution: " + formatTornadoVMGFlops + " GFlops, Total Time = " + tornadoMillis + " ms");
        System.out.println("\tSpeedup: " + speedup + "x");
        System.out.println("\tVerification " + MatrixMultiplication2DV1.verify(matrixC, resultSeq, size));
    }

    /**
     * Compares the two result matrices element by element.
     *
     * @param par  result produced through the execution plan
     * @param seq  result produced by the sequential reference
     * @param size matrix dimension; {@code size * size} elements are compared
     * @return {@code true} only if every pair of elements differs by at most
     *         {@value #VERIFICATION_TOLERANCE}; a NaN in either result fails the check
     */
    private static boolean verify(FloatArray par, FloatArray seq, int size) {
        int elements = size * size;
        for (int i = 0; i < elements; ++i) {
            float difference = Math.abs(par.get(i) - seq.get(i));
            // Written as a negated "<=" so that a NaN difference is rejected too.
            if (!(difference <= VERIFICATION_TOLERANCE)) {
                return false;
            }
        }
        return true;
    }
}

