package org.encog.ml.hmm.train.kmeans;

import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import org.encog.ml.MLMethod;
import org.encog.ml.TrainingImplementationType;
import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.MLSequenceSet;
import org.encog.ml.data.basic.BasicMLDataSet;
import org.encog.ml.hmm.HiddenMarkovModel;
import org.encog.ml.hmm.alog.ViterbiCalculator;
import org.encog.ml.train.MLTrain;
import org.encog.ml.train.strategy.Strategy;
import org.encog.neural.networks.training.propagation.TrainingContinuation;

/* loaded from: classes.dex */
/**
 * Trains a {@link HiddenMarkovModel} by alternating cluster-based estimation and Viterbi
 * re-segmentation.
 *
 * <p>Each call to {@link #iteration()} does the following:
 * <ol>
 *   <li>Clones the structure of the original model.</li>
 *   <li>Estimates the initial-state probabilities, the transition probabilities and the
 *       per-state distributions from the current assignment of observations to clusters.</li>
 *   <li>Reassigns every observation to the state chosen by the Viterbi path of the new model.</li>
 * </ol>
 * Training is considered done when an iteration reassigns no observation.
 *
 * <p>Strategies, pausing and resuming are not supported.
 */
public class TrainKMeans implements MLTrain {
    /** Current assignment of observations to states (one cluster per state). */
    private final Clusters clusters;
    /** True once an iteration finished without moving any observation between clusters. */
    private boolean done = false;
    /** Number of iterations performed so far (or the value last passed to setIteration). */
    private int iteration;
    /** Most recently trained model; initially the model passed to the constructor. */
    private HiddenMarkovModel method;
    /** Template model whose structure is cloned at the start of every iteration. */
    private final HiddenMarkovModel modelHMM;
    /** Training sequences; also reported as the training set. */
    private final MLSequenceSet sequences;
    /** Number of hidden states in the template model. */
    private final int states;

    /**
     * Creates a trainer for the given model and training sequences.
     *
     * @param hiddenMarkovModel the model whose structure is trained
     * @param mLSequenceSet the training sequences
     */
    public TrainKMeans(final HiddenMarkovModel hiddenMarkovModel, final MLSequenceSet mLSequenceSet) {
        this.method = hiddenMarkovModel;
        this.modelHMM = hiddenMarkovModel;
        this.sequences = mLSequenceSet;
        this.states = hiddenMarkovModel.getStateCount();
        this.clusters = new Clusters(this.states, mLSequenceSet);
    }

    /**
     * Fits each state's distribution to the observations currently assigned to it.
     *
     * <p>A state with no assigned observations gets a fresh distribution from the template model.
     */
    private void learnOpdf(final HiddenMarkovModel hmm) {
        for (int state = 0; state < hmm.getStateCount(); state++) {
            final Collection<MLDataPair> members = this.clusters.cluster(state);
            if (members.isEmpty()) {
                hmm.setStateDistribution(state, this.modelHMM.createNewDistribution());
            } else {
                final BasicMLDataSet data = new BasicMLDataSet();
                for (final MLDataPair pair : members) {
                    data.add(pair);
                }
                hmm.getStateDistribution(state).fit(data);
            }
        }
    }

    /**
     * Sets the initial-state probabilities to the relative frequency with which each state's
     * cluster holds the first observation of a sequence.
     *
     * <p>The counts are normalized by the number of sequences actually counted, so the
     * probabilities sum to 1. Empty sequences are skipped because they have no first
     * observation. If no sequence can be counted, a uniform distribution is used instead of
     * producing NaN.
     */
    private void learnPi(final HiddenMarkovModel hmm) {
        final double[] counts = new double[this.states];
        int counted = 0;
        for (final MLDataSet sequence : this.sequences.getSequences()) {
            if (sequence.size() < 1) {
                continue;
            }
            counts[this.clusters.cluster(sequence.get(0))] += 1.0d;
            counted++;
        }
        // NOTE(review): the previous code divided by sequences.size(). On a sequence set
        // that may be a record count rather than a sequence count; counting sequences
        // directly avoids depending on it.
        for (int state = 0; state < this.states; state++) {
            final double p = (counted == 0) ? 1.0d / this.states : counts[state] / counted;
            hmm.setPi(state, p);
        }
    }

