/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.table.planner.plan.nodes.exec.common;

import java.lang.reflect.Constructor;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexFieldAccess;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.flink.api.dag.Transformation;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.configuration.ReadableConfig;
import org.apache.flink.core.memory.ManagedMemoryUseCase;
import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
import org.apache.flink.streaming.api.transformations.OneInputTransformation;
import org.apache.flink.table.api.TableException;
import org.apache.flink.table.connector.Projection;
import org.apache.flink.table.data.RowData;
import org.apache.flink.table.functions.python.PythonFunctionInfo;
import org.apache.flink.table.functions.python.PythonFunctionKind;
import org.apache.flink.table.planner.calcite.FlinkTypeFactory;
import org.apache.flink.table.planner.codegen.CodeGeneratorContext;
import org.apache.flink.table.planner.codegen.ProjectionCodeGenerator;
import org.apache.flink.table.planner.delegation.PlannerBase;
import org.apache.flink.table.planner.plan.nodes.exec.ExecEdge;
import org.apache.flink.table.planner.plan.nodes.exec.ExecNodeBase;
import org.apache.flink.table.planner.plan.nodes.exec.ExecNodeConfig;
import org.apache.flink.table.planner.plan.nodes.exec.ExecNodeContext;
import org.apache.flink.table.planner.plan.nodes.exec.InputProperty;
import org.apache.flink.table.planner.plan.nodes.exec.SingleTransformationTranslator;
import org.apache.flink.table.planner.plan.nodes.exec.utils.CommonPythonUtil;
import org.apache.flink.table.planner.plan.nodes.exec.utils.ExecNodeUtil;
import org.apache.flink.table.planner.plan.utils.PythonUtil;
import org.apache.flink.table.runtime.generated.GeneratedProjection;
import org.apache.flink.table.runtime.typeutils.InternalTypeInfo;
import org.apache.flink.table.types.logical.LogicalType;
import org.apache.flink.table.types.logical.RowType;
import org.apache.flink.util.Preconditions;

/**
 * Base exec node for a Calc whose projection contains Python scalar function calls.
 *
 * <p>The projection is split into two parts:
 *
 * <ul>
 *   <li>{@link RexInputRef}s: input fields forwarded unchanged;
 *   <li>{@link RexCall}s: Python scalar function calls evaluated by a Python operator.
 * </ul>
 *
 * <p>Projection entries of any other kind are ignored here. The operator's output row consists of
 * the forwarded fields followed by one field per Python call. NOTE(review): this layout assumes the
 * planner already ordered the projection as "forwarded fields first, then calls" — confirm against
 * the rule that produces this node.
 */
