/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.coalescent.basta;

import dr.evolution.coalescent.IntervalType;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evomodel.bigfasttree.BigFastTreeIntervals;
import dr.evomodel.branchratemodel.BranchRateModel;
import dr.evomodel.coalescent.basta.ProcessOnCoalescentIntervalDelegate;
import dr.evomodel.tree.TreeModel;
import dr.evomodel.treedatalikelihood.TreeTraversal;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;

/**
 * Traversal that walks the intervals reported by a {@link BigFastTreeIntervals} in index order and
 * records, per (sub-)interval, the branch and transition-matrix operations consumed by
 * {@link ProcessOnCoalescentIntervalDelegate}.
 *
 * <p>Every interval is split into {@code numberSubIntervals} equal sub-intervals (sampling intervals
 * of zero length are not split at all). After
 * {@link #dispatchTreeTraversalCollectBranchAndNodeOperations()} the collected operations are
 * available from {@link #getBranchIntervalOperations()} and {@link #getMatrixOperations()}, and
 * {@link #getIntervalStarts()} holds fence-post indices into the branch-operation list: entry
 * {@code k} is where sub-interval {@code k} begins and the last entry is the total count.
 */
public class CoalescentIntervalTraversal
extends TreeTraversal {
    private final BigFastTreeIntervals treeIntervals;
    /** Number of equal pieces each interval is divided into (assumed to be >= 1 -- TODO confirm). */
    private final int numberSubIntervals;
    /** Matrix index of the most recently emitted transition-matrix operation; -1 before the first. */
    private int currentMatrixNumber;
    /** Number of sub-intervals completed so far in the current traversal. */
    private int currentLikelihoodInterval;
    private final List<ProcessOnCoalescentIntervalDelegate.BranchIntervalOperation> branchIntervalOperations = new ArrayList<>();
    private final List<ProcessOnCoalescentIntervalDelegate.TransitionMatrixOperation> matrixOperations = new ArrayList<>();
    private final List<Integer> intervalStarts = new ArrayList<>();

    /**
     * @param tree                 the tree to traverse (asserted to be a {@link TreeModel})
     * @param bigFastTreeIntervals interval decomposition of {@code tree}
     * @param branchRateModel      rate model used to scale sub-interval lengths
     * @param n                    number of sub-intervals per interval
     */
    protected CoalescentIntervalTraversal(Tree tree, BigFastTreeIntervals bigFastTreeIntervals, BranchRateModel branchRateModel, int n) {
        super(tree, branchRateModel, TreeTraversal.TraversalType.REVERSE_LEVEL_ORDER);
        assert (tree instanceof TreeModel);
        this.treeIntervals = bigFastTreeIntervals;
        this.numberSubIntervals = n;
    }

    /**
     * Clears all previously collected operations and rebuilds them. Only
     * {@code REVERSE_LEVEL_ORDER} is supported; any other traversal type trips an assertion
     * (and is silently ignored when assertions are disabled).
     */
    @Override
    public final void dispatchTreeTraversalCollectBranchAndNodeOperations() {
        matrixOperations.clear();
        branchIntervalOperations.clear();
        intervalStarts.clear();
        if (traversalType == TreeTraversal.TraversalType.REVERSE_LEVEL_ORDER) {
            traverseReverseCoalescentLevelOrder();
        } else {
            assert false : "Unknown traversal type";
        }
    }

    /** Returns the branch operations collected by the last traversal (live, mutable list). */
    public List<ProcessOnCoalescentIntervalDelegate.BranchIntervalOperation> getBranchIntervalOperations() {
        return branchIntervalOperations;
    }

    /** Returns the transition-matrix operations collected by the last traversal (live, mutable list). */
    public List<ProcessOnCoalescentIntervalDelegate.TransitionMatrixOperation> getMatrixOperations() {
        return matrixOperations;
    }

    /**
     * Returns one more than the number of sub-intervals processed by the last traversal, i.e. the
     * same value as {@code getIntervalStarts().size()}.
     */
    public int getCoalescentIntervalCount() {
        return currentLikelihoodInterval + 1;
    }

    /** Returns the fence-post start indices into {@link #getBranchIntervalOperations()}. */
    public List<Integer> getIntervalStarts() {
        return intervalStarts;
    }

    /**
     * Scales a raw length by the branch rate of {@code nodeRef}.
     *
     * @param tree    tree the node belongs to
     * @param nodeRef node whose branch rate is used
     * @param d       raw (time) length
     * @return {@code rate * d}; asserted to be non-negative
     */
    protected final double computeRateScaledIntervalLength(Tree tree, NodeRef nodeRef, double d) {
        final double rate;
        // The rate lookup is performed while holding the rate model's monitor.
        synchronized (branchRateModel) {
            rate = branchRateModel.getBranchRate(tree, nodeRef);
        }
        final double scaledLength = rate * d;
        assert (scaledLength >= 0.0) : "Negative interval length: " + scaledLength + " for node " + nodeRef.getNumber() + (tree.isExternal(nodeRef) ? " (" + tree.getNodeTaxon(nodeRef).getId() + ")" : "");
        return scaledLength;
    }

    /**
     * Walks every interval in index order, starting with the node returned by
     * {@code getSamplingNode(-1)} as the only active lineage.
     *
     * @throws RuntimeException if an interval type is neither COALESCENT nor SAMPLE, or if the last
     *                          interval is not a coalescence
     */
    private void traverseReverseCoalescentLevelOrder() {
        currentLikelihoodInterval = 0;
        currentMatrixNumber = -1;
        ActiveNodesForInterval activeNodes = new ActiveNodesForInterval(treeModel.getNodeCount());
        activeNodes.add(treeIntervals.getSamplingNode(-1));
        intervalStarts.add(0);

        final int intervalCount = treeIntervals.getIntervalCount();
        for (int i = 0; i < intervalCount; ++i) {
            IntervalType type = treeIntervals.getIntervalType(i);
            if (type == IntervalType.COALESCENT) {
                processCoalescentEvent(i, activeNodes);
            } else if (type == IntervalType.SAMPLE) {
                processSamplingEvent(i, activeNodes);
            } else {
                throw new RuntimeException("Unknown interval type");
            }
            if (i == intervalCount - 1 && type != IntervalType.COALESCENT) {
                throw new RuntimeException("Not a coalescence at top");
            }
        }
    }

    /**
     * Always 0 here. NOTE(review): presumably a hook for selecting a per-node decomposition;
     * confirm before relying on it.
     */
    private int getDecompositionNumber(NodeRef nodeRef) {
        return 0;
    }

    /**
     * Ensures a transition-matrix operation exists for {@code matrixNumber}. A new operation is
     * emitted only when {@code matrixNumber} differs from the last one emitted.
     *
     * <p>NOTE(review): the matrix is scaled by the rate of whichever node first requests it; other
     * nodes in the same sub-interval reuse it regardless of their own branch rate. This is only
     * exact if those rates coincide -- confirm.
     *
     * @return {@code matrixNumber}
     */
    private int computeTransmissionProbabilities(int matrixNumber, NodeRef nodeRef, double length) {
        if (matrixNumber != currentMatrixNumber) {
            double scaledLength = computeRateScaledIntervalLength(treeModel, nodeRef, length);
            matrixOperations.add(new ProcessOnCoalescentIntervalDelegate.TransitionMatrixOperation(matrixNumber, getDecompositionNumber(nodeRef), scaledLength));
            currentMatrixNumber = matrixNumber;
        }
        return matrixNumber;
    }

    /**
     * Emits an operation carrying {@code nodeRef}'s lineage through one sub-interval: it reads the
     * node's current active buffer, advances the node to its next buffer, and bumps the node's
     * execution order by one.
     */
    private void propagateTransmissionProbabilities(int n, NodeRef nodeRef, double d, ActiveNodesForInterval activeNodes) {
        int bufferBefore = activeNodes.getActiveBuffer(nodeRef);
        activeNodes.incrementActiveBuffer(nodeRef);
        int bufferAfter = activeNodes.getActiveBuffer(nodeRef);
        int executionOrder = activeNodes.getExecutionOrder(nodeRef) + 1;
        int matrix = computeTransmissionProbabilities(n, nodeRef, d);
        branchIntervalOperations.add(new ProcessOnCoalescentIntervalDelegate.BranchIntervalOperation(
                bufferAfter, bufferBefore, -1, matrix, -1, bufferAfter, -1, d, executionOrder, n));
        activeNodes.setExecutionOrder(nodeRef, executionOrder);
    }

    /**
     * Emits the operation that merges the two child lineages into {@code parent}. The parent's
     * execution order is one past the larger of its children's.
     */
    private void coalescenceTransmissionProbabilities(int n, NodeRef parent, NodeRef leftChild, NodeRef rightChild, double d, ActiveNodesForInterval activeNodes) {
        int leftActive = activeNodes.getActiveBuffer(leftChild);
        int rightActive = activeNodes.getActiveBuffer(rightChild);
        int leftAccumulation = activeNodes.getAccumulationBuffer(leftChild);
        int rightAccumulation = activeNodes.getAccumulationBuffer(rightChild);
        int parentActive = activeNodes.getActiveBuffer(parent);
        int executionOrder = Math.max(activeNodes.getExecutionOrder(leftChild), activeNodes.getExecutionOrder(rightChild)) + 1;
        int leftMatrix = computeTransmissionProbabilities(n, leftChild, d);
        // Previously passed leftChild here (copy-paste slip); each child now requests its own matrix.
        int rightMatrix = computeTransmissionProbabilities(n, rightChild, d);
        branchIntervalOperations.add(new ProcessOnCoalescentIntervalDelegate.BranchIntervalOperation(
                parentActive, leftActive, rightActive, leftMatrix, rightMatrix, leftAccumulation, rightAccumulation, d, executionOrder, n));
        activeNodes.setExecutionOrder(parent, executionOrder);
    }

    /**
     * Handles a coalescent interval. It propagates all active lineages through the first
     * {@code numberSubIntervals - 1} sub-intervals. In the last sub-interval it merges the two
     * children into the coalescent node, retires the children, and propagates every other lineage.
     *
     * @throws RuntimeException if the sub-interval length is not positive, or if either child was
     *                          not active
     */
    private void processCoalescentEvent(int n, ActiveNodesForInterval activeNodes) {
        NodeRef parent = treeIntervals.getCoalescentNode(n);
        double subIntervalLength = treeIntervals.getInterval(n) / (double) numberSubIntervals;
        NodeRef leftChild = treeModel.getChild(parent, 0);
        NodeRef rightChild = treeModel.getChild(parent, 1);
        if (subIntervalLength <= 0.0) {
            throw new RuntimeException("Cannot coalescence in <= 0.0 time");
        }

        // NOTE(review): currentLikelihoodInterval already counts sub-intervals, so for
        // numberSubIntervals > 1 this base leaves gaps in matrix numbering -- confirm intended.
        int matrixNumber = currentLikelihoodInterval * numberSubIntervals;
        for (int sub = 0; sub < numberSubIntervals - 1; ++sub) {
            for (NodeRef node : activeNodes) {
                propagateTransmissionProbabilities(matrixNumber, node, subIntervalLength, activeNodes);
            }
            ++matrixNumber;
            endSubInterval();
        }

        activeNodes.add(parent);
        coalescenceTransmissionProbabilities(matrixNumber, parent, leftChild, rightChild, subIntervalLength, activeNodes);
        // Evaluate both removals before checking so each child is retired regardless of the other.
        boolean removedLeft = activeNodes.remove(leftChild);
        boolean removedRight = activeNodes.remove(rightChild);
        if (!removedLeft || !removedRight) {
            throw new RuntimeException("Missing node");
        }
        for (NodeRef node : activeNodes) {
            if (node == parent) {
                continue; // already handled by the coalescence operation above
            }
            propagateTransmissionProbabilities(matrixNumber, node, subIntervalLength, activeNodes);
        }
        endSubInterval();
    }

    /**
     * Handles a sampling interval. If the interval has positive length, all active lineages are
     * propagated through each of its sub-intervals. The sampled node then becomes active.
     */
    private void processSamplingEvent(int n, ActiveNodesForInterval activeNodes) {
        NodeRef sampledNode = treeIntervals.getSamplingNode(n);
        double length = treeIntervals.getInterval(n);
        if (length > 0.0) {
            double subIntervalLength = length / (double) numberSubIntervals;
            int matrixNumber = currentLikelihoodInterval * numberSubIntervals;
            for (int sub = 0; sub < numberSubIntervals; ++sub) {
                for (NodeRef node : activeNodes) {
                    propagateTransmissionProbabilities(matrixNumber, node, subIntervalLength, activeNodes);
                }
                ++matrixNumber;
                endSubInterval();
            }
        }
        activeNodes.add(sampledNode);
    }

    /** Closes the current sub-interval and records where the next one starts. */
    private void endSubInterval() {
        ++currentLikelihoodInterval;
        intervalStarts.add(branchIntervalOperations.size());
    }

    /**
     * Set of currently active lineages, plus per-node buffer offsets and execution orders indexed by
     * {@link NodeRef#getNumber()}. Iteration order is that of the backing {@link HashSet}
     * (unspecified). Bulk operations, {@code toArray} and {@code clear} are unsupported.
     */
    static class ActiveNodesForInterval
    implements Set<NodeRef> {
        private final Set<NodeRef> activeSet = new HashSet<>();
        private final int[] currentOffset;
        private final int[] executionOrder;
        private final int stride;
        /** Nodes in the order they first became active (no duplicates while they stay active). */
        private final List<NodeRef> intervalNodeOrder = new ArrayList<>();

        /** @param n node count; used both as array size and as the buffer stride */
        public ActiveNodesForInterval(int n) {
            this.currentOffset = new int[n];
            this.executionOrder = new int[n];
            this.stride = n;
        }

        /** Returns an independent snapshot of the active nodes. */
        public Set<NodeRef> copy() {
            return new HashSet<>(activeSet);
        }

        /** Throws if {@code nodeRef} is not currently active. */
        private void checkActive(NodeRef nodeRef) {
            if (!activeSet.contains(nodeRef)) {
                throw new RuntimeException("Not in active set");
            }
        }

        public int getCurrentOffset(NodeRef nodeRef) {
            checkActive(nodeRef);
            return currentOffset[nodeRef.getNumber()];
        }

        /**
         * Returns {@code slot * stride + number}, where slot is 0 for offset 0 and offset + 1
         * otherwise. Slot 1 (the index {@link #getAccumulationBuffer} returns) is never produced.
         */
        public int getActiveBuffer(NodeRef nodeRef) {
            int offset = getCurrentOffset(nodeRef); // performs the active-set check
            int slot = offset > 0 ? offset + 1 : 0;
            return slot * stride + nodeRef.getNumber();
        }

        /** Returns {@code stride + number}; does not require the node to be active. */
        public int getAccumulationBuffer(NodeRef nodeRef) {
            return stride + nodeRef.getNumber();
        }

        public int getExecutionOrder(NodeRef nodeRef) {
            checkActive(nodeRef);
            return executionOrder[nodeRef.getNumber()];
        }

        public void incrementActiveBuffer(NodeRef nodeRef) {
            checkActive(nodeRef);
            currentOffset[nodeRef.getNumber()]++;
        }

        public void incrementExecutionOrder(NodeRef nodeRef) {
            checkActive(nodeRef);
            executionOrder[nodeRef.getNumber()]++;
        }

        public void setExecutionOrder(NodeRef nodeRef, int n) {
            checkActive(nodeRef);
            executionOrder[nodeRef.getNumber()] = n;
        }

        /** Returns the index at which {@code nodeRef} (by identity) was first added, or -1. */
        public int getNodeOrder(NodeRef nodeRef) {
            for (int i = 0; i < intervalNodeOrder.size(); ++i) {
                if (nodeRef == intervalNodeOrder.get(i)) {
                    return i;
                }
            }
            return -1;
        }

        @Override
        public int size() {
            return activeSet.size();
        }

        @Override
        public boolean isEmpty() {
            return activeSet.isEmpty();
        }

        @Override
        public boolean contains(Object object) {
            return activeSet.contains(object);
        }

        @Override
        public Iterator<NodeRef> iterator() {
            return activeSet.iterator();
        }

        @Override
        public Object[] toArray() {
            throw new UnsupportedOperationException();
        }

        @Override
        public <T> T[] toArray(T[] array) {
            throw new UnsupportedOperationException();
        }

        /** Activates {@code nodeRef}; records its order only if it was not already active. */
        @Override
        public boolean add(NodeRef nodeRef) {
            boolean added = activeSet.add(nodeRef);
            if (added) {
                intervalNodeOrder.add(nodeRef);
            }
            return added;
        }

        @Override
        public boolean remove(Object object) {
            return activeSet.remove(object);
        }

        @Override
        public boolean containsAll(Collection<?> collection) {
            throw new UnsupportedOperationException();
        }

        @Override
        public boolean addAll(Collection<? extends NodeRef> collection) {
            throw new UnsupportedOperationException();
        }

        @Override
        public boolean retainAll(Collection<?> collection) {
            throw new UnsupportedOperationException();
        }

        @Override
        public boolean removeAll(Collection<?> collection) {
            throw new UnsupportedOperationException();
        }

        @Override
        public void clear() {
            throw new UnsupportedOperationException();
        }
    }
}

