/*
 * Decompiled with CFR 0.152.
 */
package uk.ac.manchester.tornado.examples.vectors;

import java.util.ArrayList;
import java.util.stream.IntStream;
import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
import uk.ac.manchester.tornado.api.TaskGraph;
import uk.ac.manchester.tornado.api.TornadoExecutionPlan;
import uk.ac.manchester.tornado.api.TornadoExecutionResult;
import uk.ac.manchester.tornado.api.common.TornadoDevice;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
import uk.ac.manchester.tornado.api.types.collections.VectorFloat2;
import uk.ac.manchester.tornado.api.types.collections.VectorFloat4;
import uk.ac.manchester.tornado.api.types.collections.VectorFloat8;
import uk.ac.manchester.tornado.api.types.vectors.Float2;
import uk.ac.manchester.tornado.api.types.vectors.Float4;
import uk.ac.manchester.tornado.api.types.vectors.Float8;
import uk.ac.manchester.tornado.examples.utils.Utils;

/**
 * Benchmark of element-wise float addition in several variants.
 *
 * <ul>
 *   <li>TornadoVM vector collections ({@code Float2}, {@code Float4} and {@code Float8} elements).</li>
 *   <li>A flat {@link FloatArray} on a TornadoVM device.</li>
 *   <li>A parallel Java stream on the host.</li>
 * </ul>
 *
 * <p>Usage: {@code VectorAddTest [vector4|vector2|vector8|plain|stream] [size]}. The option is
 * matched by prefix. Every variant processes {@code 4 * size} floats; for {@code vector8} this is
 * {@code 8 * (size / 2)}. Each variant runs {@link #WARMUP} untimed executions, then
 * {@link #ITERATIONS} timed ones, and prints statistics through {@code Utils.computeStatistics}.
 */
public class VectorAddTest {
    /** Number of untimed executions performed before measurement starts. */
    public static final int WARMUP = 100;
    /** Number of timed executions whose timings feed the reported statistics. */
    public static final int ITERATIONS = 100;

    /** Size used when no valid second command-line argument is given. */
    private static final int DEFAULT_SIZE = 0x1000000;
    /** Largest size for which {@code size * 4} still fits in an {@code int}. */
    private static final int MAX_SIZE = Integer.MAX_VALUE / 4;

    // Transfer-mode codes passed to transferToDevice / transferToHost.
    // NOTE(review): these look like the inlined values of DataTransferMode.FIRST_EXECUTION (0)
    // and DataTransferMode.EVERY_EXECUTION (1); confirm and prefer the named API constants.
    private static final int INPUT_TRANSFER_MODE = 0;
    private static final int OUTPUT_TRANSFER_MODE = 1;

    /**
     * Stores {@code a[i] + b[i]} into {@code results[i]} for every index of {@code a}.
     * {@code b} and {@code results} must be at least as long as {@code a}.
     */
    private static void computeAdd(FloatArray a, FloatArray b, FloatArray results) {
        // NOTE(review): no @Parallel on the loop index (possibly lost in decompilation); without it
        // TornadoVM presumably does not parallelise this loop - confirm.
        for (int i = 0; i < a.getSize(); ++i) {
            results.set(i, a.get(i) + b.get(i));
        }
    }

    /**
     * Stores the component-wise sum {@code a[i] + b[i]} into {@code results[i]} for every
     * element of {@code a}. {@code b} and {@code results} must be at least as long as {@code a}.
     */
    private static void computeAddWithVectors2(VectorFloat2 a, VectorFloat2 b, VectorFloat2 results) {
        // NOTE(review): see computeAdd regarding the missing @Parallel annotation.
        for (int i = 0; i < a.getLength(); ++i) {
            results.set(i, Float2.add(a.get(i), b.get(i)));
        }
    }

    /**
     * Stores the component-wise sum {@code a[i] + b[i]} into {@code results[i]} for every
     * element of {@code a}. {@code b} and {@code results} must be at least as long as {@code a}.
     */
    private static void computeAddWithVectors4(VectorFloat4 a, VectorFloat4 b, VectorFloat4 results) {
        // NOTE(review): see computeAdd regarding the missing @Parallel annotation.
        for (int i = 0; i < a.getLength(); ++i) {
            results.set(i, Float4.add(a.get(i), b.get(i)));
        }
    }

