package opennlp.tools.ml.maxent.quasinewton;

import java.lang.reflect.Constructor;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import opennlp.tools.ml.model.DataIndexer;

/**
 * Multi-threaded variant of {@link NegLogLikelihood}.
 *
 * <p>The contexts {@code [0, numContexts)} are split into {@code threads} contiguous chunks.
 * Each chunk is processed by its own task. Every task writes only to its own slot of
 * {@code negLogLikelihoodThread} / {@code gradientThread}, and the per-thread partial results
 * are summed once all tasks have completed.
 *
 * <p>NOTE(review): the per-thread buffers and the inherited {@code gradient} array are reused
 * across calls, so a single instance should not be evaluated concurrently from several
 * caller threads.
 */
public class ParallelNegLogLikelihood extends NegLogLikelihood {
    private final int threads;
    private final double[] negLogLikelihoodThread;
    private final double[][] gradientThread;

    /** Creates a task covering {@code length} contexts starting at {@code startIndex}. */
    @FunctionalInterface
    private interface TaskFactory {
        ComputeTask create(int threadIndex, int startIndex, int length, double[] x);
    }

    /**
     * A unit of work covering the contexts {@code [startIndex, startIndex + length)},
     * evaluated at the parameter vector {@code x}.
     */
    public abstract class ComputeTask implements Callable<ComputeTask> {
        final int threadIndex;
        final int startIndex;
        final int length;
        final double[] x;

        /**
         * @param threadIndex slot in the per-thread result buffers this task writes to
         * @param startIndex first context index (inclusive) handled by this task
         * @param length number of consecutive contexts handled by this task
         * @param x parameter vector to evaluate at
         */
        public ComputeTask(int threadIndex, int startIndex, int length, double[] x) {
            this.threadIndex = threadIndex;
            this.startIndex = startIndex;
            this.length = length;
            this.x = x;
        }

        /** Performs the computation and returns this task; declared without checked exceptions. */
        @Override
        public abstract ComputeTask call();
    }

    /** Accumulates this chunk's contribution to the gradient into {@code gradientThread[threadIndex]}. */
    class GradientComputeTask extends ComputeTask {
        // Scratch buffer: per-outcome scores, then normalized probabilities, for one context.
        final double[] expectation;

        public GradientComputeTask(int threadIndex, int startIndex, int length, double[] x) {
            super(threadIndex, startIndex, length, x);
            this.expectation = new double[numOutcomes];
        }

        @Override
        public ComputeTask call() {
            double[] threadGradient = gradientThread[threadIndex];
            Arrays.fill(threadGradient, 0.0d);
            for (int ci = startIndex; ci < startIndex + length; ci++) {
                // Linear score of every outcome for this context (predicate value defaults to 1).
                for (int oi = 0; oi < numOutcomes; oi++) {
                    expectation[oi] = 0.0d;
                    for (int ai = 0; ai < contexts[ci].length; ai++) {
                        int vectorIndex = indexOf(oi, contexts[ci][ai]);
                        double predValue = values != null ? values[ci][ai] : 1.0d;
                        expectation[oi] += predValue * x[vectorIndex];
                    }
                }

                // Normalize the scores into probabilities via exp(score - logSumExp(scores)).
                double logSumOfExps = opennlp.tools.ml.ArrayMath.logSumOfExps(expectation);
                for (int oi = 0; oi < numOutcomes; oi++) {
                    expectation[oi] = StrictMath.exp(expectation[oi] - logSumOfExps);
                }

                // Gradient term: (model probability - observed indicator), weighted by the
                // predicate value and by how often this event was seen.
                for (int oi = 0; oi < numOutcomes; oi++) {
                    int empirical = outcomeList[ci] == oi ? 1 : 0;
                    for (int ai = 0; ai < contexts[ci].length; ai++) {
                        int vectorIndex = indexOf(oi, contexts[ci][ai]);
                        double predValue = values != null ? values[ci][ai] : 1.0d;
                        threadGradient[vectorIndex] +=
                            predValue * (expectation[oi] - empirical) * numTimesEventsSeen[ci];
                    }
                }
            }
            return this;
        }
    }

    /** Stores this chunk's contribution to the negative log-likelihood in {@code negLogLikelihoodThread[threadIndex]}. */
    class NegLLComputeTask extends ComputeTask {
        // Scratch buffer: per-outcome linear scores for one context.
        final double[] tempSums;

        public NegLLComputeTask(int threadIndex, int startIndex, int length, double[] x) {
            super(threadIndex, startIndex, length, x);
            this.tempSums = new double[numOutcomes];
        }

