/*
 * Decompiled with CFR 0.152.
 */
package uk.ac.manchester.tornado.examples.vectors;

import java.util.ArrayList;
import java.util.Random;
import java.util.stream.IntStream;
import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
import uk.ac.manchester.tornado.api.TaskGraph;
import uk.ac.manchester.tornado.api.TornadoExecutionPlan;
import uk.ac.manchester.tornado.api.TornadoExecutionResult;
import uk.ac.manchester.tornado.api.common.TornadoDevice;
import uk.ac.manchester.tornado.api.types.collections.VectorFloat;
import uk.ac.manchester.tornado.api.types.collections.VectorFloat4;
import uk.ac.manchester.tornado.api.types.matrix.Matrix2DFloat;
import uk.ac.manchester.tornado.api.types.matrix.Matrix2DFloat4;
import uk.ac.manchester.tornado.api.types.vectors.Float4;
import uk.ac.manchester.tornado.examples.utils.Utils;

/**
 * Matrix-vector multiplication benchmark comparing TornadoVM kernels with and without
 * {@link Float4} vector types, plus a host-only parallel-stream baseline.
 *
 * <p>Usage: {@code MatrixVector [version] [size]}
 * <ul>
 *   <li>{@code version}: a value starting with {@code "vector"} (the default) runs the Float4
 *       kernel on the device; starting with {@code "stream"} runs the scalar algorithm on the
 *       host with a parallel Java stream; anything else runs the scalar kernel on the device.</li>
 *   <li>{@code size}: base problem size (default {@value #DEFAULT_SIZE}); unparsable or
 *       non-positive values fall back to the default with a warning on stderr.</li>
 * </ul>
 * Every variant performs {@value #WARMUP} untimed runs followed by {@value #ITERATIONS} timed runs
 * and hands the collected timings to {@code Utils.computeStatistics}.
 */
public class MatrixVector {
    /** Number of untimed executions performed before measuring. */
    public static final int WARMUP = 100;
    /** Number of timed executions whose timings are reported. */
    public static final int ITERATIONS = 100;

    /** Base problem size used when no valid size argument is supplied. */
    private static final int DEFAULT_SIZE = 2048;

    /**
     * Scalar matrix-vector product: {@code output[i] = sum_j matrix[i][j] * vector[j]}.
     *
     * <p>NOTE(review): this file is decompiler output. If the upstream source annotated the outer
     * loop index with {@code @Parallel}, that annotation may not have survived decompilation, in
     * which case TornadoVM would presumably not parallelise this loop - confirm against upstream.
     */
    private static void compute(Matrix2DFloat matrix, VectorFloat vector, VectorFloat output) {
        for (int i = 0; i < matrix.getNumRows(); ++i) {
            float sum = 0.0f;
            for (int j = 0; j < matrix.getNumColumns(); ++j) {
                sum += matrix.get(i, j) * vector.get(j);
            }
            output.set(i, sum);
        }
    }

    /**
     * Matrix-vector product over {@link Float4} elements: {@code output[i]} accumulates
     * {@code Float4.sum(Float4.mult(matrix[i][j], vector[j]))} over every column {@code j}
     * (presumably a 4-wide dot product per element).
     *
     * <p>NOTE(review): see {@code compute} regarding a possibly lost {@code @Parallel} annotation.
     */
    private static void computeWithVectors(Matrix2DFloat4 matrix, VectorFloat4 vector, VectorFloat output) {
        for (int i = 0; i < matrix.getNumRows(); ++i) {
            float sum = 0.0f;
            for (int j = 0; j < matrix.getNumColumns(); ++j) {
                sum += Float4.sum(Float4.mult(matrix.get(i, j), vector.get(j)));
            }
            output.set(i, sum);
        }
    }

    /**
     * Benchmarks {@code computeWithVectors} on {@code device} with a {@code size x size} matrix of
     * random {@link Float4} values and a random {@link Float4} vector of length {@code size}.
     *
     * <p>NOTE(review): this variant processes {@code size} rows of {@code 4 * size} floats, whereas
     * the scalar variants use a {@code (4 * size) x (4 * size)} matrix, i.e. four times the work.
     * Confirm this is intended before comparing timings across variants.
     */
    private static void runWithVectorTypes(int size, TornadoDevice device) {
        Matrix2DFloat4 matrix = new Matrix2DFloat4(size, size);
        VectorFloat4 vector = new VectorFloat4(size);
        VectorFloat result = new VectorFloat(size);
        Random r = new Random();
        for (int idx = 0; idx < size; ++idx) {
            vector.set(idx, randomFloat4(r));
        }
        for (int row = 0; row < size; ++row) {
            for (int col = 0; col < size; ++col) {
                matrix.set(row, col, randomFloat4(r));
            }
        }
        // NOTE(review): the literals 0 and 1 are presumably DataTransferMode.FIRST_EXECUTION and
        // DataTransferMode.EVERY_EXECUTION inlined by the decompiler - confirm.
        TaskGraph taskGraph = new TaskGraph("computeVectors")
                .transferToDevice(0, new Object[]{vector, matrix})
                .task("witVectors", MatrixVector::computeWithVectors, matrix, vector, result)
                .transferToHost(1, new Object[]{result});
        benchmark(taskGraph, device);
    }

