/*
 * Decompiled with CFR 0.152.
 */
package pal.eval;

import java.util.Vector;
import pal.alignment.SitePattern;
import pal.distance.AlignmentDistanceMatrix;
import pal.eval.BranchLikelihood;
import pal.eval.NodeLikelihood;
import pal.eval.RateLikelihood;
import pal.eval.TreeLikelihood;
import pal.math.DifferentialEvolution;
import pal.math.MultivariateMinimum;
import pal.math.NumericalDerivative;
import pal.math.UnivariateMinimum;
import pal.substmodel.SubstitutionModel;
import pal.tree.AttributeNode;
import pal.tree.ClockTree;
import pal.tree.DatedTipsClockTree;
import pal.tree.Node;
import pal.tree.NodeUtils;
import pal.tree.ParameterizedTree;
import pal.tree.Tree;
import pal.tree.TreeUtils;
import pal.tree.UnconstrainedTree;

public class LikelihoodValue {
    /** Parameter updates whose absolute change is below this count towards convergence. */
    private static final double CONVERGENCE_EPS = 5.0E-7;
    /** Gap kept between node heights, also used as the length of restored short branches. */
    private static final double MIN_GAP = 1.0E-9;
    /** Internal branches no longer than this are collapsed during clock-tree optimisation. */
    private static final double SHORT_BRANCH_LENGTH = 2.0E-9;
    /** Curvature at or below this makes the branch-length SE fall back to MAX_BRANCH_SE. */
    private static final double MIN_BRANCH_CURVATURE = 1.0E-4;
    /** Branch-length SE reported when the curvature is too small to invert. */
    private static final double MAX_BRANCH_SE = 100.0;

    /** Weighted sum of the per-pattern log-likelihoods from the last tree evaluation. */
    public double logL;
    /** Log-likelihood of each site pattern, indexed by pattern. */
    public double[] siteLogL;
    /** For each pattern, the rate category contributing the largest weighted likelihood. */
    public int[] rateAtSite;
    int numStates;
    int numRates;
    int numPatterns;
    double[] frequency;
    double[] rprob;
    SitePattern sitePattern;
    int numParams;
    Tree tree;
    ParameterizedTree ptree;
    /** Upper bound on the number of full traversal rounds an optimisation loop performs. */
    private final int MAXROUNDS = 1000;
    private SubstitutionModel model;
    private AlignmentDistanceMatrix distMat;
    /** Partial likelihoods indexed as [nodeKey][pattern][rateCategory][state]. */
    private double[][][][] partials;
    /** Traversal direction flag used by traverseTree(): true while descending. */
    private boolean down;
    /** Current position of the iterative tree traversal. */
    private Node currentBranch;
    private UnivariateMinimum um;
    private MultivariateMinimum mvm;
    private BranchLikelihood bl;
    private TreeLikelihood tl;
    private NodeLikelihood nl;
    private RateLikelihood rl;
    /** Internal nodes temporarily removed by collapseShortInternalBranches(). */
    private Vector shortBranches = null;

    /**
     * Creates a likelihood evaluator for the given site pattern. A model and a tree must be
     * supplied via {@link #setModel} and {@link #setTree} before computing likelihoods.
     *
     * @param sp the site pattern (alignment compressed into weighted patterns)
     */
    public LikelihoodValue(SitePattern sp) {
        this.sitePattern = sp;
        this.numPatterns = sp.numPatterns;
        this.siteLogL = new double[this.numPatterns];
        this.rateAtSite = new int[this.numPatterns];
    }

    /** @return the site pattern currently in use */
    public SitePattern getSitePattern() {
        return this.sitePattern;
    }

    /**
     * Replaces the site pattern and re-derives the model- and tree-dependent state
     * (partial buffers and tip sequences) for it. Model or tree state that has not been set
     * yet is skipped instead of failing with a NullPointerException.
     *
     * @param sp the new site pattern
     */
    public void renewSitePattern(SitePattern sp) {
        this.sitePattern = sp;
        this.numPatterns = sp.numPatterns;
        this.siteLogL = new double[this.numPatterns];
        this.rateAtSite = new int[this.numPatterns];
        SubstitutionModel m = this.getModel();
        if (m != null) {
            this.setModel(m);
        }
        Tree t = this.getTree();
        if (t != null) {
            this.setTree(t);
        }
    }

