package ai.djl.pytorch.engine;

import ai.djl.Device;
import ai.djl.Model;
import ai.djl.engine.Engine;
import ai.djl.engine.EngineException;
import ai.djl.ndarray.NDManager;
import ai.djl.nn.SymbolBlock;
import ai.djl.pytorch.jni.JniUtils;
import ai.djl.pytorch.jni.LibUtils;
import ai.djl.training.GradientCollector;
import ai.djl.util.Utils;
import java.io.FileNotFoundException;
import java.nio.file.Files;
import java.nio.file.LinkOption;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Iterator;
import jnr.ffi.provider.jffi.JNINativeInterface;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:META-INF/bundled-dependencies/pytorch-engine-0.22.1.jar:ai/djl/pytorch/engine/PtEngine.class */
public final class PtEngine extends Engine {

    private static final Logger logger = LoggerFactory.getLogger(PtEngine.class);

    /** The name reported by {@link #getEngineName()}. */
    public static final String ENGINE_NAME = "PyTorch";

    /** The value reported by {@link #getRank()}. */
    static final int RANK = 2;

    private PtEngine() {
    }

    /**
     * Loads the PyTorch native library, applies process-wide settings, and creates the engine.
     *
     * <p>Steps, in order:
     *
     * <ol>
     *   <li>{@code LibUtils.loadLibrary()} and {@code JniUtils.setGradMode(false)}.
     *   <li>If set, applies the integer system properties {@code ai.djl.pytorch.num_interop_threads}
     *       and then {@code ai.djl.pytorch.num_threads}.
     *   <li>If the boolean system property {@code ai.djl.pytorch.cudnn_benchmark} is true, calls
     *       {@code JniUtils.setBenchmarkCuDNN(true)}.
     *   <li>If {@code ai.djl.pytorch.graph_optimizer} is {@code "true"} (the default when unset), logs
     *       an informational notice. This method does not itself change the optimizer setting.
     *   <li>Logs the resulting inter-op and intra-op thread counts.
     *   <li>Loads every library listed in the comma-separated {@code PYTORCH_EXTRA_LIBRARY_PATH}
     *       environment variable or system property. Entries are trimmed and blank entries are
     *       skipped.
     * </ol>
     *
     * <p>NOTE(review): the decompiler reports this method was originally package-private; it is
     * kept {@code public} so existing callers continue to compile.
     *
     * @return a new {@code PtEngine}
     * @throws EngineException if any step fails; an {@code EngineException} raised by a step is
     *     rethrown unchanged, and any other {@code Throwable} (including native link errors) is
     *     wrapped
     */
    public static Engine newInstance() {
        try {
            LibUtils.loadLibrary();
            JniUtils.setGradMode(false);
            applyThreadSettings();
            if (Boolean.getBoolean("ai.djl.pytorch.cudnn_benchmark")) {
                JniUtils.setBenchmarkCuDNN(true);
            }
            if ("true".equals(System.getProperty("ai.djl.pytorch.graph_optimizer", "true"))) {
                logger.info(
                        "PyTorch graph executor optimizer is enabled, this may impact your inference latency and throughput. See: https://docs.djl.ai/docs/development/inference_performance_optimization.html#graph-executor-optimization");
            }
            logger.info("Number of inter-op threads is {}", JniUtils.getNumInteropThreads());
            logger.info("Number of intra-op threads is {}", JniUtils.getNumThreads());
            loadExtraLibraries();
            return new PtEngine();
        } catch (EngineException e) {
            throw e;
        } catch (Throwable t) {
            // Deliberately broad: native loading failures surface as Errors (e.g. link errors).
            throw new EngineException("Failed to load PyTorch native library", t);
        }
    }

    /** Applies the optional inter-op and intra-op thread-count system properties, in that order. */
    private static void applyThreadSettings() {
        // Read each property once so a concurrent clear cannot cause an NPE between check and use.
        Integer interopThreads = Integer.getInteger("ai.djl.pytorch.num_interop_threads");
        if (interopThreads != null) {
            JniUtils.setNumInteropThreads(interopThreads.intValue());
        }
        Integer intraopThreads = Integer.getInteger("ai.djl.pytorch.num_threads");
        if (intraopThreads != null) {
            JniUtils.setNumThreads(intraopThreads.intValue());
        }
    }

