/*
 * Decompiled with CFR 0.152.
 */
package uk.ac.manchester.tornado.examples.compute;

import java.io.FileNotFoundException;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.stream.IntStream;
import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
import uk.ac.manchester.tornado.api.TaskGraph;
import uk.ac.manchester.tornado.api.TornadoExecutionPlan;
import uk.ac.manchester.tornado.api.enums.TornadoDeviceType;
import uk.ac.manchester.tornado.api.exceptions.TornadoExecutionPlanException;
import uk.ac.manchester.tornado.api.types.matrix.Matrix2DFloat;

/**
 * Benchmarks a dense {@code size x size} single-precision matrix multiplication three ways:
 * sequential Java, Java parallel streams, and a TornadoVM task graph.
 *
 * <p>Each variant runs {@link #WARMING_UP_ITERATIONS} warm-up iterations followed by one measured
 * run. Every iteration's time (warm-up and measured) is written to {@code stats-mxm-<size>.txt}.
 * The measured run is the one reported on stdout.
 *
 * <p>Usage: {@code MatrixMultiplication2D [size (default 512)] [verify (default true)]}.
 */
public class MatrixMultiplication2D {
    /** Number of warm-up runs executed (and recorded) before the measured run of each variant. */
    private static final int WARMING_UP_ITERATIONS = 100;

    /** Absolute per-element tolerance when comparing the TornadoVM result with the Java result. */
    private static final float TOLERANCE = 0.1f;

    /**
     * Sequential triple-loop MxM, {@code C = A * B}. This method is also the TornadoVM task body,
     * so it deliberately uses only plain loops and matrix accessors.
     *
     * <p>NOTE(review): the {@code i}/{@code j} loops carry no {@code @Parallel} annotation. Without
     * it, TornadoVM's loop-parallel API does not parallelise the kernel. The decompiler may have
     * dropped such annotations; confirm against the upstream source.
     */
    private static void matrixMultiplication(Matrix2DFloat A, Matrix2DFloat B, Matrix2DFloat C, int size) {
        for (int i = 0; i < size; ++i) {
            for (int j = 0; j < size; ++j) {
                float sum = 0.0f;
                for (int k = 0; k < size; ++k) {
                    sum += A.get(i, k) * B.get(k, j);
                }
                C.set(i, j, sum);
            }
        }
    }

    /**
     * MxM ({@code C = A * B}) using nested parallel {@link IntStream}s over rows and columns.
     * Each {@code (i, j)} cell is written by exactly one task.
     */
    private static void parallelStreamsMxM(Matrix2DFloat A, Matrix2DFloat B, Matrix2DFloat C, int size) {
        IntStream.range(0, size).parallel().forEach(i -> IntStream.range(0, size).parallel().forEach(j -> {
            float sum = 0.0f;
            for (int k = 0; k < size; ++k) {
                sum += A.get(i, k) * B.get(k, j);
            }
            C.set(i, j, sum);
        }));
    }

