/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.substmodel;

import dr.evolution.datatype.DataType;
import dr.evomodel.siteratemodel.SiteRateModel;
import dr.evomodel.substmodel.ComplexSubstitutionModel;
import dr.evomodel.substmodel.DifferentiableSubstitutionModel;
import dr.evomodel.substmodel.DifferentialMassProvider;
import dr.evomodel.substmodel.EigenDecomposition;
import dr.evomodel.substmodel.EigenSystem;
import dr.evomodel.substmodel.FrequencyModel;
import dr.evomodel.substmodel.MarkovModulatedFrequencyModel;
import dr.evomodel.substmodel.ParameterReplaceableSubstitutionModel;
import dr.evomodel.substmodel.SubstitutionModel;
import dr.inference.loggers.LogColumn;
import dr.inference.loggers.Loggable;
import dr.inference.loggers.NumberColumn;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.inference.model.Variable;
import dr.math.matrixAlgebra.WrappedMatrix;
import dr.util.Citable;
import dr.util.Citation;
import dr.util.CommonCitations;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.logging.Logger;

/**
 * Substitution model built by combining several base substitution models into
 * one larger rate matrix.
 *
 * <p>The combined matrix is laid out in {@code numBaseModel} square blocks of
 * size {@code baseStateCount}. Diagonal block {@code m} holds the infinitesimal
 * matrix of base model {@code m} multiplied by {@link #getModelRateScalar(int)}.
 * Off-diagonal blocks hold the switching rates: a switch from base model
 * {@code a} to base model {@code b} keeps the base state and therefore only
 * populates the diagonal of block {@code (a, b)}.
 *
 * <p>Two switching layouts are accepted. If the switching-rate parameter has
 * {@code 2 * (numBaseModel - 1)} entries, only neighbouring base models
 * (indices differing by one) are connected; otherwise it must have
 * {@code numBaseModel * (numBaseModel - 1)} entries and every ordered pair of
 * distinct base models is connected.
 */
