package org.encog.ml.svm.training;

import org.encog.EncogError;
import org.encog.mathutil.error.ErrorCalculation;
import org.encog.mathutil.libsvm.svm;
import org.encog.mathutil.libsvm.svm_parameter;
import org.encog.mathutil.libsvm.svm_problem;
import org.encog.ml.MLMethod;
import org.encog.ml.TrainingImplementationType;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.svm.SVM;
import org.encog.ml.train.BasicTraining;
import org.encog.neural.networks.training.propagation.TrainingContinuation;
import org.encog.util.logging.EncogLogging;

/* loaded from: classes.dex */
/**
 * Single-pass trainer for an Encog {@link SVM}.
 *
 * <p>The training data is encoded into a libsvm {@code svm_problem} once, at
 * construction time. A single call to {@link #iteration()} then either trains
 * the model directly, or (when {@link #getFold() fold} is greater than 1) runs
 * a k-fold cross validation to score the current {@code C}/{@code gamma} pair.
 * Training cannot be paused or resumed.
 */
public class SVMTrain extends BasicTraining {
    /**
     * Smallest accepted value for {@code C} and {@code gamma}; anything below
     * this (including zero and negative values) is rejected by the setters.
     */
    private static final double MIN_PARAMETER = 1.0E-13d;

    /** libsvm {@code svm_type} code for epsilon-SVR (regression). */
    private static final int SVM_TYPE_EPSILON_SVR = 3;

    /** libsvm {@code svm_type} code for nu-SVR (regression). */
    private static final int SVM_TYPE_NU_SVR = 4;

    /** The C (cost) parameter copied into the SVM parameters on each iteration. */
    private double c;

    /** Number of cross-validation folds; values of 1 or less disable cross validation. */
    private int fold;

    /** The gamma parameter copied into the SVM parameters on each iteration. */
    private double gamma;

    /** The SVM being trained. */
    private final SVM network;

    /** The training data encoded as a libsvm problem. */
    private svm_problem problem;

    /** Set to {@code true} once {@link #iteration()} has completed. */
    private boolean trainingDone;

    /**
     * Creates a trainer for the given SVM and training set.
     *
     * <p>Defaults: no cross validation ({@code fold = 0}), {@code C = 1.0} and
     * {@code gamma = 1 / inputCount} of the supplied SVM.
     *
     * @param svm the SVM to train
     * @param mLDataSet the training data; it is encoded into a libsvm problem here
     */
    public SVMTrain(SVM svm, MLDataSet mLDataSet) {
        super(TrainingImplementationType.OnePass);
        this.fold = 0;
        this.network = svm;
        setTraining(mLDataSet);
        this.trainingDone = false;
        this.problem = EncodeSVMProblem.encode(mLDataSet, 0);
        this.gamma = 1.0d / this.network.getInputCount();
        this.c = 1.0d;
    }

    /**
     * Scores a set of cross-validation predictions against the problem's targets.
     *
     * @param svm_parameterVar the SVM parameters; only {@code svm_type} is read
     * @param svm_problemVar the problem supplying the expected targets ({@code y})
     *        and the sample count ({@code l})
     * @param dArr the predicted target for each sample, indexed like {@code y}
     * @return for the two SVR types, the value produced by {@link ErrorCalculation};
     *         for every other type, the percentage (0-100) of predictions that
     *         exactly equal their target
     */
    private double evaluate(svm_parameter svm_parameterVar, svm_problem svm_problemVar, double[] dArr) {
        final int sampleCount = svm_problemVar.l;
        final boolean regression = svm_parameterVar.svm_type == SVM_TYPE_EPSILON_SVR
                || svm_parameterVar.svm_type == SVM_TYPE_NU_SVR;

        if (regression) {
            ErrorCalculation errorCalculation = new ErrorCalculation();
            for (int i = 0; i < sampleCount; i++) {
                errorCalculation.updateError(dArr[i], svm_problemVar.y[i]);
            }
            return errorCalculation.calculate();
        }

        // Classification: labels are compared exactly, and the result is a
        // percentage of correct predictions rather than an error value.
        int correct = 0;
        for (int i = 0; i < sampleCount; i++) {
            if (dArr[i] == svm_problemVar.y[i]) {
                correct++;
            }
        }
        return (100.0d * correct) / sampleCount;
    }