        @Override
        public ComputeTask call() {
            double negLogLikelihood = 0.0d;
            for (int ci = startIndex; ci < startIndex + length; ci++) {
                for (int oi = 0; oi < numOutcomes; oi++) {
                    tempSums[oi] = 0.0d;
                    for (int ai = 0; ai < contexts[ci].length; ai++) {
                        int vectorIndex = indexOf(oi, contexts[ci][ai]);
                        double predValue = values != null ? values[ci][ai] : 1.0d;
                        tempSums[oi] += predValue * x[vectorIndex];
                    }
                }
                // log P(observed outcome) = score(observed) - logSumExp(scores).
                double logSumOfExps = opennlp.tools.ml.ArrayMath.logSumOfExps(tempSums);
                int observed = outcomeList[ci];
                negLogLikelihood -= (tempSums[observed] - logSumOfExps) * numTimesEventsSeen[ci];
            }
            negLogLikelihoodThread[threadIndex] = negLogLikelihood;
            return this;
        }
    }

    /**
     * @param dataIndexer indexed training data, passed to {@link NegLogLikelihood}
     * @param threads number of worker threads (and contiguous context chunks) to use
     * @throws IllegalArgumentException if {@code threads} is less than 1
     */
    public ParallelNegLogLikelihood(DataIndexer dataIndexer, int threads) {
        super(dataIndexer);
        if (threads <= 0) {
            throw new IllegalArgumentException("Number of threads must be 1 or larger, but was " + threads);
        }
        this.threads = threads;
        this.negLogLikelihoodThread = new double[threads];
        this.gradientThread = new double[threads][this.dimension];
    }

    /**
     * Computes the negative log-likelihood at {@code x}, summing the per-thread partial values.
     *
     * @throws IllegalArgumentException if {@code x.length} differs from the function dimension
     * @throws IllegalStateException if a worker fails or the caller is interrupted while waiting
     */
    @Override
    public double valueAt(double[] x) {
        if (x.length != this.dimension) {
            throw new IllegalArgumentException("x is invalid, its dimension is not equal to domain dimension.");
        }
        computeInParallel(x, NegLLComputeTask::new);
        double sum = 0.0d;
        for (int t = 0; t < threads; t++) {
            sum += negLogLikelihoodThread[t];
        }
        return sum;
    }

    /**
     * Computes the gradient at {@code x} by summing the per-thread partial gradients.
     *
     * <p>The returned array is the inherited {@code gradient} field and is overwritten by the next call.
     *
     * @throws IllegalArgumentException if {@code x.length} differs from the function dimension
     * @throws IllegalStateException if a worker fails or the caller is interrupted while waiting
     */
    @Override
    public double[] gradientAt(double[] x) {
        if (x.length != this.dimension) {
            throw new IllegalArgumentException("x is invalid, its dimension is not equal to the function.");
        }
        computeInParallel(x, GradientComputeTask::new);
        for (int i = 0; i < dimension; i++) {
            gradient[i] = 0.0d;
            for (int t = 0; t < threads; t++) {
                gradient[i] += gradientThread[t][i];
            }
        }
        return gradient;
    }

    /**
     * Runs one task per thread over contiguous chunks of the contexts and waits for all of them.
     * The last chunk also absorbs the {@code numContexts % threads} leftover contexts.
     *
     * @throws IllegalStateException if any task fails or the calling thread is interrupted
     */
    private void computeInParallel(double[] x, TaskFactory taskFactory) {
        ExecutorService executor = Executors.newFixedThreadPool(threads, runnable -> {
            Thread thread = new Thread(runnable);
            thread.setName("opennlp.tools.ml.maxent.quasinewton.ParallelNegLogLikelihood.computeInParallel()");
            thread.setDaemon(true);
            return thread;
        });
        int taskSize = numContexts / threads;
        int leftOver = numContexts % threads;
        try {
            List<Future<ComputeTask>> futures = new ArrayList<>(threads);
            for (int t = 0; t < threads; t++) {
                int chunkLength = (t == threads - 1) ? taskSize + leftOver : taskSize;
                futures.add(executor.submit(taskFactory.create(t, t * taskSize, chunkLength, x)));
            }
            // Future.get() also establishes happens-before for the workers' writes to the buffers.
            for (Future<ComputeTask> future : futures) {
                future.get();
            }
        } catch (ExecutionException e) {
            // Surface the failure: otherwise callers would silently read stale or partial sums.
            throw new IllegalStateException("Parallel likelihood computation failed in a worker thread", e.getCause());
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException("Interrupted while waiting for parallel likelihood computation", e);
        } finally {
            // The pool is per call, so always release it. All tasks are done on the success path.
            executor.shutdownNow();
        }
    }
}
