/*
 * Decompiled with CFR 0.152.
 */
package marytts.machinelearning;

import java.awt.Color;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import javax.swing.JFrame;
import marytts.machinelearning.PolynomialCluster;
import marytts.signalproc.display.FunctionGraph;
import marytts.util.math.Polynomial;

/**
 * Agglomerative (bottom-up) hierarchical clustering of polynomials.
 *
 * <p>Each input polynomial starts in its own cluster. The constructor precomputes a pairwise value
 * for every ordered pair of polynomials via {@link Polynomial#polynomialPearsonProductMomentCorr}.
 * {@link #train(int, String)} then repeatedly merges the best-scoring pair of clusters until the
 * requested number of clusters remains.
 */
public class PolynomialHierarchicalClusteringTrainer {
    /** Sentinel bound used to seed the best-pair search; see {@link #setSimilarityMeasure(boolean)}. */
    private static final double INFINITE = 1.0E7;
    /** Target cluster count used by the no-argument {@link #clustering()} overload. */
    private static final int CLUSTER_DEFAULT_SIZE = 5;
    /** Linkage identifiers accepted by {@link #train(int, String)}. */
    private static final String LINKAGE_SHORT = "Short";
    private static final String LINKAGE_COMPLETE = "Complete";
    private static final String LINKAGE_AVERAGE = "Average";

    /** Ids of the samples: the decimal index ("0", "1", ...) of each polynomial in {@link #polynomials}. */
    private HashSet<String> dataPointSet;
    /** Current clusters: one singleton per sample after construction, shrinking as pairs are merged. */
    private ArrayList<Cluster> clusterList;
    /** Pairwise values keyed "i_j" for every ordered index pair (both "i_j" and "j_i" are stored). */
    private HashMap<String, Double> distanceTableMap;
    private boolean isSimilarityMeasure;
    /** Initial "best so far" value for the pair search; +INFINITE or -INFINITE depending on the mode. */
    private double MINDISTANCE;
    Polynomial[] polynomials;

    /**
     * Creates a trainer and precomputes all pairwise values between the given polynomials.
     *
     * @param polynomials the samples to cluster; must contain more than two entries
     * @throws NullPointerException if {@code polynomials} is null
     * @throws IllegalArgumentException if fewer than three polynomials are given
     */
    public PolynomialHierarchicalClusteringTrainer(Polynomial[] polynomials) {
        if (polynomials == null) {
            throw new NullPointerException("Input polynomial array should not be null");
        }
        if (polynomials.length <= 2) {
            throw new IllegalArgumentException("Number of samples for clustering should be more than two.");
        }
        this.dataPointSet = new HashSet<String>();
        this.distanceTableMap = new HashMap<String, Double>();
        this.clusterList = new ArrayList<Cluster>();
        this.polynomials = polynomials;
        this.setSimilarityMeasure(true);
        this.computeSampleDistances();
        this.initializeClustering();
    }

    /** Returns true if {@code linkageType} is one of the supported linkage names. */
    private static boolean isSupportedLinkage(String linkageType) {
        return LINKAGE_SHORT.equals(linkageType)
                || LINKAGE_COMPLETE.equals(linkageType)
                || LINKAGE_AVERAGE.equals(linkageType);
    }

    /** Looks up the stored value for a sample pair in either key order; returns null if absent. */
    private Double lookupDistance(String x, String y) {
        Double value = this.distanceTableMap.get(x + "_" + y);
        if (value == null) {
            value = this.distanceTableMap.get(y + "_" + x);
        }
        return value;
    }

    /**
     * Computes the linkage value between two clusters from the precomputed pairwise table.
     *
     * <ul>
     *   <li>"Short": minimum of the non-NaN pair values (NaN if there are none).</li>
     *   <li>"Complete": maximum of the non-NaN pair values (NaN if there are none).</li>
     *   <li>"Average": arithmetic mean of all pair values found; a NaN pair value makes the result NaN,
     *       and no pairs at all yields NaN (0/0).</li>
     * </ul>
     *
     * @throws NullPointerException if either cluster is null
     * @throws IllegalArgumentException if {@code linkageType} is not a supported linkage name
     */
    private double getClusterDistance(Cluster xCluster, Cluster yCluster, String linkageType) {
        if (xCluster == null || yCluster == null) {
            throw new NullPointerException("Input clusters should not be null");
        }
        if (!isSupportedLinkage(linkageType)) {
            throw new IllegalArgumentException("Only Short, Complete, or Average linkage clustering supported");
        }
        double sum = 0.0;
        int count = 0;
        double min = Double.NaN;
        double max = Double.NaN;
        for (String x : xCluster.getAllDataPoints()) {
            for (String y : yCluster.getAllDataPoints()) {
                Double value = this.lookupDistance(x, y);
                if (value == null) {
                    continue;
                }
                double d = value.doubleValue();
                sum += d;
                ++count;
                // NaN values are skipped for the extreme-value linkages but still poison the average.
                if (!Double.isNaN(d)) {
                    if (Double.isNaN(min) || d < min) {
                        min = d;
                    }
                    if (Double.isNaN(max) || d > max) {
                        max = d;
                    }
                }
            }
        }
        if (LINKAGE_SHORT.equals(linkageType)) {
            return min;
        }
        if (LINKAGE_COMPLETE.equals(linkageType)) {
            return max;
        }
        return sum / (double) count;
    }

