/*
 * Decompiled with CFR 0.152.
 */
package uk.ac.manchester.tornado.benchmarks.sgemm;

import java.util.Random;
import uk.ac.manchester.tornado.api.GridScheduler;
import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
import uk.ac.manchester.tornado.api.TaskGraph;
import uk.ac.manchester.tornado.api.TornadoExecutionPlan;
import uk.ac.manchester.tornado.api.WorkerGrid;
import uk.ac.manchester.tornado.api.WorkerGrid2D;
import uk.ac.manchester.tornado.api.common.TornadoDevice;
import uk.ac.manchester.tornado.api.math.TornadoMath;
import uk.ac.manchester.tornado.api.runtime.TornadoRuntimeProvider;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
import uk.ac.manchester.tornado.benchmarks.BenchmarkDriver;
import uk.ac.manchester.tornado.benchmarks.LinearAlgebraArrays;

/**
 * TornadoVM benchmark driver for the {@code sgemm} kernel in {@link LinearAlgebraArrays}.
 *
 * <p>{@link #setUp()} builds a single-task graph. The graph copies {@code a} and {@code b} to the
 * device, runs {@code sgemm} into {@code c}, and copies {@code c} back to the host. When the
 * {@code usegrid} runtime property parses as {@code true}, a 2D worker grid with a 16x16 local
 * work size is registered for the task and applied on every run.
 */
public class SgemmTornado extends BenchmarkDriver {
    /** Name of the task graph; also the prefix of the grid-scheduler key. */
    private static final String GRAPH_NAME = "benchmark";
    /** Name of the sgemm task inside the graph. */
    private static final String TASK_NAME = "sgemm";
    /**
     * Transfer mode passed to both transferToDevice and transferToHost. The decompiled code
     * used the literal {@code 1}. NOTE(review): presumably DataTransferMode.EVERY_EXECUTION;
     * confirm.
     */
    private static final int TRANSFER_MODE = 1;
    /** Largest absolute difference tolerated between device and host results. */
    private static final double TOLERANCE = 0.01;

    private final int m;
    private final int n;
    /** Whether to attach a custom worker grid; read once from the {@code usegrid} property. */
    private final boolean useGrid = Boolean.parseBoolean(TornadoRuntimeProvider.getProperty("usegrid", "False"));
    private WorkerGrid worker;
    private FloatArray a;
    private FloatArray b;
    private FloatArray c;
    private GridScheduler grid;

    /**
     * @param iterations number of benchmark iterations, forwarded to {@link BenchmarkDriver}
     * @param m          first size argument; the arrays hold {@code m * n} floats
     * @param n          second size argument; also passed as the third sgemm size argument
     */
    public SgemmTornado(int iterations, int m, int n) {
        super(iterations);
        this.m = m;
        this.n = n;
    }

    /**
     * Allocates and fills the {@code m * n} arrays, optionally configures the worker grid, and
     * builds a pre-compiled execution plan.
     *
     * <p>{@code a} receives {@code 1.0f} at index {@code i * (m + 1)}, which is the diagonal when
     * rows are {@code m} elements wide. {@code b} receives random floats. {@code c} is the output.
     */
    @Override
    public void setUp() {
        final int size = m * n;
        a = new FloatArray(size);
        b = new FloatArray(size);
        c = new FloatArray(size);

        // Bounding i by min(m, n) keeps i * (m + 1) below m * n even when m > n.
        // For m <= n this visits exactly the same indices as iterating up to m.
        final int diagonalLength = Math.min(m, n);
        for (int i = 0; i < diagonalLength; i++) {
            a.set(i * (m + 1), 1.0f);
        }

        final Random random = new Random();
        for (int i = 0; i < size; i++) {
            b.set(i, random.nextFloat());
        }

        if (useGrid) {
            worker = new WorkerGrid2D(m, n);
            worker.setLocalWork(16, 16, 1);
            grid = new GridScheduler();
            // Derived from the same names used to build the graph below so the key cannot drift.
            grid.addWorkerGrid(GRAPH_NAME + "." + TASK_NAME, worker);
        }

        taskGraph = new TaskGraph(GRAPH_NAME);
        taskGraph.transferToDevice(TRANSFER_MODE, a, b);
        taskGraph.task(TASK_NAME, LinearAlgebraArrays::sgemm, m, n, n, a, b, c);
        taskGraph.transferToHost(TRANSFER_MODE, c);

        final ImmutableTaskGraph snapshot = taskGraph.snapshot();
        executionPlan = new TornadoExecutionPlan(snapshot);
        executionPlan.withPreCompilation();
    }

    /**
     * Dumps profiler output from the last run (if any), drops the arrays, and resets the device.
     */
    @Override
    public void tearDown() {
        // executionResult is assigned in runBenchmark. If no run happened it may still be null,
        // so skip the dump rather than failing with a NullPointerException.
        if (executionResult != null) {
            executionResult.getProfilerResult().dumpProfiles();
        }
        a = null;
        b = null;
        c = null;
        executionPlan.resetDevice();
        super.tearDown();
    }

    /**
     * Executes the plan once on {@code device}, applying the worker grid when one was
     * configured in {@link #setUp()}.
     *
     * @param device the TornadoVM device to run on
     */
    @Override
    public void runBenchmark(TornadoDevice device) {
        if (grid != null) {
            executionPlan.withGridScheduler(grid);
        }
        executionResult = executionPlan.withDevice(device).execute();
    }

    /**
     * Runs the benchmark once on {@code device} and compares {@code c} against a host-side
     * {@code sgemm} on the same inputs.
     *
     * @param device the TornadoVM device to validate
     * @return {@code true} if every compared element is within {@link #TOLERANCE}
     */
    @Override
    public boolean validate(TornadoDevice device) {
        final FloatArray expected = new FloatArray(m * n);

        runBenchmark(device);
        executionPlan.clearProfiles();

        // Use the same size arguments (m, n, n) as the device task in setUp.
        // This keeps both sides computing the same product even when m != n.
        LinearAlgebraArrays.sgemm(m, n, n, a, b, expected);

        final boolean val = resultsMatch(expected);
        System.out.println("Number validation: " + val);
        return val;
    }

    /**
     * Compares {@code c} with {@code expected} over the leading {@code n * min(m, n)} elements.
     *
     * <p>The check covers rows of width {@code n}. It is capped at {@code m * n} so it never reads
     * past the end of the arrays when {@code m < n}. For {@code m >= n} it covers the same
     * {@code n * n} elements the original nested loop did.
     *
     * @param expected host-computed reference values
     * @return {@code true} if no element differs by more than {@link #TOLERANCE}
     */
    private boolean resultsMatch(FloatArray expected) {
        final int count = n * Math.min(m, n);
        for (int idx = 0; idx < count; idx++) {
            if (TornadoMath.abs(expected.get(idx) - c.get(idx)) > TOLERANCE) {
                return false;
            }
        }
        return true;
    }
}

