/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.treedatalikelihood.continuous;

import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evomodel.treedatalikelihood.continuous.ContinuousRateTransformation;
import dr.evomodel.treedatalikelihood.continuous.ContinuousTraitDataModel;
import dr.evomodel.treedatalikelihood.continuous.ContinuousTraitPartialsProvider;
import dr.evomodel.treedatalikelihood.continuous.RepeatedMeasuresTraitDataModel;
import dr.evomodel.treedatalikelihood.continuous.cdi.PrecisionType;
import dr.inference.model.CompoundParameter;
import dr.inference.model.MatrixParameterInterface;
import org.ejml.data.DenseMatrix64F;
import org.ejml.ops.CommonOps;

public class TreeScaledRepeatedMeasuresTraitDataModel
extends RepeatedMeasuresTraitDataModel {
    private Tree treeModel;
    private ContinuousRateTransformation rateTransformation;
    private static final boolean DEBUG = false;

    public TreeScaledRepeatedMeasuresTraitDataModel(String string, ContinuousTraitPartialsProvider continuousTraitPartialsProvider, CompoundParameter compoundParameter, boolean[] blArray, boolean bl, int n, int n2, MatrixParameterInterface matrixParameterInterface, PrecisionType precisionType) {
        super(string, continuousTraitPartialsProvider, compoundParameter, blArray, bl, n, n2, matrixParameterInterface, precisionType);
        if (!(continuousTraitPartialsProvider instanceof ContinuousTraitDataModel)) {
            throw new RuntimeException("not yet implemented for alternative child models. (can't just scale the partial in super.getTipPartial)");
        }
    }

    @Override
    public void addTreeAndRateModel(Tree tree, ContinuousRateTransformation continuousRateTransformation) {
        this.treeModel = tree;
        this.rateTransformation = continuousRateTransformation;
    }

    @Override
    public double[] getTipPartial(int n, boolean bl) {
        double[] dArray = super.getTipPartial(n, bl);
        double d = this.getTipHeight(n);
        this.scalePartial(d, dArray);
        return dArray;
    }

    private double getTipHeight(int n) {
        double d = this.treeModel.getNodeHeight(this.treeModel.getRoot()) - this.treeModel.getNodeHeight(this.treeModel.getExternalNode(n));
        return d * this.rateTransformation.getNormalization();
    }

    private void scalePartial(double d, double[] dArray) {
        this.scaleArray(1.0 / d, dArray, this.dimTrait, this.dimTrait * this.dimTrait);
        this.scaleArray(d, dArray, this.dimTrait + this.dimTrait * this.dimTrait, this.dimTrait * this.dimTrait);
    }

    private void scaleArray(double d, double[] dArray, int n, int n2) {
        int n3 = n;
        while (n3 < n + n2) {
            int n4 = n3++;
            dArray[n4] = dArray[n4] * d;
        }
    }

    @Override
    public void chainRuleWrtVariance(double[] dArray, NodeRef nodeRef) {
        double d = this.getTipHeight(nodeRef.getNumber());
        this.scaleArray(d, dArray, 0, dArray.length);
    }

    @Override
    public DenseMatrix64F getExtensionVariance(NodeRef nodeRef) {
        DenseMatrix64F denseMatrix64F = this.getExtensionVariance();
        double d = this.getTipHeight(nodeRef.getNumber());
        CommonOps.scale(d, denseMatrix64F);
        return denseMatrix64F;
    }

    @Override
    public void getMeanTipVariances(DenseMatrix64F denseMatrix64F, DenseMatrix64F denseMatrix64F2) {
        double d = 0.0;
        for (int i = 0; i < this.treeModel.getExternalNodeCount(); ++i) {
            d += this.getTipHeight(this.treeModel.getExternalNode(i).getNumber());
        }
        CommonOps.scale(d /= (double)this.treeModel.getExternalNodeCount(), denseMatrix64F, denseMatrix64F2);
    }
}

