package org.encog.ensemble.adaboost;

import java.util.ArrayList;
import org.encog.ensemble.Ensemble;
import org.encog.ensemble.EnsembleAggregator;
import org.encog.ensemble.EnsembleML;
import org.encog.ensemble.EnsembleMLMethodFactory;
import org.encog.ensemble.EnsembleTrainFactory;
import org.encog.ensemble.EnsembleTypes;
import org.encog.ensemble.GenericEnsembleML;
import org.encog.ensemble.data.EnsembleDataSet;
import org.encog.ensemble.data.factories.ResamplingDataSetFactory;
import org.encog.mathutil.VectorAlgebra;
import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.MLDataSet;

/* loaded from: classes.dex */
/**
 * AdaBoost ensemble.
 *
 * <p>{@link #train(double, double, EnsembleDataSet, boolean)} runs {@code T} boosting rounds. Each
 * round resamples the input data according to the current per-sample weight distribution D,
 * trains a fresh member on that resampled set, and then re-weights D so that samples the new
 * member handles poorly become more significant in the next round.
 *
 * <p>Members are only created by {@code train}; {@link #addMember(EnsembleML)} always throws.
 */
public class AdaBoost extends Ensemble {
    /**
     * Clamp applied to a member's error rate before computing its vote weight. An error rate of
     * exactly 0 or 1 would otherwise give an infinite (or NaN) {@code alpha} and poison D.
     */
    private static final double MIN_EPSILON = 1.0e-10;

    /** Number of boosting rounds, i.e. the number of members {@code train} creates. */
    private final int T;
    private final VectorAlgebra va;
    /**
     * Significance-weighted error of each trained member, in training order.
     * NOTE(review): populated by {@code train} but not read anywhere in this class.
     */
    private final ArrayList<Double> weights;

    /**
     * Creates an AdaBoost ensemble.
     *
     * @param iterationCount number of boosting rounds (members to train)
     * @param dataSetSize passed to {@link ResamplingDataSetFactory}; presumably the size of each
     *     resampled training set -- confirm against that factory
     * @param ensembleMLMethodFactory creates (and re-initialises) each member's ML method
     * @param ensembleTrainFactory creates the trainer for each member
     * @param ensembleAggregator aggregator stored for the ensemble
     */
    public AdaBoost(int iterationCount, int dataSetSize, EnsembleMLMethodFactory ensembleMLMethodFactory,
            EnsembleTrainFactory ensembleTrainFactory, EnsembleAggregator ensembleAggregator) {
        this.dataSetFactory = new ResamplingDataSetFactory(dataSetSize);
        this.T = iterationCount;
        this.mlFactory = ensembleMLMethodFactory;
        this.va = new VectorAlgebra();
        this.weights = new ArrayList<>();
        this.members = new ArrayList<>();
        this.trainFactory = ensembleTrainFactory;
        this.aggregator = ensembleAggregator;
    }

    /**
     * Returns the unweighted misclassification rate of {@code member} on {@code dataSet}.
     *
     * <p>A pair counts as misclassified when {@code member.classify(input)} differs from
     * {@code member.winner(ideal)}.
     *
     * @return the fraction of misclassified pairs in [0, 1], or 0 for an empty data set
     */
    private double epsilon(GenericEnsembleML member, MLDataSet dataSet) {
        int total = dataSet.size();
        if (total == 0) {
            return 0.0d;
        }
        int misclassified = 0;
        for (MLDataPair pair : dataSet) {
            if (member.classify(pair.getInput()) != member.winner(pair.getIdeal())) {
                misclassified++;
            }
        }
        // Cast before dividing: int / int would truncate every rate below 1 down to 0.
        return (double) misclassified / total;
    }

    /**
     * Sums the significance of every pair in {@code dataSet} that {@code member} misclassifies.
     *
     * @return the significance-weighted error (0 when nothing is misclassified)
     */
    private double getWeightedError(GenericEnsembleML member, MLDataSet dataSet) {
        double error = 0.0d;
        for (int i = 0; i < dataSet.size(); i++) {
            MLDataPair pair = dataSet.get(i);
            if (member.classify(pair.getInput()) != member.winner(pair.getIdeal())) {
                error += pair.getSignificance();
            }
        }
        return error;
    }

