package org.encog.neural.networks.training.lma;

import org.encog.mathutil.error.ErrorCalculation;
import org.encog.mathutil.matrices.decomposition.LUDecomposition;
import org.encog.mathutil.matrices.hessian.ComputeHessian;
import org.encog.mathutil.matrices.hessian.HessianCR;
import org.encog.ml.MLMethod;
import org.encog.ml.TrainingImplementationType;
import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.basic.BasicMLData;
import org.encog.ml.data.basic.BasicMLDataPair;
import org.encog.ml.train.BasicTraining;
import org.encog.neural.networks.BasicNetwork;
import org.encog.neural.networks.structure.NetworkCODEC;
import org.encog.neural.networks.training.TrainingError;
import org.encog.neural.networks.training.propagation.TrainingContinuation;
import org.encog.util.concurrency.MultiThreadable;
import org.encog.util.validate.ValidateNetwork;

/* loaded from: classes.dex */
/**
 * Trains a {@link BasicNetwork} with a Levenberg-Marquardt style damped step.
 *
 * <p>Each {@link #iteration()} computes the Hessian, gradients and SSE through the supplied
 * {@link ComputeHessian}. It then repeatedly solves {@code (H + lambda*I) * delta = gradients} by
 * LU decomposition, adjusting {@code lambda} until a step lowers the error or {@code lambda}
 * saturates at {@link #LAMBDA_MAX}.
 *
 * <p>Pause/resume is not supported: {@link #canContinue()} returns {@code false}.
 */
public class LevenbergMarquardtTraining extends BasicTraining implements MultiThreadable {
    /** Upper bound for the damping factor; reaching it ends the step search for one iteration. */
    public static final double LAMBDA_MAX = 1.0E25d;
    /** Factor lambda is multiplied by on a rejected step and divided by on an accepted one. */
    public static final double SCALE_LAMBDA = 10.0d;
    /**
     * Lower bound for the damping factor. Without it, many consecutive accepted steps can
     * underflow lambda to 0.0, after which {@code lambda * SCALE_LAMBDA} never grows and the step
     * search in {@link #iteration()} could never terminate.
     */
    private static final double LAMBDA_MIN = Double.MIN_NORMAL;
    /** Damping factor used for the first iteration. */
    private static final double INITIAL_LAMBDA = 0.1d;

    /** Step most recently solved for; added to {@link #weights} by {@link #updateWeights()}. */
    private double[] deltas;
    /** Copy of the undamped Hessian diagonal, so lambda can be re-applied from scratch. */
    private final double[] diagonal;
    private final ComputeHessian hessian;
    private final MLDataSet indexableTraining;
    private boolean initComplete;
    private double lambda;
    private final BasicNetwork network;
    /** Reusable holder for records read while evaluating the error. */
    private final MLDataPair pair;
    private final int trainingLength;
    private final int weightCount;
    /** Network weights captured at the start of the current iteration. */
    private double[] weights;

    /**
     * Creates a trainer that uses {@link HessianCR} to compute the Hessian.
     *
     * @param basicNetwork the network to train
     * @param mLDataSet the training data
     */
    public LevenbergMarquardtTraining(BasicNetwork basicNetwork, MLDataSet mLDataSet) {
        this(basicNetwork, mLDataSet, new HessianCR());
    }

    /**
     * Creates a trainer with an explicit Hessian calculator.
     *
     * @param basicNetwork the network to train
     * @param mLDataSet the training data; validated against the network's shape
     * @param computeHessian the object used to compute Hessian, gradients and SSE
     */
    public LevenbergMarquardtTraining(BasicNetwork basicNetwork, MLDataSet mLDataSet, ComputeHessian computeHessian) {
        super(TrainingImplementationType.Iterative);
        ValidateNetwork.validateMethodToData(basicNetwork, mLDataSet);
        setTraining(mLDataSet);
        this.indexableTraining = getTraining();
        this.network = basicNetwork;
        this.trainingLength = (int) this.indexableTraining.getRecordCount();
        this.weightCount = this.network.getStructure().calculateSize();
        this.lambda = INITIAL_LAMBDA;
        this.deltas = new double[this.weightCount];
        this.diagonal = new double[this.weightCount];
        this.pair = new BasicMLDataPair(
                new BasicMLData(this.indexableTraining.getInputSize()),
                new BasicMLData(this.indexableTraining.getIdealSize()));
        this.hessian = computeHessian;
    }

    /** Overwrites the Hessian diagonal with the saved undamped diagonal plus the current lambda. */
    private void applyLambda() {
        final double[][] matrix = this.hessian.getHessian();
        for (int i = 0; i < this.weightCount; i++) {
            matrix[i][i] = this.diagonal[i] + this.lambda;
        }
    }