    /** Puts every sample id into its own singleton cluster. */
    private void initializeClustering() {
        assert (this.dataPointSet != null);
        assert (this.clusterList != null);
        for (String dataPoint : this.dataPointSet) {
            ArrayList<String> dataSet = new ArrayList<String>();
            dataSet.add(dataPoint);
            this.clusterList.add(new Cluster(dataSet));
        }
    }

    /**
     * Registers every sample id and fills {@link #distanceTableMap} with
     * {@code Polynomial.polynomialPearsonProductMomentCorr} for every ordered pair (i, j).
     */
    private void computeSampleDistances() {
        assert (this.polynomials != null);
        assert (this.polynomials.length > 2);
        assert (this.dataPointSet != null);
        assert (this.distanceTableMap != null);
        int observations = this.polynomials.length;
        for (int i = 0; i < observations; ++i) {
            this.dataPointSet.add(String.valueOf(i));
            for (int j = 0; j < observations; ++j) {
                double value = Polynomial.polynomialPearsonProductMomentCorr(this.polynomials[i].coeffs, this.polynomials[j].coeffs);
                this.distanceTableMap.put(i + "_" + j, Double.valueOf(value));
            }
        }
    }

    private boolean hasSimilarityMeasure() {
        return this.isSimilarityMeasure;
    }

    private double getClusterDistance(Cluster xCluster, Cluster yCluster) {
        return this.getClusterDistance(xCluster, yCluster, LINKAGE_AVERAGE);
    }

    private void clustering() {
        this.clustering(CLUSTER_DEFAULT_SIZE, LINKAGE_AVERAGE);
    }

    private void clustering(int tagetClusterSize) {
        this.clustering(tagetClusterSize, LINKAGE_AVERAGE);
    }

    /**
     * Merges the best-scoring pair of clusters, one pair per round, until at most
     * {@code tagetClusterSize} clusters remain, then prints the resulting clusters to stdout.
     *
     * @throws IllegalStateException if, in some round, no pair of clusters beats the initial
     *     {@link #MINDISTANCE} bound (e.g. all linkage values are NaN, or only one cluster is left)
     */
    private void clustering(int tagetClusterSize, String linkageType) {
        assert (this.clusterList != null);
        while (this.clusterList.size() > tagetClusterSize) {
            int clusterCount = this.clusterList.size();
            double bestValue = this.MINDISTANCE;
            // Reset every round: stale indices from the previous round could name a removed,
            // out-of-range, or identical pair.
            int bestOne = -1;
            int bestTwo = -1;
            for (int j = 0; j < clusterCount; ++j) {
                Cluster clusterOne = this.clusterList.get(j);
                for (int k = j + 1; k < clusterCount; ++k) {
                    double distance = this.getClusterDistance(clusterOne, this.clusterList.get(k), linkageType);
                    // NOTE(review): in similarity mode the *smallest* value wins. If
                    // polynomialPearsonProductMomentCorr returns a correlation (higher = more similar)
                    // rather than a distance, this merges the least similar pair -- TODO confirm.
                    boolean better = this.hasSimilarityMeasure() ? distance < bestValue : distance > bestValue;
                    if (better) {
                        bestValue = distance;
                        bestOne = j;
                        bestTwo = k;
                    }
                }
            }
            if (bestOne < 0) {
                throw new IllegalStateException("No mergeable cluster pair found among " + clusterCount
                        + " clusters: all " + linkageType + " linkage values are NaN or do not beat the bound "
                        + this.MINDISTANCE);
            }
            // bestOne < bestTwo, so removing bestTwo does not shift bestOne.
            Cluster survivor = this.clusterList.get(bestOne);
            Cluster absorbed = this.clusterList.remove(bestTwo);
            survivor.mergeCluster(absorbed);
        }
        this.printClusterData();
    }

    /** Prints the number of clusters and the sample ids of each cluster to stdout. */
    private void printClusterData() {
        assert (this.clusterList != null);
        System.out.println("Total No of Clusters: " + this.clusterList.size());
        int noCluster = 1;
        for (Cluster aCluster : this.clusterList) {
            System.out.println("Cluster Number : " + noCluster);
            for (String point : aCluster.getAllDataPoints()) {
                System.out.print(String.valueOf(point) + " ");
            }
            System.out.println();
            ++noCluster;
        }
    }

    /**
     * Selects the comparison mode of the pair search and seeds its bound: +INFINITE with a
     * smallest-wins search when true, -INFINITE with a largest-wins search when false.
     */
    private void setSimilarityMeasure(boolean isSimilarityMeasure) {
        this.isSimilarityMeasure = isSimilarityMeasure;
        this.MINDISTANCE = this.isSimilarityMeasure ? INFINITE : -INFINITE;
    }

