/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.runtime.checkpoint;

import java.util.Arrays;
import java.util.Collections;
import java.util.EnumMap;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.flink.runtime.JobException;
import org.apache.flink.runtime.OperatorIDPair;
import org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptor;
import org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptorUtil;
import org.apache.flink.runtime.checkpoint.OperatorState;
import org.apache.flink.runtime.checkpoint.OperatorStateRepartitioner;
import org.apache.flink.runtime.checkpoint.OperatorSubtaskState;
import org.apache.flink.runtime.checkpoint.RoundRobinOperatorStateRepartitioner;
import org.apache.flink.runtime.checkpoint.StateAssignmentOperation;
import org.apache.flink.runtime.checkpoint.StateHandleDummyUtil;
import org.apache.flink.runtime.checkpoint.StateObjectCollection;
import org.apache.flink.runtime.client.JobExecutionException;
import org.apache.flink.runtime.executiongraph.DefaultExecutionGraph;
import org.apache.flink.runtime.executiongraph.ExecutionGraph;
import org.apache.flink.runtime.executiongraph.ExecutionGraphTestUtils;
import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
import org.apache.flink.runtime.executiongraph.TestingDefaultExecutionGraphBuilder;
import org.apache.flink.runtime.io.network.api.writer.SubtaskStateMapper;
import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
import org.apache.flink.runtime.jobgraph.DistributionPattern;
import org.apache.flink.runtime.jobgraph.JobEdge;
import org.apache.flink.runtime.jobgraph.JobGraph;
import org.apache.flink.runtime.jobgraph.JobGraphTestUtils;
import org.apache.flink.runtime.jobgraph.JobVertex;
import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.runtime.jobgraph.OperatorID;
import org.apache.flink.runtime.jobgraph.OperatorInstanceID;
import org.apache.flink.runtime.state.KeyGroupRange;
import org.apache.flink.runtime.state.OperatorStateHandle;
import org.apache.flink.runtime.state.OperatorStreamStateHandle;
import org.apache.flink.runtime.state.StateObject;
import org.apache.flink.runtime.state.StreamStateHandle;
import org.apache.flink.runtime.state.memory.ByteStreamStateHandle;
import org.apache.flink.runtime.testtasks.NoOpInvokable;
import org.apache.flink.util.TestLogger;
import org.hamcrest.Matcher;
import org.hamcrest.Matchers;
import org.junit.Assert;
import org.junit.Test;

