/*
 * Decompiled with CFR 0.152.
 */
package uk.ac.manchester.tornado.examples.kernelcontext.compute;

import java.util.Random;
import java.util.stream.IntStream;
import uk.ac.manchester.tornado.api.GridScheduler;
import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
import uk.ac.manchester.tornado.api.KernelContext;
import uk.ac.manchester.tornado.api.TaskGraph;
import uk.ac.manchester.tornado.api.TornadoExecutionPlan;
import uk.ac.manchester.tornado.api.WorkerGrid;
import uk.ac.manchester.tornado.api.WorkerGrid1D;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;

/**
 * Square matrix multiplication ({@code C = A x B}) expressed with the
 * {@link KernelContext} API and launched over a one-dimensional worker grid.
 *
 * <p>{@link #main(String[])} runs the kernel through a {@link TornadoExecutionPlan},
 * runs a plain Java reference implementation, prints timing / GFlops figures for
 * both and reports whether the two results agree.
 *
 * <p>All matrices are stored as flat {@link FloatArray}s of {@code size * size}
 * elements, addressed as {@code row * size + column}.
 */
public class MatrixMultiplication1D {
    /** Number of untimed runs executed before each timed measurement. */
    private static final int WARMING_UP_ITERATIONS = 15;

    /** Matrix dimension used when no usable size is given on the command line. */
    private static final int DEFAULT_SIZE = 512;

    /** Largest absolute per-element difference accepted by {@link #verify}. */
    private static final float VERIFY_TOLERANCE = 0.1f;

    /**
     * Kernel body: computes one full row of {@code C}.
     *
     * <p>The row is selected by {@code context.globalIdx}; the method then walks
     * every column of that row and accumulates the dot product of row {@code idx}
     * of {@code A} with the matching column of {@code B}.
     *
     * @param context kernel context supplying {@code globalIdx}
     * @param A       left operand, {@code size * size} elements
     * @param B       right operand, {@code size * size} elements
     * @param C       output matrix, {@code size * size} elements
     * @param size    dimension of the square matrices
     */
    public static void matrixMultiplication(KernelContext context, FloatArray A, FloatArray B, FloatArray C, int size) {
        int idx = context.globalIdx;
        for (int jdx = 0; jdx < size; ++jdx) {
            float sum = 0.0f;
            for (int k = 0; k < size; ++k) {
                sum += A.get(idx * size + k) * B.get(k * size + jdx);
            }
            C.set(idx * size + jdx, sum);
        }
    }

    /**
     * Sequential reference implementation of {@code C = A x B} used for timing
     * comparison and result verification.
     *
     * @param A    left operand, {@code size * size} elements
     * @param B    right operand, {@code size * size} elements
     * @param C    output matrix, {@code size * size} elements
     * @param size dimension of the square matrices
     */
    private static void matrixMultiplication(FloatArray A, FloatArray B, FloatArray C, int size) {
        for (int i = 0; i < size; ++i) {
            for (int j = 0; j < size; ++j) {
                float sum = 0.0f;
                for (int k = 0; k < size; ++k) {
                    sum += A.get(i * size + k) * B.get(k * size + j);
                }
                C.set(i * size + j, sum);
            }
        }
    }