    /**
     * Sets the substitution model, caches its equilibrium frequencies, rate-category
     * probabilities and dimensions, and (re)allocates the partial-likelihood buffers.
     *
     * @param m the substitution model
     */
    public void setModel(SubstitutionModel m) {
        this.model = m;
        this.frequency = this.model.getRateMatrix().getEquilibriumFrequencies();
        this.rprob = this.model.getRateDistribution().probability;
        this.numStates = this.model.getRateMatrix().getDimension();
        this.numRates = this.model.getRateDistribution().numRates;
        int maxNodes = 2 * this.sitePattern.getSequenceCount() - 2;
        this.allocatePartialMemory(maxNodes);
    }

    /** @return the substitution model, or null if none has been set */
    public SubstitutionModel getModel() {
        return this.model;
    }

    /**
     * Sets the tree, attaches the matching site-pattern row to each external node, and
     * records whether (and with how many parameters) the tree is parameterized.
     *
     * @param t the tree to evaluate
     */
    public void setTree(Tree t) {
        this.tree = t;
        int[] alias = TreeUtils.mapExternalIdentifiers(this.sitePattern, this.tree);
        for (int i = 0; i < this.tree.getExternalNodeCount(); i++) {
            this.tree.getExternalNode(i).setSequence(this.sitePattern.pattern[alias[i]]);
        }
        if (this.tree instanceof ParameterizedTree) {
            this.ptree = (ParameterizedTree) this.tree;
            this.numParams = this.ptree.getNumParameters();
        } else {
            this.ptree = null;
            this.numParams = 0;
        }
    }

    /** @return the tree, or null if none has been set */
    public Tree getTree() {
        return this.tree;
    }

    /**
     * Computes the log-likelihood of the current tree and also fills {@link #siteLogL} and
     * {@link #rateAtSite}.
     *
     * @return the log-likelihood
     */
    public double compute() {
        this.treeLikelihood();
        return this.logL;
    }

    /**
     * Optimises the tree parameters using the default multivariate minimiser.
     *
     * @return the maximised log-likelihood
     * @throws IllegalArgumentException if the tree is not a ParameterizedTree
     */
    public double optimiseParameters() {
        return this.optimiseParameters(null);
    }

    /**
     * Optimises the tree parameters. Unconstrained, clock and dated-tips clock trees use
     * dedicated one-parameter-at-a-time optimisers. Any other ParameterizedTree uses the
     * given multivariate minimiser; if {@code mm} is null, a previously used minimiser is
     * reused or a DifferentialEvolution is created. A final pass then estimates
     * branch-length standard errors.
     *
     * @param mm the multivariate minimiser to use for generic parameterized trees, or null
     * @return the maximised log-likelihood
     * @throws IllegalArgumentException if the tree is not a ParameterizedTree
     */
    public double optimiseParameters(MultivariateMinimum mm) {
        if (!(this.tree instanceof ParameterizedTree)) {
            // The original code constructed this exception without throwing it and then
            // dereferenced the null ptree further down.
            throw new IllegalArgumentException("ParameterizedTree required");
        }
        if (this.um == null) {
            this.um = new UnivariateMinimum();
        }
        if (this.tl == null) {
            this.tl = new TreeLikelihood(this);
        }
        if (this.bl == null) {
            this.bl = new BranchLikelihood(this);
        } else {
            this.bl.update();
        }
        if (this.tree instanceof UnconstrainedTree) {
            this.optimiseUnconstrainedTree(true);
        } else if (this.tree instanceof ClockTree) {
            this.prepareNodeLikelihood();
            this.optimiseClockTreeCollapsing(false);
        } else if (this.tree instanceof DatedTipsClockTree) {
            this.prepareNodeLikelihood();
            if (this.rl == null) {
                this.rl = new RateLikelihood(this);
            } else {
                this.rl.update();
            }
            this.optimiseClockTreeCollapsing(true);
        } else {
            double[] estimate = this.currentParameters();
            if (mm != null) {
                this.mvm = mm;
            } else if (this.mvm == null) {
                this.mvm = new DifferentialEvolution(this.numParams);
            }
            this.mvm.findMinimum(this.tl, estimate, 6, 6);
        }
        // Pass without optimisation: only records branch-length standard errors.
        this.optimiseUnconstrainedTree(false);
        return -this.tl.evaluate(this.currentParameters());
    }

    /** Creates the node-height likelihood function, or refreshes the existing one. */
    private void prepareNodeLikelihood() {
        if (this.nl == null) {
            this.nl = new NodeLikelihood(this);
        } else {
            this.nl.update();
        }
    }

