/*
 * Decompiled with CFR 0.152.
 */
package dr.inference.operators.hmc;

import dr.inference.hmc.GradientWrtParameterProvider;
import dr.inference.hmc.PathGradient;
import dr.inference.hmc.ReversibleHMCProvider;
import dr.inference.model.Likelihood;
import dr.inference.model.Parameter;
import dr.inference.operators.AbstractAdaptableOperator;
import dr.inference.operators.AdaptationMode;
import dr.inference.operators.GeneralOperator;
import dr.inference.operators.PathDependent;
import dr.inference.operators.hmc.MassPreconditioner;
import dr.math.MathUtils;
import dr.math.MultivariateFunction;
import dr.math.NumericalDerivative;
import dr.math.matrixAlgebra.ReadableVector;
import dr.math.matrixAlgebra.WrappedVector;
import dr.util.Transform;

/**
 * Hamiltonian Monte Carlo operator over a single {@link Parameter}.
 *
 * <p>Each move draws a (masked) momentum from the {@link MassPreconditioner}, runs a leapfrog
 * trajectory using gradients from a {@link GradientWrtParameterProvider}, and returns a log
 * Hastings ratio built from the kinetic energy and, when a {@link Transform} is supplied, the
 * transform's log-Jacobian. The leapfrog step size is the adaptable quantity (adapted on the
 * log scale). The operator must be driven through {@link #doOperation(Likelihood)}; the
 * no-argument {@link #doOperation()} throws.
 */