    /**
     * Benchmarks the scalar kernel {@code compute} on {@code device} with a random
     * {@code (4 * size) x (4 * size)} matrix and a random vector of length {@code 4 * size}.
     */
    private static void runWithoutVectorTypes(int size, TornadoDevice device) {
        // Each row holds as many floats as a Float4 row of length `size` in runWithVectorTypes.
        final int n = size * 4;
        Matrix2DFloat matrix = new Matrix2DFloat(n, n);
        VectorFloat vector = new VectorFloat(n);
        VectorFloat result = new VectorFloat(n);
        fillRandom(n, matrix, vector, new Random());
        // NOTE(review): see runWithVectorTypes regarding the 0/1 transfer-mode literals.
        TaskGraph taskGraph = new TaskGraph("compute")
                .transferToDevice(0, new Object[]{vector, matrix})
                .task("noVectors", MatrixVector::compute, matrix, vector, result)
                .transferToHost(1, new Object[]{result});
        benchmark(taskGraph, device);
    }

    /**
     * Scalar matrix-vector product on the host, distributing rows {@code 0..size-1} over a
     * parallel {@link IntStream}; each row writes only its own {@code output} element.
     */
    private static void computeWithStreams(int size, Matrix2DFloat matrix, VectorFloat vector, VectorFloat output) {
        IntStream.range(0, size).parallel().forEach(i -> {
            float sum = 0.0f;
            for (int j = 0; j < matrix.getNumColumns(); ++j) {
                sum += matrix.get(i, j) * vector.get(j);
            }
            output.set(i, sum);
        });
    }

    /**
     * Benchmarks {@code computeWithStreams} on the host with the same problem shape as
     * {@code runWithoutVectorTypes}, timing each run with {@link System#nanoTime()}.
     */
    private static void runWithJavaStreams(int size) {
        final int n = size * 4;
        Matrix2DFloat matrix = new Matrix2DFloat(n, n);
        VectorFloat vector = new VectorFloat(n);
        VectorFloat result = new VectorFloat(n);
        fillRandom(n, matrix, vector, new Random());
        for (int i = 0; i < WARMUP; ++i) {
            computeWithStreams(n, matrix, vector, result);
        }
        long[] times = new long[ITERATIONS];
        for (int i = 0; i < ITERATIONS; ++i) {
            long start = System.nanoTime();
            computeWithStreams(n, matrix, vector, result);
            times[i] = System.nanoTime() - start;
        }
        printStatistics("Stats", times);
    }

    /**
     * Snapshots {@code taskGraph} into an execution plan bound to {@code device}, runs it
     * {@link #WARMUP} times untimed and {@link #ITERATIONS} times timed, and prints the device
     * kernel-time and total-time statistics reported by the profiler result.
     * Device memory is released even if an execution throws.
     */
    private static void benchmark(TaskGraph taskGraph, TornadoDevice device) {
        ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
        TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(new ImmutableTaskGraph[]{immutableTaskGraph});
        executionPlan.withDevice(device).withPreCompilation();
        long[] kernelTimes = new long[ITERATIONS];
        long[] totalTimes = new long[ITERATIONS];
        try {
            for (int i = 0; i < WARMUP; ++i) {
                executionPlan.execute();
            }
            for (int i = 0; i < ITERATIONS; ++i) {
                TornadoExecutionResult executionResult = executionPlan.execute();
                kernelTimes[i] = executionResult.getProfilerResult().getDeviceKernelTime();
                totalTimes[i] = executionResult.getProfilerResult().getTotalTime();
            }
        } finally {
            executionPlan.freeDeviceMemory();
        }
        printStatistics("Stats KernelTime", kernelTimes);
        printStatistics("Stats TotalTime", totalTimes);
    }

    /** Prints {@code label}, then passes {@code timings} to {@code Utils.computeStatistics}. */
    private static void printStatistics(String label, long[] timings) {
        System.out.println(label);
        Utils.computeStatistics(timings);
    }

    /** Returns a {@link Float4} whose four components are successive {@code r.nextFloat()} draws. */
    private static Float4 randomFloat4(Random r) {
        return new Float4(r.nextFloat(), r.nextFloat(), r.nextFloat(), r.nextFloat());
    }

    /**
     * Fills the first {@code n} elements of {@code vector}, then the {@code n x n} entries of
     * {@code matrix} (row by row), with values from {@code r.nextFloat()}.
     */
    private static void fillRandom(int n, Matrix2DFloat matrix, VectorFloat vector, Random r) {
        for (int idx = 0; idx < n; ++idx) {
            vector.set(idx, r.nextFloat());
        }
        for (int row = 0; row < n; ++row) {
            for (int col = 0; col < n; ++col) {
                matrix.set(row, col, r.nextFloat());
            }
        }
    }

    /**
     * Parses a positive problem size, falling back to {@link #DEFAULT_SIZE} (with a warning on
     * stderr) when {@code arg} is not an integer or is not positive.
     */
    private static int parseSize(String arg) {
        try {
            int parsed = Integer.parseInt(arg);
            if (parsed > 0) {
                return parsed;
            }
            System.err.println("Size must be positive, got " + parsed + "; using default " + DEFAULT_SIZE);
        } catch (NumberFormatException e) {
            System.err.println("Invalid size '" + arg + "'; using default " + DEFAULT_SIZE);
        }
        return DEFAULT_SIZE;
    }

    /**
     * Entry point: {@code args[0]} selects the variant (default {@code "vector"}),
     * {@code args[1]} the base problem size (default {@value #DEFAULT_SIZE}).
     */
    public static void main(String[] args) {
        String version = args.length > 0 ? args[0] : "vector";
        int size = args.length > 1 ? parseSize(args[1]) : DEFAULT_SIZE;
        // The TornadoVM device is only looked up for the variants that actually run on it.
        if (version.startsWith("vector")) {
            runWithVectorTypes(size, TornadoExecutionPlan.getDevice(0, 0));
        } else if (version.startsWith("stream")) {
            runWithJavaStreams(size);
        } else {
            runWithoutVectorTypes(size, TornadoExecutionPlan.getDevice(0, 0));
        }
    }
}