    /**
     * Computes the next weight distribution D_{t+1} from the current one.
     *
     * <p>The member's vote weight is {@code alpha = ln((1 - eps) / eps)}, where {@code eps} is its
     * (clamped) error rate on {@code dataSet}. Each weight is multiplied by
     * {@code exp(-alpha * <ideal, output>)}, and the result is normalised to sum to 1.
     *
     * @param current the current distribution; must have at least {@code dataSet.size()} entries,
     *     aligned index-by-index with {@code dataSet}
     * @return a new list with one updated weight per pair of {@code dataSet}
     */
    private ArrayList<Double> updateD(GenericEnsembleML member, MLDataSet dataSet, ArrayList<Double> current) {
        int size = dataSet.size();
        ArrayList<Double> next = new ArrayList<>(size);
        // Keep eps strictly inside (0, 1) so that alpha stays finite.
        double eps = Math.min(Math.max(epsilon(member, dataSet), MIN_EPSILON), 1.0d - MIN_EPSILON);
        double alpha = Math.log((1.0d - eps) / eps);
        double total = 0.0d;
        for (int i = 0; i < size; i++) {
            MLDataPair pair = dataSet.get(i);
            double agreement = this.va.dotProduct(pair.getIdeal().getData(),
                    member.compute(pair.getInput()).getData());
            double weight = current.get(i) * Math.exp(-alpha * agreement);
            next.add(weight);
            total += weight;
        }
        // Normalise by Z_t so that D stays a distribution, as it was when it was seeded with 1/size.
        // Skip this when the sum is degenerate (0, NaN or infinite) rather than dividing by it.
        if (total > 0.0d && total < Double.POSITIVE_INFINITY) {
            for (int i = 0; i < size; i++) {
                next.set(i, next.get(i) / total);
            }
        }
        return next;
    }

    /**
     * Not supported: AdaBoost members are created only by {@code train}.
     *
     * @throws Ensemble.NotPossibleInThisMethod always
     */
    @Override // org.encog.ensemble.Ensemble
    public void addMember(EnsembleML ensembleML) throws Ensemble.NotPossibleInThisMethod {
        throw new Ensemble.NotPossibleInThisMethod();
    }

    /** AdaBoost here is always a classification ensemble. */
    @Override // org.encog.ensemble.Ensemble
    public EnsembleTypes.ProblemType getProblemType() {
        return EnsembleTypes.ProblemType.CLASSIFICATION;
    }

    /** No-op: members are created during {@code train}. */
    @Override // org.encog.ensemble.Ensemble
    public void initMembers() {
    }

    /**
     * Runs {@code T} boosting rounds, appending one trained member per round.
     *
     * @param targetAccuracy forwarded as the first argument of {@code GenericEnsembleML.train}
     *     (presumably its training target -- confirm)
     * @param selectionError a member is re-initialised and retrained until its error on
     *     {@code selectionSet} is at most this value.
     *     NOTE(review): this loop has no attempt limit.
     * @param selectionSet data set used for the retraining check above
     * @param verbose forwarded as the second argument of {@code GenericEnsembleML.train}
     */
    @Override // org.encog.ensemble.Ensemble
    public void train(double targetAccuracy, double selectionError, EnsembleDataSet selectionSet, boolean verbose) {
        MLDataSet inputData = this.dataSetFactory.getInputData();
        int size = inputData.size();
        // D starts uniform: one weight per ORIGINAL input sample.
        ArrayList<Double> distribution = new ArrayList<>(size);
        for (int i = 0; i < size; i++) {
            distribution.add(1.0d / size);
        }
        for (int round = 0; round < this.T; round++) {
            this.dataSetFactory.setSignificance(distribution);
            EnsembleDataSet resampled = this.dataSetFactory.getNewDataSet();
            GenericEnsembleML member = new GenericEnsembleML(
                    this.mlFactory.createML(inputData.getInputSize(), inputData.getIdealSize()),
                    this.mlFactory.getLabel());
            do {
                this.mlFactory.reInit(member.getMl());
                member.setTraining(this.trainFactory.getTraining(member.getMl(), resampled));
                member.train(targetAccuracy, verbose);
            } while (member.getError(selectionSet) > selectionError);
            double weightedError = getWeightedError(member, resampled);
            this.members.add(member);
            this.weights.add(weightedError);
            // Update D over the original input data, not the resampled set. D must keep one entry
            // per input sample for the next setSignificance call, and the resampled set's size is
            // set independently by the factory.
            distribution = updateD(member, inputData, distribution);
        }
    }
}
