/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.instructions.gpu.context;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.HashMap;
import jcuda.CudaException;
import jcuda.NativePointerObject;
import jcuda.Pointer;
import jcuda.driver.CUfunction;
import jcuda.driver.CUmodule;
import jcuda.driver.CUresult;
import jcuda.driver.CUstream;
import jcuda.driver.JCudaDriver;
import jcuda.runtime.JCuda;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.instructions.gpu.context.ExecutionConfig;
import org.apache.sysds.runtime.io.IOUtilFunctions;
import org.apache.sysds.runtime.matrix.data.LibMatrixCUDA;

public class JCudaKernels {
    private static final String ptxFileName = "/cuda/kernels/SystemDS.ptx";
    private HashMap<String, CUfunction> kernels = new HashMap();
    private CUmodule module = new CUmodule();

    JCudaKernels() {
        JCudaKernels.checkResult(JCudaDriver.cuModuleLoadDataEx((CUmodule)this.module, (Pointer)JCudaKernels.initKernels(ptxFileName), (int)0, (int[])new int[0], (Pointer)Pointer.to((int[])new int[0])));
    }

    public void launchKernel(String name, ExecutionConfig config, Object ... arguments) {
        CUfunction function = this.kernels.get(name = name + LibMatrixCUDA.customKernelSuffix);
        if (function == null) {
            function = new CUfunction();
            try {
                JCudaKernels.checkResult(JCudaDriver.cuModuleGetFunction((CUfunction)function, (CUmodule)this.module, (String)name));
            }
            catch (CudaException e) {
                throw new DMLRuntimeException("Error finding the custom kernel:" + name, (Exception)((Object)e));
            }
        }
        Pointer[] kernelParams = new Pointer[arguments.length];
        for (int i = 0; i < arguments.length; ++i) {
            if (arguments[i] == null) {
                throw new DMLRuntimeException("The argument to the kernel cannot be null.");
            }
            if (arguments[i] instanceof Pointer) {
                kernelParams[i] = Pointer.to((NativePointerObject[])new NativePointerObject[]{(Pointer)arguments[i]});
                continue;
            }
            if (arguments[i] instanceof Integer) {
                kernelParams[i] = Pointer.to((int[])new int[]{(Integer)arguments[i]});
                continue;
            }
            if (arguments[i] instanceof Double) {
                kernelParams[i] = Pointer.to((double[])new double[]{(Double)arguments[i]});
                continue;
            }
            if (arguments[i] instanceof Long) {
                kernelParams[i] = Pointer.to((long[])new long[]{(Long)arguments[i]});
                continue;
            }
            if (arguments[i] instanceof Float) {
                kernelParams[i] = Pointer.to((float[])new float[]{((Float)arguments[i]).floatValue()});
                continue;
            }
            throw new DMLRuntimeException("The argument of type " + arguments[i].getClass() + " is not supported.");
        }
        JCudaKernels.checkResult(JCudaDriver.cuLaunchKernel((CUfunction)function, (int)config.gridDimX, (int)config.gridDimY, (int)config.gridDimZ, (int)config.blockDimX, (int)config.blockDimY, (int)config.blockDimZ, (int)config.sharedMemBytes, (CUstream)config.stream, (Pointer)Pointer.to((NativePointerObject[])kernelParams), null));
        if (DMLScript.SYNCHRONIZE_GPU) {
            JCuda.cudaDeviceSynchronize();
        }
    }

    public static void checkResult(int cuResult) {
        if (cuResult != 0) {
            throw new DMLRuntimeException(CUresult.stringFor((int)cuResult));
        }
    }

    /*
     * WARNING - Removed back jump from a try to a catch block - possible behaviour change.
     * Unable to fully structure code
     * Enabled aggressive block sorting
     * Enabled unnecessary exception pruning
     * Enabled aggressive exception aggregation
     */
    private static Pointer initKernels(String ptxFileName) {
        block19: {
            out = null;
            try {
                block20: {
                    in = JCudaKernels.class.getResourceAsStream(ptxFileName);
                    var3_4 = null;
                    try {
                        if (in != null) {
                            out = new ByteArrayOutputStream();
                            buffer = new byte[8192];
                            while (true) {
                                if ((read = in.read(buffer)) == -1) {
                                    out.write(0);
                                    out.flush();
                                    var5_8 = Pointer.to((byte[])out.toByteArray());
                                    if (in != null) {
                                        break;
                                    }
                                    break block19;
                                }
                                out.write(buffer, 0, read);
                            }
                            if (var3_4 == null) break block20;
                        }
                        ** GOTO lbl-1000
                    }
                    catch (Throwable var4_6) {
                        var3_4 = var4_6;
                        throw var4_6;
                    }
                    catch (Throwable var7_10) {
                        if (in == null) throw var7_10;
                        if (var3_4 == null) {
                            in.close();
                            throw var7_10;
                        }
                        try {
                            in.close();
                            throw var7_10;
                        }
                        catch (Throwable var8_11) {
                            var3_4.addSuppressed(var8_11);
                            throw var7_10;
                        }
                    }
                    try {
                        in.close();
                    }
                    catch (Throwable var6_9) {
                        var3_4.addSuppressed(var6_9);
                    }
                    break block19;
                }
                in.close();
            }
            catch (IOException e) {
                throw new DMLRuntimeException("Could not initialize the kernels", e);
            }
            catch (Throwable var9_12) {
                IOUtilFunctions.closeSilently(out);
                throw var9_12;
            }
        }
        IOUtilFunctions.closeSilently(out);
        return var5_8;
lbl-1000:
        // 1 sources

        {
            throw new DMLRuntimeException("The input file " + ptxFileName + " not found. (Hint: Please compile SystemDS using -DenableGPU=true flag. Example: mvn package -DenableGPU=true).");
        }
    }
}