    /**
     * Entry point.
     *
     * @param args optional; {@code args[0]} is the matrix dimension. A missing,
     *             non-numeric or non-positive value falls back to
     *             {@value #DEFAULT_SIZE}.
     * @throws ArithmeticException if {@code size * size} does not fit in an {@code int}
     */
    public static void main(String[] args) {
        int size = parseSize(args);
        System.out.println("Computing MxM of " + size + "x" + size);

        // Fail loudly instead of silently allocating a wrapped-around element count.
        int elements = Math.multiplyExact(size, size);
        FloatArray matrixA = new FloatArray(elements);
        FloatArray matrixB = new FloatArray(elements);
        FloatArray matrixC = new FloatArray(elements);
        FloatArray resultSeq = new FloatArray(elements);

        Random r = new Random();
        IntStream.range(0, elements).parallel().forEach(idx -> {
            matrixA.set(idx, r.nextFloat());
            matrixB.set(idx, r.nextFloat());
        });

        WorkerGrid1D workerGrid = new WorkerGrid1D(size);
        GridScheduler gridScheduler = new GridScheduler("s0.t0", (WorkerGrid) workerGrid);
        KernelContext context = new KernelContext();
        workerGrid.setGlobalWork((long) size, 1L, 1L);
        // NOTE(review): the local size is halved only once, so sizes above 2048 still
        // request more than 1024 work-items per group — confirm against device limits.
        workerGrid.setLocalWork((long) (size <= 1024 ? size : size / 2), 1L, 1L);

        // The arguments are passed with their real static types: casting them to Object
        // would prevent the overloaded method reference from resolving to the kernel.
        // NOTE(review): the literal modes 0 and 1 look like inlined
        // DataTransferMode.FIRST_EXECUTION / EVERY_EXECUTION constants — confirm.
        TaskGraph taskGraph = new TaskGraph("s0")
                .transferToDevice(0, matrixA, matrixB)
                .task("t0", MatrixMultiplication1D::matrixMultiplication, context, matrixA, matrixB, matrixC, size)
                .transferToHost(1, matrixC);
        ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
        TornadoExecutionPlan executor = new TornadoExecutionPlan(immutableTaskGraph);
        executor.withGridScheduler(gridScheduler);

        // Untimed warm-up runs, then a single timed run.
        for (int i = 0; i < WARMING_UP_ITERATIONS; ++i) {
            executor.execute();
        }
        // nanoTime is monotonic and fine-grained, so short runs do not collapse to 0.
        long start = System.nanoTime();
        executor.execute();
        long tornadoNanos = System.nanoTime() - start;

        // Same protocol for the sequential reference.
        for (int i = 0; i < WARMING_UP_ITERATIONS; ++i) {
            MatrixMultiplication1D.matrixMultiplication(matrixA, matrixB, resultSeq, size);
        }
        long startSequential = System.nanoTime();
        MatrixMultiplication1D.matrixMultiplication(matrixA, matrixB, resultSeq, size);
        long sequentialNanos = System.nanoTime() - startSequential;

        // One multiply and one add per inner-loop iteration: 2 * size^3 operations.
        double flops = 2.0 * Math.pow(size, 3.0);
        // Operations per nanosecond is numerically equal to GFlop/s.
        double tornadoVMGigaFlops = flops / (double) tornadoNanos;
        double sequentialGigaFlops = flops / (double) sequentialNanos;
        double speedup = (double) sequentialNanos / (double) tornadoNanos;

        long msecTornadoVMElapsedTime = tornadoNanos / 1_000_000L;
        long msecSequentialElapsedTime = sequentialNanos / 1_000_000L;
        String formatTornadoVMGFlops = String.format("%.2f", tornadoVMGigaFlops);
        String formatSequentialGFlops = String.format("%.2f", sequentialGigaFlops);
        System.out.println("\tSequential Execution: " + formatSequentialGFlops + " GFlops, Total time = " + msecSequentialElapsedTime + " ms");
        System.out.println("\tTornadoVM Execution: " + formatTornadoVMGFlops + " GFlops, Total Time = " + msecTornadoVMElapsedTime + " ms");
        System.out.println("\tSpeedup: " + speedup + "x");
        System.out.println("\tVerification " + MatrixMultiplication1D.verify(matrixC, resultSeq, size));
    }

    /**
     * Reads the matrix dimension from the command line.
     *
     * <p>Running the example with a bad argument stays best-effort: the problem is
     * reported on {@code System.err} and the default size is used.
     *
     * @param args command-line arguments; only {@code args[0]} is inspected
     * @return a strictly positive matrix dimension
     */
    private static int parseSize(String[] args) {
        if (args.length < 1) {
            return DEFAULT_SIZE;
        }
        try {
            int parsed = Integer.parseInt(args[0]);
            if (parsed > 0) {
                return parsed;
            }
            System.err.println("Matrix size must be positive, got " + parsed + "; using default " + DEFAULT_SIZE);
        } catch (NumberFormatException e) {
            System.err.println("Invalid matrix size '" + args[0] + "'; using default " + DEFAULT_SIZE);
        }
        return DEFAULT_SIZE;
    }

    /**
     * Compares two {@code size x size} matrices element by element.
     *
     * @param par  matrix produced by the accelerated run
     * @param seq  matrix produced by the sequential reference
     * @param size dimension of the square matrices
     * @return {@code true} only if every pair of elements differs by at most
     *         {@value #VERIFY_TOLERANCE}; {@code false} on the first mismatch
     */
    private static boolean verify(FloatArray par, FloatArray seq, int size) {
        for (int i = 0; i < size; ++i) {
            for (int j = 0; j < size; ++j) {
                int index = i * size + j;
                float difference = Math.abs(par.get(index) - seq.get(index));
                // Negated "<=" so that a NaN difference is rejected rather than accepted.
                if (!(difference <= VERIFY_TOLERANCE)) {
                    return false;
                }
            }
        }
        return true;
    }
}