    /**
     * Repeatedly optimises node heights and collapses short internal branches until none
     * are collapsed, then restores the collapsed branches. The parameter count is adjusted
     * accordingly.
     */
    private void optimiseClockTreeCollapsing(boolean datedTips) {
        int ns;
        do {
            this.optimiseClockTree(datedTips);
            ns = this.collapseShortInternalBranches();
            this.updateClockTree(datedTips);
            this.numParams -= ns;
        } while (ns != 0);
        this.numParams += this.restoreShortInternalBranches();
        this.updateClockTree(datedTips);
    }

    /** Calls update() on ptree through its concrete clock-tree type. */
    private void updateClockTree(boolean datedTips) {
        if (datedTips) {
            ((DatedTipsClockTree) this.ptree).update();
        } else {
            ((ClockTree) this.ptree).update();
        }
    }

    /** Returns a fresh array holding the current values of the first numParams ptree parameters. */
    private double[] currentParameters() {
        double[] estimate = new double[this.numParams];
        for (int i = 0; i < estimate.length; i++) {
            estimate[i] = this.ptree.getParameter(i);
        }
        return estimate;
    }

    /** @return the partial-likelihood buffer [pattern][rate][state] associated with a node */
    double[][][] getPartial(Node branch) {
        return this.partials[this.getKey(branch)];
    }

    /**
     * Like getNextBranchOrRoot, but when the next branch would be the root node itself
     * (which has no branch above it), returns the root's first child instead.
     */
    Node getNextBranch(Node branch, Node center) {
        Node b = this.getNextBranchOrRoot(branch, center);
        if (b.isRoot()) {
            b = b.getChild(0);
        }
        return b;
    }

    /**
     * Multiplies, element-wise, the partials of all branches around {@code center} except
     * {@code branch}. The product is accumulated in place into the partial of the branch
     * that follows {@code branch} around {@code center}.
     */
    void productPartials(Node branch, Node center) {
        int numBranches = this.getBranchCount(center);
        Node nextBranch = this.getNextBranch(branch, center);
        double[][][] partial = this.getPartial(nextBranch);
        for (int i = 0; i < numBranches - 2; i++) {
            nextBranch = this.getNextBranch(nextBranch, center);
            double[][][] partial2 = this.getPartial(nextBranch);
            for (int l = 0; l < this.numPatterns; l++) {
                for (int r = 0; r < this.numRates; r++) {
                    double[] p = partial[l][r];
                    double[] p2 = partial2[l][r];
                    for (int d = 0; d < this.numStates; d++) {
                        p[d] *= p2[d];
                    }
                }
            }
        }
    }

    /**
     * Computes the partial of an internal {@code branch}. It propagates the product partial
     * stored in the branch after {@code branch} around {@code center} through the
     * transition probabilities for {@code branch}'s length.
     */
    void partialsInternal(Node branch, Node center) {
        double[][][] partial = this.getPartial(branch);
        double[][][] multPartial = this.getPartial(this.getNextBranch(branch, center));
        this.model.setDistance(branch.getBranchLength());
        for (int l = 0; l < this.numPatterns; l++) {
            for (int r = 0; r < this.numRates; r++) {
                double[] p = partial[l][r];
                double[] mp = multPartial[l][r];
                for (int d = 0; d < this.numStates; d++) {
                    double sum = 0.0;
                    for (int j = 0; j < this.numStates; j++) {
                        sum += this.model.transProb(r, d, j) * mp[j];
                    }
                    p[d] = sum;
                }
            }
        }
    }

    /**
     * Computes the partial of an external (leaf) branch from its observed sequence. A state
     * code equal to numStates is treated as uninformative (all partials 1.0); presumably
     * this is the gap/unknown code, to be confirmed against the alphabet.
     */
    void partialsExternal(Node branch) {
        double[][][] partial = this.getPartial(branch);
        byte[] seq = branch.getSequence();
        this.model.setDistance(branch.getBranchLength());
        for (int l = 0; l < this.numPatterns; l++) {
            byte sl = seq[l];
            for (int r = 0; r < this.numRates; r++) {
                double[] p = partial[l][r];
                if (sl == this.numStates) {
                    for (int d = 0; d < this.numStates; d++) {
                        p[d] = 1.0;
                    }
                } else {
                    for (int d = 0; d < this.numStates; d++) {
                        p[d] = this.model.transProb(r, d, sl);
                    }
                }
            }
        }
    }