public class HamiltonianMonteCarloOperator
extends AbstractAdaptableOperator
implements GeneralOperator,
PathDependent,
ReversibleHMCProvider {
    final GradientWrtParameterProvider gradientProvider;
    /** Current leapfrog step size. */
    protected double stepSize;
    LeapFrogEngine leapFrogEngine;
    protected final Parameter parameter;
    protected final MassPreconditioner preconditioning;
    private final Options runtimeOptions;
    /** Per-coordinate multipliers (0.0 or 1.0) applied to momenta and gradients, or {@code null} for no masking. */
    protected final double[] mask;
    /** Optional transform; {@code null} means the trajectory runs directly on the parameter values. */
    protected final Transform transform;
    private static final boolean DEBUG = false;

    /**
     * Creates the operator.
     *
     * @param adaptationMode whether step-size adaptation is enabled
     * @param d operator weight
     * @param gradientWrtParameterProvider supplies the likelihood and its gradient
     * @param parameter the parameter moved by the trajectory
     * @param transform optional transform (may be {@code null}); when non-null the trajectory is
     *     integrated in the transformed space
     * @param parameter2 optional mask parameter (may be {@code null}); coordinates whose value is
     *     exactly 0.0 get their momentum and gradient multiplied by zero
     * @param options runtime settings (initial step size, trajectory length, checks, ...)
     * @param type factory used to build the mass preconditioner
     */
    public HamiltonianMonteCarloOperator(AdaptationMode adaptationMode, double d, GradientWrtParameterProvider gradientWrtParameterProvider, Parameter parameter, Transform transform, Parameter parameter2, Options options, MassPreconditioner.Type type) {
        super(adaptationMode, options.targetAcceptanceProbability);
        this.setWeight(d);
        this.gradientProvider = gradientWrtParameterProvider;
        this.runtimeOptions = options;
        this.stepSize = options.initialStepSize;
        this.preconditioning = type.factory(gradientWrtParameterProvider, transform, options);
        this.parameter = parameter;
        this.mask = HamiltonianMonteCarloOperator.buildMask(parameter2);
        this.transform = transform;
        // Called last: the engine captures parameter, preconditioning and mask.
        this.leapFrogEngine = this.constructLeapFrogEngine(transform);
    }

    /**
     * Builds the leapfrog integrator. A transform-aware engine is used when {@code transform} is
     * non-null. This is invoked from the constructor, so overrides must not depend on subclass
     * fields that are initialised after the super-constructor returns.
     */
    protected LeapFrogEngine constructLeapFrogEngine(Transform transform) {
        return transform != null ? new LeapFrogEngine.WithTransform(this.parameter, transform, this.getDefaultInstabilityHandler(), this.preconditioning, this.mask) : new LeapFrogEngine.Default(this.parameter, this.getDefaultInstabilityHandler(), this.preconditioning, this.mask);
    }

    @Override
    public String getOperatorName() {
        return "VanillaHMC(" + this.parameter.getParameterName() + ")";
    }

    /**
     * Returns true when preconditioning updates are enabled, the operation count is a multiple
     * of the update frequency, and the count is past the configured delay.
     */
    private boolean shouldUpdatePreconditioning() {
        return this.runtimeOptions.preconditioningUpdateFrequency > 0 && this.getCount() % (long)this.runtimeOptions.preconditioningUpdateFrequency == 0L && this.getCount() > (long)this.runtimeOptions.preconditioningDelay;
    }

    /**
     * Converts an optional mask parameter to a 0/1 array: 0.0 where the parameter value is
     * exactly 0.0, 1.0 otherwise.
     *
     * @return the mask, or {@code null} when {@code parameter} is {@code null}
     */
    private static double[] buildMask(Parameter parameter) {
        if (parameter == null) {
            return null;
        }
        double[] dArray = new double[parameter.getDimension()];
        for (int i = 0; i < dArray.length; ++i) {
            dArray[i] = parameter.getParameterValue(i) == 0.0 ? 0.0 : 1.0;
        }
        return dArray;
    }

    /** Not supported: this operator needs a {@link Likelihood}; use {@link #doOperation(Likelihood)}. */
    @Override
    public double doOperation() {
        throw new RuntimeException("Should not be executed");
    }

    /**
     * Performs one HMC proposal.
     *
     * <p>Before integrating it may: check the initial step size (first call, adaptation on),
     * compare analytic and numerical gradients (first {@code gradientCheckCount} calls), and
     * update the mass preconditioner from the last recorded gradient and position.
     *
     * @return the log Hastings ratio of the trajectory, or {@link Double#NEGATIVE_INFINITY} when
     *     the integrator reports a numerical instability (presumably rejecting the proposal)
     */
    @Override
    public double doOperation(Likelihood likelihood) {
        if (this.shouldCheckStepSize()) {
            this.checkStepSize();
        }
        if (this.shouldCheckGradient()) {
            this.checkGradient(likelihood);
        }
        if (this.shouldUpdatePreconditioning()) {
            this.preconditioning.storeSecant(new WrappedVector.Raw(this.leapFrogEngine.getLastGradient()), new WrappedVector.Raw(this.leapFrogEngine.getLastPosition()));
            this.preconditioning.updateMass();
        }
        try {
            return this.leapFrog();
        }
        catch (NumericInstabilityException numericInstabilityException) {
            return Double.NEGATIVE_INFINITY;
        }
    }

    /** Forwards the path parameter to the gradient provider when it is a {@link PathGradient}. */
    @Override
    public void setPathParameter(double d) {
        if (this.gradientProvider instanceof PathGradient) {
            ((PathGradient)this.gradientProvider).setPathParameter(d);
        }
    }

    /** Step-size check runs only on the very first operation, and only while adapting. */
    private boolean shouldCheckStepSize() {
        return this.getCount() < 1L && this.getMode() == AdaptationMode.ADAPTATION_ON;
    }

    /**
     * Shrinks {@link #stepSize} until a trial trajectory ends with a finite log-likelihood.
     *
     * <p>Up to {@code checkStepSizeMaxIterations} trials are run. After each trial the parameter
     * is restored to its starting values. A trial that throws or ends with a NaN/infinite
     * log-likelihood multiplies the step size by {@code checkStepSizeReductionFactor}.
     *
     * @throws RuntimeException if no trial succeeds
     */
    private void checkStepSize() {
        double[] dArray = this.parameter.getParameterValues();
        boolean bl = false;
        for (int i = 0; !bl && i < this.runtimeOptions.checkStepSizeMaxIterations; ++i) {
            try {
                this.leapFrog();
                double d = this.gradientProvider.getLikelihood().getLogLikelihood();
                if (!Double.isNaN(d) && !Double.isInfinite(d)) {
                    bl = true;
                }
            }
            catch (Exception ignored) {
                // Deliberately broad: any failure of a trial trajectory is treated as
                // "step size too large"; the step size is reduced below and the trial retried.
            }
            if (!bl) {
                this.stepSize *= this.runtimeOptions.checkStepSizeReductionFactor;
            }
            ReadableVector.Utils.setParameter(dArray, this.parameter);
        }
        if (!bl) {
            throw new RuntimeException("Unable to find acceptable initial HMC step-size");
        }
    }

    /** Gradient checks run for the first {@code gradientCheckCount} operations. */
    boolean shouldCheckGradient() {
        return this.getCount() < (long)this.runtimeOptions.gradientCheckCount;
    }

    /**
     * Compares the analytic gradient from {@link #gradientProvider} with a numerical gradient
     * (via {@link NumericalDerivative}) of {@code likelihood} at the current parameter value.
     *
     * <p>Without a transform both gradients are taken with respect to the raw parameter. With a
     * transform, the numerical gradient is taken of {@code logLikelihood - logJacobian} as a
     * function of the transformed coordinates, and the analytic gradient is mapped into that
     * space with {@code Transform.updateGradientLogDensity} at the original (unperturbed) point.
     *
     * <p>The parameter is always restored to its starting values, including when the check fails.
     *
     * @throws RuntimeException if dimensions differ or the gradients are not close within
     *     {@code gradientCheckTolerance}
     */
    void checkGradient(final Likelihood likelihood) {
        if (this.parameter.getDimension() != this.gradientProvider.getDimension()) {
            throw new RuntimeException("Unequal dimensions");
        }
        MultivariateFunction multivariateFunction = new MultivariateFunction(){

            // Side effect: writes the evaluation point (untransformed) into the parameter.
            @Override
            public double evaluate(double[] dArray) {
                if (HamiltonianMonteCarloOperator.this.transform == null) {
                    ReadableVector.Utils.setParameter(dArray, HamiltonianMonteCarloOperator.this.parameter);
                    return likelihood.getLogLikelihood();
                }
                double[] dArray2 = HamiltonianMonteCarloOperator.this.transform.inverse(dArray, 0, dArray.length);
                ReadableVector.Utils.setParameter(dArray2, HamiltonianMonteCarloOperator.this.parameter);
                return likelihood.getLogLikelihood() - HamiltonianMonteCarloOperator.this.transform.getLogJacobian(dArray2, 0, dArray2.length);
            }

            @Override
            public int getNumArguments() {
                return HamiltonianMonteCarloOperator.this.parameter.getDimension();
            }

            @Override
            public double getLowerBound(int n) {
                return HamiltonianMonteCarloOperator.this.parameter.getBounds().getLowerLimit(n);
            }

            @Override
            public double getUpperBound(int n) {
                return HamiltonianMonteCarloOperator.this.parameter.getBounds().getUpperLimit(n);
            }
        };
        double[] analyticGradient = this.gradientProvider.getGradientLogDensity();
        // Snapshot of the starting (untransformed) values. Numerical differentiation calls
        // evaluate(), which overwrites the parameter with perturbed values, so the parameter
        // must not be re-read once differentiation has started.
        double[] originalValues = this.parameter.getParameterValues();
        try {
            if (this.transform == null) {
                double[] numericGradient = NumericalDerivative.gradient(multivariateFunction, originalValues.clone());
                if (!MathUtils.isClose(analyticGradient, numericGradient, this.runtimeOptions.gradientCheckTolerance)) {
                    String string = "Gradients do not match:\n\tAnalytic: " + new WrappedVector.Raw(analyticGradient) + "\n\tNumeric : " + new WrappedVector.Raw(numericGradient) + "\n";
                    throw new RuntimeException(string);
                }
            } else {
                double[] transformedValues = this.transform.transform(originalValues.clone(), 0, originalValues.length);
                double[] numericGradient = NumericalDerivative.gradient(multivariateFunction, transformedValues.clone());
                // Map the analytic gradient at the point where it was computed, not at the
                // perturbed point left in the parameter by the numerical derivative.
                double[] analyticTransformedGradient = this.transform.updateGradientLogDensity(analyticGradient, originalValues.clone(), 0, originalValues.length);
                if (!MathUtils.isClose(analyticTransformedGradient, numericGradient, this.runtimeOptions.gradientCheckTolerance)) {
                    String string = "Transformed Gradients do not match:\n\tAnalytic: " + new WrappedVector.Raw(analyticTransformedGradient) + "\n\tNumeric : " + new WrappedVector.Raw(numericGradient) + "\n\tParameter : " + new WrappedVector.Raw(originalValues) + "\n\tTransformed Parameter : " + new WrappedVector.Raw(transformedValues) + "\n";
                    throw new RuntimeException(string);
                }
            }
        } finally {
            // Undo the perturbations written by evaluate(), whether or not the check passed.
            ReadableVector.Utils.setParameter(originalValues, this.parameter);
        }
    }

    /**
     * Multiplies {@code dArray} element-wise by {@code dArray2} in place.
     *
     * @param dArray values to mask (modified in place)
     * @param dArray2 mask of the same length, or {@code null} for no-op
     * @return {@code dArray}
     */
    static double[] mask(double[] dArray, double[] dArray2) {
        assert (dArray2 == null || dArray2.length == dArray.length);
        if (dArray2 != null) {
            for (int i = 0; i < dArray.length; ++i) {
                dArray[i] = dArray[i] * dArray2[i];
            }
        }
        return dArray;
    }

    /**
     * Multiplies {@code wrappedVector} element-wise by {@code dArray} in place.
     *
     * @param wrappedVector vector to mask (modified in place)
     * @param dArray mask of the same dimension, or {@code null} for no-op
     * @return {@code wrappedVector}
     */
    static WrappedVector mask(WrappedVector wrappedVector, double[] dArray) {
        assert (dArray == null || dArray.length == wrappedVector.getDim());
        if (dArray != null) {
            for (int i = 0; i < wrappedVector.getDim(); ++i) {
                wrappedVector.set(i, wrappedVector.get(i) * dArray[i]);
            }
        }
        return wrappedVector;
    }

    /**
     * Returns the number of leapfrog steps: {@code nSteps}, or, when
     * {@code randomStepCountFraction > 0}, {@code nSteps} scaled by a uniform factor in
     * {@code [1 - f/2, 1 + f/2)}, truncated and clamped to at least 1.
     */
    private int getNumberOfSteps() {
        int n = this.runtimeOptions.nSteps;
        if (this.runtimeOptions.randomStepCountFraction > 0.0) {
            double d = (double)n * (1.0 + this.runtimeOptions.randomStepCountFraction * (MathUtils.nextDouble() - 0.5));
            n = Math.max(1, (int)d);
        }
        return n;
    }

    /**
     * Kinetic energy {@code 0.5 * sum_i p_i * v_i}, where {@code v_i} is the velocity the
     * preconditioner returns for coordinate {@code i}.
     */
    @Override
    public double getKineticEnergy(ReadableVector readableVector) {
        int n = readableVector.getDim();
        double d = 0.0;
        for (int i = 0; i < n; ++i) {
            d += readableVector.get(i) * this.preconditioning.getVelocity(i, readableVector);
        }
        return d / 2.0;
    }

    /**
     * Runs one full leapfrog trajectory from the current parameter value.
     *
     * <p>Scheme: half momentum step, then {@code n} position steps interleaved with
     * {@code n - 1} full momentum steps, then a final half momentum step. The parameter is
     * left at the end point.
     *
     * @return {@code (K0 + J0) - (K1 + J1)}, where K is kinetic energy and J the engine's
     *     parameter log-Jacobian, at the start (0) and end (1) of the trajectory
     * @throws NumericInstabilityException if the engine's instability handler rejects a state,
     *     or an {@link ArithmeticException} occurs inside the loop
     */
    private double leapFrog() throws NumericInstabilityException {
        double[] dArray = this.leapFrogEngine.getInitialPosition();
        WrappedVector wrappedVector = HamiltonianMonteCarloOperator.mask(this.preconditioning.drawInitialMomentum(), this.mask);
        double d = this.getKineticEnergy(wrappedVector) + this.leapFrogEngine.getParameterLogJacobian();
        this.leapFrogEngine.updateMomentum(dArray, wrappedVector.getBuffer(), HamiltonianMonteCarloOperator.mask(this.gradientProvider.getGradientLogDensity(), this.mask), this.stepSize / 2.0);
        int n = this.getNumberOfSteps();
        for (int i = 0; i < n; ++i) {
            try {
                this.leapFrogEngine.updatePosition(dArray, wrappedVector, this.stepSize);
            }
            catch (ArithmeticException arithmeticException) {
                throw new NumericInstabilityException();
            }
            // The last momentum update is the closing half step below.
            if (i >= n - 1) continue;
            try {
                this.leapFrogEngine.updateMomentum(dArray, wrappedVector.getBuffer(), HamiltonianMonteCarloOperator.mask(this.gradientProvider.getGradientLogDensity(), this.mask), this.stepSize);
            }
            catch (ArithmeticException arithmeticException) {
                throw new NumericInstabilityException();
            }
        }
        this.leapFrogEngine.updateMomentum(dArray, wrappedVector.getBuffer(), HamiltonianMonteCarloOperator.mask(this.gradientProvider.getGradientLogDensity(), this.mask), this.stepSize / 2.0);
        double d2 = this.getKineticEnergy(wrappedVector) + this.leapFrogEngine.getParameterLogJacobian();
        return d - d2;
    }

    /** The adaptable quantity is {@code log(stepSize)}. */
    @Override
    protected double getAdaptableParameterValue() {
        return Math.log(this.stepSize);
    }

    /** Sets {@code stepSize = exp(d)}. */
    @Override
    public void setAdaptableParameterValue(double d) {
        this.stepSize = Math.exp(d);
    }

    /** Returns the step size on its natural (not log) scale. */
    @Override
    public double getRawParameter() {
        return this.stepSize;
    }

    /** Instability handler passed to the leapfrog engine; {@link InstabilityHandler#REJECT} by default. */
    protected InstabilityHandler getDefaultInstabilityHandler() {
        return InstabilityHandler.REJECT;
    }

    @Override
    public String getAdaptableParameterName() {
        return "stepSize";
    }

    /**
     * Performs a single leapfrog step of size {@code d}: half momentum, full position, half
     * momentum.
     *
     * @param dArray position (modified in place)
     * @param wrappedVector momentum (modified in place)
     * @param d step size
     */
    protected void doLeap(double[] dArray, WrappedVector wrappedVector, double d) throws NumericInstabilityException {
        this.leapFrogEngine.updateMomentum(dArray, wrappedVector.getBuffer(), HamiltonianMonteCarloOperator.mask(this.gradientProvider.getGradientLogDensity(), this.mask), d / 2.0);
        this.leapFrogEngine.updatePosition(dArray, wrappedVector, d);
        this.leapFrogEngine.updateMomentum(dArray, wrappedVector.getBuffer(), HamiltonianMonteCarloOperator.mask(this.gradientProvider.getGradientLogDensity(), this.mask), d / 2.0);
    }

    /**
     * Single leapfrog step of size {@code n * d}, where {@code n} is presumably the direction
     * (+1/-1) chosen by the caller. Instabilities go to {@link #handleInstability()}.
     */
    @Override
    public void reversiblePositionMomentumUpdate(WrappedVector wrappedVector, WrappedVector wrappedVector2, int n, double d) {
        try {
            this.doLeap(wrappedVector.getBuffer(), wrappedVector2, (double)n * d);
        }
        catch (NumericInstabilityException numericInstabilityException) {
            this.handleInstability();
        }
    }

    @Override
    public double[] getInitialPosition() {
        return this.leapFrogEngine.getInitialPosition();
    }

    @Override
    public double getParameterLogJacobian() {
        return this.leapFrogEngine.getParameterLogJacobian();
    }

    @Override
    public void setParameter(double[] dArray) {
        this.leapFrogEngine.setParameter(dArray);
    }

    /** Draws a fresh momentum from the preconditioner and applies the mask. */
    @Override
    public WrappedVector drawMomentum() {
        return HamiltonianMonteCarloOperator.mask(this.preconditioning.drawInitialMomentum(), this.mask);
    }

    /** Returns {@code logLikelihood - kineticEnergy(momentum) - parameterLogJacobian}. */
    @Override
    public double getJointProbability(WrappedVector wrappedVector) {
        return this.gradientProvider.getLikelihood().getLogLikelihood() - this.getKineticEnergy(wrappedVector) - this.getParameterLogJacobian();
    }

    @Override
    public double getLogLikelihood() {
        return this.gradientProvider.getLikelihood().getLogLikelihood();
    }

    @Override
    public double getStepSize() {
        return this.stepSize;
    }

    /** Called when a reversible update hits a numerical instability; throws by default. */
    protected void handleInstability() {
        throw new RuntimeException("Numerical instability; need to handle");
    }

    /** Strategy for the position/momentum updates of a leapfrog trajectory. */
    static interface LeapFrogEngine {
        /** Returns the starting position (in the engine's working space) as a new array. */
        public double[] getInitialPosition();

        /** Log-Jacobian term added to the kinetic energy in the Hastings ratio. */
        public double getParameterLogJacobian();

        /** Adds {@code var4 * gradient(var3)} to momentum {@code var2} at position {@code var1}. */
        public void updateMomentum(double[] var1, double[] var2, double[] var3, double var4) throws NumericInstabilityException;

        /** Moves position {@code var1} by {@code var3 * velocity(momentum var2)} and writes it to the parameter. */
        public void updatePosition(double[] var1, WrappedVector var2, double var3) throws NumericInstabilityException;

        /** Writes a working-space position back to the parameter. */
        public void setParameter(double[] var1);

        /** Gradient array passed to the most recent {@link #updateMomentum} call (not a copy). */
        public double[] getLastGradient();

        /** Position array passed to the most recent {@link #updateMomentum} call (not a copy). */
        public double[] getLastPosition();

        /**
         * Engine that integrates in the space defined by a {@link Transform}: positions are
         * transformed values, while the parameter holds the inverse-transformed values.
         */
        public static class WithTransform
        extends Default {
            private final Transform transform;
            /** Last untransformed position, refreshed by getInitialPosition and setParameter. */
            double[] unTransformedPosition;

            private WithTransform(Parameter parameter, Transform transform, InstabilityHandler instabilityHandler, MassPreconditioner massPreconditioner, double[] dArray) {
                super(parameter, instabilityHandler, massPreconditioner, dArray);
                this.transform = transform;
            }

            @Override
            public double getParameterLogJacobian() {
                return this.transform.getLogJacobian(this.unTransformedPosition, 0, this.unTransformedPosition.length);
            }

            @Override
            public double[] getInitialPosition() {
                this.unTransformedPosition = super.getInitialPosition();
                return this.transform.transform(this.unTransformedPosition, 0, this.unTransformedPosition.length);
            }

            /** Maps the gradient into the transformed space and re-applies the mask before the update. */
            @Override
            public void updateMomentum(double[] dArray, double[] dArray2, double[] dArray3, double d) throws NumericInstabilityException {
                dArray3 = this.transform.updateGradientLogDensity(dArray3, this.unTransformedPosition, 0, this.unTransformedPosition.length);
                HamiltonianMonteCarloOperator.mask(dArray3, this.mask);
                super.updateMomentum(dArray, dArray2, dArray3, d);
            }

            @Override
            public void updatePosition(double[] dArray, WrappedVector wrappedVector, double d) throws NumericInstabilityException {
                // super.updatePosition calls the overridden setParameter below, which refreshes
                // unTransformedPosition before the domain check.
                super.updatePosition(dArray, wrappedVector, d);
                if (this.instabilityHandler.checkPositionTransform()) {
                    this.checkPosition(this.unTransformedPosition);
                }
            }

            @Override
            public void setParameter(double[] dArray) {
                this.unTransformedPosition = this.transform.inverse(dArray, 0, dArray.length);
                super.setParameter(this.unTransformedPosition);
            }

            private void checkPosition(double[] dArray) throws NumericInstabilityException {
                this.instabilityHandler.checkPosition(this.transform, dArray);
            }
        }

        /** Engine that integrates directly on the parameter values. */
        public static class Default
        implements LeapFrogEngine {
            protected final Parameter parameter;
            final InstabilityHandler instabilityHandler;
            private final MassPreconditioner preconditioning;
            final double[] mask;
            double[] lastGradient;
            double[] lastPosition;

            Default(Parameter parameter, InstabilityHandler instabilityHandler, MassPreconditioner massPreconditioner, double[] dArray) {
                this.parameter = parameter;
                this.instabilityHandler = instabilityHandler;
                this.preconditioning = massPreconditioner;
                this.mask = dArray;
            }

            @Override
            public double[] getInitialPosition() {
                return this.parameter.getParameterValues();
            }

            @Override
            public double getParameterLogJacobian() {
                return 0.0;
            }

            @Override
            public double[] getLastGradient() {
                return this.lastGradient;
            }

            @Override
            public double[] getLastPosition() {
                return this.lastPosition;
            }

            @Override
            public void updateMomentum(double[] dArray, double[] dArray2, double[] dArray3, double d) throws NumericInstabilityException {
                int n = dArray2.length;
                for (int i = 0; i < n; ++i) {
                    dArray2[i] = dArray2[i] + d * dArray3[i];
                    this.instabilityHandler.checkValue(dArray2[i]);
                }
                // References, not copies: later in-place updates of these arrays are visible here.
                this.lastGradient = dArray3;
                this.lastPosition = dArray;
            }

            @Override
            public void updatePosition(double[] dArray, WrappedVector wrappedVector, double d) throws NumericInstabilityException {
                int n = wrappedVector.getDim();
                for (int i = 0; i < n; ++i) {
                    dArray[i] = dArray[i] + d * this.preconditioning.getVelocity(i, wrappedVector);
                    this.instabilityHandler.checkValue(dArray[i]);
                }
                this.setParameter(dArray);
            }

            @Override
            public void setParameter(double[] dArray) {
                ReadableVector.Utils.setParameter(dArray, this.parameter);
            }
        }
    }

    /**
     * Policy applied to intermediate trajectory values. {@code checkValue} only tests for NaN;
     * {@code checkPosition} tests the untransformed position against the transform's interior
     * domain.
     */
    static enum InstabilityHandler {
        /** Throws on NaN values and out-of-domain positions. */
        REJECT{

            @Override
            void checkValue(double d) throws NumericInstabilityException {
                if (Double.isNaN(d)) {
                    throw new NumericInstabilityException();
                }
            }

            @Override
            void checkPosition(Transform transform, double[] dArray) throws NumericInstabilityException {
                if (!transform.isInInteriorDomain(dArray, 0, dArray.length)) {
                    throw new NumericInstabilityException();
                }
            }

            @Override
            boolean checkPositionTransform() {
                return true;
            }
        }
        ,
        /** Same checks as REJECT, but also prints a message to stderr before throwing. */
        DEBUG{

            @Override
            void checkValue(double d) throws NumericInstabilityException {
                if (Double.isNaN(d)) {
                    System.err.println("Numerical instability in HMC momentum; throwing exception");
                    throw new NumericInstabilityException();
                }
            }

            @Override
            void checkPosition(Transform transform, double[] dArray) throws NumericInstabilityException {
                if (!transform.isInInteriorDomain(dArray, 0, dArray.length)) {
                    System.err.println("Numerical instability in HMC position; throwing exception");
                    throw new NumericInstabilityException();
                }
            }

            @Override
            boolean checkPositionTransform() {
                return true;
            }
        }
        ,
        /** Performs no checks. */
        IGNORE{

            @Override
            void checkValue(double d) {
            }

            @Override
            void checkPosition(Transform transform, double[] dArray) throws NumericInstabilityException {
            }

            @Override
            boolean checkPositionTransform() {
                return false;
            }
        };


        abstract void checkValue(double var1) throws NumericInstabilityException;

        abstract void checkPosition(Transform var1, double[] var2) throws NumericInstabilityException;

        abstract boolean checkPositionTransform();
    }

    /** Signals that a trajectory produced an unacceptable (e.g. NaN or out-of-domain) state. */
    static class NumericInstabilityException
    extends Exception {
        NumericInstabilityException() {
        }
    }

    /** Immutable runtime settings for the operator. */
    public static class Options {
        final double initialStepSize;
        final int nSteps;
        final double randomStepCountFraction;
        final int preconditioningUpdateFrequency;
        final int preconditioningDelay;
        final int preconditioningMemory;
        final int gradientCheckCount;
        final double gradientCheckTolerance;
        final int checkStepSizeMaxIterations;
        final double checkStepSizeReductionFactor;
        final double targetAcceptanceProbability;

        /**
         * @param d initial step size
         * @param n number of leapfrog steps per trajectory
         * @param d2 random step-count fraction (0 disables jitter)
         * @param n2 preconditioning update frequency (values &lt;= 0 disable updates)
         * @param n3 preconditioning delay (operation count that must be exceeded before updates)
         * @param n4 preconditioning memory
         * @param n5 number of initial operations that run a gradient check
         * @param d3 gradient-check tolerance
         * @param n6 maximum step-size check iterations
         * @param d4 step-size reduction factor applied on a failed check
         * @param d5 target acceptance probability passed to the adaptable-operator base class
         */
        public Options(double d, int n, double d2, int n2, int n3, int n4, int n5, double d3, int n6, double d4, double d5) {
            this.initialStepSize = d;
            this.nSteps = n;
            this.randomStepCountFraction = d2;
            this.preconditioningUpdateFrequency = n2;
            this.preconditioningDelay = n3;
            this.preconditioningMemory = n4;
            this.gradientCheckCount = n5;
            this.gradientCheckTolerance = d3;
            this.checkStepSizeMaxIterations = n6;
            this.checkStepSizeReductionFactor = d4;
            this.targetAcceptanceProbability = d5;
        }
    }
}

