package org.apache.flink.iteration.operator.perround;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import org.apache.commons.collections.IteratorUtils;
import org.apache.flink.annotation.Internal;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.typeutils.base.IntSerializer;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.configuration.MetricOptions;
import org.apache.flink.contrib.streaming.state.RocksDBKeyedStateBackend;
import org.apache.flink.core.memory.ManagedMemoryUseCase;
import org.apache.flink.iteration.IterationListener;
import org.apache.flink.iteration.IterationRecord;
import org.apache.flink.iteration.operator.AbstractWrapperOperator;
import org.apache.flink.iteration.operator.OperatorStateUtils;
import org.apache.flink.iteration.operator.OperatorUtils;
import org.apache.flink.iteration.proxy.state.ProxyStateSnapshotContext;
import org.apache.flink.iteration.proxy.state.ProxyStreamOperatorStateContext;
import org.apache.flink.iteration.utils.ReflectionUtils;
import org.apache.flink.metrics.groups.OperatorMetricGroup;
import org.apache.flink.runtime.checkpoint.CheckpointOptions;
import org.apache.flink.runtime.jobgraph.OperatorID;
import org.apache.flink.runtime.metrics.groups.UnregisteredMetricGroups;
import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
import org.apache.flink.runtime.state.CheckpointStreamFactory;
import org.apache.flink.runtime.state.DefaultOperatorStateBackend;
import org.apache.flink.runtime.state.KeyedStateBackend;
import org.apache.flink.runtime.state.OperatorStateBackend;
import org.apache.flink.runtime.state.OperatorStateCheckpointOutputStream;
import org.apache.flink.runtime.state.StateInitializationContext;
import org.apache.flink.runtime.state.StatePartitionStreamProvider;
import org.apache.flink.runtime.state.StateSnapshotContext;
import org.apache.flink.runtime.state.heap.HeapKeyedStateBackend;
import org.apache.flink.streaming.api.operators.InternalTimeServiceManager;
import org.apache.flink.streaming.api.operators.OperatorSnapshotFutures;
import org.apache.flink.streaming.api.operators.StreamOperator;
import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
import org.apache.flink.streaming.api.operators.StreamOperatorFactoryUtil;
import org.apache.flink.streaming.api.operators.StreamOperatorParameters;
import org.apache.flink.streaming.api.operators.StreamOperatorStateContext;
import org.apache.flink.streaming.api.operators.StreamOperatorStateHandler;
import org.apache.flink.streaming.api.operators.StreamTaskStateInitializer;
import org.apache.flink.streaming.runtime.streamrecord.LatencyMarker;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.streaming.util.LatencyStats;
import org.apache.flink.util.CloseableIterable;
import org.apache.flink.util.ExceptionUtils;
import org.apache.flink.util.InstantiationUtil;
import org.apache.flink.util.Preconditions;
import org.apache.flink.util.function.BiConsumerWithException;
import org.rocksdb.RocksDB;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/apache/flink/iteration/operator/perround/AbstractPerRoundWrapperOperator.class */
/**
 * Base class for wrapper operators that keep a separate instance of the wrapped operator for each
 * round.
 *
 * <p>Instances are created lazily from a clone of the operator factory and stored by round number.
 * Each instance is initialized with a {@link ProxyStreamOperatorStateContext} built with the
 * round's state-name prefix ({@code "r<round>-"}). When the epoch watermark reaches a round, that
 * round's instance is notified (if it or its UDF is an {@link IterationListener}), finished and
 * closed, and state entries whose names carry the round prefix are removed from the operator and
 * keyed state backends. An epoch watermark of {@link Integer#MAX_VALUE} closes every remaining
 * instance in ascending round order.
 *
 * <p>On checkpoint, the still-open rounds, the latest epoch watermark, the parallelism and the
 * round that wrote each raw operator state partition are persisted so the per-round instances can
 * be re-created on restore.
 *
 * @param <T> the output type of the wrapped operator
 * @param <S> the type of the wrapped operator
 */