    /**
     * Allocates the partial buffers unless the existing ones already have the requested
     * shape [numNodes][numPatterns][numRates][numStates].
     */
    private void allocatePartialMemory(int numNodes) {
        if (!this.partialsHaveShape(numNodes)) {
            this.partials = new double[numNodes][this.numPatterns][this.numRates][this.numStates];
        }
    }

    /**
     * Checks the shape of the existing buffers. Zero-length dimensions are handled explicitly
     * so that inner arrays are never indexed when they do not exist.
     */
    private boolean partialsHaveShape(int numNodes) {
        if (this.partials == null || this.partials.length != numNodes) {
            return false;
        }
        if (numNodes == 0) {
            return true;
        }
        if (this.partials[0].length != this.numPatterns) {
            return false;
        }
        if (this.numPatterns == 0) {
            return true;
        }
        if (this.partials[0][0].length != this.numRates) {
            return false;
        }
        if (this.numRates == 0) {
            return true;
        }
        return this.partials[0][0][0].length == this.numStates;
    }

    /**
     * Returns the branch following {@code branch} in the cyclic order around
     * {@code center}. The order is center's children followed by center itself, which
     * stands for the branch towards center's parent. If {@code branch} is not a child of
     * center, the first child is returned (or center, when it has no children).
     */
    private Node getNextBranchOrRoot(Node branch, Node center) {
        int numChilds = center.getChildCount();
        int num = 0;
        while (num < numChilds && center.getChild(num) != branch) {
            num++;
        }
        num++;
        if (num > numChilds) {
            num = 0;
        }
        if (num == numChilds) {
            return center;
        }
        return center.getChild(num);
    }

    /** Maps a node to its slot in partials: leaves by number, internal nodes offset by the leaf count. */
    private int getKey(Node node) {
        return node.isLeaf() ? node.getNumber() : node.getNumber() + this.tree.getExternalNodeCount();
    }

    /** @return the number of branches meeting at center (children plus the parent branch, if any) */
    private int getBranchCount(Node center) {
        if (center.isRoot()) {
            return center.getChildCount();
        }
        return center.getChildCount() + 1;
    }

    /**
     * Advances currentBranch one step of the iterative traversal. When descending into an
     * internal node, or starting from the root, it moves to the first child. Otherwise it
     * moves to the next sibling, or back up to the parent (setting down = false) once the
     * last sibling is done.
     */
    private void traverseTree() {
        if ((!this.currentBranch.isLeaf() && this.down) || this.currentBranch.isRoot()) {
            this.currentBranch = this.currentBranch.getChild(0);
            this.down = true;
        } else {
            Node center = this.currentBranch.getParent();
            this.currentBranch = this.getNextBranchOrRoot(this.currentBranch, center);
            this.down = this.currentBranch != center;
        }
    }

    /**
     * Performs one full traversal from the root. It computes leaf partials, and internal-node
     * partials when each internal node is left upwards (i.e. after all its children).
     */
    private void initPartials() {
        this.currentBranch = this.tree.getRoot();
        this.down = true;
        Node firstBranch = this.currentBranch;
        do {
            if (!this.currentBranch.isRoot()) {
                if (this.currentBranch.isLeaf()) {
                    this.partialsExternal(this.currentBranch);
                } else if (!this.down) {
                    this.productPartials(this.currentBranch, this.currentBranch);
                    this.partialsInternal(this.currentBranch, this.currentBranch);
                }
            }
            this.traverseTree();
        } while (this.currentBranch != firstBranch);
    }

    /**
     * Computes logL, siteLogL and rateAtSite. It combines the root's child partials with the
     * equilibrium frequencies and rate-category probabilities.
     */
    private void treeLikelihood() {
        this.initPartials();
        Node center = this.tree.getRoot();
        Node firstBranch = center.getChild(0);
        Node lastBranch = center.getChild(center.getChildCount() - 1);
        double[][][] partial1 = this.getPartial(firstBranch);
        double[][][] partial2 = this.getPartial(lastBranch);
        // Folds every root child except the last into partial1 (a no-op for two children).
        this.productPartials(lastBranch, center);
        this.logL = 0.0;
        for (int l = 0; l < this.numPatterns; l++) {
            int bestR = 0;
            double maxSum = 0.0;
            double rsum = 0.0;
            for (int r = 0; r < this.numRates; r++) {
                double[] p1 = partial1[l][r];
                double[] p2 = partial2[l][r];
                double sum = 0.0;
                for (int d = 0; d < this.numStates; d++) {
                    sum += this.frequency[d] * p1[d] * p2[d];
                }
                sum *= this.rprob[r];
                if (r == 0) {
                    bestR = 0;
                    maxSum = sum;
                } else if (sum > maxSum) {
                    bestR = r;
                    maxSum = sum;
                }
                rsum += sum;
            }
            this.siteLogL[l] = Math.log(rsum);
            this.rateAtSite[l] = bestR;
            this.logL += this.siteLogL[l] * (double) this.sitePattern.weight[l];
        }
    }