    /**
     * @return always {@code false}; this trainer does not support continuation
     */
    @Override // org.encog.ml.train.MLTrain
    public boolean canContinue() {
        return false;
    }

    /**
     * @return the C parameter that will be used for training
     */
    public double getC() {
        return this.c;
    }

    /**
     * @return the number of cross-validation folds
     */
    public int getFold() {
        return this.fold;
    }

    /**
     * @return the gamma parameter that will be used for training
     */
    public double getGamma() {
        return this.gamma;
    }

    /**
     * @return the SVM being trained
     */
    @Override // org.encog.ml.train.MLTrain
    public MLMethod getMethod() {
        return this.network;
    }

    /**
     * @return the libsvm problem encoded from the training data
     */
    public svm_problem getProblem() {
        return this.problem;
    }

    /**
     * @return {@code true} once {@link #iteration()} has run to completion
     */
    @Override // org.encog.ml.train.BasicTraining, org.encog.ml.train.MLTrain
    public boolean isTrainingDone() {
        return this.trainingDone;
    }

    /**
     * Performs the single training pass with the current C and gamma.
     *
     * <p>With {@code fold > 1} a cross validation is run, the SVM's model is
     * cleared, and the value recorded via {@code setError} is the result of
     * scoring the cross-validation predictions (note: for non-SVR types that is
     * a percentage of correct predictions, not an error). Otherwise the SVM is
     * trained on the full problem and the recorded error is the SVM's error
     * over the training set.
     */
    @Override // org.encog.ml.train.MLTrain
    public void iteration() {
        this.network.getParams().C = this.c;
        this.network.getParams().gamma = this.gamma;
        EncogLogging.log(1, "Training with parameters C = " + this.c + ", gamma = " + this.gamma);
        if (this.fold > 1) {
            // One predicted target per sample, filled in by the cross validation.
            double[] predictions = new double[this.problem.l];
            svm.svm_cross_validation(this.problem, this.network.getParams(), this.fold, predictions);
            // Cross validation only scores the parameters; it leaves no usable model.
            this.network.setModel(null);
            setError(evaluate(this.network.getParams(), this.problem, predictions));
        } else {
            this.network.setModel(svm.svm_train(this.problem, this.network.getParams()));
            setError(this.network.calculateError(getTraining()));
        }
        this.trainingDone = true;
    }

    /**
     * @return always {@code null}; pausing is not supported
     */
    @Override // org.encog.ml.train.MLTrain
    public final TrainingContinuation pause() {
        return null;
    }

    /**
     * Does nothing; resuming is not supported.
     *
     * @param trainingContinuation ignored
     */
    @Override // org.encog.ml.train.MLTrain
    public void resume(TrainingContinuation trainingContinuation) {
    }

    /**
     * Sets the C parameter.
     *
     * <p>The value is validated before it is stored, so a rejected value
     * leaves the previously configured C untouched.
     *
     * @param d the new C value; must be a number no smaller than 1.0E-13
     * @throws EncogError if {@code d} is NaN or smaller than 1.0E-13
     *         (which includes zero and all negative values)
     */
    public void setC(double d) {
        if (Double.isNaN(d) || d < MIN_PARAMETER) {
            throw new EncogError("SVM training cannot use a c value less than zero.");
        }
        this.c = d;
    }

    /**
     * Sets the number of cross-validation folds.
     *
     * @param i the fold count; cross validation is only performed when this is
     *        greater than 1
     */
    public void setFold(int i) {
        this.fold = i;
    }

    /**
     * Sets the gamma parameter.
     *
     * <p>The value is validated before it is stored, so a rejected value
     * leaves the previously configured gamma untouched.
     *
     * @param d the new gamma value; must be a number no smaller than 1.0E-13
     * @throws EncogError if {@code d} is NaN or smaller than 1.0E-13
     *         (which includes zero and all negative values)
     */
    public void setGamma(double d) {
        if (Double.isNaN(d) || d < MIN_PARAMETER) {
            throw new EncogError("SVM training cannot use a gamma value less than zero.");
        }
        this.gamma = d;
    }
}