    /**
     * Estimates transition probabilities from consecutive cluster assignments within each
     * sequence.
     *
     * <p>Each row is normalized to sum to 1. A row with no observed transitions becomes
     * uniform.
     */
    private void learnTransition(final HiddenMarkovModel hmm) {
        final int n = hmm.getStateCount();
        final double[][] counts = new double[n][n];

        for (final MLDataSet sequence : this.sequences.getSequences()) {
            if (sequence.size() < 2) {
                continue; // no transitions in a sequence of length 0 or 1
            }
            int previous = this.clusters.cluster(sequence.get(0));
            for (int i = 1; i < sequence.size(); i++) {
                final int current = this.clusters.cluster(sequence.get(i));
                counts[previous][current] += 1.0d;
                previous = current;
            }
        }

        for (int from = 0; from < n; from++) {
            double rowTotal = 0.0d;
            for (int to = 0; to < n; to++) {
                rowTotal += counts[from][to];
            }
            for (int to = 0; to < n; to++) {
                final double p = (rowTotal == 0.0d) ? 1.0d / n : counts[from][to] / rowTotal;
                hmm.setTransitionProbability(from, to, p);
            }
        }
    }

    /**
     * Moves every observation to the cluster of the state the Viterbi path assigns it under
     * {@code hmm}.
     *
     * @return true if no observation changed cluster (training has converged)
     */
    private boolean optimizeCluster(final HiddenMarkovModel hmm) {
        boolean modified = false;
        for (final MLDataSet sequence : this.sequences.getSequences()) {
            final int[] path = new ViterbiCalculator(sequence, hmm).stateSequence();
            for (int i = 0; i < path.length; i++) {
                final MLDataPair pair = sequence.get(i);
                final int current = this.clusters.cluster(pair);
                if (current != path[i]) {
                    this.clusters.remove(pair, current);
                    this.clusters.put(pair, path[i]);
                    modified = true;
                }
            }
        }
        return !modified;
    }

    /** Strategies are not supported; the argument is ignored. */
    @Override
    public void addStrategy(final Strategy strategy) {
    }

    /** Training cannot be paused and resumed. */
    @Override
    public boolean canContinue() {
        return false;
    }

    /** No cleanup is required. */
    @Override
    public void finishTraining() {
    }

    /**
     * Returns a coarse convergence indicator rather than a true error.
     *
     * @return 0 once training is done, otherwise 100
     */
    @Override
    public double getError() {
        return this.done ? 0.0d : 100.0d;
    }

    @Override
    public TrainingImplementationType getImplementationType() {
        return TrainingImplementationType.Iterative;
    }

    /** @return the number of iterations performed (or the value last set) */
    @Override
    public int getIteration() {
        return this.iteration;
    }

    /** @return the most recently trained model (the original model before any iteration) */
    @Override
    public MLMethod getMethod() {
        return this.method;
    }

    /**
     * Strategies are not supported.
     *
     * @return an immutable empty list (never null)
     */
    @Override
    public List<Strategy> getStrategies() {
        return Collections.emptyList();
    }

    /** @return the training sequences */
    @Override
    public MLDataSet getTraining() {
        return this.sequences;
    }

    /** @return true once an iteration reassigned no observation */
    @Override
    public boolean isTrainingDone() {
        return this.done;
    }

    /**
     * Performs one estimation and re-segmentation pass and advances the iteration counter.
     *
     * <p>This builds a fresh model from the template structure, so the trained model is
     * replaced rather than updated in place.
     */
    @Override
    public void iteration() {
        final HiddenMarkovModel hmm = this.modelHMM.cloneStructure();
        learnPi(hmm);
        learnTransition(hmm);
        learnOpdf(hmm);
        this.done = optimizeCluster(hmm);
        this.method = hmm;
        this.iteration++;
    }

    /**
     * Performs {@code count} training iterations.
     *
     * @param count number of iterations to run; values of 0 or less run none
     */
    @Override
    public void iteration(final int count) {
        for (int i = 0; i < count; i++) {
            iteration();
        }
    }

    /** Pausing is not supported. @return always null */
    @Override
    public TrainingContinuation pause() {
        return null;
    }

    /** Resuming is not supported; the argument is ignored. */
    @Override
    public void resume(final TrainingContinuation trainingContinuation) {
    }

    /** The error is derived from convergence state; the argument is ignored. */
    @Override
    public void setError(final double d) {
    }

    /** Overrides the iteration counter. */
    @Override
    public void setIteration(final int i) {
        this.iteration = i;
    }
}