public class StateAssignmentOperationTest
extends TestLogger {
    private static final int MAX_P = 256;

    @Test
    public void testRepartitionSplitDistributeStates() {
        OperatorID operatorID = new OperatorID();
        OperatorState operatorState = new OperatorState(operatorID, 2, 4);
        HashMap<String, OperatorStateHandle.StateMetaInfo> metaInfoMap1 = new HashMap<String, OperatorStateHandle.StateMetaInfo>(1);
        metaInfoMap1.put("t-1", new OperatorStateHandle.StateMetaInfo(new long[]{0L, 10L}, OperatorStateHandle.Mode.SPLIT_DISTRIBUTE));
        OperatorStreamStateHandle osh1 = new OperatorStreamStateHandle(metaInfoMap1, (StreamStateHandle)new ByteStreamStateHandle("test1", new byte[30]));
        operatorState.putState(0, OperatorSubtaskState.builder().setManagedOperatorState((OperatorStateHandle)osh1).build());
        HashMap<String, OperatorStateHandle.StateMetaInfo> metaInfoMap2 = new HashMap<String, OperatorStateHandle.StateMetaInfo>(1);
        metaInfoMap2.put("t-2", new OperatorStateHandle.StateMetaInfo(new long[]{0L, 15L}, OperatorStateHandle.Mode.SPLIT_DISTRIBUTE));
        OperatorStreamStateHandle osh2 = new OperatorStreamStateHandle(metaInfoMap2, (StreamStateHandle)new ByteStreamStateHandle("test2", new byte[40]));
        operatorState.putState(1, OperatorSubtaskState.builder().setManagedOperatorState((OperatorStateHandle)osh2).build());
        this.verifyOneKindPartitionableStateRescale(operatorState, operatorID);
    }

    @Test
    public void testRepartitionUnionState() {
        OperatorID operatorID = new OperatorID();
        OperatorState operatorState = new OperatorState(operatorID, 2, 4);
        HashMap<String, OperatorStateHandle.StateMetaInfo> metaInfoMap1 = new HashMap<String, OperatorStateHandle.StateMetaInfo>(2);
        metaInfoMap1.put("t-3", new OperatorStateHandle.StateMetaInfo(new long[]{0L}, OperatorStateHandle.Mode.UNION));
        metaInfoMap1.put("t-4", new OperatorStateHandle.StateMetaInfo(new long[]{22L, 44L}, OperatorStateHandle.Mode.UNION));
        OperatorStreamStateHandle osh1 = new OperatorStreamStateHandle(metaInfoMap1, (StreamStateHandle)new ByteStreamStateHandle("test1", new byte[50]));
        operatorState.putState(0, OperatorSubtaskState.builder().setManagedOperatorState((OperatorStateHandle)osh1).build());
        HashMap<String, OperatorStateHandle.StateMetaInfo> metaInfoMap2 = new HashMap<String, OperatorStateHandle.StateMetaInfo>(1);
        metaInfoMap2.put("t-3", new OperatorStateHandle.StateMetaInfo(new long[]{0L}, OperatorStateHandle.Mode.UNION));
        OperatorStreamStateHandle osh2 = new OperatorStreamStateHandle(metaInfoMap2, (StreamStateHandle)new ByteStreamStateHandle("test2", new byte[20]));
        operatorState.putState(1, OperatorSubtaskState.builder().setManagedOperatorState((OperatorStateHandle)osh2).build());
        this.verifyOneKindPartitionableStateRescale(operatorState, operatorID);
    }

    @Test
    public void testRepartitionBroadcastState() {
        OperatorID operatorID = new OperatorID();
        OperatorState operatorState = new OperatorState(operatorID, 2, 4);
        HashMap<String, OperatorStateHandle.StateMetaInfo> metaInfoMap1 = new HashMap<String, OperatorStateHandle.StateMetaInfo>(2);
        metaInfoMap1.put("t-5", new OperatorStateHandle.StateMetaInfo(new long[]{0L, 10L, 20L}, OperatorStateHandle.Mode.BROADCAST));
        metaInfoMap1.put("t-6", new OperatorStateHandle.StateMetaInfo(new long[]{30L, 40L, 50L}, OperatorStateHandle.Mode.BROADCAST));
        OperatorStreamStateHandle osh1 = new OperatorStreamStateHandle(metaInfoMap1, (StreamStateHandle)new ByteStreamStateHandle("test1", new byte[60]));
        operatorState.putState(0, OperatorSubtaskState.builder().setManagedOperatorState((OperatorStateHandle)osh1).build());
        HashMap<String, OperatorStateHandle.StateMetaInfo> metaInfoMap2 = new HashMap<String, OperatorStateHandle.StateMetaInfo>(2);
        metaInfoMap2.put("t-5", new OperatorStateHandle.StateMetaInfo(new long[]{0L, 10L, 20L}, OperatorStateHandle.Mode.BROADCAST));
        metaInfoMap2.put("t-6", new OperatorStateHandle.StateMetaInfo(new long[]{30L, 40L, 50L}, OperatorStateHandle.Mode.BROADCAST));
        OperatorStreamStateHandle osh2 = new OperatorStreamStateHandle(metaInfoMap2, (StreamStateHandle)new ByteStreamStateHandle("test2", new byte[60]));
        operatorState.putState(1, OperatorSubtaskState.builder().setManagedOperatorState((OperatorStateHandle)osh2).build());
        this.verifyOneKindPartitionableStateRescale(operatorState, operatorID);
    }

    @Test
    public void testReDistributeCombinedPartitionableStates() {
        OperatorID operatorID = new OperatorID();
        OperatorState operatorState = new OperatorState(operatorID, 2, 4);
        HashMap<String, OperatorStateHandle.StateMetaInfo> metaInfoMap1 = new HashMap<String, OperatorStateHandle.StateMetaInfo>(6);
        metaInfoMap1.put("t-1", new OperatorStateHandle.StateMetaInfo(new long[]{0L}, OperatorStateHandle.Mode.UNION));
        metaInfoMap1.put("t-2", new OperatorStateHandle.StateMetaInfo(new long[]{22L, 44L}, OperatorStateHandle.Mode.UNION));
        metaInfoMap1.put("t-3", new OperatorStateHandle.StateMetaInfo(new long[]{52L, 63L}, OperatorStateHandle.Mode.SPLIT_DISTRIBUTE));
        metaInfoMap1.put("t-4", new OperatorStateHandle.StateMetaInfo(new long[]{67L, 74L, 75L}, OperatorStateHandle.Mode.BROADCAST));
        metaInfoMap1.put("t-5", new OperatorStateHandle.StateMetaInfo(new long[]{77L, 88L, 92L}, OperatorStateHandle.Mode.BROADCAST));
        metaInfoMap1.put("t-6", new OperatorStateHandle.StateMetaInfo(new long[]{101L, 123L, 127L}, OperatorStateHandle.Mode.BROADCAST));
        OperatorStreamStateHandle osh1 = new OperatorStreamStateHandle(metaInfoMap1, (StreamStateHandle)new ByteStreamStateHandle("test1", new byte[130]));
        operatorState.putState(0, OperatorSubtaskState.builder().setManagedOperatorState((OperatorStateHandle)osh1).build());
        HashMap<String, OperatorStateHandle.StateMetaInfo> metaInfoMap2 = new HashMap<String, OperatorStateHandle.StateMetaInfo>(3);
        metaInfoMap2.put("t-1", new OperatorStateHandle.StateMetaInfo(new long[]{0L}, OperatorStateHandle.Mode.UNION));
        metaInfoMap2.put("t-4", new OperatorStateHandle.StateMetaInfo(new long[]{20L, 27L, 28L}, OperatorStateHandle.Mode.BROADCAST));
        metaInfoMap2.put("t-5", new OperatorStateHandle.StateMetaInfo(new long[]{30L, 44L, 48L}, OperatorStateHandle.Mode.BROADCAST));
        metaInfoMap2.put("t-6", new OperatorStateHandle.StateMetaInfo(new long[]{57L, 79L, 83L}, OperatorStateHandle.Mode.BROADCAST));
        OperatorStreamStateHandle osh2 = new OperatorStreamStateHandle(metaInfoMap2, (StreamStateHandle)new ByteStreamStateHandle("test2", new byte[86]));
        operatorState.putState(1, OperatorSubtaskState.builder().setManagedOperatorState((OperatorStateHandle)osh2).build());
        this.verifyCombinedPartitionableStateRescale(operatorState, operatorID, 2, 3);
        this.verifyCombinedPartitionableStateRescale(operatorState, operatorID, 2, 1);
        this.verifyCombinedPartitionableStateRescale(operatorState, operatorID, 2, 2);
    }

    private void verifyAndCollectStateInfo(OperatorState operatorState, OperatorID operatorID, int oldParallelism, int newParallelism, Map<String, Integer> stateInfoCounts) {
        HashMap newManagedOperatorStates = new HashMap();
        StateAssignmentOperation.reDistributePartitionableStates(Collections.singletonMap(operatorID, operatorState), (int)newParallelism, OperatorSubtaskState::getManagedOperatorState, (OperatorStateRepartitioner)RoundRobinOperatorStateRepartitioner.INSTANCE, newManagedOperatorStates);
        for (List operatorStateHandles : newManagedOperatorStates.values()) {
            EnumMap stateModeOffsets = new EnumMap(OperatorStateHandle.Mode.class);
            for (OperatorStateHandle.Mode mode : OperatorStateHandle.Mode.values()) {
                stateModeOffsets.put(mode, new HashMap());
            }
            for (OperatorStateHandle operatorStateHandle : operatorStateHandles) {
                for (Map.Entry stateNameToMetaInfo : operatorStateHandle.getStateNameToPartitionOffsets().entrySet()) {
                    String stateName = (String)stateNameToMetaInfo.getKey();
                    stateInfoCounts.merge(stateName, 1, (count, inc) -> count + inc);
                    OperatorStateHandle.StateMetaInfo stateMetaInfo = (OperatorStateHandle.StateMetaInfo)stateNameToMetaInfo.getValue();
                    ((Map)stateModeOffsets.get(stateMetaInfo.getDistributionMode())).merge(stateName, stateMetaInfo.getOffsets().length, (count, inc) -> count + inc);
                }
            }
            for (Map.Entry entry : stateModeOffsets.entrySet()) {
                OperatorStateHandle.Mode mode = (OperatorStateHandle.Mode)entry.getKey();
                Map stateOffsets = (Map)entry.getValue();
                if (OperatorStateHandle.Mode.SPLIT_DISTRIBUTE.equals((Object)mode)) {
                    if (oldParallelism < newParallelism) {
                        stateOffsets.values().forEach(length -> Assert.assertEquals((long)1L, (long)length.intValue()));
                        continue;
                    }
                    stateOffsets.values().forEach(length -> Assert.assertEquals((long)2L, (long)length.intValue()));
                    continue;
                }
                if (OperatorStateHandle.Mode.UNION.equals((Object)mode)) {
                    stateOffsets.values().forEach(length -> Assert.assertEquals((long)2L, (long)length.intValue()));
                    continue;
                }
                stateOffsets.values().forEach(length -> Assert.assertEquals((long)3L, (long)length.intValue()));
            }
        }
    }

    private void verifyOneKindPartitionableStateRescale(OperatorState operatorState, OperatorID operatorID) {
        this.verifyOneKindPartitionableStateRescale(operatorState, operatorID, 2, 3);
        this.verifyOneKindPartitionableStateRescale(operatorState, operatorID, 2, 1);
        this.verifyOneKindPartitionableStateRescale(operatorState, operatorID, 2, 2);
    }

    private void verifyOneKindPartitionableStateRescale(OperatorState operatorState, OperatorID operatorID, int oldParallelism, int newParallelism) {
        HashMap<String, Integer> stateInfoCounts = new HashMap<String, Integer>();
        this.verifyAndCollectStateInfo(operatorState, operatorID, oldParallelism, newParallelism, stateInfoCounts);
        Assert.assertEquals((long)2L, (long)stateInfoCounts.size());
        if (stateInfoCounts.containsKey("t-1")) {
            if (oldParallelism < newParallelism) {
                Assert.assertEquals((long)2L, (long)((Integer)stateInfoCounts.get("t-1")).intValue());
                Assert.assertEquals((long)2L, (long)((Integer)stateInfoCounts.get("t-2")).intValue());
            } else {
                Assert.assertEquals((long)1L, (long)((Integer)stateInfoCounts.get("t-1")).intValue());
                Assert.assertEquals((long)1L, (long)((Integer)stateInfoCounts.get("t-2")).intValue());
            }
        }
        if (stateInfoCounts.containsKey("t-3")) {
            Assert.assertEquals((long)(2 * newParallelism), (long)((Integer)stateInfoCounts.get("t-3")).intValue());
            Assert.assertEquals((long)newParallelism, (long)((Integer)stateInfoCounts.get("t-4")).intValue());
        }
        if (stateInfoCounts.containsKey("t-5")) {
            Assert.assertEquals((long)newParallelism, (long)((Integer)stateInfoCounts.get("t-5")).intValue());
            Assert.assertEquals((long)newParallelism, (long)((Integer)stateInfoCounts.get("t-6")).intValue());
        }
    }

    private void verifyCombinedPartitionableStateRescale(OperatorState operatorState, OperatorID operatorID, int oldParallelism, int newParallelism) {
        HashMap<String, Integer> stateInfoCounts = new HashMap<String, Integer>();
        this.verifyAndCollectStateInfo(operatorState, operatorID, oldParallelism, newParallelism, stateInfoCounts);
        Assert.assertEquals((long)6L, (long)stateInfoCounts.size());
        Assert.assertEquals((long)(2 * newParallelism), (long)((Integer)stateInfoCounts.get("t-1")).intValue());
        Assert.assertEquals((long)newParallelism, (long)((Integer)stateInfoCounts.get("t-2")).intValue());
        if (oldParallelism < newParallelism) {
            Assert.assertEquals((long)2L, (long)((Integer)stateInfoCounts.get("t-3")).intValue());
        } else {
            Assert.assertEquals((long)1L, (long)((Integer)stateInfoCounts.get("t-3")).intValue());
        }
        Assert.assertEquals((long)newParallelism, (long)((Integer)stateInfoCounts.get("t-4")).intValue());
        Assert.assertEquals((long)newParallelism, (long)((Integer)stateInfoCounts.get("t-5")).intValue());
        Assert.assertEquals((long)newParallelism, (long)((Integer)stateInfoCounts.get("t-6")).intValue());
    }

    @Test
    public void testChannelStateAssignmentStability() throws JobException, JobExecutionException {
        int numOperators = 10;
        int numSubTasks = 100;
        List<OperatorID> operatorIds = this.buildOperatorIds(numOperators);
        Map<OperatorID, ExecutionJobVertex> vertices = this.buildVertices(operatorIds, numSubTasks, SubtaskStateMapper.RANGE, SubtaskStateMapper.ROUND_ROBIN);
        Map<OperatorID, OperatorState> states = this.buildOperatorStates(operatorIds, numSubTasks);
        new StateAssignmentOperation(0L, new HashSet<ExecutionJobVertex>(vertices.values()), states, false).assignStates();
        for (OperatorID operatorId : operatorIds) {
            for (int subtaskIdx = 0; subtaskIdx < numSubTasks; ++subtaskIdx) {
                Assert.assertEquals((Object)states.get(operatorId).getState(subtaskIdx), (Object)this.getAssignedState(vertices.get(operatorId), operatorId, subtaskIdx));
            }
        }
    }

    @Test
    public void testChannelStateAssignmentDownscaling() throws JobException, JobExecutionException {
        List<OperatorID> operatorIds = this.buildOperatorIds(2);
        Map<OperatorID, OperatorState> states = this.buildOperatorStates(operatorIds, 3);
        Map<OperatorID, ExecutionJobVertex> vertices = this.buildVertices(operatorIds, 2, SubtaskStateMapper.RANGE, SubtaskStateMapper.ROUND_ROBIN);
        new StateAssignmentOperation(0L, new HashSet<ExecutionJobVertex>(vertices.values()), states, false).assignStates();
        for (OperatorID operatorId : operatorIds) {
            this.assertState(vertices, operatorId, states, 0, OperatorSubtaskState::getInputChannelState, 0, 1);
            this.assertState(vertices, operatorId, states, 1, OperatorSubtaskState::getInputChannelState, 1, 2);
            this.assertState(vertices, operatorId, states, 0, OperatorSubtaskState::getResultSubpartitionState, 0, 2);
            this.assertState(vertices, operatorId, states, 1, OperatorSubtaskState::getResultSubpartitionState, 1);
        }
        Assert.assertEquals((Object)new InflightDataRescalingDescriptor(InflightDataRescalingDescriptorUtil.to(0, 2), InflightDataRescalingDescriptorUtil.array(InflightDataRescalingDescriptorUtil.mappings(InflightDataRescalingDescriptorUtil.to(0, 1), InflightDataRescalingDescriptorUtil.to(1, 2))), InflightDataRescalingDescriptorUtil.set(new Integer[0])), (Object)this.getAssignedState(vertices.get(operatorIds.get(0)), operatorIds.get(0), 0).getOutputRescalingDescriptor());
        Assert.assertEquals((Object)new InflightDataRescalingDescriptor(InflightDataRescalingDescriptorUtil.to(1), InflightDataRescalingDescriptorUtil.array(InflightDataRescalingDescriptorUtil.mappings(InflightDataRescalingDescriptorUtil.to(0, 1), InflightDataRescalingDescriptorUtil.to(1, 2))), InflightDataRescalingDescriptorUtil.set(new Integer[0])), (Object)this.getAssignedState(vertices.get(operatorIds.get(0)), operatorIds.get(0), 1).getOutputRescalingDescriptor());
        Assert.assertEquals((Object)new InflightDataRescalingDescriptor(InflightDataRescalingDescriptorUtil.to(0, 1), InflightDataRescalingDescriptorUtil.array(InflightDataRescalingDescriptorUtil.mappings(InflightDataRescalingDescriptorUtil.to(0, 2), InflightDataRescalingDescriptorUtil.to(1))), InflightDataRescalingDescriptorUtil.set(1)), (Object)this.getAssignedState(vertices.get(operatorIds.get(1)), operatorIds.get(1), 0).getInputRescalingDescriptor());
        Assert.assertEquals((Object)new InflightDataRescalingDescriptor(InflightDataRescalingDescriptorUtil.to(1, 2), InflightDataRescalingDescriptorUtil.array(InflightDataRescalingDescriptorUtil.mappings(InflightDataRescalingDescriptorUtil.to(0, 2), InflightDataRescalingDescriptorUtil.to(1))), InflightDataRescalingDescriptorUtil.set(1)), (Object)this.getAssignedState(vertices.get(operatorIds.get(1)), operatorIds.get(1), 1).getInputRescalingDescriptor());
    }

    @Test
    public void testChannelStateAssignmentNoRescale() throws JobException, JobExecutionException {
        List<OperatorID> operatorIds = this.buildOperatorIds(2);
        Map<OperatorID, OperatorState> states = this.buildOperatorStates(operatorIds, 2);
        Map<OperatorID, ExecutionJobVertex> vertices = this.buildVertices(operatorIds, 2, SubtaskStateMapper.RANGE, SubtaskStateMapper.ROUND_ROBIN);
        new StateAssignmentOperation(0L, new HashSet<ExecutionJobVertex>(vertices.values()), states, false).assignStates();
        for (OperatorID operatorId : operatorIds) {
            this.assertState(vertices, operatorId, states, 0, OperatorSubtaskState::getInputChannelState, 0);
            this.assertState(vertices, operatorId, states, 1, OperatorSubtaskState::getInputChannelState, 1);
            this.assertState(vertices, operatorId, states, 0, OperatorSubtaskState::getResultSubpartitionState, 0);
            this.assertState(vertices, operatorId, states, 1, OperatorSubtaskState::getResultSubpartitionState, 1);
        }
        Assert.assertEquals((Object)InflightDataRescalingDescriptor.NO_RESCALE, (Object)this.getAssignedState(vertices.get(operatorIds.get(0)), operatorIds.get(0), 0).getOutputRescalingDescriptor());
        Assert.assertEquals((Object)InflightDataRescalingDescriptor.NO_RESCALE, (Object)this.getAssignedState(vertices.get(operatorIds.get(0)), operatorIds.get(0), 1).getOutputRescalingDescriptor());
        Assert.assertEquals((Object)InflightDataRescalingDescriptor.NO_RESCALE, (Object)this.getAssignedState(vertices.get(operatorIds.get(1)), operatorIds.get(1), 0).getInputRescalingDescriptor());
        Assert.assertEquals((Object)InflightDataRescalingDescriptor.NO_RESCALE, (Object)this.getAssignedState(vertices.get(operatorIds.get(1)), operatorIds.get(1), 1).getInputRescalingDescriptor());
    }

    @Test
    public void testChannelStateAssignmentUpscaling() throws JobException, JobExecutionException {
        List<OperatorID> operatorIds = this.buildOperatorIds(2);
        Map<OperatorID, OperatorState> states = this.buildOperatorStates(operatorIds, 2);
        Map<OperatorID, ExecutionJobVertex> vertices = this.buildVertices(operatorIds, 3, SubtaskStateMapper.RANGE, SubtaskStateMapper.ROUND_ROBIN);
        new StateAssignmentOperation(0L, new HashSet<ExecutionJobVertex>(vertices.values()), states, false).assignStates();
        for (OperatorID operatorId : operatorIds) {
            this.assertState(vertices, operatorId, states, 0, OperatorSubtaskState::getInputChannelState, 0);
            this.assertState(vertices, operatorId, states, 1, OperatorSubtaskState::getInputChannelState, 0, 1);
            this.assertState(vertices, operatorId, states, 2, OperatorSubtaskState::getInputChannelState, 1);
            this.assertState(vertices, operatorId, states, 0, OperatorSubtaskState::getResultSubpartitionState, 0);
            this.assertState(vertices, operatorId, states, 1, OperatorSubtaskState::getResultSubpartitionState, 1);
            this.assertState(vertices, operatorId, states, 2, OperatorSubtaskState::getResultSubpartitionState, new int[0]);
        }
        Assert.assertEquals((Object)new InflightDataRescalingDescriptor(InflightDataRescalingDescriptorUtil.to(0), InflightDataRescalingDescriptorUtil.array(InflightDataRescalingDescriptorUtil.mappings(InflightDataRescalingDescriptorUtil.to(0), InflightDataRescalingDescriptorUtil.to(0, 1), InflightDataRescalingDescriptorUtil.to(1))), InflightDataRescalingDescriptorUtil.set(new Integer[0])), (Object)this.getAssignedState(vertices.get(operatorIds.get(0)), operatorIds.get(0), 0).getOutputRescalingDescriptor());
        Assert.assertEquals((Object)new InflightDataRescalingDescriptor(InflightDataRescalingDescriptorUtil.to(1), InflightDataRescalingDescriptorUtil.array(InflightDataRescalingDescriptorUtil.mappings(InflightDataRescalingDescriptorUtil.to(0), InflightDataRescalingDescriptorUtil.to(0, 1), InflightDataRescalingDescriptorUtil.to(1))), InflightDataRescalingDescriptorUtil.set(new Integer[0])), (Object)this.getAssignedState(vertices.get(operatorIds.get(0)), operatorIds.get(0), 1).getOutputRescalingDescriptor());
        Assert.assertEquals((Object)InflightDataRescalingDescriptor.NO_RESCALE, (Object)this.getAssignedState(vertices.get(operatorIds.get(0)), operatorIds.get(0), 2).getOutputRescalingDescriptor());
        Assert.assertEquals((Object)new InflightDataRescalingDescriptor(InflightDataRescalingDescriptorUtil.to(0), InflightDataRescalingDescriptorUtil.array(InflightDataRescalingDescriptorUtil.mappings(InflightDataRescalingDescriptorUtil.to(0), InflightDataRescalingDescriptorUtil.to(1), InflightDataRescalingDescriptorUtil.to(new int[0]))), InflightDataRescalingDescriptorUtil.set(0, 1)), (Object)this.getAssignedState(vertices.get(operatorIds.get(1)), operatorIds.get(1), 0).getInputRescalingDescriptor());
        Assert.assertEquals((Object)new InflightDataRescalingDescriptor(InflightDataRescalingDescriptorUtil.to(0, 1), InflightDataRescalingDescriptorUtil.array(InflightDataRescalingDescriptorUtil.mappings(InflightDataRescalingDescriptorUtil.to(0), InflightDataRescalingDescriptorUtil.to(1), InflightDataRescalingDescriptorUtil.to(new int[0]))), InflightDataRescalingDescriptorUtil.set(0, 1)), (Object)this.getAssignedState(vertices.get(operatorIds.get(1)), operatorIds.get(1), 1).getInputRescalingDescriptor());
        Assert.assertEquals((Object)new InflightDataRescalingDescriptor(InflightDataRescalingDescriptorUtil.to(1), InflightDataRescalingDescriptorUtil.array(InflightDataRescalingDescriptorUtil.mappings(InflightDataRescalingDescriptorUtil.to(0), InflightDataRescalingDescriptorUtil.to(1), InflightDataRescalingDescriptorUtil.to(new int[0]))), InflightDataRescalingDescriptorUtil.set(0, 1)), (Object)this.getAssignedState(vertices.get(operatorIds.get(1)), operatorIds.get(1), 2).getInputRescalingDescriptor());
    }

    private void assertState(Map<OperatorID, ExecutionJobVertex> vertices, OperatorID operatorId, Map<OperatorID, OperatorState> states, int newSubtaskIndex, Function<OperatorSubtaskState, StateObjectCollection<?>> extractor, int ... oldSubtaskIndexes) {
        OperatorSubtaskState subState = this.getAssignedState(vertices.get(operatorId), operatorId, newSubtaskIndex);
        Assert.assertThat(extractor.apply(subState), (Matcher)Matchers.containsInAnyOrder((Object[])Arrays.stream(oldSubtaskIndexes).boxed().flatMap(oldIndex -> ((StateObjectCollection)extractor.apply(((OperatorState)states.get(operatorId)).getState(oldIndex.intValue()))).stream()).toArray()));
    }

    @Test
    public void assigningStatesShouldWorkWithUserDefinedOperatorIdsAsWell() {
        int numSubTasks = 1;
        OperatorID operatorId = new OperatorID();
        OperatorID userDefinedOperatorId = new OperatorID();
        List<OperatorID> operatorIds = Collections.singletonList(userDefinedOperatorId);
        ExecutionJobVertex executionJobVertex = this.buildExecutionJobVertex(operatorId, userDefinedOperatorId, 1);
        Map<OperatorID, OperatorState> states = this.buildOperatorStates(operatorIds, numSubTasks);
        new StateAssignmentOperation(0L, Collections.singleton(executionJobVertex), states, false).assignStates();
        Assert.assertEquals((Object)states.get(userDefinedOperatorId).getState(0), (Object)this.getAssignedState(executionJobVertex, operatorId, 0));
    }

    private List<OperatorID> buildOperatorIds(int numOperators) {
        return IntStream.range(0, numOperators).mapToObj(j -> new OperatorID()).collect(Collectors.toList());
    }

    private Map<OperatorID, OperatorState> buildOperatorStates(List<OperatorID> operatorIDs, int numSubTasks) {
        Random random = new Random();
        OperatorID lastId = operatorIDs.get(operatorIDs.size() - 1);
        return operatorIDs.stream().collect(Collectors.toMap(Function.identity(), operatorID -> {
            OperatorState state = new OperatorState(operatorID, numSubTasks, 256);
            for (int i = 0; i < numSubTasks; ++i) {
                state.putState(i, OperatorSubtaskState.builder().setManagedOperatorState(new StateObjectCollection(Arrays.asList(StateHandleDummyUtil.createNewOperatorStateHandle(10, random), StateHandleDummyUtil.createNewOperatorStateHandle(10, random)))).setRawOperatorState(new StateObjectCollection(Arrays.asList(StateHandleDummyUtil.createNewOperatorStateHandle(10, random), StateHandleDummyUtil.createNewOperatorStateHandle(10, random)))).setManagedKeyedState(StateObjectCollection.singleton((StateObject)StateHandleDummyUtil.createNewKeyedStateHandle(KeyGroupRange.of((int)i, (int)i)))).setRawKeyedState(StateObjectCollection.singleton((StateObject)StateHandleDummyUtil.createNewKeyedStateHandle(KeyGroupRange.of((int)i, (int)i)))).setInputChannelState(operatorID == operatorIDs.get(0) ? StateObjectCollection.empty() : new StateObjectCollection(Arrays.asList(StateHandleDummyUtil.createNewInputChannelStateHandle(10, random), StateHandleDummyUtil.createNewInputChannelStateHandle(10, random)))).setResultSubpartitionState(operatorID == lastId ? StateObjectCollection.empty() : new StateObjectCollection(Arrays.asList(StateHandleDummyUtil.createNewResultSubpartitionStateHandle(10, random), StateHandleDummyUtil.createNewResultSubpartitionStateHandle(10, random)))).build());
            }
            return state;
        }));
    }

    private Map<OperatorID, ExecutionJobVertex> buildVertices(List<OperatorID> operatorIds, int parallelism, SubtaskStateMapper downstreamRescaler, SubtaskStateMapper upstreamRescaler) throws JobException, JobExecutionException {
        JobVertex[] jobVertices = (JobVertex[])operatorIds.stream().map(id -> {
            JobVertex jobVertex = this.createJobVertex((OperatorID)id, (OperatorID)id, parallelism);
            return jobVertex;
        }).toArray(JobVertex[]::new);
        for (int index = 1; index < jobVertices.length; ++index) {
            JobEdge jobEdge = jobVertices[index].connectNewDataSetAsInput(jobVertices[index - 1], DistributionPattern.ALL_TO_ALL, ResultPartitionType.PIPELINED);
            jobEdge.setDownstreamSubtaskStateMapper(downstreamRescaler);
            jobEdge.setUpstreamSubtaskStateMapper(upstreamRescaler);
        }
        JobGraph jobGraph = JobGraphTestUtils.streamingJobGraph(jobVertices);
        DefaultExecutionGraph eg = TestingDefaultExecutionGraphBuilder.newBuilder().setJobGraph(jobGraph).build();
        return Arrays.stream(jobVertices).collect(Collectors.toMap(jobVertex -> ((OperatorIDPair)jobVertex.getOperatorIDs().get(0)).getGeneratedOperatorID(), arg_0 -> StateAssignmentOperationTest.lambda$buildVertices$12((ExecutionGraph)eg, arg_0)));
    }

    private ExecutionJobVertex buildExecutionJobVertex(OperatorID operatorID, OperatorID userDefinedOperatorId, int parallelism) {
        try {
            JobVertex jobVertex = this.createJobVertex(operatorID, userDefinedOperatorId, parallelism);
            return ExecutionGraphTestUtils.getExecutionJobVertex(jobVertex);
        }
        catch (Exception e) {
            throw new AssertionError("Cannot create ExecutionJobVertex", e);
        }
    }

    private JobVertex createJobVertex(OperatorID operatorID, OperatorID userDefinedOperatorId, int parallelism) {
        JobVertex jobVertex = new JobVertex(operatorID.toHexString(), new JobVertexID(), Collections.singletonList(OperatorIDPair.of((OperatorID)operatorID, (OperatorID)userDefinedOperatorId)));
        jobVertex.setInvokableClass(NoOpInvokable.class);
        jobVertex.setParallelism(parallelism);
        return jobVertex;
    }

    private OperatorSubtaskState getAssignedState(ExecutionJobVertex executionJobVertex, OperatorID operatorId, int subtaskIdx) {
        return executionJobVertex.getTaskVertices()[subtaskIdx].getCurrentExecutionAttempt().getTaskRestore().getTaskStateSnapshot().getSubtaskStateByOperatorID(operatorId);
    }

    private static /* synthetic */ ExecutionJobVertex lambda$buildVertices$12(ExecutionGraph eg, JobVertex jobVertex) {
        try {
            return eg.getJobVertex(jobVertex.getID());
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
    }
}