    /**
     * Clusters the polynomials down to {@code tagetClusterSize} clusters and returns one
     * {@link PolynomialCluster} per cluster, holding the mean polynomial and the members.
     *
     * @param tagetClusterSize number of clusters to produce; must be at least 1 and less than the
     *     current number of clusters
     * @param linkageType "Short", "Complete" or "Average"
     * @return the resulting clusters
     * @throws IllegalArgumentException if the target size or linkage type is invalid
     */
    public PolynomialCluster[] train(int tagetClusterSize, String linkageType) {
        if (tagetClusterSize < 1) {
            throw new IllegalArgumentException("target cluster size should be at least 1, but was " + tagetClusterSize);
        }
        if (this.clusterList.size() <= tagetClusterSize) {
            throw new IllegalArgumentException("taget cluster size should be less than number of samples");
        }
        if (!isSupportedLinkage(linkageType)) {
            throw new IllegalArgumentException("Only Short, Complete, or Average linkage clustering supported");
        }
        this.clustering(tagetClusterSize, linkageType);
        assert (this.clusterList.size() == tagetClusterSize) : "After clustering, number of clusters and the target cluster size should be same, but now the number of clusters are " + this.clusterList.size();
        PolynomialCluster[] clusters = new PolynomialCluster[tagetClusterSize];
        for (int i = 0; i < tagetClusterSize; ++i) {
            ArrayList<String> dataPoints = this.clusterList.get(i).getAllDataPoints();
            Polynomial[] members = new Polynomial[dataPoints.size()];
            for (int j = 0; j < members.length; ++j) {
                // Sample ids are the decimal indices assigned in computeSampleDistances().
                members[j] = this.polynomials[Integer.parseInt(dataPoints.get(j))];
            }
            clusters[i] = new PolynomialCluster(Polynomial.mean(members), members);
        }
        return clusters;
    }

    /**
     * Demo: clusters 100 random cubic polynomials into 5 clusters with average linkage and shows
     * each cluster (mean plus members) in a window, one cluster every five seconds.
     */
    public static void main(String[] args) {
        int order = 3;
        int numPolynomials = 100;
        Polynomial[] ps = new Polynomial[numPolynomials];
        for (int i = 0; i < numPolynomials; ++i) {
            double[] coeffs = new double[order + 1];
            for (int c = 0; c < coeffs.length; ++c) {
                coeffs[c] = Math.random();
            }
            ps[i] = new Polynomial(coeffs);
        }
        PolynomialHierarchicalClusteringTrainer phCT = new PolynomialHierarchicalClusteringTrainer(ps);
        PolynomialCluster[] clusters = phCT.train(5, "Average");
        FunctionGraph clusterGraph = new FunctionGraph(0.0, 1.0, new double[1]);
        clusterGraph.setYMinMax(0.0, 5.0);
        clusterGraph.setPrimaryDataSeriesStyle(Color.BLUE, 2, 1);
        JFrame jf = clusterGraph.showInJFrame("", false, true);
        for (int i = 0; i < clusters.length; ++i) {
            double[] meanValues = clusters[i].getMeanPolynomial().generatePolynomialValues(100, 0.0, 1.0);
            clusterGraph.updateData(0.0, 1.0 / (double) meanValues.length, meanValues);
            Polynomial[] members = clusters[i].getClusterMembers();
            for (Polynomial member : members) {
                double[] pred = member.generatePolynomialValues(meanValues.length, 0.0, 1.0);
                clusterGraph.addDataSeries(pred, Color.GRAY, 1, -1);
                jf.repaint();
            }
            jf.setTitle("Cluster " + (i + 1) + " of " + clusters.length + ": " + members.length + " members");
            jf.repaint();
            try {
                Thread.sleep(5000L);
            } catch (InterruptedException e) {
                // Preserve the interrupt status and stop the slideshow rather than ignoring the request.
                Thread.currentThread().interrupt();
                return;
            }
        }
    }

    /** A cluster of sample ids; merging appends the other cluster's ids to this one. */
    class Cluster {
        private ArrayList<String> dataPoints;

        /**
         * @param dataSet initial sample ids; the list is used directly (not copied)
         * @throws NullPointerException if {@code dataSet} is null
         */
        public Cluster(ArrayList<String> dataSet) {
            if (dataSet == null) {
                throw new NullPointerException("Input dataset for a cluster should not be null");
            }
            this.dataPoints = dataSet;
        }

        /** Returns the live (mutable) list of sample ids in this cluster. */
        public ArrayList<String> getAllDataPoints() {
            return this.dataPoints;
        }

        /**
         * Appends all sample ids of {@code xCluster} to this cluster.
         *
         * @throws NullPointerException if {@code xCluster} is null
         * @throws IllegalArgumentException if {@code xCluster} is this cluster
         */
        public void mergeCluster(Cluster xCluster) {
            if (xCluster == null) {
                throw new NullPointerException("Input cluster should not be null");
            }
            if (xCluster == this) {
                throw new IllegalArgumentException("A cluster cannot be merged with itself");
            }
            this.dataPoints.addAll(xCluster.getAllDataPoints());
        }
    }
}

