/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.substmodel;

import dr.evomodel.substmodel.BaseSubstitutionModel;
import dr.evomodel.substmodel.DifferentiableSubstitutionModel;
import dr.evomodel.substmodel.DifferentialMassProvider;
import dr.evomodel.substmodel.EigenDecomposition;
import dr.math.matrixAlgebra.ReadableMatrix;
import dr.math.matrixAlgebra.WrappedMatrix;

/**
 * Static helpers for differentiating substitution-model quantities through the
 * model's eigen decomposition.
 */
public class DifferentiableSubstitutionModelUtil {

    /**
     * Transforms a rate-matrix differential {@code M} into its "differential mass" matrix
     * for elapsed time {@code time}, using the eigen decomposition of the rate matrix.
     *
     * <p>With {@code V} the eigenvector matrix, {@code V^-1} the inverse eigenvector matrix
     * and {@code lambda} the first {@code stateCount} entries of
     * {@link EigenDecomposition#getEigenValues()}, this computes
     * {@code V (F o (V^-1 M V)) V^-1}, where {@code o} is the element-wise product and
     * <pre>
     *   F[i][j] = time                                                      if i == j or lambda_i == lambda_j
     *   F[i][j] = (1 - exp((lambda_j - lambda_i) * time)) / (lambda_i - lambda_j)   otherwise
     * </pre>
     * The off-diagonal factor tends to {@code time} as {@code lambda_j -> lambda_i}, so both
     * branches agree in the limit.
     *
     * <p><b>Side effect:</b> {@code differentialMassMatrix} is overwritten with the result.
     *
     * @param time                   value multiplying the eigenvalues in the exponent
     *                               (the elapsed time {@code t})
     * @param stateCount             matrix dimension {@code n}; all matrices are {@code n x n}
     * @param differentialMassMatrix on entry the matrix {@code M}; on exit the result
     * @param eigenDecomposition     supplies {@code V}, {@code V^-1} and the eigenvalues
     * @return a new array of length {@code n * n} holding the result in the linear index
     *         order of {@code WrappedMatrix.get(int)}
     */
    public static double[] getDifferentialMassMatrix(double time, int stateCount, WrappedMatrix differentialMassMatrix, EigenDecomposition eigenDecomposition) {
        final double[] eigenValues = eigenDecomposition.getEigenValues();
        final WrappedMatrix.Raw eigenVectors = new WrappedMatrix.Raw(eigenDecomposition.getEigenVectors(), 0, stateCount, stateCount);
        final WrappedMatrix.Raw inverseEigenVectors = new WrappedMatrix.Raw(eigenDecomposition.getInverseEigenVectors(), 0, stateCount, stateCount);

        // Move M into the eigenbasis: V^-1 M V (stored back into differentialMassMatrix).
        getTripleMatrixMultiplication(stateCount, inverseEigenVectors, differentialMassMatrix, eigenVectors);

        for (int i = 0; i < stateCount; ++i) {
            for (int j = 0; j < stateCount; ++j) {
                final double factor;
                if (i == j || eigenValues[i] == eigenValues[j]) {
                    factor = time;
                } else {
                    // 1 - exp(x) == -expm1(x). expm1 stays accurate when the two eigenvalues
                    // are close (x near 0), where 1 - exp(x) would cancel most significant digits.
                    factor = -Math.expm1((eigenValues[j] - eigenValues[i]) * time) / (eigenValues[i] - eigenValues[j]);
                }
                differentialMassMatrix.set(i, j, differentialMassMatrix.get(i, j) * factor);
            }
        }

        // Move back out of the eigenbasis: V (F o ...) V^-1.
        getTripleMatrixMultiplication(stateCount, eigenVectors, differentialMassMatrix, inverseEigenVectors);

        final int size = stateCount * stateCount;
        final double[] result = new double[size];
        for (int k = 0; k < size; ++k) {
            result[k] = differentialMassMatrix.get(k);
        }
        return result;
    }