    /**
     * Evaluates the network over every training record.
     *
     * @return the error as reported by {@link ErrorCalculation#calculateESS()}
     */
    private double calculateError() {
        final ErrorCalculation errorCalculation = new ErrorCalculation();
        for (int i = 0; i < this.trainingLength; i++) {
            this.indexableTraining.getRecord(i, this.pair);
            errorCalculation.updateError(
                    this.network.compute(this.pair.getInput()).getData(),
                    this.pair.getIdeal().getData(),
                    this.pair.getSignificance());
        }
        return errorCalculation.calculateESS();
    }

    /** Copies the freshly computed (undamped) Hessian diagonal into {@link #diagonal}. */
    private void saveDiagonal() {
        final double[][] matrix = this.hessian.getHessian();
        for (int i = 0; i < this.weightCount; i++) {
            this.diagonal[i] = matrix[i][i];
        }
    }

    /** @return always {@code false}; this trainer cannot be paused and resumed */
    @Override // org.encog.ml.train.MLTrain
    public boolean canContinue() {
        return false;
    }

    /** @return the Hessian calculator supplied at construction */
    public ComputeHessian getHessian() {
        return this.hessian;
    }

    /** @return the network being trained */
    @Override // org.encog.ml.train.MLTrain
    public MLMethod getMethod() {
        return this.network;
    }

    /**
     * @return the Hessian calculator's thread count if it is {@link MultiThreadable}, otherwise 1
     */
    @Override // org.encog.util.concurrency.MultiThreadable
    public int getThreadCount() {
        if (this.hessian instanceof MultiThreadable) {
            return ((MultiThreadable) this.hessian).getThreadCount();
        }
        return 1;
    }

    /**
     * Performs one training iteration.
     *
     * <p>The Hessian, gradients and starting error are computed at the current weights. Damped
     * steps are then tried with growing lambda until one strictly lowers the error (lambda is then
     * reduced) or lambda exceeds {@link #LAMBDA_MAX}. In the latter case the starting weights are
     * restored and the starting error is reported, so a failed iteration never leaves the network
     * worse than it was.
     */
    @Override // org.encog.ml.train.MLTrain
    public void iteration() {
        if (!this.initComplete) {
            this.hessian.init(this.network, getTraining());
            this.initComplete = true;
        }
        preIteration();

        this.hessian.clear();
        this.weights = NetworkCODEC.networkToArray(this.network);
        this.hessian.compute();
        final double startingError = this.hessian.getSSE();
        saveDiagonal();

        double error;
        while (true) {
            applyLambda();
            final LUDecomposition decomposition = new LUDecomposition(this.hessian.getHessianMatrix());
            final boolean nonsingular = decomposition.isNonsingular();
            double trialError = Double.NaN;
            if (nonsingular) {
                this.deltas = decomposition.Solve(this.hessian.getGradients());
                updateWeights();
                trialError = calculateError();
            }
            // Written as "trialError < startingError" so that a NaN error (diverged step) is
            // rejected; the previous "!(trialError >= startingError)" form accepted it.
            if (nonsingular && trialError < startingError) {
                this.lambda = Math.max(this.lambda / SCALE_LAMBDA, LAMBDA_MIN);
                error = trialError;
                break;
            }
            this.lambda *= SCALE_LAMBDA;
            if (this.lambda > LAMBDA_MAX) {
                this.lambda = LAMBDA_MAX;
                // No improving step was found: undo the last rejected trial step.
                NetworkCODEC.arrayToNetwork(this.weights, this.network);
                error = startingError;
                break;
            }
        }

        setError(error);
        postIteration();
    }

    /** @return always {@code null}; pausing is not supported */
    @Override // org.encog.ml.train.MLTrain
    public TrainingContinuation pause() {
        return null;
    }

    /** No-op; resuming is not supported. */
    @Override // org.encog.ml.train.MLTrain
    public void resume(TrainingContinuation trainingContinuation) {
    }

    /**
     * Forwards the thread count to the Hessian calculator when it is {@link MultiThreadable}.
     *
     * @param i requested thread count
     * @throws TrainingError if the calculator is single-threaded and {@code i} is neither 0 nor 1
     */
    @Override // org.encog.util.concurrency.MultiThreadable
    public void setThreadCount(int i) {
        if (this.hessian instanceof MultiThreadable) {
            ((MultiThreadable) this.hessian).setThreadCount(i);
            return;
        }
        if (i == 1 || i == 0) {
            return;
        }
        throw new TrainingError("The Hessian object in use(" + this.hessian.getClass().toString() + ") does not support multi-threaded mode.");
    }

    /**
     * Loads {@code weights + deltas} into the network, leaving the saved starting weights intact.
     * Expects {@link #iteration()} to have captured the starting weights first.
     */
    public void updateWeights() {
        final double[] trial = this.weights.clone();
        for (int i = 0; i < trial.length; i++) {
            trial[i] += this.deltas[i];
        }
        NetworkCODEC.arrayToNetwork(trial, this.network);
    }
}