public class MarkovModulatedSubstitutionModel
        extends ComplexSubstitutionModel
        implements ParameterReplaceableSubstitutionModel,
        DifferentiableSubstitutionModel,
        Citable,
        Loggable {
    private final List<SubstitutionModel> baseModels;
    private final int numBaseModel;
    private final int baseStateCount;
    private final Parameter switchingRates;
    // The three flags below are not read anywhere in this class; they are kept
    // as they appeared in the original.
    private static final boolean IGNORE_RATES = false;
    private static final boolean DEBUG = false;
    private static final boolean NEW_STORE_RESTORE = true;
    /** Scratch buffer that receives one base model's flattened infinitesimal matrix. */
    private final double[] baseMatrix;
    private final Parameter rateScalar;
    /** True when only neighbouring base models are connected by switching rates. */
    private final boolean birthDeathModel;
    private final boolean geometricRates;
    private final Parameter relativeWeights;
    private final SiteRateModel gammaRateModel;
    private EigenDecomposition storedEigenDecomposition;
    private boolean storedUpdateMatrix;

    /**
     * Creates a model without rate scalar, geometric rates, site rate model or
     * relative weights.
     *
     * @param name           model name passed to the superclass
     * @param baseModels     the base substitution models; must not be empty
     * @param switchingRates rates of switching between base models
     * @param dataType       data type of the combined state space
     * @param eigenSystem    forwarded to the full constructor, which does not use it
     */
    public MarkovModulatedSubstitutionModel(String name, List<SubstitutionModel> baseModels, Parameter switchingRates, DataType dataType, EigenSystem eigenSystem) {
        this(name, baseModels, switchingRates, dataType, eigenSystem, null, false, null, null);
    }

    /**
     * Creates a fully configured model.
     *
     * @param name            model name passed to the superclass
     * @param baseModels      the base substitution models; must not be empty
     * @param switchingRates  rates of switching between base models; its dimension must be
     *                        {@code 2 * (n - 1)} or {@code n * (n - 1)} for {@code n} base models
     * @param dataType        data type of the combined state space
     * @param eigenSystem     accepted but not used by this constructor
     * @param rateScalar      optional per-model rate multiplier of dimension 1 or {@code n};
     *                        may be {@code null}
     * @param geometricRates  when true, each switching rate is rescaled in
     *                        {@link #setupQMatrix(double[], double[], double[][])}
     * @param siteRateModel   optional site rate model whose category rates take precedence
     *                        over {@code rateScalar}; its category count must divide
     *                        {@code n}; may be {@code null}
     * @param relativeWeights forwarded to the {@link MarkovModulatedFrequencyModel};
     *                        may be {@code null}
     * @throws RuntimeException if there are no base models, or if any of the dimension
     *                          checks described above fails, or if the base models' state
     *                          counts do not add up to this model's state count
     */
    public MarkovModulatedSubstitutionModel(String name, List<SubstitutionModel> baseModels, Parameter switchingRates, DataType dataType, EigenSystem eigenSystem, Parameter rateScalar, boolean geometricRates, SiteRateModel siteRateModel, Parameter relativeWeights) {
        super(name, dataType, (FrequencyModel)null, (Parameter)null);
        this.baseModels = baseModels;
        this.numBaseModel = baseModels.size();
        if (this.numBaseModel == 0) {
            throw new RuntimeException("May not construct MarkovModulatedSubstitutionModel with 0 base models");
        }
        this.switchingRates = switchingRates;
        this.addVariable(switchingRates);
        this.relativeWeights = relativeWeights;

        final int switchingDimension = switchingRates.getDimension();
        final int neighbourOnlyDimension = 2 * (this.numBaseModel - 1);
        final int allPairsDimension = this.numBaseModel * (this.numBaseModel - 1);
        if (switchingDimension != neighbourOnlyDimension && switchingDimension != allPairsDimension) {
            throw new RuntimeException("Wrong switching rate dimensions");
        }

        // The block size is taken from the first base model and then used for
        // every block of the combined matrix.
        this.baseStateCount = baseModels.get(0).getFrequencyModel().getFrequencyCount();
        this.baseMatrix = new double[this.baseStateCount * this.baseStateCount];

        ArrayList<FrequencyModel> frequencyModels = new ArrayList<FrequencyModel>();
        int totalStateCount = 0;
        for (SubstitutionModel baseModel : baseModels) {
            FrequencyModel baseFrequencies = baseModel.getFrequencyModel();
            this.addModel(baseModel);
            frequencyModels.add(baseFrequencies);
            this.addModel(baseFrequencies);
            totalStateCount += baseModel.getDataType().getStateCount();
        }
        this.freqModel = new MarkovModulatedFrequencyModel("mm", frequencyModels, switchingRates, relativeWeights);
        this.addModel(this.freqModel);
        if (this.stateCount != totalStateCount) {
            throw new RuntimeException("Incompatible state counts in " + this.getModelName() + " (currently: " + this.stateCount + "). Models add up to " + totalStateCount + ".");
        }

        // Neighbour-only switching is assumed unless the parameter dimension
        // rules it out (with a single base model there is nothing to switch).
        this.birthDeathModel = !(this.numBaseModel > 1 && switchingDimension != neighbourOnlyDimension);
        this.geometricRates = geometricRates;

        if (siteRateModel != null) {
            this.addModel(siteRateModel);
            final int categoryCount = siteRateModel.getCategoryCount();
            if (categoryCount != this.numBaseModel && this.numBaseModel % categoryCount != 0) {
                throw new RuntimeException("Wrong discretized gamma dimension");
            }
        }
        this.gammaRateModel = siteRateModel;

        if (rateScalar != null) {
            this.addVariable(rateScalar);
            final int scalarDimension = rateScalar.getDimension();
            if (scalarDimension != 1 && scalarDimension != this.numBaseModel) {
                throw new RuntimeException("Wrong rate scalar dimensions");
            }
        }
        this.rateScalar = rateScalar;

        this.setDoNormalization(false);
        this.updateMatrix = true;
        Logger.getLogger("dr.app.beagle").info("\tConstructing a Markov-modulated Markov chain substitution model with " + this.stateCount + " states;  please cite:\n" + Citable.Utils.getCitationString(this));
    }

    /**
     * @return the number of base models combined by this model
     */
    public int getNumBaseModel() {
        return this.numBaseModel;
    }

    /**
     * Returns the rate multiplier applied to a base model's block.
     *
     * <p>Lookup order: the site rate model (category {@code index} modulo its
     * category count) if one was supplied; otherwise {@code 1.0} when no rate
     * scalar parameter exists; otherwise the single shared scalar value or the
     * per-model entry {@code index}.
     *
     * @param index index of the base model
     * @return the multiplier for that base model
     */
    public double getModelRateScalar(int index) {
        if (this.gammaRateModel != null) {
            final int category = index % this.gammaRateModel.getCategoryCount();
            return this.gammaRateModel.getRateForCategory(category);
        }
        if (this.rateScalar == null) {
            return 1.0;
        }
        final int scalarIndex = this.rateScalar.getDimension() == 1 ? 0 : index;
        return this.rateScalar.getParameterValue(scalarIndex);
    }

    /**
     * Saves a copy of the current eigen decomposition (if any) together with
     * the matrix-dirty flag. When there is no current decomposition the
     * previously stored one is left in place.
     */
    @Override
    protected void storeState() {
        if (this.eigenDecomposition != null) {
            this.storedEigenDecomposition = this.eigenDecomposition.copy();
        }
        this.storedUpdateMatrix = this.updateMatrix;
    }

    /**
     * Swaps the current and stored eigen decompositions and restores the
     * matrix-dirty flag saved by {@link #storeState()}.
     */
    @Override
    protected void restoreState() {
        final EigenDecomposition previouslyStored = this.storedEigenDecomposition;
        this.storedEigenDecomposition = this.eigenDecomposition;
        this.eigenDecomposition = previouslyStored;
        this.updateMatrix = this.storedUpdateMatrix;
    }

    /**
     * Fills {@code matrix} with the combined rate matrix.
     *
     * <p>The matrix is first zeroed. Each diagonal block then receives the
     * corresponding base model's infinitesimal matrix times its rate scalar.
     * With more than one base model, switching rates are read in order while
     * walking the ordered pairs (from, to) row by row, and each is written on
     * the diagonal of the (from, to) block. When geometric rates are enabled,
     * a switching rate is additionally multiplied by
     * {@code getModelRateScalar(numBaseModel - to - 1)} and divided by the sum
     * of all switching rates.
     *
     * @param rates  not used by this implementation
     * @param pi     not used by this implementation
     * @param matrix output; overwritten with the combined rate matrix
     */
    @Override
    protected void setupQMatrix(double[] rates, double[] pi, double[][] matrix) {
        for (double[] row : matrix) {
            Arrays.fill(row, 0.0);
        }

        for (int model = 0; model < this.numBaseModel; ++model) {
            final int offset = model * this.baseStateCount;
            this.baseModels.get(model).getInfinitesimalMatrix(this.baseMatrix);
            final double scalar = this.getModelRateScalar(model);
            // baseMatrix is consumed row by row.
            int flatIndex = 0;
            for (int row = 0; row < this.baseStateCount; ++row) {
                for (int col = 0; col < this.baseStateCount; ++col) {
                    matrix[offset + row][offset + col] = scalar * this.baseMatrix[flatIndex];
                    ++flatIndex;
                }
            }
        }

        if (this.numBaseModel <= 1) {
            return;
        }

        final double[] switching = this.switchingRates.getParameterValues();
        double totalSwitchingRate = 0.0;
        for (double rate : switching) {
            totalSwitchingRate += rate;
        }

        int nextRate = 0;
        for (int from = 0; from < this.numBaseModel; ++from) {
            for (int to = 0; to < this.numBaseModel; ++to) {
                // Neighbour-only layout connects adjacent models; otherwise
                // every pair of distinct models is connected.
                final boolean connected = this.birthDeathModel ? Math.abs(from - to) == 1 : from != to;
                if (!connected) {
                    continue;
                }
                double rate = switching[nextRate];
                if (this.geometricRates) {
                    rate *= this.getModelRateScalar(this.numBaseModel - to - 1) / totalSwitchingRate;
                }
                final int fromOffset = from * this.baseStateCount;
                final int toOffset = to * this.baseStateCount;
                for (int state = 0; state < this.baseStateCount; ++state) {
                    matrix[fromOffset + state][toOffset + state] = rate;
                }
                ++nextRate;
            }
        }
    }

    /**
     * @return the eigen decomposition computed by the superclass
     */
    @Override
    public EigenDecomposition getEigenDecomposition() {
        return super.getEigenDecomposition();
    }

    @Override
    public Citation.Category getCategory() {
        return Citation.Category.SUBSTITUTION_MODELS;
    }

    @Override
    public String getDescription() {
        return "Markov modulated substitution model";
    }

    @Override
    public List<Citation> getCitations() {
        return Collections.singletonList(CommonCitations.SUCHARD_2020_MMM);
    }

    /** Intentionally does nothing. */
    @Override
    protected void frequenciesChanged() {
    }

    /** Marks the rate matrix as needing to be rebuilt. */
    @Override
    protected void ratesChanged() {
        this.updateMatrix = true;
    }

    /** Intentionally does nothing; this model does not fill relative rates. */
    @Override
    protected void setupRelativeRates(double[] dArray) {
    }

    /**
     * Any change in a registered sub-model marks the matrix as dirty and
     * notifies listeners.
     */
    @Override
    protected void handleModelChangedEvent(Model model, Object object, int n) {
        this.updateMatrix = true;
        this.fireModelChanged();
    }

    /**
     * Marks the matrix as dirty and notifies listeners when the switching
     * rates or the rate scalar change; other variables are ignored.
     */
    @Override
    protected void handleVariableChangedEvent(Variable variable, int n, Variable.ChangeType changeType) {
        if (variable == this.switchingRates || variable == this.rateScalar) {
            this.updateMatrix = true;
            this.fireModelChanged();
        }
    }

    /**
     * @return the superclass's log columns followed by one {@code rateScalar.<i>}
     *         column per base model
     */
    @Override
    public LogColumn[] getColumns() {
        List<LogColumn> columns = new ArrayList<LogColumn>();
        Collections.addAll(columns, super.getColumns());
        for (int model = 0; model < this.numBaseModel; ++model) {
            columns.add(new RateColumn("rateScalar." + model, model));
        }
        return columns.toArray(new LogColumn[0]);
    }

    /**
     * Builds a new model in which listed parameters are replaced.
     *
     * <p>The switching rates and rate scalar are swapped for their
     * counterparts in {@code list2} when they occur in {@code list}; every
     * base model is rebuilt through its own {@code factory}. Each base model
     * must implement {@link ParameterReplaceableSubstitutionModel}.
     *
     * @param list  parameters to be replaced
     * @param list2 replacement parameters, index-aligned with {@code list}
     * @return the new model
     */
    @Override
    public ParameterReplaceableSubstitutionModel factory(List<Parameter> list, List<Parameter> list2) {
        final Parameter newSwitchingRates = replaceIfListed(this.switchingRates, list, list2);
        final Parameter newRateScalar = replaceIfListed(this.rateScalar, list, list2);
        List<SubstitutionModel> newBaseModels = new ArrayList<SubstitutionModel>();
        for (SubstitutionModel baseModel : this.baseModels) {
            ParameterReplaceableSubstitutionModel replaceable = (ParameterReplaceableSubstitutionModel)baseModel;
            newBaseModels.add(replaceable.factory(list, list2));
        }
        return new MarkovModulatedSubstitutionModel(this.getModelName(), newBaseModels, newSwitchingRates, this.dataType, null, newRateScalar, this.geometricRates, this.gammaRateModel, this.relativeWeights);
    }

    /** Returns the replacement aligned with {@code current} in {@code oldParameters}, or {@code current} if it is not listed. */
    private static Parameter replaceIfListed(Parameter current, List<Parameter> oldParameters, List<Parameter> newParameters) {
        final int position = oldParameters.indexOf(current);
        return position >= 0 ? newParameters.get(position) : current;
    }

    /**
     * Embeds a base model's differential matrix into a zero matrix of the
     * combined dimension, at the diagonal block of the base model that owns
     * the parameter.
     *
     * @param wrtParameter must be a {@link BaseWrtParameter} created by
     *                     {@link #factory(Parameter, int)}
     * @return the combined-dimension differential matrix
     */
    @Override
    public WrappedMatrix getInfinitesimalDifferentialMatrix(DifferentialMassProvider.DifferentialWrapper.WrtParameter wrtParameter) {
        final BaseWrtParameter wrt = (BaseWrtParameter)wrtParameter;
        final int modelIndex = wrt.getBaseModelIndex();
        final WrappedMatrix baseDifferential = ((DifferentiableSubstitutionModel)this.baseModels.get(modelIndex)).getInfinitesimalDifferentialMatrix(wrt.getBaseWrtParameter());

        final int dimension = this.baseStateCount * this.numBaseModel;
        // A freshly allocated array is already zero, so only the owning block is written.
        final double[][] differential = new double[dimension][dimension];
        final int offset = modelIndex * this.baseStateCount;
        // NOTE(review): the block is copied unscaled, whereas setupQMatrix
        // multiplies the same block by getModelRateScalar - confirm whether
        // the scalar should also be applied here.
        for (int row = 0; row < this.baseStateCount; ++row) {
            for (int col = 0; col < this.baseStateCount; ++col) {
                differential[offset + row][offset + col] = baseDifferential.get(row, col);
            }
        }
        return new WrappedMatrix.ArrayOfArray(differential);
    }

    /**
     * Finds the first base model whose own {@code factory} accepts the
     * parameter and wraps the result together with that model's index.
     *
     * @param parameter the parameter to differentiate with respect to
     * @param n         dimension index within the parameter
     * @return a {@link BaseWrtParameter} for the owning base model
     * @throws RuntimeException if every base model rejects the parameter; the
     *                          last rejection is attached as the cause
     */
    @Override
    public DifferentialMassProvider.DifferentialWrapper.WrtParameter factory(Parameter parameter, int n) {
        RuntimeException lastFailure = null;
        for (int model = 0; model < this.numBaseModel; ++model) {
            DifferentiableSubstitutionModel differentiable = (DifferentiableSubstitutionModel)this.baseModels.get(model);
            try {
                DifferentialMassProvider.DifferentialWrapper.WrtParameter baseWrt = differentiable.factory(parameter, n);
                return new BaseWrtParameter(parameter, n, baseWrt, model);
            }
            catch (RuntimeException rejected) {
                // A base model signals "not my parameter" by throwing; remember
                // it for diagnostics and try the next model.
                lastFailure = rejected;
            }
        }
        throw new RuntimeException("Parameter not found in any base model", lastFailure);
    }

    /** Not implemented; always throws. */
    @Override
    public void setupDifferentialRates(DifferentialMassProvider.DifferentialWrapper.WrtParameter wrtParameter, double[] dArray, double d) {
        throw new RuntimeException("Not yet implemented!");
    }

    /** Not implemented; always throws. */
    @Override
    public void setupDifferentialFrequency(DifferentialMassProvider.DifferentialWrapper.WrtParameter wrtParameter, double[] dArray) {
        throw new RuntimeException("Not yet implemented!");
    }

    /** Not implemented; always throws. */
    @Override
    public double getWeightedNormalizationGradient(DifferentialMassProvider.DifferentialWrapper.WrtParameter wrtParameter, double[][] dArray, double[] dArray2) {
        throw new RuntimeException("Not yet implemented!");
    }

    /** Log column reporting {@link #getModelRateScalar(int)} for one base model. */
    private class RateColumn
            extends NumberColumn {
        private final int index;

        public RateColumn(String label, int index) {
            super(label);
            this.index = index;
        }

        @Override
        public double getDoubleValue() {
            return MarkovModulatedSubstitutionModel.this.getModelRateScalar(this.index);
        }
    }

    /**
     * Pairs a base model's own differentiation handle with the index of that
     * base model. Only the two getters are functional; the remaining interface
     * methods always throw.
     */
    class BaseWrtParameter
            implements DifferentialMassProvider.DifferentialWrapper.WrtParameter {
        Parameter parameter;
        int parameterDim;
        int baseModelIndex;
        DifferentialMassProvider.DifferentialWrapper.WrtParameter baseWrtParameter;

        public BaseWrtParameter(Parameter parameter, int parameterDim, DifferentialMassProvider.DifferentialWrapper.WrtParameter baseWrtParameter, int baseModelIndex) {
            this.parameter = parameter;
            this.parameterDim = parameterDim;
            this.baseModelIndex = baseModelIndex;
            this.baseWrtParameter = baseWrtParameter;
        }

        /** @return index of the base model that owns the parameter */
        public int getBaseModelIndex() {
            return this.baseModelIndex;
        }

        /** @return the handle produced by the owning base model */
        public DifferentialMassProvider.DifferentialWrapper.WrtParameter getBaseWrtParameter() {
            return this.baseWrtParameter;
        }

        @Override
        public double getRate(int n) {
            throw new RuntimeException("Not yet implemented!");
        }

        @Override
        public double getNormalizationDifferential() {
            throw new RuntimeException("Not yet implemented!");
        }

        @Override
        public void setupDifferentialFrequencies(double[] dArray, double[] dArray2) {
            throw new RuntimeException("Not yet implemented!");
        }

        @Override
        public void setupDifferentialRates(double[] dArray, double[] dArray2, double d) {
            throw new RuntimeException("Not yet implemented!");
        }
    }
}