    /**
     * Computes {@code left * middle * right} for {@code n x n} matrices and writes the
     * product back into {@code middle}.
     *
     * <p>{@code middle * right} is first accumulated into an {@code n x n} temporary, so
     * {@code right} may be the same object as {@code middle}. {@code left} must NOT alias
     * {@code middle}: {@code middle} is overwritten entry by entry while {@code left}'s row
     * is still being read.
     *
     * @param n      matrix dimension
     * @param left   left factor (read only)
     * @param middle middle factor on entry; receives the triple product on exit
     * @param right  right factor (read only)
     */
    public static void getTripleMatrixMultiplication(int n, ReadableMatrix left, WrappedMatrix middle, ReadableMatrix right) {
        // temp = middle * right
        final double[][] temp = new double[n][n];
        for (int row = 0; row < n; ++row) {
            for (int col = 0; col < n; ++col) {
                double sum = 0.0;
                for (int k = 0; k < n; ++k) {
                    sum += middle.get(row, k) * right.get(k, col);
                }
                temp[row][col] = sum;
            }
        }

        // middle = left * temp
        for (int row = 0; row < n; ++row) {
            for (int col = 0; col < n; ++col) {
                double sum = 0.0;
                for (int k = 0; k < n; ++k) {
                    sum += left.get(row, k) * temp[k][col];
                }
                middle.set(row, col, sum);
            }
        }
    }

    /**
     * Builds the differential of the infinitesimal rate matrix with respect to
     * {@code wrtParameter}.
     *
     * <p>The result is {@code D - Q * g}, where:
     * <ul>
     *   <li>{@code D} is assembled with {@code setupQMatrix} and {@code makeValid} from the
     *       rates filled in by {@code setupDifferentialRates};</li>
     *   <li>{@code Q} is the model's current infinitesimal matrix;</li>
     *   <li>{@code g} is {@code getWeightedNormalizationGradient(wrtParameter, D, frequencies)}.</li>
     * </ul>
     * This presumably corresponds to the quotient rule for the normalized rate matrix;
     * confirm against the {@code DifferentiableSubstitutionModel} implementations.
     *
     * <p>Calls {@code setupMatrix()} on the model first; the value it returns is passed
     * through to {@code setupDifferentialRates}. It is presumably the normalizing constant.
     *
     * @param wrtParameter           parameter to differentiate with respect to
     * @param baseSubstitutionModel  model; must also implement {@link DifferentiableSubstitutionModel}
     * @return the {@code n x n} differential matrix, wrapped as an array of arrays
     * @throws RuntimeException if the model does not implement {@link DifferentiableSubstitutionModel}
     */
    public static WrappedMatrix getInfinitesimalDifferentialMatrix(DifferentialMassProvider.DifferentialWrapper.WrtParameter wrtParameter, BaseSubstitutionModel baseSubstitutionModel) {
        if (!(baseSubstitutionModel instanceof DifferentiableSubstitutionModel)) {
            throw new RuntimeException("Not supported!");
        }
        final DifferentiableSubstitutionModel differentiableModel = (DifferentiableSubstitutionModel) baseSubstitutionModel;

        final double setupValue = baseSubstitutionModel.setupMatrix();
        final int stateCount = baseSubstitutionModel.getDataType().getStateCount();
        final int rateCount = baseSubstitutionModel.getRateCount(stateCount);

        // Current infinitesimal matrix Q, flattened with index i * stateCount + j.
        final double[] q = new double[stateCount * stateCount];
        baseSubstitutionModel.getInfinitesimalMatrix(q);

        // Differential rates -> differential Q-matrix D (diagonal fixed up by makeValid).
        final double[] differentialRates = new double[rateCount];
        differentiableModel.setupDifferentialRates(wrtParameter, differentialRates, setupValue);
        final double[][] differential = new double[stateCount][stateCount];
        baseSubstitutionModel.setupQMatrix(differentialRates, baseSubstitutionModel.getFrequencyModel().getFrequencies(), differential);
        baseSubstitutionModel.makeValid(differential, stateCount);

        final double normalizationGradient = differentiableModel.getWeightedNormalizationGradient(wrtParameter, differential, baseSubstitutionModel.getFrequencyModel().getFrequencies());

        // D <- D - Q * g
        for (int i = 0; i < stateCount; ++i) {
            for (int j = 0; j < stateCount; ++j) {
                differential[i][j] -= q[i * stateCount + j] * normalizationGradient;
            }
        }
        return new WrappedMatrix.ArrayOfArray(differential);
    }
}

