/*
 * Decompiled with CFR 0.152.
 */
package uk.ac.manchester.tornado.unittests.arrays;

import java.lang.foreign.MemorySegment;
import java.util.Random;
import java.util.stream.IntStream;
import org.junit.Assert;
import org.junit.Assume;
import org.junit.Test;
import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
import uk.ac.manchester.tornado.api.TaskGraph;
import uk.ac.manchester.tornado.api.TornadoExecutionPlan;
import uk.ac.manchester.tornado.api.exceptions.TornadoExecutionPlanException;
import uk.ac.manchester.tornado.api.exceptions.TornadoOutOfMemoryException;
import uk.ac.manchester.tornado.api.runtime.TornadoRuntimeProvider;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
import uk.ac.manchester.tornado.unittests.common.TornadoTestBase;

/**
 * Tests that a very large {@link FloatArray} can be transferred to the default device, updated by a
 * task, and copied back, and that a larger array makes execution throw
 * {@link TornadoOutOfMemoryException}.
 *
 * <p>Every test is skipped unless the default device reports more than 3 GiB of global memory.
 */
public class TestLargeArrays extends TornadoTestBase {

    /** Device global memory that must be strictly exceeded to run these tests: 3 GiB (0xC0000000). */
    private static final long REQUIRED_DEVICE_MEMORY_BYTES = 3L * 1024 * 1024 * 1024;

    /** Element count expected to succeed: 510M floats is ~2.04e9 bytes of data, just below 2^31 bytes. */
    private static final int SAFE_NUM_ELEMENTS = 510_000_000;

    /**
     * Element count expected to fail: 540M floats is ~2.16e9 bytes of data, just above 2^31 bytes.
     * NOTE(review): presumably this probes a ~2 GiB per-buffer limit on the device side — confirm.
     */
    private static final int OVERFLOW_NUM_ELEMENTS = 540_000_000;

    /** Value the task adds to every element. */
    private static final float ACCUMULATOR = 1.0f;

    /** Tolerance used when comparing device results against host-side expected values. */
    private static final float DELTA = 0.01f;

    /** Fixed seed so the input data, and therefore any failure, is reproducible across runs. */
    private static final long SEED = 42L;

    /**
     * Reports whether the default Tornado device has enough global memory for these tests.
     *
     * @return {@code true} if the device's max global memory is strictly greater than 3 GiB
     */
    public static boolean checkDeviceMemory() {
        long mem = TornadoRuntimeProvider.getTornadoRuntime().getDefaultDevice().getMaxGlobalMemory();
        return mem > REQUIRED_DEVICE_MEMORY_BYTES;
    }

    /**
     * Adds {@code value} to every element of {@code a}, in place.
     *
     * <p>NOTE(review): this file is decompiled (CFR). The loop carries no parallel annotation here;
     * if the original source marked it for parallel execution, that marker may have been lost —
     * confirm against the original source.
     *
     * @param a array updated in place
     * @param value amount added to each element
     */
    public static void addAccumulator(FloatArray a, float value) {
        for (int i = 0; i < a.getSize(); ++i) {
            a.set(i, a.get(i) + value);
        }
    }

    /**
     * Skips every test in this class unless {@link #checkDeviceMemory()} returns {@code true}.
     *
     * <p>NOTE(review): overrides the base-class {@code before()} without calling
     * {@code super.before()}; confirm the base implementation has no setup that is being skipped.
     */
    @Override
    public void before() {
        Assume.assumeTrue("Skipping TestLargeArrays: requires > 3GB global memory", checkDeviceMemory());
    }

    /** An array just under the limit must round-trip through the device correctly. */
    @Test
    public void testLargeFloatArraySafe() throws TornadoExecutionPlanException {
        testFloatArrayWithSize(SAFE_NUM_ELEMENTS);
    }

    /** An array just over the limit must make execution throw {@link TornadoOutOfMemoryException}. */
    @Test(expected = TornadoOutOfMemoryException.class)
    public void testLargeFloatArrayOverflow() throws TornadoExecutionPlanException {
        testFloatArrayWithSize(OVERFLOW_NUM_ELEMENTS);
    }

    /**
     * Fills an array of {@code numElements} floats with pseudo-random data. It then runs
     * {@link #addAccumulator} on the device and checks every element came back incremented by
     * {@link #ACCUMULATOR}.
     *
     * @param numElements number of floats to allocate and process
     * @throws TornadoExecutionPlanException if the execution plan fails
     */
    private void testFloatArrayWithSize(int numElements) throws TornadoExecutionPlanException {
        FloatArray a = new FloatArray(numElements);

        // Seeded generator: the expected values are regenerated from the same seed below.
        // This avoids keeping a second multi-gigabyte host copy of the input.
        Random inputGenerator = new Random(SEED);
        for (int i = 0; i < numElements; i++) {
            a.set(i, inputGenerator.nextFloat());
        }

        // NOTE(review): transfer mode literal 1 comes from decompiled output; presumably
        // DataTransferMode.EVERY_EXECUTION — confirm.
        TaskGraph taskGraph = new TaskGraph("s0")
                .transferToDevice(1, a)
                .task("t0", TestLargeArrays::addAccumulator, a, ACCUMULATOR)
                .transferToHost(1, a);
        ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
        try (TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(immutableTaskGraph)) {
            executionPlan.execute();
        }

        // Replay the identical pseudo-random sequence to recover each element's input value.
        Random expectedGenerator = new Random(SEED);
        for (int i = 0; i < a.getSize(); i++) {
            Assert.assertEquals(expectedGenerator.nextFloat() + ACCUMULATOR, a.get(i), DELTA);
        }
    }
}