public abstract class AbstractPerRoundWrapperOperator<T, S extends StreamOperator<T>>
        extends AbstractWrapperOperator<T>
        implements StreamOperatorStateHandler.CheckpointedStreamOperator {

    private static final Logger LOG =
            LoggerFactory.getLogger(AbstractPerRoundWrapperOperator.class);

    // Keyed state backends are identified by exact class name before their private fields are
    // accessed reflectively in cleanupKeyedStates.
    private static final String HEAP_KEYED_STATE_NAME =
            "org.apache.flink.runtime.state.heap.HeapKeyedStateBackend";

    private static final String ROCKSDB_KEYED_STATE_NAME =
            "org.apache.flink.contrib.streaming.state.RocksDBKeyedStateBackend";

    /** Currently open wrapped operator instances, keyed by round. */
    private final Map<Integer, S> wrappedOperators;

    protected final LatencyStats latencyStats;

    private transient StreamOperatorStateContext streamOperatorStateContext;

    private transient StreamOperatorStateHandler stateHandler;

    private transient InternalTimeServiceManager<?> timeServiceManager;

    private transient KeySelector<?, ?> stateKeySelector1;

    private transient KeySelector<?, ?> stateKeySelector2;

    /** Highest epoch watermark seen so far; -1 until the first one arrives or is restored. */
    private int latestEpochWatermark;

    private ListState<Integer> parallelismState;

    private ListState<Integer> latestEpochWatermarkState;

    /** Rounds whose wrapped operator was still open when the checkpoint was taken (sorted). */
    private ListState<Integer> pendingEpochState;

    /** For each raw operator state partition, in write order, the round that wrote it. */
    private ListState<Integer> rawStateEpochState;

    public AbstractPerRoundWrapperOperator(
            StreamOperatorParameters<IterationRecord<T>> streamOperatorParameters,
            StreamOperatorFactory<T> streamOperatorFactory) {
        super(streamOperatorParameters, streamOperatorFactory);
        this.latestEpochWatermark = -1;
        this.wrappedOperators = new HashMap<>();
        this.latencyStats = initializeLatencyStats();
    }

    /**
     * Returns the wrapped operator of {@code round}, creating, initializing and opening it on first
     * access. A newly created operator receives no raw operator state.
     *
     * <p>The {@code mo21} prefix was introduced by the decompiler (originally {@code
     * getWrappedOperator}, declared {@code protected}); the name and widened access are kept so
     * existing callers still compile.
     *
     * @param round the round whose operator is requested
     * @return the (possibly newly created) wrapped operator
     */
    public S mo21getWrappedOperator(int round) {
        return getWrappedOperator(
                round, Collections.<StatePartitionStreamProvider>emptyIterator(), 0);
    }

    /**
     * Returns the wrapped operator of {@code round}, creating it if it does not exist yet.
     *
     * @param round the round whose operator is requested
     * @param rawStateInputs iterator over restored raw operator state partitions, shared across
     *     rounds; handed to the new operator's state context together with {@code
     *     numberOfRawStates} (presumably it consumes that many entries — confirm in
     *     ProxyStreamOperatorStateContext)
     * @param numberOfRawStates number of raw state partitions that belong to {@code round}
     * @return the existing or newly created operator
     */
    @SuppressWarnings("unchecked")
    private S getWrappedOperator(
            int round,
            Iterator<StatePartitionStreamProvider> rawStateInputs,
            int numberOfRawStates) {
        S existing = wrappedOperators.get(round);
        if (existing != null) {
            return existing;
        }

        try {
            // A fresh clone of the factory is used for every new instance.
            S operator =
                    (S)
                            StreamOperatorFactoryUtil.createOperator(
                                            InstantiationUtil.clone(this.operatorFactory),
                                            this.parameters.getContainingTask(),
                                            OperatorUtils.createWrappedOperatorConfig(
                                                    this.parameters.getStreamConfig()),
                                            this.proxyOutput,
                                            this.parameters.getOperatorEventDispatcher())
                                    .f0;
            initializeStreamOperator(operator, round, rawStateInputs, numberOfRawStates);
            // Only registered once initialization and open() succeeded.
            wrappedOperators.put(round, operator);
            return operator;
        } catch (Exception e) {
            ExceptionUtils.rethrow(e);
            // Unreachable: rethrow always throws; the return only satisfies the compiler.
            return null;
        }
    }

    /**
     * Signals the end of input to {@code operator} and emits the max watermark for it.
     *
     * @param operator the wrapped operator being closed
     * @param round the round the operator belongs to
     * @param epochWatermark the epoch watermark that triggered the close
     */
    protected abstract void endInputAndEmitMaxWatermark(S operator, int round, int epochWatermark)
            throws Exception;

    /**
     * Finishes and closes the wrapped operator of {@code round}, then drops its state.
     *
     * <p>Steps, in order: set the iteration context to {@code round}; notify the operator or its
     * UDF if it is an {@link IterationListener}; {@link #endInputAndEmitMaxWatermark}; {@code
     * finish()}; {@code close()}; clear the iteration context; remove the round-prefixed operator
     * state, and the keyed state too if a keyed state backend exists.
     *
     * @param operator the operator to close (already removed from the wrapped operator map)
     * @param round the round the operator belongs to
     * @param epochWatermark the epoch watermark that triggered the close
     */
    public void closeStreamOperator(S operator, int round, int epochWatermark) throws Exception {
        setIterationContextRound(round);
        OperatorUtils.processOperatorOrUdfIfSatisfy(
                operator,
                IterationListener.class,
                listener -> notifyEpochWatermarkIncrement(listener, epochWatermark));
        endInputAndEmitMaxWatermark(operator, round, epochWatermark);
        operator.finish();
        operator.close();
        setIterationContextRound(null);

        cleanupOperatorStates(round);
        if (stateHandler.getKeyedStateBackend() != null) {
            cleanupKeyedStates(round);
        }
    }

    /**
     * Closes the wrapped operator(s) made obsolete by a new epoch watermark, then forwards the
     * watermark to the parent.
     *
     * <p>Only a watermark larger than the latest one seen closes anything. A watermark below
     * {@link Integer#MAX_VALUE} closes the operator of exactly that round (if any). {@code
     * Integer.MAX_VALUE} closes all remaining operators in ascending round order.
     *
     * @param epochWatermark the new epoch watermark; must be non-negative
     * @throws IllegalStateException if {@code epochWatermark} is negative
     */
    @Override // org.apache.flink.iteration.operator.AbstractWrapperOperator, org.apache.flink.iteration.progresstrack.OperatorEpochWatermarkTrackerListener
    public void onEpochWatermarkIncrement(int epochWatermark) throws IOException {
        Preconditions.checkState(
                epochWatermark >= 0, "The epoch watermark should be non-negative.");

        if (epochWatermark > latestEpochWatermark) {
            latestEpochWatermark = epochWatermark;
            try {
                if (epochWatermark < Integer.MAX_VALUE) {
                    S finished = wrappedOperators.remove(epochWatermark);
                    if (finished != null) {
                        closeStreamOperator(finished, epochWatermark, epochWatermark);
                    }
                } else {
                    List<Integer> remainingRounds = new ArrayList<>(wrappedOperators.keySet());
                    Collections.sort(remainingRounds);
                    for (Integer round : remainingRounds) {
                        closeStreamOperator(wrappedOperators.remove(round), round, epochWatermark);
                    }
                }
            } catch (Exception e) {
                ExceptionUtils.rethrow(e);
            }
        }

        super.onEpochWatermarkIncrement(epochWatermark);
    }

    /**
     * Invokes {@code consumer} with (round, operator) for every currently open wrapped operator.
     *
     * @param consumer callback applied to each entry; must not add or remove wrapped operators
     */
    public void processForEachWrappedOperator(
            BiConsumerWithException<Integer, S, Exception> consumer) throws Exception {
        for (Map.Entry<Integer, S> entry : wrappedOperators.entrySet()) {
            consumer.accept(entry.getKey(), entry.getValue());
        }
    }

    /** Intentionally empty: wrapped operators are opened lazily when their round first appears. */
    @Override
    public void open() throws Exception {}

    /**
     * Creates the shared state context and state handler, then restores this operator's own
     * operator state (which calls back into {@link #initializeState(StateInitializationContext)}).
     */
    @Override
    public void initializeState(StreamTaskStateInitializer streamTaskStateManager)
            throws Exception {
        ClassLoader userCodeClassLoader = containingTask.getUserCodeClassLoader();

        streamOperatorStateContext =
                streamTaskStateManager.streamOperatorStateContext(
                        getOperatorID(),
                        getClass().getSimpleName(),
                        parameters.getProcessingTimeService(),
                        this,
                        streamConfig.getStateKeySerializer(userCodeClassLoader),
                        containingTask.getCancelables(),
                        metrics,
                        streamConfig.getManagedMemoryFractionOperatorUseCaseOfSlot(
                                ManagedMemoryUseCase.STATE_BACKEND,
                                containingTask
                                        .getEnvironment()
                                        .getTaskManagerInfo()
                                        .getConfiguration(),
                                userCodeClassLoader),
                        isUsingCustomRawKeyedState());

        stateHandler =
                new StreamOperatorStateHandler(
                        streamOperatorStateContext,
                        containingTask.getExecutionConfig(),
                        containingTask.getCancelables());
        stateHandler.initializeOperatorState(this);
        timeServiceManager = streamOperatorStateContext.internalTimerServiceManager();

        stateKeySelector1 = streamConfig.getStatePartitioner(0, userCodeClassLoader);
        stateKeySelector2 = streamConfig.getStatePartitioner(1, userCodeClassLoader);
    }

    /**
     * Restores the bookkeeping state and re-creates the wrapped operators that were open when the
     * checkpoint was taken.
     *
     * <p>{@link #snapshotState(StateSnapshotContext)} writes raw operator state partitions while
     * iterating rounds in ascending order and records the writing round of each partition. On
     * restore, each pending round (also ascending) therefore receives the run of consecutive
     * partitions tagged with its round number.
     *
     * @throws IllegalStateException if the parallelism changed or the restored raw-state indices
     *     are inconsistent with the pending rounds
     */
    @Override
    public void initializeState(StateInitializationContext context) throws Exception {
        parallelismState =
                context.getOperatorStateStore()
                        .getUnionListState(
                                new ListStateDescriptor<>("parallelism", IntSerializer.INSTANCE));
        int currentParallelism =
                containingTask.getEnvironment().getTaskInfo().getNumberOfParallelSubtasks();
        OperatorStateUtils.getUniqueElement(parallelismState, "parallelism")
                .ifPresent(
                        oldParallelism ->
                                Preconditions.checkState(
                                        oldParallelism == currentParallelism,
                                        "The per-round wrapper operator is recovered with parallelism changed from %s to %s",
                                        oldParallelism,
                                        currentParallelism));

        latestEpochWatermarkState =
                context.getOperatorStateStore()
                        .getListState(
                                new ListStateDescriptor<>("latestEpoch", IntSerializer.INSTANCE));
        OperatorStateUtils.getUniqueElement(latestEpochWatermarkState, "latestEpoch")
                .ifPresent(latest -> latestEpochWatermark = latest);

        // NOTE(review): unlike the other states this descriptor uses Integer.class instead of
        // IntSerializer; left as-is because changing it may break restoring existing snapshots.
        rawStateEpochState =
                context.getOperatorStateStore()
                        .getListState(new ListStateDescriptor<>("rawStateEpoch", Integer.class));
        List<Integer> rawStateEpochs = copyToList(rawStateEpochState);

        pendingEpochState =
                context.getOperatorStateStore()
                        .getListState(
                                new ListStateDescriptor<>("pendingEpochs", IntSerializer.INSTANCE));
        List<Integer> pendingEpochs = copyToList(pendingEpochState);

        Iterator<StatePartitionStreamProvider> rawStateInputs =
                context.getRawOperatorStateInputs().iterator();
        int rawStateIndex = 0;
        for (int epoch : pendingEpochs) {
            // Raw partitions not yet consumed must belong to this round or a later one.
            Preconditions.checkState(
                    rawStateIndex == rawStateEpochs.size()
                            || rawStateEpochs.get(rawStateIndex) >= epoch,
                    "Unexpected raw state indices %s and epochs %s",
                    rawStateEpochs,
                    pendingEpochs);

            int numberOfRawStates = 0;
            while (rawStateIndex < rawStateEpochs.size()
                    && rawStateEpochs.get(rawStateIndex) == epoch) {
                numberOfRawStates++;
                rawStateIndex++;
            }

            getWrappedOperator(epoch, rawStateInputs, numberOfRawStates);
        }
    }

    /** Copies all elements of {@code state} into a new mutable list, preserving order. */
    private static List<Integer> copyToList(ListState<Integer> state) throws Exception {
        List<Integer> values = new ArrayList<>();
        for (Integer value : state.get()) {
            values.add(value);
        }
        return values;
    }

    /**
     * Whether custom raw keyed state is used; passed to the state context and the state handler's
     * snapshot. Defaults to {@code false}.
     */
    @Internal
    protected boolean isUsingCustomRawKeyedState() {
        return false;
    }

    /**
     * Verifies that every wrapped operator has already been closed.
     *
     * @throws IllegalStateException if some wrapped operators are still open
     */
    @Override
    public void finish() throws Exception {
        Preconditions.checkState(
                wrappedOperators.isEmpty(),
                "Some wrapped operators are still not closed yet: %s",
                wrappedOperators.keySet());
    }

    /** Disposes the state handler if it was created. */
    @Override
    public void close() throws Exception {
        if (stateHandler != null) {
            stateHandler.dispose();
        }
    }

    /** Forwards {@code prepareSnapshotPreBarrier} to every open wrapped operator. */
    @Override
    public void prepareSnapshotPreBarrier(long checkpointId) throws Exception {
        for (S operator : wrappedOperators.values()) {
            operator.prepareSnapshotPreBarrier(checkpointId);
        }
    }

    /** Delegates to the state handler, which calls back {@link #snapshotState(StateSnapshotContext)}. */
    @Override
    public OperatorSnapshotFutures snapshotState(
            long checkpointId,
            long timestamp,
            CheckpointOptions checkpointOptions,
            CheckpointStreamFactory factory)
            throws Exception {
        return stateHandler.snapshotState(
                this,
                Optional.ofNullable(timeServiceManager),
                streamConfig.getOperatorName(),
                checkpointId,
                timestamp,
                checkpointOptions,
                factory,
                isUsingCustomRawKeyedState());
    }

    /**
     * Snapshots every checkpointed wrapped operator in ascending round order, then records the
     * parallelism (subtask 0 only, since the state is a union list), the latest epoch watermark,
     * the round that wrote each raw operator state partition, and the set of open rounds.
     */
    @Override
    public void snapshotState(StateSnapshotContext context) throws Exception {
        OperatorStateCheckpointOutputStream rawOperatorStateOutput =
                context.getRawOperatorStateOutput();
        List<Integer> rawStateEpochs = new ArrayList<>();
        List<Integer> pendingEpochs = new ArrayList<>(wrappedOperators.keySet());
        Collections.sort(pendingEpochs);

        for (int epoch : pendingEpochs) {
            S operator = wrappedOperators.get(epoch);
            if (operator instanceof StreamOperatorStateHandler.CheckpointedStreamOperator) {
                ((StreamOperatorStateHandler.CheckpointedStreamOperator) operator)
                        .snapshotState(new ProxyStateSnapshotContext(context));
                // Tag every raw partition written by this operator with its round.
                int numberOfPartitions = rawOperatorStateOutput.getNumberOfPartitions();
                while (rawStateEpochs.size() < numberOfPartitions) {
                    rawStateEpochs.add(epoch);
                }
            }
        }

        parallelismState.clear();
        if (containingTask.getEnvironment().getTaskInfo().getIndexOfThisSubtask() == 0) {
            parallelismState.update(
                    Collections.singletonList(
                            containingTask
                                    .getEnvironment()
                                    .getTaskInfo()
                                    .getNumberOfParallelSubtasks()));
        }
        latestEpochWatermarkState.update(Collections.singletonList(latestEpochWatermark));
        rawStateEpochState.update(rawStateEpochs);
        pendingEpochState.update(pendingEpochs);
    }

    @Override
    @SuppressWarnings({"unchecked", "rawtypes"})
    public void setKeyContextElement1(StreamRecord record) throws Exception {
        setKeyContextElement(record, stateKeySelector1);
    }

    @Override
    @SuppressWarnings({"unchecked", "rawtypes"})
    public void setKeyContextElement2(StreamRecord record) throws Exception {
        setKeyContextElement(record, stateKeySelector2);
    }

    /**
     * Sets the current key from {@code record} when a key selector is configured and the wrapped
     * value is an {@link IterationRecord} of type {@code RECORD}; otherwise does nothing.
     */
    private <R> void setKeyContextElement(StreamRecord<R> record, KeySelector<R, ?> selector)
            throws Exception {
        if (selector != null
                && ((IterationRecord<?>) record.getValue()).getType()
                        == IterationRecord.Type.RECORD) {
            setCurrentKey(selector.getKey(record.getValue()));
        }
    }

    @Override
    public OperatorMetricGroup getMetricGroup() {
        return metrics;
    }

    @Override
    public OperatorID getOperatorID() {
        return streamConfig.getOperatorID();
    }

    /** Forwards the checkpoint-complete notification to every open wrapped operator. */
    @Override
    public void notifyCheckpointComplete(long checkpointId) throws Exception {
        for (S operator : wrappedOperators.values()) {
            operator.notifyCheckpointComplete(checkpointId);
        }
    }

    /** Forwards the checkpoint-aborted notification to every open wrapped operator. */
    @Override
    public void notifyCheckpointAborted(long checkpointId) throws Exception {
        for (S operator : wrappedOperators.values()) {
            operator.notifyCheckpointAborted(checkpointId);
        }
    }

    @Override
    public void setCurrentKey(Object key) {
        stateHandler.setCurrentKey(key);
    }

    /** Returns the current key, or {@code null} before the state handler is created. */
    @Override
    public Object getCurrentKey() {
        return stateHandler == null ? null : stateHandler.getCurrentKey();
    }

    /** Records the marker in {@link #latencyStats} and forwards it downstream. */
    public void reportOrForwardLatencyMarker(LatencyMarker marker) {
        latencyStats.reportLatency(marker);
        output.emitLatencyMarker(marker);
    }

    /**
     * Builds the latency statistics from the task manager configuration, falling back to the
     * default history size and to {@code OPERATOR} granularity for invalid values, and to an
     * unregistered metric group if anything fails.
     */
    private LatencyStats initializeLatencyStats() {
        try {
            Configuration taskManagerConfig =
                    containingTask.getEnvironment().getTaskManagerInfo().getConfiguration();

            int historySize = taskManagerConfig.getInteger(MetricOptions.LATENCY_HISTORY_SIZE);
            if (historySize <= 0) {
                LOG.warn(
                        "{} has been set to a value equal or below 0: {}. Using default.",
                        MetricOptions.LATENCY_HISTORY_SIZE,
                        historySize);
                historySize = MetricOptions.LATENCY_HISTORY_SIZE.defaultValue();
            }

            String configuredGranularity =
                    taskManagerConfig.getString(MetricOptions.LATENCY_SOURCE_GRANULARITY);
            LatencyStats.Granularity granularity;
            try {
                granularity =
                        LatencyStats.Granularity.valueOf(
                                configuredGranularity.toUpperCase(Locale.ROOT));
            } catch (IllegalArgumentException iae) {
                granularity = LatencyStats.Granularity.OPERATOR;
                LOG.warn(
                        "Configured value {} option for {} is invalid. Defaulting to {}.",
                        configuredGranularity,
                        MetricOptions.LATENCY_SOURCE_GRANULARITY.key(),
                        granularity);
            }

            return new LatencyStats(
                    metrics.getJobMetricGroup().addGroup("latency"),
                    historySize,
                    containingTask.getIndexInSubtaskGroup(),
                    getOperatorID(),
                    granularity);
        } catch (Exception e) {
            LOG.warn("An error occurred while instantiating latency metrics.", e);
            return new LatencyStats(
                    UnregisteredMetricGroups.createUnregisteredTaskManagerJobMetricGroup()
                            .addGroup("latency"),
                    1,
                    0,
                    new OperatorID(),
                    LatencyStats.Granularity.SINGLE);
        }
    }

    /**
     * Initializes {@code operator} with a state context scoped to {@code round}'s state prefix and
     * the given raw state inputs, then opens it.
     */
    private void initializeStreamOperator(
            S operator,
            int round,
            Iterator<StatePartitionStreamProvider> rawStateInputs,
            int numberOfRawStates)
            throws Exception {
        operator.initializeState(
                (operatorID,
                        operatorClassName,
                        processingTimeService,
                        keyContext,
                        keySerializer,
                        closeableRegistry,
                        metricGroup,
                        managedMemoryFraction,
                        customRawKeyedState) ->
                        new ProxyStreamOperatorStateContext(
                                streamOperatorStateContext,
                                getRoundStatePrefix(round),
                                rawStateInputs,
                                numberOfRawStates));
        operator.open();
    }

    /**
     * Removes, via reflection, every operator/broadcast state whose name starts with {@code
     * round}'s prefix. Only {@link DefaultOperatorStateBackend} is supported; other backends are
     * logged and skipped.
     */
    @SuppressWarnings("unchecked")
    private void cleanupOperatorStates(int round) {
        String roundPrefix = getRoundStatePrefix(round);
        OperatorStateBackend operatorStateBackend = stateHandler.getOperatorStateBackend();
        if (!(operatorStateBackend instanceof DefaultOperatorStateBackend)) {
            LOG.warn("Unable to cleanup the operator state {}", operatorStateBackend);
            return;
        }

        String[] stateMapFields = {
            "registeredOperatorStates",
            "registeredBroadcastStates",
            "accessedStatesByName",
            "accessedBroadcastStatesByName"
        };
        for (String fieldName : stateMapFields) {
            removeEntriesWithPrefix(
                    (Map<String, ?>)
                            ReflectionUtils.getFieldValue(
                                    operatorStateBackend,
                                    DefaultOperatorStateBackend.class,
                                    fieldName),
                    roundPrefix);
        }
    }

    /**
     * Removes, via reflection, every keyed state whose name starts with {@code round}'s prefix.
     * For RocksDB the matching column families are dropped as well; a failed drop is logged with
     * its cause and cleanup continues. Other backends are logged and skipped.
     */
    @SuppressWarnings("unchecked")
    private void cleanupKeyedStates(int round) {
        String roundPrefix = getRoundStatePrefix(round);
        KeyedStateBackend<?> keyedStateBackend = stateHandler.getKeyedStateBackend();
        String backendClassName = keyedStateBackend.getClass().getName();

        if (backendClassName.equals(HEAP_KEYED_STATE_NAME)) {
            removeEntriesWithPrefix(
                    (Map<String, ?>)
                            ReflectionUtils.getFieldValue(
                                    keyedStateBackend,
                                    HeapKeyedStateBackend.class,
                                    "registeredKVStates"),
                    roundPrefix);
            removeEntriesWithPrefix(
                    (Map<String, ?>)
                            ReflectionUtils.getFieldValue(
                                    keyedStateBackend,
                                    AbstractKeyedStateBackend.class,
                                    "keyValueStatesByName"),
                    roundPrefix);
        } else if (backendClassName.equals(ROCKSDB_KEYED_STATE_NAME)) {
            RocksDB db =
                    (RocksDB)
                            ReflectionUtils.getFieldValue(
                                    keyedStateBackend, RocksDBKeyedStateBackend.class, "db");
            Map<String, RocksDBKeyedStateBackend.RocksDbKvStateInfo> kvStateInformation =
                    (Map<String, RocksDBKeyedStateBackend.RocksDbKvStateInfo>)
                            ReflectionUtils.getFieldValue(
                                    keyedStateBackend,
                                    RocksDBKeyedStateBackend.class,
                                    "kvStateInformation");

            for (Map.Entry<String, RocksDBKeyedStateBackend.RocksDbKvStateInfo> entry :
                    kvStateInformation.entrySet()) {
                if (!entry.getKey().startsWith(roundPrefix)) {
                    continue;
                }
                try {
                    db.dropColumnFamily(entry.getValue().columnFamilyHandle);
                } catch (Exception e) {
                    // Best effort: keep cleaning up the remaining states, but keep the cause.
                    LOG.error("Failed to drop state {} for round {}", entry.getKey(), round, e);
                }
            }
            // NOTE(review): the dropped ColumnFamilyHandles are not closed here; once removed
            // from kvStateInformation they may no longer be closed by the backend — confirm.
            removeEntriesWithPrefix(kvStateInformation, roundPrefix);
            removeEntriesWithPrefix(
                    (Map<String, ?>)
                            ReflectionUtils.getFieldValue(
                                    keyedStateBackend,
                                    AbstractKeyedStateBackend.class,
                                    "keyValueStatesByName"),
                    roundPrefix);
        } else {
            LOG.warn("Unable to cleanup the keyed state {}", keyedStateBackend);
        }
    }

    /** Removes every entry of {@code map} whose key starts with {@code prefix}. */
    private static void removeEntriesWithPrefix(Map<String, ?> map, String prefix) {
        map.entrySet().removeIf(entry -> entry.getKey().startsWith(prefix));
    }

    /** Returns the state-name prefix of {@code round}, e.g. {@code "r3-"}. */
    private String getRoundStatePrefix(int round) {
        return "r" + round + "-";
    }

    /** Returns the highest epoch watermark seen so far (-1 if none). */
    int getLatestEpochWatermark() {
        return latestEpochWatermark;
    }

    /** Returns the live (mutable) map of open wrapped operators keyed by round. */
    public Map<Integer, S> getWrappedOperators() {
        return wrappedOperators;
    }
}
