/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.treedatalikelihood.continuous;

import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evomodel.treedatalikelihood.continuous.ContinuousRateTransformation;
import dr.evomodel.treedatalikelihood.continuous.ContinuousTraitPartialsProvider;
import dr.evomodel.treedatalikelihood.continuous.cdi.PrecisionType;
import dr.evomodel.treedatalikelihood.preorder.WrappedNormalSufficientStatistics;
import dr.inference.model.AbstractModel;
import dr.inference.model.CompoundParameter;
import dr.inference.model.Model;
import dr.inference.model.Variable;
import dr.math.matrixAlgebra.WrappedMatrix;
import dr.math.matrixAlgebra.WrappedVector;
import dr.xml.AbstractXMLObjectParser;
import dr.xml.ElementRule;
import dr.xml.XMLObject;
import dr.xml.XMLParseException;
import dr.xml.XMLSyntaxRule;
import java.util.List;
import org.ejml.data.DenseMatrix64F;
import org.ejml.ops.CommonOps;

/**
 * A {@link ContinuousTraitPartialsProvider} that stacks several child providers into one.
 *
 * <p>The joint trait (and data) vector is the concatenation of the children's vectors, in the
 * order of the {@code providers} array; the joint trait and data dimensions are the sums of the
 * children's dimensions. Tip partials are assembled with block-diagonal precision and variance
 * (one block per child). The XML parser requires all children to share the same trait count and
 * {@link PrecisionType}.
 */
