/*
 * Decompiled with CFR 0.152.
 */
package uk.ac.manchester.tornado.unittests.vector.api;

import java.lang.foreign.MemorySegment;
import java.nio.ByteOrder;
import java.util.Random;
import java.util.stream.IntStream;
import jdk.incubator.vector.FloatVector;
import jdk.incubator.vector.Vector;
import jdk.incubator.vector.VectorSpecies;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
import uk.ac.manchester.tornado.unittests.common.TornadoTestBase;

/**
 * Checks that element-wise float addition written with the incubating Vector API matches a scalar
 * reference loop, for each fixed {@link FloatVector} species width (64, 128, 256 and 512 bits).
 *
 * <p>The vector inputs are loaded directly from the {@link MemorySegment} returned by
 * {@link FloatArray#getSegment()}, assuming element {@code i} starts at byte offset
 * {@code i * Float.BYTES} of that segment.
 */
public class TestVectorAPI extends TornadoTestBase {
    /** Number of elements in each input array. */
    private static final int SIZE = 2048;
    /** Absolute tolerance used when comparing vector results against the scalar reference. */
    private static final float DELTA = 0.001f;
    private static final Random rand = new Random();
    private static FloatArray arrayA;
    private static FloatArray arrayB;
    private static FloatArray referenceResult;

    /**
     * Returns a pseudo-random float obtained by scaling {@link Random#nextFloat()} into the interval
     * between {@code min} and {@code max}.
     *
     * @param min  lower end of the interval
     * @param max  upper end of the interval
     * @param rand source of randomness
     * @return {@code rand.nextFloat() * (max - min) + min}
     */
    public static float randFloat(float min, float max, Random rand) {
        return rand.nextFloat() * (max - min) + min;
    }

    /**
     * Fills the two input arrays with random values and computes the scalar reference sum once for
     * all tests.
     */
    @BeforeClass
    public static void setUpBeforeClass() {
        arrayA = new FloatArray(SIZE);
        arrayB = new FloatArray(SIZE);
        for (int i = 0; i < SIZE; i++) {
            arrayA.set(i, randFloat(0.0f, 2.0f, rand));
            arrayB.set(i, randFloat(0.0f, 3.0f, rand));
        }
        referenceResult = vectorAdditionFloatArray(arrayA, arrayB);
    }

    /**
     * Scalar reference implementation: returns a new array with {@code a[i] + b[i]} for every index
     * of {@code a}.
     *
     * @param a first operand; its size determines the result size
     * @param b second operand; must have at least as many elements as {@code a}
     * @return a newly allocated array of the same size as {@code a}
     */
    private static FloatArray vectorAdditionFloatArray(FloatArray a, FloatArray b) {
        final int size = a.getSize();
        FloatArray res = new FloatArray(size);
        for (int i = 0; i < size; i++) {
            res.set(i, a.get(i) + b.get(i));
        }
        return res;
    }

    /**
     * Adds two arrays element-wise using {@link FloatVector}s of the given species, processing
     * full-width chunks in parallel and any leftover elements with a scalar loop.
     *
     * @param vector1 first operand; its size determines the result size
     * @param vector2 second operand; must have at least as many elements as {@code vector1}
     * @param species vector shape to use for the SIMD part of the computation
     * @return a new heap array holding {@code vector1[i] + vector2[i]}
     * @throws IllegalArgumentException if {@code vector2} is shorter than {@code vector1}
     */
    private float[] parallelVectorAdd(FloatArray vector1, FloatArray vector2, VectorSpecies<Float> species) {
        final int size = vector1.getSize();
        if (vector2.getSize() < size) {
            throw new IllegalArgumentException("vector2 has " + vector2.getSize() + " elements, expected at least " + size);
        }
        final float[] result = new float[size];
        final int lanes = species.length();
        final MemorySegment segment1 = vector1.getSegment();
        final MemorySegment segment2 = vector2.getSegment();
        // Largest multiple of the lane count that is <= size; only [0, upperBound) uses full vectors.
        final int upperBound = species.loopBound(size);

        // Each chunk writes a disjoint range of 'result', so the parallel writes do not overlap.
        IntStream.range(0, upperBound / lanes).parallel().forEach(chunk -> {
            int index = chunk * lanes;
            // fromMemorySegment expects a byte offset, not an element index.
            long byteOffset = (long) index * Float.BYTES;
            FloatVector vec1 = FloatVector.fromMemorySegment(species, segment1, byteOffset, ByteOrder.nativeOrder());
            FloatVector vec2 = FloatVector.fromMemorySegment(species, segment2, byteOffset, ByteOrder.nativeOrder());
            vec1.add(vec2).intoArray(result, index);
        });

        // Scalar tail for sizes that are not a multiple of the lane count.
        for (int i = upperBound; i < size; i++) {
            result[i] = vector1.get(i) + vector2.get(i);
        }
        return result;
    }

    /**
     * Runs the vectorised addition with {@code species} and compares it against the scalar
     * reference within {@link #DELTA}.
     */
    private void assertVectorAddMatchesReference(VectorSpecies<Float> species) {
        float[] result = parallelVectorAdd(arrayA, arrayB, species);
        // JUnit's argument order is (expected, actual, delta).
        Assert.assertArrayEquals(referenceResult.toHeapArray(), result, DELTA);
    }

    @Test
    public void test64BitVectors() {
        assertVectorAddMatchesReference(FloatVector.SPECIES_64);
    }

    @Test
    public void test128BitVectors() {
        assertVectorAddMatchesReference(FloatVector.SPECIES_128);
    }

    @Test
    public void test256BitVectors() {
        assertVectorAddMatchesReference(FloatVector.SPECIES_256);
    }

    @Test
    public void test512BitVectors() {
        assertVectorAddMatchesReference(FloatVector.SPECIES_512);
    }
}