public abstract class CommonExecPythonCalc extends ExecNodeBase<RowData>
        implements SingleTransformationTranslator<RowData> {

    public static final String FIELD_NAME_PROJECTION = "projection";

    private static final String PYTHON_SCALAR_FUNCTION_OPERATOR_NAME =
            "org.apache.flink.table.runtime.operators.python.scalar.PythonScalarFunctionOperator";

    private static final String EMBEDDED_PYTHON_SCALAR_FUNCTION_OPERATOR_NAME =
            "org.apache.flink.table.runtime.operators.python.scalar.EmbeddedPythonScalarFunctionOperator";

    private static final String ARROW_PYTHON_SCALAR_FUNCTION_OPERATOR_NAME =
            "org.apache.flink.table.runtime.operators.python.scalar.arrow.ArrowPythonScalarFunctionOperator";

    private final List<RexNode> projection;

    /**
     * Creates the node.
     *
     * @param projection the projection expressions; must not be null
     * @param inputProperties must contain exactly one input
     * @throws IllegalArgumentException if {@code inputProperties} does not have exactly one element
     * @throws NullPointerException if {@code projection} is null
     */
    public CommonExecPythonCalc(
            int id,
            ExecNodeContext context,
            ReadableConfig persistedConfig,
            List<RexNode> projection,
            List<InputProperty> inputProperties,
            RowType outputType,
            String description) {
        super(id, context, persistedConfig, inputProperties, outputType, description);
        Preconditions.checkArgument(inputProperties.size() == 1);
        this.projection = Preconditions.checkNotNull(projection);
    }

    /**
     * Translates the single input, wraps it in a Python scalar function transformation and, when
     * the Python worker uses managed memory, declares the PYTHON managed-memory use case at slot
     * scope.
     */
    @Override
    @SuppressWarnings("unchecked")
    protected Transformation<RowData> translateToPlanInternal(
            PlannerBase planner, ExecNodeConfig config) {
        final ExecEdge inputEdge = getInputEdges().get(0);
        // The single input of this node produces RowData; translateToPlan only exposes a wildcard.
        final Transformation<RowData> inputTransform =
                (Transformation<RowData>) inputEdge.translateToPlan(planner);
        final Configuration pythonConfig =
                CommonPythonUtil.extractPythonConfiguration(planner.getExecEnv(), config);
        final OneInputTransformation<RowData, RowData> ret =
                createPythonOneInputTransformation(inputTransform, config, pythonConfig);
        if (CommonPythonUtil.isPythonWorkerUsingManagedMemory(pythonConfig)) {
            ret.declareManagedMemoryUseCaseAtSlotScope(ManagedMemoryUseCase.PYTHON);
        }
        return ret;
    }

    /**
     * Builds the one-input transformation running the Python scalar function operator.
     *
     * <p>The operator's output row is the forwarded input fields followed by one field per Python
     * call. The transformation keeps the parallelism of the input.
     */
    @SuppressWarnings("unchecked")
    private OneInputTransformation<RowData, RowData> createPythonOneInputTransformation(
            Transformation<RowData> inputTransform,
            ExecNodeConfig config,
            Configuration pythonConfig) {
        final List<RexCall> pythonRexCalls =
                projection.stream()
                        .filter(x -> x instanceof RexCall)
                        .map(x -> (RexCall) x)
                        .collect(Collectors.toList());
        final List<Integer> forwardedFields =
                projection.stream()
                        .filter(x -> x instanceof RexInputRef)
                        .map(x -> ((RexInputRef) x).getIndex())
                        .collect(Collectors.toList());

        final Tuple2<int[], PythonFunctionInfo[]> extractResult =
                extractPythonScalarFunctionInfos(pythonRexCalls);
        final int[] pythonUdfInputOffsets = extractResult.f0;
        final PythonFunctionInfo[] pythonFunctionInfos = extractResult.f1;

        final InternalTypeInfo<RowData> pythonOperatorInputTypeInfo =
                (InternalTypeInfo<RowData>) inputTransform.getOutputType();
        final LogicalType[] inputLogicalTypes = pythonOperatorInputTypeInfo.toRowFieldTypes();

        // Output layout: forwarded input fields first, then the result of each Python call.
        final List<LogicalType> fieldsLogicalTypes =
                new ArrayList<>(forwardedFields.size() + pythonRexCalls.size());
        for (int index : forwardedFields) {
            fieldsLogicalTypes.add(inputLogicalTypes[index]);
        }
        for (RexCall call : pythonRexCalls) {
            fieldsLogicalTypes.add(FlinkTypeFactory.toLogicalType(call.getType()));
        }
        final InternalTypeInfo<RowData> pythonOperatorResultTypeInfo =
                InternalTypeInfo.ofFields(fieldsLogicalTypes.toArray(new LogicalType[0]));

        // The Arrow (Pandas) operator is used as soon as any call contains a Pandas Python call.
        final boolean isArrow =
                pythonRexCalls.stream()
                        .anyMatch(x -> PythonUtil.containsPythonCall(x, PythonFunctionKind.PANDAS));

        final OneInputStreamOperator<RowData, RowData> pythonOperator =
                getPythonScalarFunctionOperator(
                        config,
                        pythonConfig,
                        pythonOperatorInputTypeInfo,
                        pythonOperatorResultTypeInfo,
                        pythonUdfInputOffsets,
                        pythonFunctionInfos,
                        forwardedFields.stream().mapToInt(Integer::intValue).toArray(),
                        isArrow);
        return ExecNodeUtil.createOneInputTransformation(
                inputTransform,
                createTransformationName(config),
                createTransformationDescription(config),
                pythonOperator,
                pythonOperatorResultTypeInfo,
                inputTransform.getParallelism());
    }

    /**
     * Creates one {@link PythonFunctionInfo} per call and collects the input offsets of the
     * distinct input nodes those calls reference.
     *
     * @return {@code f0}: the UDF input offsets, in first-seen order; {@code f1}: the function
     *     infos, in the order of {@code rexCalls}
     * @throws TableException if a referenced input node is neither a {@link RexInputRef} nor a
     *     {@link RexFieldAccess}
     */
    private Tuple2<int[], PythonFunctionInfo[]> extractPythonScalarFunctionInfos(
            List<RexCall> rexCalls) {
        // Presumably filled by createPythonFunctionInfo with the input nodes the calls reference.
        // A LinkedHashMap keeps its keys in insertion order, so the offsets come out in that order.
        final LinkedHashMap<RexNode, Integer> inputNodes = new LinkedHashMap<>();
        final PythonFunctionInfo[] pythonFunctionInfos = new PythonFunctionInfo[rexCalls.size()];
        for (int i = 0; i < rexCalls.size(); i++) {
            pythonFunctionInfos[i] =
                    CommonPythonUtil.createPythonFunctionInfo(rexCalls.get(i), inputNodes);
        }

        final int[] udfInputOffsets = new int[inputNodes.size()];
        int pos = 0;
        for (RexNode node : inputNodes.keySet()) {
            udfInputOffsets[pos++] = toInputOffset(node);
        }
        return Tuple2.of(udfInputOffsets, pythonFunctionInfos);
    }

    /**
     * Returns the field index a Python UDF input node refers to.
     *
     * <p>A {@link RexInputRef} yields its input index. A {@link RexFieldAccess} yields the index of
     * the accessed field.
     *
     * @throws TableException for any other node kind. Previously such a node produced an opaque
     *     NullPointerException during unboxing.
     */
    private static int toInputOffset(RexNode node) {
        if (node instanceof RexInputRef) {
            return ((RexInputRef) node).getIndex();
        }
        if (node instanceof RexFieldAccess) {
            return ((RexFieldAccess) node).getField().getIndex();
        }
        throw new TableException("Unsupported input node for a Python scalar function: " + node);
    }

    /**
     * Instantiates the Python scalar function operator via reflection.
     *
     * <p>The constructor used depends on the worker mode:
     *
     * <ul>
     *   <li>in process mode, the operator receives generated projections for the UDF input and the
     *       forwarded fields;
     *   <li>otherwise it receives the raw UDF input offsets, plus a generated forwarded-field
     *       projection only when there are forwarded fields.
     * </ul>
     *
     * @throws TableException if the operator class or a matching constructor cannot be used
     */
    @SuppressWarnings("unchecked")
    private OneInputStreamOperator<RowData, RowData> getPythonScalarFunctionOperator(
            ExecNodeConfig config,
            Configuration pythonConfig,
            InternalTypeInfo<RowData> inputRowTypeInfo,
            InternalTypeInfo<RowData> outputRowTypeInfo,
            int[] udfInputOffsets,
            PythonFunctionInfo[] pythonFunctionInfos,
            int[] forwardedFields,
            boolean isArrow) {
        final boolean isInProcessMode = CommonPythonUtil.isPythonWorkerInProcessMode(pythonConfig);
        final Class<?> clazz = loadOperatorClass(isArrow, isInProcessMode);

        final RowType inputType = inputRowTypeInfo.toRowType();
        final RowType outputType = outputRowTypeInfo.toRowType();
        final RowType udfInputType = (RowType) Projection.of(udfInputOffsets).project(inputType);
        final RowType forwardedFieldType =
                (RowType) Projection.of(forwardedFields).project(inputType);
        // The UDF results occupy the output fields that follow the forwarded ones.
        final RowType udfOutputType =
                (RowType)
                        Projection.range(forwardedFields.length, outputType.getFieldCount())
                                .project(outputType);

        try {
            if (isInProcessMode) {
                final Constructor<?> ctor =
                        clazz.getConstructor(
                                Configuration.class,
                                PythonFunctionInfo[].class,
                                RowType.class,
                                RowType.class,
                                RowType.class,
                                GeneratedProjection.class,
                                GeneratedProjection.class);
                return (OneInputStreamOperator<RowData, RowData>)
                        ctor.newInstance(
                                pythonConfig,
                                pythonFunctionInfos,
                                inputType,
                                udfInputType,
                                udfOutputType,
                                generateProjection(
                                        config,
                                        "UdfInputProjection",
                                        inputType,
                                        udfInputType,
                                        udfInputOffsets),
                                generateProjection(
                                        config,
                                        "ForwardedFieldProjection",
                                        inputType,
                                        forwardedFieldType,
                                        forwardedFields));
            }
            if (forwardedFields.length > 0) {
                final Constructor<?> ctor =
                        clazz.getConstructor(
                                Configuration.class,
                                PythonFunctionInfo[].class,
                                RowType.class,
                                RowType.class,
                                RowType.class,
                                int[].class,
                                GeneratedProjection.class);
                return (OneInputStreamOperator<RowData, RowData>)
                        ctor.newInstance(
                                pythonConfig,
                                pythonFunctionInfos,
                                inputType,
                                udfInputType,
                                udfOutputType,
                                udfInputOffsets,
                                generateProjection(
                                        config,
                                        "ForwardedFieldProjection",
                                        inputType,
                                        forwardedFieldType,
                                        forwardedFields));
            }
            final Constructor<?> ctor =
                    clazz.getConstructor(
                            Configuration.class,
                            PythonFunctionInfo[].class,
                            RowType.class,
                            RowType.class,
                            RowType.class,
                            int[].class);
            return (OneInputStreamOperator<RowData, RowData>)
                    ctor.newInstance(
                            pythonConfig,
                            pythonFunctionInfos,
                            inputType,
                            udfInputType,
                            udfOutputType,
                            udfInputOffsets);
        } catch (Exception e) {
            throw new TableException("Python Scalar Function Operator constructed failed.", e);
        }
    }

    /**
     * Loads the operator class for the given UDF kind and execution mode.
     *
     * <ul>
     *   <li>Arrow (Pandas) calls load the Arrow operator;
     *   <li>otherwise process mode loads the process-based operator;
     *   <li>any other mode loads the embedded operator.
     * </ul>
     */
    private static Class<?> loadOperatorClass(boolean isArrow, boolean isInProcessMode) {
        if (isArrow) {
            // NOTE(review): the Arrow operator is chosen regardless of the mode. Outside process
            // mode, the int[]-based constructors are looked up on it. Confirm they exist;
            // otherwise construction fails with a TableException.
            return CommonPythonUtil.loadClass(ARROW_PYTHON_SCALAR_FUNCTION_OPERATOR_NAME);
        }
        return CommonPythonUtil.loadClass(
                isInProcessMode
                        ? PYTHON_SCALAR_FUNCTION_OPERATOR_NAME
                        : EMBEDDED_PYTHON_SCALAR_FUNCTION_OPERATOR_NAME);
    }

    /**
     * Generates a projection from {@code inputType} to {@code outputType} selecting {@code
     * inputMapping}. Each call uses a fresh {@link CodeGeneratorContext}, matching the original
     * behavior.
     */
    private static GeneratedProjection generateProjection(
            ExecNodeConfig config,
            String name,
            RowType inputType,
            RowType outputType,
            int[] inputMapping) {
        return ProjectionCodeGenerator.generateProjection(
                CodeGeneratorContext.apply(config.getTableConfig()),
                name,
                inputType,
                outputType,
                inputMapping);
    }
}

