/*
 * Decompiled with CFR 0.152.
 */
package uk.ac.manchester.tornado.unittests.compute;

import org.junit.Assert;
import org.junit.Test;
import uk.ac.manchester.tornado.api.GridScheduler;
import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
import uk.ac.manchester.tornado.api.KernelContext;
import uk.ac.manchester.tornado.api.TaskGraph;
import uk.ac.manchester.tornado.api.TornadoExecutionPlan;
import uk.ac.manchester.tornado.api.WorkerGrid;
import uk.ac.manchester.tornado.api.WorkerGrid1D;
import uk.ac.manchester.tornado.api.exceptions.TornadoExecutionPlanException;
import uk.ac.manchester.tornado.api.types.arrays.ByteArray;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
import uk.ac.manchester.tornado.unittests.common.TornadoTestBase;

public class MMwithBytes extends TornadoTestBase {

    /** Number of quantized 8-bit values stored in one block. */
    private static final int BLOCK_SIZE = 32;

    /** Bytes per block: a 2-byte little-endian float16 scale followed by {@link #BLOCK_SIZE} quantized bytes. */
    private static final int BYTES_PER_BLOCK = 2 + BLOCK_SIZE;

    /**
     * Runs {@link #matmulTornado} over a {@code numRows x dim} block-quantized matrix and checks that
     * no output element is NaN or infinite.
     */
    @Test
    public void testMatrixMultiplicationWithBytes() throws TornadoExecutionPlanException {
        final int dim = 128;
        final int numRows = 1024;
        // The kernel addresses row r, column j as flat element r * dim + j, so the buffer must hold
        // numRows * dim quantized values, i.e. numRows * dim / BLOCK_SIZE blocks (not numRows blocks).
        final int numBlocks = numRows * dim / BLOCK_SIZE;

        ByteArray byteArrayWeights = new ByteArray(numBlocks * BYTES_PER_BLOCK);
        FloatArray inputVector = new FloatArray(dim);
        FloatArray outputVector = new FloatArray(numRows);

        for (int block = 0; block < numBlocks; ++block) {
            int offset = block * BYTES_PER_BLOCK;
            // Scale bytes 0x00 0x00 encode a float16 zero scale.
            byteArrayWeights.set(offset, (byte) 0);
            byteArrayWeights.set(offset + 1, (byte) 0);
            for (int j = 2; j < BYTES_PER_BLOCK; ++j) {
                byteArrayWeights.set(offset + j, (byte) ((float) block * 3.0f * 255.0f - 128.0f));
            }
        }
        for (int i = 0; i < dim; ++i) {
            inputVector.set(i, (float) Math.random());
        }
        outputVector.init(0.0f);

        // One work-item per output row, in work-groups of 32.
        WorkerGrid1D workerGrid = new WorkerGrid1D(numRows);
        GridScheduler gridScheduler = new GridScheduler("s0.t0", workerGrid);
        workerGrid.setGlobalWork(numRows, 1, 1);
        workerGrid.setLocalWork(32, 1, 1);

        // NOTE(review): 0 / 1 are presumably DataTransferMode.FIRST_EXECUTION / EVERY_EXECUTION — confirm.
        TaskGraph taskGraph = new TaskGraph("s0")
                .transferToDevice(0, byteArrayWeights, inputVector)
                .task("t0", MMwithBytes::matmulTornado, new KernelContext(), byteArrayWeights, inputVector, outputVector, dim)
                .transferToHost(1, outputVector);

        ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
        try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
            executionPlan.withGridScheduler(gridScheduler).execute();
        }