    /**
     * Cycles over all branches, either optimising each branch length
     * ({@code optimise == true}) or recording its standard error. It stops once numBranches
     * consecutive branches change by less than CONVERGENCE_EPS, or after MAXROUNDS rounds.
     */
    private void optimiseUnconstrainedTree(boolean optimise) {
        int numBranches = this.tree.getInternalNodeCount() + this.tree.getExternalNodeCount() - 1;
        this.initPartials();
        Node firstBranch = this.currentBranch;
        int nconv = 0;
        int numRounds = 0;
        while (true) {
            if (!this.currentBranch.isRoot()) {
                if (this.currentBranch.isLeaf()) {
                    this.productPartials(this.currentBranch, this.currentBranch.getParent());
                    this.bl.setBranch(this.currentBranch);
                    double lenOld = this.currentBranch.getBranchLength();
                    double len = this.updateCurrentBranchLength(lenOld, optimise);
                    nconv = Math.abs(len - lenOld) < CONVERGENCE_EPS ? nconv + 1 : 0;
                    if (nconv >= numBranches || numRounds == this.MAXROUNDS) {
                        this.bl.evaluate(len);
                        break;
                    }
                    this.partialsExternal(this.currentBranch);
                } else if (this.down) {
                    this.productPartials(this.currentBranch, this.currentBranch.getParent());
                    this.partialsInternal(this.currentBranch, this.currentBranch.getParent());
                } else {
                    this.productPartials(this.currentBranch, this.currentBranch);
                    this.bl.setBranch(this.currentBranch);
                    double lenOld = this.currentBranch.getBranchLength();
                    double len = this.updateCurrentBranchLength(lenOld, optimise);
                    nconv = Math.abs(len - lenOld) < CONVERGENCE_EPS ? nconv + 1 : 0;
                    if (nconv >= numBranches || numRounds == this.MAXROUNDS) {
                        this.bl.evaluate(len);
                        break;
                    }
                    this.partialsInternal(this.currentBranch, this.currentBranch);
                }
            }
            this.traverseTree();
            if (this.currentBranch == firstBranch) {
                ++numRounds;
            }
        }
    }

    /**
     * Updates currentBranch, whose likelihood function bl has already been set. It either
     * optimises the branch length or stores a standard error derived from the numerical
     * second derivative at {@code lenOld}.
     *
     * @return the branch length after the update ({@code lenOld} when not optimising)
     */
    private double updateCurrentBranchLength(double lenOld, boolean optimise) {
        if (optimise) {
            double len = this.um.findMinimum(lenOld, this.bl, 6);
            this.currentBranch.setBranchLength(len);
            return len;
        }
        double curvature = NumericalDerivative.secondDerivative(this.bl, lenOld);
        double lenSE = MIN_BRANCH_CURVATURE < curvature ? Math.sqrt(1.0 / curvature) : MAX_BRANCH_SE;
        this.currentBranch.setBranchLengthSE(lenSE);
        return lenOld;
    }

    /**
     * Removes every non-root internal branch no longer than SHORT_BRANCH_LENGTH, remembering
     * it for restoreShortInternalBranches().
     *
     * @return the number of branches removed
     */
    private int collapseShortInternalBranches() {
        int numInternalBranches = this.tree.getInternalNodeCount() - 1;
        int numShortBranches = 0;
        for (int i = 0; i < numInternalBranches; i++) {
            Node b = this.tree.getInternalNode(i);
            if (b.getBranchLength() <= SHORT_BRANCH_LENGTH) {
                ++numShortBranches;
                NodeUtils.removeBranch(b);
                if (this.shortBranches == null) {
                    this.shortBranches = new Vector();
                }
                this.shortBranches.addElement(b);
            }
        }
        this.tree.createNodeList();
        return numShortBranches;
    }