    /**
     * Stores the component-wise sum {@code a[i] + b[i]} into {@code results[i]} for every
     * element of {@code a}. {@code b} and {@code results} must be at least as long as {@code a}.
     */
    private static void computeAddWithVectors8(VectorFloat8 a, VectorFloat8 b, VectorFloat8 results) {
        // NOTE(review): see computeAdd regarding the missing @Parallel annotation.
        for (int i = 0; i < a.getLength(); ++i) {
            results.set(i, Float8.add(a.get(i), b.get(i)));
        }
    }

    /**
     * Snapshots {@code taskGraph}, pre-compiles it for {@code device}, runs it {@link #WARMUP}
     * times untimed and {@link #ITERATIONS} times timed, and prints kernel-time and total-time
     * statistics taken from the profiler result of each timed run.
     *
     * <p>Device memory is released even if an execution throws.
     */
    private static void benchmarkOnDevice(TaskGraph taskGraph, TornadoDevice device) {
        ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
        TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph);
        executionPlan.withDevice(device).withPreCompilation();
        long[] kernelTimes = new long[ITERATIONS];
        long[] totalTimes = new long[ITERATIONS];
        try {
            for (int i = 0; i < WARMUP; ++i) {
                executionPlan.execute();
            }
            for (int i = 0; i < ITERATIONS; ++i) {
                TornadoExecutionResult executionResult = executionPlan.execute();
                kernelTimes[i] = executionResult.getProfilerResult().getDeviceKernelTime();
                totalTimes[i] = executionResult.getProfilerResult().getTotalTime();
            }
        } finally {
            executionPlan.freeDeviceMemory();
        }
        System.out.println("Stats KernelTime");
        Utils.computeStatistics(kernelTimes);
        System.out.println("Stats TotalTime");
        Utils.computeStatistics(totalTimes);
    }

    /** Benchmarks {@code size} {@code Float4} elements (4 * size floats) on {@code device}. */
    private static void runWithVectorTypes4(int size, TornadoDevice device) {
        VectorFloat4 a = new VectorFloat4(size);
        VectorFloat4 b = new VectorFloat4(size);
        VectorFloat4 results = new VectorFloat4(size);
        for (int i = 0; i < a.getLength(); ++i) {
            float x = i;
            float y = 2.0f * i;
            a.set(i, new Float4(x, x, x, x));
            b.set(i, new Float4(y, y, y, y));
        }
        TaskGraph taskGraph = new TaskGraph("compute")
                .transferToDevice(INPUT_TRANSFER_MODE, a, b)
                .task("addWithVectors4", VectorAddTest::computeAddWithVectors4, a, b, results)
                .transferToHost(OUTPUT_TRANSFER_MODE, results);
        benchmarkOnDevice(taskGraph, device);
    }

    /** Benchmarks {@code 2 * size} {@code Float2} elements (4 * size floats) on {@code device}. */
    private static void runWithVectorTypes2(int size, TornadoDevice device) {
        int length = size * 2;
        VectorFloat2 a = new VectorFloat2(length);
        VectorFloat2 b = new VectorFloat2(length);
        VectorFloat2 results = new VectorFloat2(length);
        for (int i = 0; i < a.getLength(); ++i) {
            float x = i;
            float y = 2.0f * i;
            a.set(i, new Float2(x, x));
            b.set(i, new Float2(y, y));
        }
        TaskGraph taskGraph = new TaskGraph("compute")
                .transferToDevice(INPUT_TRANSFER_MODE, a, b)
                .task("addWithVectors2", VectorAddTest::computeAddWithVectors2, a, b, results)
                .transferToHost(OUTPUT_TRANSFER_MODE, results);
        benchmarkOnDevice(taskGraph, device);
    }

    /**
     * Benchmarks {@code size / 2} {@code Float8} elements on {@code device}. This is 4 * size
     * floats when {@code size} is even; an odd size drops the remainder.
     */
    private static void runWithVectorTypes8(int size, TornadoDevice device) {
        int length = size / 2;
        VectorFloat8 a = new VectorFloat8(length);
        VectorFloat8 b = new VectorFloat8(length);
        VectorFloat8 results = new VectorFloat8(length);
        for (int i = 0; i < a.getLength(); ++i) {
            float x = i;
            float y = 2.0f * i;
            a.set(i, new Float8(x, x, x, x, x, x, x, x));
            b.set(i, new Float8(y, y, y, y, y, y, y, y));
        }
        TaskGraph taskGraph = new TaskGraph("compute")
                .transferToDevice(INPUT_TRANSFER_MODE, a, b)
                .task("addWithVectors8", VectorAddTest::computeAddWithVectors8, a, b, results)
                .transferToHost(OUTPUT_TRANSFER_MODE, results);
        benchmarkOnDevice(taskGraph, device);
    }

    /** Benchmarks a flat {@link FloatArray} of {@code 4 * size} floats on {@code device}. */
    private static void runWithoutVectorTypes(int size, TornadoDevice device) {
        int length = size * 4;
        FloatArray af = new FloatArray(length);
        FloatArray bf = new FloatArray(length);
        FloatArray rf = new FloatArray(length);
        for (int i = 0; i < af.getSize(); ++i) {
            af.set(i, (float) i);
            bf.set(i, 2.0f * i);
        }
        TaskGraph taskGraph = new TaskGraph("nonVector")
                .transferToDevice(INPUT_TRANSFER_MODE, af, bf)
                .task("computeWithPrimitiveArray", VectorAddTest::computeAdd, af, bf, rf)
                .transferToHost(OUTPUT_TRANSFER_MODE, rf);
        benchmarkOnDevice(taskGraph, device);
    }

    /** Stores {@code a[i] + b[i]} into {@code results[i]} for {@code i} in {@code [0, size)}, in parallel. */
    private static void computeWithStreams(int size, FloatArray a, FloatArray b, FloatArray results) {
        IntStream.range(0, size).parallel().forEach(i -> results.set(i, a.get(i) + b.get(i)));
    }

    /**
     * Benchmarks a parallel Java stream over {@code 4 * size} floats on the host. Timings are
     * wall-clock nanoseconds measured with {@link System#nanoTime()}.
     */
    private static void runWithJavaStreams(int size) {
        int length = size * 4;
        FloatArray a = new FloatArray(length);
        FloatArray b = new FloatArray(length);
        FloatArray results = new FloatArray(length);
        for (int i = 0; i < a.getSize(); ++i) {
            a.set(i, (float) i);
            b.set(i, 2.0f * i);
        }
        for (int i = 0; i < WARMUP; ++i) {
            computeWithStreams(length, a, b, results);
        }
        long[] times = new long[ITERATIONS];
        for (int i = 0; i < ITERATIONS; ++i) {
            long start = System.nanoTime();
            computeWithStreams(length, a, b, results);
            times[i] = System.nanoTime() - start;
        }
        System.out.println("Stats");
        Utils.computeStatistics(times);
    }

    /**
     * Entry point.
     *
     * <p>{@code args[0]} selects the variant by prefix and defaults to {@code vector4}.
     * {@code args[1]} is the size and defaults to {@code 0x1000000}; an unparsable size is
     * reported on stderr and the default is used instead.
     *
     * @throws IllegalArgumentException if the size is negative or too large for
     *     {@code 4 * size} to fit in an int, or if the option is unknown
     */
    public static void main(String[] args) {
        // A bare "vector" default matched none of the prefixes below, so running without
        // arguments always failed; default to an option that exists.
        String version = args.length > 0 ? args[0] : "vector4";
        int size = DEFAULT_SIZE;
        if (args.length > 1) {
            try {
                size = Integer.parseInt(args[1]);
            } catch (NumberFormatException ignored) {
                // Deliberate best-effort fallback, but make it visible instead of silent.
                System.err.println("Invalid size '" + args[1] + "', using default " + DEFAULT_SIZE);
            }
        }
        if (size < 0 || size > MAX_SIZE) {
            throw new IllegalArgumentException("size must be in [0, " + MAX_SIZE + "], got " + size);
        }
        TornadoDevice device = TornadoExecutionPlan.getDevice(0, 0);
        if (version.startsWith("vector4")) {
            runWithVectorTypes4(size, device);
        } else if (version.startsWith("vector2")) {
            runWithVectorTypes2(size, device);
        } else if (version.startsWith("vector8")) {
            runWithVectorTypes8(size, device);
        } else if (version.startsWith("stream")) {
            runWithJavaStreams(size);
        } else if (version.startsWith("plain")) {
            runWithoutVectorTypes(size, device);
        } else {
            throw new IllegalArgumentException("Option not found: " + version);
        }
    }
}