        for (int i = 0; i < numRows; ++i) {
            float result = outputVector.get(i);
            Assert.assertFalse("Output contains NaN at index " + i, Float.isNaN(result));
            Assert.assertFalse("Output contains Infinity at index " + i, Float.isInfinite(result));
        }
    }

    /**
     * Computes one element of a matrix-vector product whose matrix is stored block-quantized.
     *
     * <p>Each block of {@link #BYTES_PER_BLOCK} bytes holds a little-endian float16 scale followed by
     * {@link #BLOCK_SIZE} signed 8-bit values. Element {@code (row, j)} of the matrix is flat element
     * {@code row * dim1 + j}, dequantized as {@code quantized * scale}.
     *
     * @param context kernel context; {@code globalIdx} selects the output row
     * @param thisx   block-quantized matrix bytes
     * @param that    input vector of length {@code dim1}
     * @param out     output vector, written at {@code globalIdx}
     * @param dim1    number of columns (row length)
     */
    public static void matmulTornado(KernelContext context, ByteArray thisx, FloatArray that, FloatArray out, int dim1) {
        int idx = context.globalIdx;
        float result = 0.0f;
        int thisOffset = idx * dim1;
        for (int j = 0; j < dim1; ++j) {
            int index = thisOffset + j;
            int blockIndex = index / BLOCK_SIZE;
            int withinBlockIndex = index % BLOCK_SIZE;
            int blockOffset = blockIndex * BYTES_PER_BLOCK;
            // Scale is stored little-endian: low byte first.
            int scaleByte1 = thisx.get(blockOffset) & 0xFF;
            int scaleByte2 = thisx.get(blockOffset + 1) & 0xFF;
            short scaleFloat16 = (short) ((scaleByte2 << 8) | scaleByte1);
            float scale = decodeFloat16(scaleFloat16);
            byte quantized = thisx.get(blockOffset + 2 + withinBlockIndex);
            result += (float) quantized * scale * that.get(j);
        }
        out.set(idx, result);
    }

    /**
     * Decodes an IEEE 754 binary16 bit pattern into a float.
     *
     * <p>Known simplification: every pattern with an all-ones exponent maps to signed infinity,
     * including the NaN encodings (non-zero fraction).
     *
     * @param value raw float16 bits
     * @return the decoded value
     */
    private static float decodeFloat16(short value) {
        int sign = (value & 0x8000) >>> 15;
        int exp = (value & 0x7C00) >>> 10;
        int frac = value & 0x3FF;
        if (exp == 0x1F) {
            return sign == 0 ? Float.POSITIVE_INFINITY : Float.NEGATIVE_INFINITY;
        }
        float magnitude;
        if (exp == 0) {
            // Zero or subnormal: frac * 2^-24 (frac == 0 yields 0.0f, negated below to -0.0f).
            magnitude = (float) frac * pow2(-24);
        } else {
            // Normal: (1 + frac / 1024) * 2^(exp - 15). The exponent scale applies to both signs;
            // previously it was only applied on the negative branch.
            magnitude = (1.0f + (float) frac / 1024.0f) * pow2(exp - 15);
        }
        return sign == 0 ? magnitude : -magnitude;
    }

    /**
     * Returns {@code 2^n} as a float.
     *
     * <p>Uses repeated exact doubling/halving rather than int shifts. Java masks shift distances to
     * 5 bits, so {@code 1 << 32} equals 1 and {@code 1 << 31} is negative.
     *
     * @param n exponent
     * @return {@code 2^n}; {@code +Infinity} for {@code n >= 128}, {@code 0.0f} for {@code n <= -150}
     */
    private static float pow2(int n) {
        if (n >= 128) {
            return Float.POSITIVE_INFINITY;
        }
        if (n <= -150) {
            return 0.0f;
        }
        float factor = n >= 0 ? 2.0f : 0.5f;
        int steps = n >= 0 ? n : -n;
        float result = 1.0f;
        // Multiplying by 2 or 0.5 is exact down to Float.MIN_VALUE (2^-149).
        for (int i = 0; i < steps; ++i) {
            result *= factor;
        }
        return result;
    }

    /** Overwrites every element that is not {@code +Infinity} with {@code -Infinity}. */
    public static void positiveInfinity(FloatArray positiveInfinity) {
        for (int i = 0; i < positiveInfinity.getSize(); ++i) {
            if (positiveInfinity.get(i) != Float.POSITIVE_INFINITY) {
                positiveInfinity.set(i, Float.NEGATIVE_INFINITY);
            }
        }
    }

    /** Checks that a device-side comparison against {@code +Infinity} preserves an all-{@code +Infinity} array. */
    @Test
    public void testPositiveInfinity() throws TornadoExecutionPlanException {
        final int n = 1024;
        FloatArray positiveInfinityArray = new FloatArray(n);
        positiveInfinityArray.init(Float.POSITIVE_INFINITY);

        TaskGraph taskGraph = new TaskGraph("s0")
                .transferToDevice(0, positiveInfinityArray)
                .task("t0", MMwithBytes::positiveInfinity, positiveInfinityArray)
                .transferToHost(1, positiveInfinityArray);

        ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
        try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
            executionPlan.execute();
        }

        for (int i = 0; i < n; ++i) {
            Assert.assertEquals(Float.POSITIVE_INFINITY, positiveInfinityArray.get(i), 0.0f);
        }
    }

    /** Overwrites every element that is not {@code -Infinity} with {@code +Infinity}. */
    public static void negativeInfinity(FloatArray negativeInfinityArray) {
        for (int i = 0; i < negativeInfinityArray.getSize(); ++i) {
            if (negativeInfinityArray.get(i) != Float.NEGATIVE_INFINITY) {
                negativeInfinityArray.set(i, Float.POSITIVE_INFINITY);
            }
        }
    }

    /** Checks that a device-side comparison against {@code -Infinity} preserves an all-{@code -Infinity} array. */
    @Test
    public void testNegativeInfinity() throws TornadoExecutionPlanException {
        final int n = 1024;
        FloatArray negativeInfinityArray = new FloatArray(n);
        negativeInfinityArray.init(Float.NEGATIVE_INFINITY);

        TaskGraph taskGraph = new TaskGraph("s0")
                .transferToDevice(0, negativeInfinityArray)
                .task("t0", MMwithBytes::negativeInfinity, negativeInfinityArray)
                .transferToHost(1, negativeInfinityArray);

        ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
        try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
            executionPlan.execute();
        }

        for (int i = 0; i < n; ++i) {
            Assert.assertEquals(Float.NEGATIVE_INFINITY, negativeInfinityArray.get(i), 0.0f);
        }
    }

    /** Assigns {@code -Infinity} to every element of {@code x}. */
    public static void negativeInfinityAssignment(FloatArray x) {
        for (int i = 0; i < x.getSize(); ++i) {
            x.set(i, Float.NEGATIVE_INFINITY);
        }
    }

    /** Checks that a device-side store of the {@code -Infinity} constant reaches the host intact. */
    @Test
    public void testNegativeInfinityAssignment() throws TornadoExecutionPlanException {
        final int n = 1024;
        FloatArray negativeInfinityArray = new FloatArray(n);
        negativeInfinityArray.init(2.0f);

        TaskGraph taskGraph = new TaskGraph("s0")
                .transferToDevice(0, negativeInfinityArray)
                .task("t0", MMwithBytes::negativeInfinityAssignment, negativeInfinityArray)
                .transferToHost(1, negativeInfinityArray);

        ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
        try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
            executionPlan.execute();
        }

        for (int i = 0; i < n; ++i) {
            Assert.assertEquals(Float.NEGATIVE_INFINITY, negativeInfinityArray.get(i), 0.0f);
        }
    }
}