    /**
     * Loads the native libraries listed in {@code PYTORCH_EXTRA_LIBRARY_PATH}, if it is set.
     *
     * @throws FileNotFoundException if a listed path does not exist
     */
    private static void loadExtraLibraries() throws FileNotFoundException {
        String paths = Utils.getEnvOrSystemProperty("PYTORCH_EXTRA_LIBRARY_PATH");
        if (paths == null) {
            return;
        }
        for (String entry : paths.split(",")) {
            String file = entry.trim();
            if (file.isEmpty()) {
                // An empty entry would resolve to the working directory and be passed to
                // System.load, which cannot load a directory; tolerate ",a" / "a,,b" instead.
                continue;
            }
            Path path = Paths.get(file);
            if (Files.notExists(path)) {
                throw new FileNotFoundException("PyTorch extra Library not found: " + file);
            }
            System.load(path.toAbsolutePath().toString());
        }
    }

    /**
     * Returns the alternative engine; this implementation has none.
     *
     * @return always {@code null}
     */
    @Override
    public Engine getAlternativeEngine() {
        return null;
    }

    /**
     * Returns the engine name.
     *
     * @return {@link #ENGINE_NAME}
     */
    @Override
    public String getEngineName() {
        return ENGINE_NAME;
    }

    /**
     * Returns this engine's rank.
     *
     * @return {@link #RANK}
     */
    @Override
    public int getRank() {
        return RANK;
    }

    /**
     * Returns the native library version.
     *
     * @return the value of {@code LibUtils.getVersion()}
     */
    @Override
    public String getVersion() {
        return LibUtils.getVersion();
    }

    /**
     * Reports whether the native library advertises a feature.
     *
     * @param capability the feature name to look up
     * @return {@code true} if {@code JniUtils.getFeatures()} contains {@code capability}
     */
    @Override
    public boolean hasCapability(String capability) {
        return JniUtils.getFeatures().contains(capability);
    }

    /**
     * Creates a symbol block bound to the given manager.
     *
     * @param manager the manager; must be a {@code PtNDManager}, otherwise the cast throws
     *     {@code ClassCastException}
     * @return a new {@code PtSymbolBlock}
     */
    @Override
    public SymbolBlock newSymbolBlock(NDManager manager) {
        return new PtSymbolBlock((PtNDManager) manager);
    }

    /**
     * Creates a new PyTorch model.
     *
     * @param name the model name
     * @param device the device passed to the model
     * @return a new {@code PtModel}
     */
    @Override
    public Model newModel(String name, Device device) {
        return new PtModel(name, device);
    }

    /**
     * Creates a new sub-manager of the PyTorch system manager.
     *
     * @return the new manager
     */
    @Override
    public NDManager newBaseManager() {
        return PtNDManager.getSystemManager().newSubManager();
    }

    /**
     * Creates a new sub-manager of the PyTorch system manager for the given device.
     *
     * @param device the device for the new manager
     * @return the new manager
     */
    @Override
    public NDManager newBaseManager(Device device) {
        return PtNDManager.getSystemManager().newSubManager(device);
    }

    /**
     * Creates a new gradient collector.
     *
     * @return a new {@code PtGradientCollector}
     */
    @Override
    public GradientCollector newGradientCollector() {
        return new PtGradientCollector();
    }

    /**
     * Sets the random seed on the base engine and on the native library.
     *
     * @param seed the seed value
     */
    @Override
    public void setRandomSeed(int seed) {
        super.setRandomSeed(seed);
        JniUtils.setSeed(seed);
    }

    /**
     * Describes the engine.
     *
     * <p>The format is {@code <name>:<version>, capabilities: [}, then one {@code \t<feature>,}
     * line per feature, then {@code ]} and a {@code PyTorch Library: <path>} line.
     *
     * @return the description
     */
    @Override
    public String toString() {
        // Initial capacity only; previously a decompiler-substituted jnr-ffi constant equal to 200.
        StringBuilder sb = new StringBuilder(200);
        sb.append(getEngineName()).append(':').append(getVersion()).append(", capabilities: [\n");
        for (String feature : JniUtils.getFeatures()) {
            sb.append("\t").append(feature).append(",\n");
        }
        sb.append("]\nPyTorch Library: ").append(LibUtils.getLibtorchPath());
        return sb.toString();
    }
}
