/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.paddlepaddle.engine;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractSymbolBlock;
import ai.djl.paddlepaddle.engine.PaddlePredictor;
import ai.djl.paddlepaddle.engine.PpNDArray;
import ai.djl.paddlepaddle.engine.PpNDManager;
import ai.djl.paddlepaddle.jni.JniUtils;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import java.util.Arrays;
import java.util.Iterator;

/**
 * Symbol block that runs its forward pass through a {@link PaddlePredictor}.
 *
 * <p>The input names are read from the predictor once, at construction time, and are used
 * both to validate the number of inputs and to drive the native forward call.
 */
public class PpSymbolBlock extends AbstractSymbolBlock {
    private final PaddlePredictor predictor;
    private final PpNDManager manager;
    private final String[] inputNames;

    /**
     * Creates a block backed by the given predictor.
     *
     * @param predictor the predictor used to execute the forward pass
     * @param manager the manager from which per-call sub-managers are created
     */
    public PpSymbolBlock(PaddlePredictor predictor, PpNDManager manager) {
        this.predictor = predictor;
        this.manager = manager;
        this.inputNames = JniUtils.getInputNames(predictor);
    }

    /**
     * Runs the predictor on {@code inputs} and returns its outputs.
     *
     * <p>{@code parameterStore}, {@code training} and {@code params} are not used by this
     * implementation.
     *
     * @param parameterStore unused
     * @param inputs the input arrays; must contain exactly one array per predictor input name
     * @param training unused
     * @param params unused
     * @return the predictor outputs, each passed through the {@code from} method of the
     *     first input's manager
     * @throws IllegalArgumentException if the number of inputs differs from the number of
     *     input names reported by the predictor
     */
    protected NDList forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, PairList<String, Object> params) {
        if (this.inputNames.length != inputs.size()) {
            throw new IllegalArgumentException("Input number mismatch, requires: " + Arrays.toString(this.inputNames));
        }
        // The sub-manager is closed when this try block exits, so every output is handed
        // to the caller's manager before returning.
        // NOTE(review): presumably closing the sub-manager releases the arrays attached to it - confirm.
        try (PpNDManager sub = this.manager.newSubManager()) {
            NDList output = JniUtils.predictorForward(this.predictor, this.getInputs(sub, inputs), this.inputNames);
            NDManager inputManager = inputs.head().getManager();
            NDList ret = new NDList();
            for (NDArray array : output) {
                ret.add(inputManager.from(array));
            }
            return ret;
        }
    }

    /**
     * Converts every input into a {@link PpNDArray} obtained from the given sub-manager.
     *
     * @param sub the sub-manager whose {@code from} method produces the converted arrays
     * @param inputs the arrays to convert, in predictor input order
     * @return an array with one converted entry per input, in the same order
     */
    private PpNDArray[] getInputs(PpNDManager sub, NDList inputs) {
        PpNDArray[] inputArray = new PpNDArray[inputs.size()];
        for (int i = 0; i < inputArray.length; ++i) {
            inputArray[i] = sub.from(inputs.get(i));
        }
        return inputArray;
    }

    /**
     * Returns an empty array; this block does not compute output shapes.
     *
     * @param inputShapes ignored
     * @return an empty {@code Shape} array
     */
    public Shape[] getOutputShapes(Shape[] inputShapes) {
        return new Shape[0];
    }
}