    /**
     * Runs the benchmark.
     *
     * @param args optional {@code [size, verify]}; size defaults to 512 and verify defaults to true
     * @throws TornadoExecutionPlanException if closing the TornadoVM execution plan fails
     * @throws FileNotFoundException if the stats file cannot be created
     * @throws IllegalArgumentException if the requested size is not positive
     */
    public static void main(String[] args) throws TornadoExecutionPlanException, FileNotFoundException {
        int size = args.length >= 1 ? Integer.parseInt(args[0]) : 512;
        if (size <= 0) {
            throw new IllegalArgumentException("Matrix size must be positive, got " + size);
        }
        boolean verify = args.length < 2 || Boolean.parseBoolean(args[1]);
        System.out.println("Computing MxM of " + size + "x" + size);

        Matrix2DFloat matrixA = new Matrix2DFloat(size, size);
        Matrix2DFloat matrixB = new Matrix2DFloat(size, size);
        Matrix2DFloat matrixC = new Matrix2DFloat(size, size);
        // Shared output of the sequential and streams runs; the streams run writes it last.
        Matrix2DFloat resultSeq = new Matrix2DFloat(size, size);
        Random r = new Random();
        for (int i = 0; i < size; ++i) {
            for (int j = 0; j < size; ++j) {
                matrixA.set(i, j, r.nextFloat());
                matrixB.set(i, j, r.nextFloat());
            }
        }

        // The literal modes 0 / 1 are presumably DataTransferMode constants inlined by the
        // decompiler (first execution / every execution). TODO confirm.
        TaskGraph taskGraph = new TaskGraph("s0")
                .transferToDevice(0, matrixA, matrixB)
                .task("t0", MatrixMultiplication2D::matrixMultiplication, matrixA, matrixB, matrixC, size)
                .transferToHost(1, matrixC);

        List<Long> tornadoElapsedTime = new ArrayList<>(WARMING_UP_ITERATIONS + 1);
        List<Long> javaElapsedTime = new ArrayList<>(WARMING_UP_ITERATIONS + 1);
        List<Long> streamsElapsedTime = new ArrayList<>(WARMING_UP_ITERATIONS + 1);

        long nanoSecGPUElapsedTime;
        TornadoDeviceType deviceType;
        ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
        try (TornadoExecutionPlan executor = new TornadoExecutionPlan(immutableTaskGraph)) {
            executor.withPreCompilation();
            for (int i = 0; i < WARMING_UP_ITERATIONS; ++i) {
                long s = System.nanoTime();
                executor.execute();
                tornadoElapsedTime.add(System.nanoTime() - s);
            }
            long start = System.nanoTime();
            executor.execute();
            nanoSecGPUElapsedTime = System.nanoTime() - start;
            tornadoElapsedTime.add(nanoSecGPUElapsedTime);
            deviceType = executor.getDevice(0).getDeviceType();
        }

        for (int i = 0; i < WARMING_UP_ITERATIONS; ++i) {
            long s = System.nanoTime();
            matrixMultiplication(matrixA, matrixB, resultSeq, size);
            javaElapsedTime.add(System.nanoTime() - s);
        }
        long startSequential = System.nanoTime();
        matrixMultiplication(matrixA, matrixB, resultSeq, size);
        long nanoSecCPUElapsedTime = System.nanoTime() - startSequential;
        javaElapsedTime.add(nanoSecCPUElapsedTime);

        for (int i = 0; i < WARMING_UP_ITERATIONS; ++i) {
            long s = System.nanoTime();
            parallelStreamsMxM(matrixA, matrixB, resultSeq, size);
            streamsElapsedTime.add(System.nanoTime() - s);
        }
        long startStream = System.nanoTime();
        parallelStreamsMxM(matrixA, matrixB, resultSeq, size);
        long nanoSecStreamElapsedTime = System.nanoTime() - startStream;
        streamsElapsedTime.add(nanoSecStreamElapsedTime);

        // A size^3 MxM performs size^3 multiplies and size^3 adds.
        double flops = 2.0 * Math.pow(size, 3.0);
        String formatGPUFGlops = String.format("%.2f", gigaFlops(flops, nanoSecGPUElapsedTime));
        String formatCPUFGlops = String.format("%.2f", gigaFlops(flops, nanoSecCPUElapsedTime));
        String formatStreamFGlops = String.format("%.2f", gigaFlops(flops, nanoSecStreamElapsedTime));
        double speedup = (double) nanoSecCPUElapsedTime / (double) nanoSecGPUElapsedTime;

        System.out.println("\tSingle Threaded CPU Execution: " + formatCPUFGlops + " GFlops, Total time = " + nanoSecCPUElapsedTime + " ns");
        System.out.println("\tStreams Execution: " + formatStreamFGlops + " GFlops, Total time = " + nanoSecStreamElapsedTime + " ns");
        System.out.println("\tTornadoVM Execution on " + deviceType + " (Accelerated): " + formatGPUFGlops + " GFlops, Total Time = " + nanoSecGPUElapsedTime + " ns");
        System.out.println("\tSpeedup: " + speedup + "x");
        if (verify) {
            System.out.println("\tVerification " + verify(matrixC, resultSeq, size));
        }
        writeStats(size, javaElapsedTime, streamsElapsedTime, tornadoElapsedTime);
    }

    /**
     * Converts an operation count and a duration in nanoseconds into GFLOP/s.
     * Because 1 GFLOP/s is one floating-point operation per nanosecond, this is simply flops / ns.
     * The division is done in double to avoid float precision loss.
     */
    private static double gigaFlops(double flops, long nanos) {
        return flops / (double) nanos;
    }

    /**
     * Writes one CSV row per recorded iteration to {@code stats-mxm-<size>.txt}.
     * The columns are Java, Stream and TornadoVM nanoseconds. The writer is always closed,
     * even if writing fails.
     */
    private static void writeStats(int size, List<Long> javaTimes, List<Long> streamTimes, List<Long> tornadoTimes)
            throws FileNotFoundException {
        try (PrintWriter fileWriter = new PrintWriter("stats-mxm-" + size + ".txt")) {
            fileWriter.println("Java, Stream, TornadoVM");
            for (int i = 0; i < javaTimes.size(); ++i) {
                fileWriter.println(javaTimes.get(i) + "," + streamTimes.get(i) + "," + tornadoTimes.get(i));
            }
        }
    }

    /**
     * Returns true iff every element of {@code par} is within {@link #TOLERANCE} of the matching
     * element of {@code seq}. Stops at the first mismatch.
     */
    private static boolean verify(Matrix2DFloat par, Matrix2DFloat seq, int size) {
        for (int i = 0; i < size; ++i) {
            for (int j = 0; j < size; ++j) {
                // Phrased as !(diff <= tol) so that a NaN difference counts as a mismatch.
                if (!(Math.abs(par.get(i, j) - seq.get(i, j)) <= TOLERANCE)) {
                    return false;
                }
            }
        }
        return true;
    }
}