public class JointPartialsProvider
extends AbstractModel
implements ContinuousTraitPartialsProvider {
    /** Child providers, in the order in which their traits are concatenated. */
    private final ContinuousTraitPartialsProvider[] providers;
    /** Sum of the children's trait dimensions. */
    private final int traitDim;
    /** Sum of the children's data dimensions. */
    private final int dataDim;
    /** Indices of the {@code true} entries of {@link #missingDataIndicators}. */
    private final List<Integer> missingIndices;
    /** Children's data-missing indicators interleaved row by row (see {@link #mergeIndicators}). */
    private final boolean[] missingDataIndicators;
    /** Children's trait-missing indicators interleaved row by row (see {@link #mergeIndicators}). */
    private final boolean[] missingTraitIndicators;
    /** True if any child allows singular partials by default. */
    private final boolean defaultAllowSingular;
    /** Whether determinants are accumulated in tip partials; currently equal to {@link #defaultAllowSingular}. */
    private final boolean computeDeterminant;
    private final PrecisionType precisionType;
    private String tipTraitName;
    /** Merge of the children's data parameters. */
    private final CompoundParameter jointDataParameter;
    /** When true, the constructor verifies the merged data parameter. */
    private static final boolean DEBUG = false;

    public static final AbstractXMLObjectParser PARSER = new AbstractXMLObjectParser(){
        private static final String PARSER_NAME = "jointPartialsProvider";

        /**
         * Builds a joint provider from every {@link ContinuousTraitPartialsProvider} child element.
         *
         * @throws XMLParseException if there are no children, or the children disagree on trait
         *                           count or precision type
         */
        @Override
        public Object parseXMLObject(XMLObject xo) throws XMLParseException {
            List<ContinuousTraitPartialsProvider> children = xo.getAllChildren(ContinuousTraitPartialsProvider.class);
            if (children == null || children.isEmpty()) {
                throw new XMLParseException("at least one partials provider is required");
            }
            ContinuousTraitPartialsProvider[] subProviders =
                    children.toArray(new ContinuousTraitPartialsProvider[0]);

            int traitCount = subProviders[0].getTraitCount();
            for (int i = 1; i < subProviders.length; ++i) {
                if (subProviders[i].getTraitCount() != traitCount) {
                    throw new XMLParseException("all partials providers must have the same trait count");
                }
            }

            PrecisionType precisionType = subProviders[0].getPrecisionType();
            for (int i = 1; i < subProviders.length; ++i) {
                PrecisionType otherType = subProviders[i].getPrecisionType();
                if (otherType != precisionType) {
                    throw new XMLParseException("all partials providers must have the same precision type. Provider for model "
                            + subProviders[0].getModelName() + " has precision type '" + precisionType
                            + "', while provider for model " + subProviders[i].getModelName()
                            + " has precision type '" + otherType + "'.");
                }
            }
            return new JointPartialsProvider(PARSER_NAME, subProviders, precisionType);
        }

        @Override
        public XMLSyntaxRule[] getSyntaxRules() {
            // At least one child is required: the constructor reads providers[0].
            return new XMLSyntaxRule[]{new ElementRule(ContinuousTraitPartialsProvider.class, 1, Integer.MAX_VALUE)};
        }

        @Override
        public String getParserDescription() {
            return "Merges two Gaussian processes.";
        }

        @Override
        public Class getReturnType() {
            return JointPartialsProvider.class;
        }

        @Override
        public String getParserName() {
            return PARSER_NAME;
        }
    };

    /**
     * Creates a joint provider over the given children.
     *
     * @param name          model name passed to {@link AbstractModel}
     * @param providers     child providers; the array is stored, not copied
     * @param precisionType precision type of the joint partials
     * @throws IllegalArgumentException if {@code providers} is null or empty
     */
    public JointPartialsProvider(String name, ContinuousTraitPartialsProvider[] providers, PrecisionType precisionType) {
        super(name);
        if (providers == null || providers.length == 0) {
            throw new IllegalArgumentException("JointPartialsProvider requires at least one child provider");
        }
        this.providers = providers;
        this.precisionType = precisionType;

        int childCount = providers.length;
        boolean[][] traitMissing = new boolean[childCount][];
        boolean[][] dataMissing = new boolean[childCount][];
        int[] traitDims = new int[childCount];
        int[] dataDims = new int[childCount];
        int totalTraitDim = 0;
        int totalDataDim = 0;
        for (int i = 0; i < childCount; ++i) {
            ContinuousTraitPartialsProvider child = providers[i];
            traitDims[i] = child.getTraitDimension();
            dataDims[i] = child.getDataDimension();
            traitMissing[i] = child.getTraitMissingIndicators();
            dataMissing[i] = child.getDataMissingIndicators();
            totalTraitDim += traitDims[i];
            totalDataDim += dataDims[i];
        }
        this.traitDim = totalTraitDim;
        this.dataDim = totalDataDim;

        // NOTE(review): the first child's parameter count is used as the row count (presumably the
        // number of taxa) for every child; this assumes all children agree on it.
        int rowCount = providers[0].getParameter().getParameterCount();
        this.missingDataIndicators = mergeIndicators(dataMissing, dataDims, rowCount, totalDataDim);
        this.missingTraitIndicators = mergeIndicators(traitMissing, traitDims, rowCount, totalTraitDim);
        this.missingIndices = ContinuousTraitPartialsProvider.indicatorToIndices(this.missingDataIndicators);
        this.defaultAllowSingular = setDefaultAllowSingular();
        this.computeDeterminant = this.defaultAllowSingular;

        // Children that are models are registered so their change events propagate to this model.
        for (ContinuousTraitPartialsProvider child : providers) {
            if (child instanceof Model) {
                addModel((Model) child);
            }
        }

        CompoundParameter[] childParameters = new CompoundParameter[childCount];
        for (int i = 0; i < childCount; ++i) {
            childParameters[i] = providers[i].getParameter();
        }
        this.jointDataParameter = CompoundParameter.mergeParameters(childParameters);
        if (DEBUG) {
            CompoundParameter.checkParametersMerged(this.jointDataParameter, childParameters);
        }
    }

    /**
     * Interleaves per-child indicator arrays row by row.
     *
     * <p>For each row {@code r}, the output segment {@code [r * totalDim, (r + 1) * totalDim)} is
     * the concatenation of each child's segment {@code [r * dims[j], (r + 1) * dims[j])}.
     *
     * @param sources  one indicator array per child, each of length {@code rowCount * dims[j]}
     * @param dims     per-child dimension
     * @param rowCount number of rows
     * @param totalDim sum of {@code dims}
     * @return merged indicators of length {@code rowCount * totalDim}
     */
    private static boolean[] mergeIndicators(boolean[][] sources, int[] dims, int rowCount, int totalDim) {
        boolean[] merged = new boolean[totalDim * rowCount];
        for (int row = 0; row < rowCount; ++row) {
            int destination = row * totalDim;
            for (int j = 0; j < sources.length; ++j) {
                int width = dims[j];
                System.arraycopy(sources[j], row * width, merged, destination, width);
                destination += width;
            }
        }
        return merged;
    }

    /** Returns the merged trait-missing indicators (internal array, not a copy). */
    @Override
    public boolean[] getTraitMissingIndicators() {
        return this.missingTraitIndicators;
    }

    /** Always returns {@code true}. */
    @Override
    public boolean bufferTips() {
        return true;
    }

    /** Returns the first child's trait count; the XML parser checks that all children agree. */
    @Override
    public int getTraitCount() {
        return this.providers[0].getTraitCount();
    }

    /** Returns the sum of the children's trait dimensions. */
    @Override
    public int getTraitDimension() {
        return this.traitDim;
    }

    @Override
    public String getTipTraitName() {
        return this.tipTraitName;
    }

    /** Sets this provider's tip trait name and names child {@code i} as {@code name + "." + i}. */
    @Override
    public void setTipTraitName(String name) {
        this.tipTraitName = name;
        for (int i = 0; i < this.providers.length; ++i) {
            this.providers[i].setTipTraitName(name + "." + i);
        }
    }

    /** Returns the sum of the children's data dimensions. */
    @Override
    public int getDataDimension() {
        return this.dataDim;
    }

    /** Returns each child's trait dimension, in child order. */
    @Override
    public int[] getPartitionDimensions() {
        int[] dims = new int[this.providers.length];
        for (int i = 0; i < this.providers.length; ++i) {
            dims[i] = this.providers[i].getTraitDimension();
        }
        return dims;
    }

    /** Returns {@code traitsAbove} unchanged. */
    @Override
    public double[] drawTraitsBelowConditionalOnDataAndTraitsAbove(double[] traitsAbove) {
        return traitsAbove;
    }

    @Override
    public PrecisionType getPrecisionType() {
        return this.precisionType;
    }

    /**
     * Assembles the joint tip partial for one taxon from the children's tip partials.
     *
     * <p>Means are concatenated; precision and variance are placed block-diagonally; effective
     * dimension and remainder are summed. When determinants are computed they are summed too,
     * but if any child reports a missing determinant the joint determinant is marked missing
     * rather than having the sentinel value added into the sum.
     *
     * @param taxonIndex   tip index forwarded to every child
     * @param wishartFlag  must be {@code false}; {@code true} (Wishart statistics, per the error
     *                     message) is not supported
     * @return the joint partial
     * @throws RuntimeException if the precision type is not {@link PrecisionType#FULL} or
     *                          {@code wishartFlag} is true
     */
    @Override
    public double[] getTipPartial(int taxonIndex, boolean wishartFlag) {
        if (this.precisionType != PrecisionType.FULL) {
            throw new RuntimeException("Currently only implemented for full precision");
        }
        if (wishartFlag) {
            throw new RuntimeException("Wishart statistics currently not implemented for joint partials provider");
        }

        double[] partial = new double[this.precisionType.getPartialsDimension(this.traitDim)];
        int meanOffset = this.precisionType.getMeanOffset(this.traitDim);
        int precisionOffset = this.precisionType.getPrecisionOffset(this.traitDim);
        int varianceOffset = this.precisionType.getVarianceOffset(this.traitDim);
        int effectiveDimOffset = this.precisionType.getEffectiveDimensionOffset(this.traitDim);
        int determinantOffset = this.precisionType.getDeterminantOffset(this.traitDim);
        int remainderOffset = this.precisionType.getRemainderOffset(this.traitDim);

        WrappedMatrix.Indexed jointPrecision =
                WrappedMatrix.Utils.wrapBlockDiagonalMatrix(partial, precisionOffset, 0, this.traitDim);
        WrappedMatrix.Indexed jointVariance =
                WrappedMatrix.Utils.wrapBlockDiagonalMatrix(partial, varianceOffset, 0, this.traitDim);

        boolean determinantMissing = false;
        int blockStart = 0;
        for (ContinuousTraitPartialsProvider child : this.providers) {
            double[] childPartial = child.getTipPartial(taxonIndex, wishartFlag);
            int childDim = child.getTraitDimension();

            WrappedMatrix.Raw childPrecision = new WrappedMatrix.Raw(
                    childPartial, this.precisionType.getPrecisionOffset(childDim), childDim, childDim);
            WrappedMatrix.Utils.transferSymmetricBlockDiagonal(childPrecision, jointPrecision, blockStart);
            WrappedMatrix.Raw childVariance = new WrappedMatrix.Raw(
                    childPartial, this.precisionType.getVarianceOffset(childDim), childDim, childDim);
            WrappedMatrix.Utils.transferSymmetricBlockDiagonal(childVariance, jointVariance, blockStart);

            System.arraycopy(childPartial, this.precisionType.getMeanOffset(childDim),
                    partial, meanOffset + blockStart, childDim);
            blockStart += childDim;

            if (this.precisionType.hasEffectiveDimension()) {
                partial[effectiveDimOffset] += childPartial[this.precisionType.getEffectiveDimensionOffset(childDim)];
                if (this.computeDeterminant) {
                    double childDeterminant = childPartial[this.precisionType.getDeterminantOffset(childDim)];
                    if (this.precisionType.isMissingDeterminantValue(childDeterminant)) {
                        // A sentinel is not a number to be summed; the joint value is unknown too.
                        determinantMissing = true;
                    } else {
                        partial[determinantOffset] += childDeterminant;
                    }
                }
            }
            if (this.precisionType.hasRemainder()) {
                partial[remainderOffset] += childPartial[this.precisionType.getRemainderOffset(childDim)];
            }
        }
        if (!this.computeDeterminant || determinantMissing) {
            this.precisionType.fillNoDeterminantInPartials(partial, 0, this.traitDim);
        }
        return partial;
    }

    @Override
    public List<Integer> getMissingIndices() {
        return this.missingIndices;
    }

    /** Returns the merged data-missing indicators (internal array, not a copy). */
    @Override
    public boolean[] getDataMissingIndicators() {
        return this.missingDataIndicators;
    }

    /** Returns the merge of the children's data parameters. */
    @Override
    public CompoundParameter getParameter() {
        return this.jointDataParameter;
    }

    /** Returns true if any child uses missing indices. */
    @Override
    public boolean usesMissingIndices() {
        for (ContinuousTraitPartialsProvider child : this.providers) {
            if (child.usesMissingIndices()) {
                return true;
            }
        }
        return false;
    }

    /** Returns the child providers (internal array, not a copy). */
    @Override
    public ContinuousTraitPartialsProvider[] getChildModels() {
        return this.providers;
    }

    /** Any change in a registered child is re-fired as a change of this model. */
    @Override
    protected void handleModelChangedEvent(Model model, Object object, int index) {
        this.fireModelChanged();
    }

    @Override
    protected void handleVariableChangedEvent(Variable variable, int index, Variable.ChangeType changeType) {
    }

    @Override
    protected void storeState() {
    }

    @Override
    protected void restoreState() {
    }

    @Override
    protected void acceptState() {
    }

    @Override
    public boolean getDefaultAllowSingular() {
        return this.defaultAllowSingular;
    }

    /** Returns true if any child allows singular partials by default (used once, from the constructor). */
    private boolean setDefaultAllowSingular() {
        for (ContinuousTraitPartialsProvider child : this.providers) {
            if (child.getDefaultAllowSingular()) {
                return true;
            }
        }
        return false;
    }

    /** Returns true only if every child supplies Wishart statistics. */
    @Override
    public boolean suppliesWishartStatistics() {
        for (ContinuousTraitPartialsProvider child : this.providers) {
            if (!child.suppliesWishartStatistics()) {
                return false;
            }
        }
        return true;
    }

    /** Forwards the tree and rate transformation to every child. */
    @Override
    public void addTreeAndRateModel(Tree tree, ContinuousRateTransformation rateTransformation) {
        for (ContinuousTraitPartialsProvider child : this.providers) {
            child.addTreeAndRateModel(tree, rateTransformation);
        }
    }

    /**
     * Extracts the block of joint normal statistics belonging to one child.
     *
     * <p>The child's mean is a view into the joint mean; its variance is copied out of the joint
     * variance and inverted to obtain the corresponding precision.
     *
     * @param statistics joint statistics
     * @param provider   the child whose block is wanted (matched by identity)
     * @return statistics for {@code provider}'s block
     * @throws RuntimeException if {@code provider} is not one of this provider's children
     */
    @Override
    public WrappedNormalSufficientStatistics partitionNormalStatistics(WrappedNormalSufficientStatistics statistics, ContinuousTraitPartialsProvider provider) {
        int offset = 0;
        boolean found = false;
        for (ContinuousTraitPartialsProvider child : this.providers) {
            if (child == provider) {
                found = true;
                break;
            }
            offset += child.getTraitDimension();
        }
        if (!found) {
            throw new RuntimeException("Provider for model " + provider.getModelName()
                    + " is not a child of this joint partials provider");
        }

        int dim = provider.getTraitDimension();
        WrappedVector.View mean = new WrappedVector.View(statistics.getMean(), offset, dim);

        WrappedMatrix jointVariance = statistics.getVariance();
        DenseMatrix64F variance = new DenseMatrix64F(dim, dim);
        for (int row = 0; row < dim; ++row) {
            for (int col = 0; col < dim; ++col) {
                variance.set(row, col, jointVariance.get(offset + row, offset + col));
            }
        }
        DenseMatrix64F precision = new DenseMatrix64F(dim, dim);
        // NOTE(review): CommonOps.invert returns false for a non-invertible matrix; that result is ignored here.
        CommonOps.invert(variance, precision);
        return new WrappedNormalSufficientStatistics(mean,
                new WrappedMatrix.WrappedDenseMatrix(precision),
                new WrappedMatrix.WrappedDenseMatrix(variance));
    }

    /**
     * Returns the provider whose tip trait name equals {@code traitName}: this provider itself,
     * or one of its direct children.
     *
     * @throws RuntimeException if no such provider exists
     */
    @Override
    public ContinuousTraitPartialsProvider getProviderForTrait(String traitName) {
        if (traitName.equals(this.getTipTraitName())) {
            return this;
        }
        for (ContinuousTraitPartialsProvider child : this.providers) {
            if (traitName.equals(child.getTipTraitName())) {
                return child;
            }
        }
        throw new RuntimeException("Partials provider does not have trait '" + traitName + "', nor did any of its sub-models");
    }

    /**
     * Delegates to the single child whose trait block is exactly {@code [offset, offset + dimension)},
     * passing the child-relative offset {@code 0}.
     */
    @Override
    public void updateTipDataGradient(DenseMatrix64F matrix1, DenseMatrix64F matrix2, NodeRef node, int offset, int dimension) {
        ContinuousTraitPartialsProvider child = this.getProviderOffset(offset, dimension);
        child.updateTipDataGradient(matrix1, matrix2, node, 0, dimension);
    }

    /**
     * Finds the child whose trait block starts at {@code offset} and has size {@code dimension}.
     *
     * @throws RuntimeException if no child aligns exactly with the requested block
     */
    private ContinuousTraitPartialsProvider getProviderOffset(int offset, int dimension) {
        int start = 0;
        int index = 0;
        while (start < offset && index < this.providers.length) {
            start += this.providers[index].getTraitDimension();
            ++index;
        }
        // Checking the index first prevents walking past the last child when offset >= traitDim.
        if (start != offset || index >= this.providers.length
                || this.providers[index].getTraitDimension() != dimension) {
            throw new RuntimeException("Offset and dimension must perfectly align with a child model (for now)");
        }
        return this.providers[index];
    }

    /**
     * Asks the child whose trait block is exactly {@code [offset, offset + dimension)}. The
     * child-relative offset {@code 0} is passed, consistent with {@link #updateTipDataGradient}.
     */
    @Override
    public boolean needToUpdateTipDataGradient(int offset, int dimension) {
        ContinuousTraitPartialsProvider child = this.getProviderOffset(offset, dimension);
        return child.needToUpdateTipDataGradient(0, dimension);
    }
}