    /**
     * Re-inserts all collapsed branches in reverse order of removal, giving each a length of
     * MIN_GAP just below its parent.
     *
     * @return the number of branches restored
     */
    private int restoreShortInternalBranches() {
        int size = 0;
        if (this.shortBranches != null) {
            size = this.shortBranches.size();
            for (int i = size - 1; i >= 0; i--) {
                Node node = (Node) this.shortBranches.elementAt(i);
                NodeUtils.restoreBranch(node);
                node.setBranchLength(MIN_GAP);
                node.setNodeHeight(node.getParent().getNodeHeight() - MIN_GAP);
                this.shortBranches.removeElementAt(i);
            }
        }
        this.tree.createNodeList();
        return size;
    }

    /**
     * Cycles over the internal nodes, optimising each node height between its largest child
     * and its parent (the root is capped at numNodes * 100). With {@code datedTips} it also
     * re-optimises the rate at the root from the second round onward. It stops once numNodes
     * consecutive heights change by less than CONVERGENCE_EPS, or after MAXROUNDS rounds.
     */
    private void optimiseClockTree(boolean datedTips) {
        int numNodes = this.tree.getInternalNodeCount();
        double maxHeight = (double) numNodes * 100.0;
        this.initPartials();
        Node firstBranch = this.currentBranch;
        int nconv = 0;
        int numRounds = 0;
        while (true) {
            if (this.currentBranch.isRoot()) {
                if (datedTips && numRounds > 0) {
                    if (numRounds == 1) {
                        nconv = 0;
                    }
                    this.optimiseRate();
                }
                double hMin = NodeUtils.findLargestChild(this.currentBranch) + MIN_GAP;
                double hMax = maxHeight - MIN_GAP;
                double hDiff = this.optimiseCurrentNodeHeight(hMin, hMax, maxHeight);
                nconv = hDiff < CONVERGENCE_EPS ? nconv + 1 : 0;
                if (nconv >= numNodes || numRounds == this.MAXROUNDS) {
                    break;
                }
            } else if (this.currentBranch.isLeaf()) {
                this.productPartials(this.currentBranch, this.currentBranch.getParent());
                this.partialsExternal(this.currentBranch);
            } else if (this.down) {
                this.productPartials(this.currentBranch, this.currentBranch.getParent());
                double hMin = NodeUtils.findLargestChild(this.currentBranch) + MIN_GAP;
                double hMax = this.currentBranch.getParent().getNodeHeight() - MIN_GAP;
                double hDiff = this.optimiseCurrentNodeHeight(hMin, hMax, maxHeight);
                nconv = hDiff < CONVERGENCE_EPS ? nconv + 1 : 0;
                if (nconv >= numNodes || numRounds == this.MAXROUNDS) {
                    break;
                }
                this.partialsInternal(this.currentBranch, this.currentBranch.getParent());
            } else {
                this.productPartials(this.currentBranch, this.currentBranch);
                this.partialsInternal(this.currentBranch, this.currentBranch);
            }
            this.traverseTree();
            if (this.currentBranch == firstBranch) {
                ++numRounds;
            }
        }
    }

    /**
     * Optimises the rate of the dated-tips clock tree with rl and stores its standard error,
     * derived from the minimiser's f2minx.
     */
    private void optimiseRate() {
        DatedTipsClockTree dtree = (DatedTipsClockTree) this.ptree;
        double rOld = dtree.getRate();
        double r = this.um.findMinimum(rOld, this.rl);
        this.rl.evaluate(r);
        double rSE = this.um.f2minx;
        rSE = 1.0 < rSE ? Math.sqrt(1.0 / rSE) : 1.0;
        dtree.setRateSE(rSE);
    }

    /**
     * Optimises the height of currentBranch within [hMin, hMax] using nl. When the node is
     * an AttributeNode, the height standard error is stored as the "node height SE"
     * attribute.
     *
     * @return the absolute change in node height
     */
    private double optimiseCurrentNodeHeight(double hMin, double hMax, double maxHeight) {
        this.nl.setBranch(this.currentBranch, hMin, hMax);
        double hOld = this.currentBranch.getNodeHeight();
        double h = this.um.findMinimum(hOld, this.nl, 6);
        this.nl.evaluate(h);
        double invMax = 1.0 / (maxHeight * maxHeight);
        double hSE = this.um.f2minx;
        hSE = invMax < hSE ? Math.sqrt(1.0 / hSE) : maxHeight;
        if (this.currentBranch instanceof AttributeNode) {
            ((AttributeNode) this.currentBranch).setAttribute("node height SE", Double.valueOf(hSE));
        }
        return Math.abs(h - hOld);
    }
}

