/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search;

import edu.cmu.tetrad.data.BoxDataSet;
import edu.cmu.tetrad.data.ContinuousVariable;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.DiscreteVariable;
import edu.cmu.tetrad.data.Discretizer;
import edu.cmu.tetrad.data.VerticalIntDataBox;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.search.AdLeafTree;
import edu.cmu.tetrad.search.AdTrees;
import edu.cmu.tetrad.util.Matrix;
import edu.cmu.tetrad.util.Vector;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.commons.math3.linear.BlockRealMatrix;
import org.apache.commons.math3.linear.CholeskyDecomposition;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.stat.correlation.Covariance;
import org.apache.commons.math3.util.FastMath;

/**
 * Computes conditional log-likelihoods (and matching degrees of freedom) of a variable given its
 * parents, over a data set that may mix {@link ContinuousVariable}s and {@link DiscreteVariable}s.
 *
 * <p>Rows are grouped into cells of discrete-variable values with an AD leaf tree. The tree is built
 * over an all-discrete copy of the data in which every continuous column is replaced by an
 * equal-frequency discretization into {@code numCategoriesToDiscretize} bins (3 by default).
 *
 * <p>NOTE(review): instances carry mutable configuration and use no synchronization; assume they are
 * not safe to share between threads without external locking.
 */
public class ConditionalGaussianOtherLikelihood {
    /** Natural log of 2*pi. */
    private static final double LOGMATH2PI = FastMath.log(Math.PI * 2);
    /** Symmetry and positivity thresholds passed to the Cholesky decomposition in {@link #logdet}. */
    private static final double CHOLESKY_TOLERANCE = 1.0E-9;
    /**
     * In {@link #likelihoodJoint}, a cell gets its own covariance only if it has more than
     * (number of continuous variables + this many) rows; otherwise the all-rows covariance is used.
     */
    private static final int EXTRA_ROWS_FOR_CELL_COVARIANCE = 10;

    /** The original data set as passed to the constructor. */
    private final DataSet mixedDataSet;
    /** All-discrete copy of {@link #mixedDataSet}; rebuilt when the bin count changes. */
    private DataSet dataSet;
    /** Number of equal-frequency bins used to discretize continuous columns. */
    private int numCategoriesToDiscretize = 3;
    /** Variables of {@link #mixedDataSet}, in column order. */
    private final List<Node> mixedVariables;
    /** Column index of each variable of {@link #mixedDataSet}. */
    private final Map<Node, Integer> nodesHash;
    /** Copies of the continuous columns, indexed by column; null for non-continuous columns. */
    private final double[][] continuousData;
    /** AD leaf tree over {@link #dataSet}; rebuilt together with it. */
    private AdLeafTree adTree;
    /** Multiplier applied (after truncation to int) to the continuous-parameter part of the dof. */
    private double penaltyDiscount = 1.0;
    /** Indices 0..N-1 of every row. */
    private final List<Integer> all;

    /**
     * Creates a likelihood calculator over {@code dataSet}.
     *
     * @param dataSet the data; continuous and discrete variables may be mixed.
     * @throws NullPointerException if {@code dataSet} is null.
     */
    public ConditionalGaussianOtherLikelihood(DataSet dataSet) {
        if (dataSet == null) {
            throw new NullPointerException();
        }
        this.mixedDataSet = dataSet;
        this.mixedVariables = dataSet.getVariables();

        int numRows = dataSet.getNumRows();
        int numColumns = dataSet.getNumColumns();

        this.continuousData = new double[numColumns][];
        this.nodesHash = new HashMap<>();
        for (int j = 0; j < numColumns; ++j) {
            Node variable = dataSet.getVariable(j);
            this.nodesHash.put(variable, j);
            if (variable instanceof ContinuousVariable) {
                double[] column = new double[numRows];
                for (int i = 0; i < numRows; ++i) {
                    column[i] = dataSet.getDouble(i, j);
                }
                this.continuousData[j] = column;
            }
        }

        this.dataSet = this.useErsatzVariables();
        this.adTree = AdTrees.getAdLeafTree(this.dataSet);

        this.all = new ArrayList<>(numRows);
        for (int i = 0; i < numRows; ++i) {
            this.all.add(i);
        }
    }

    /**
     * Builds an all-discrete copy of {@link #mixedDataSet}.
     *
     * <p>Discrete columns are copied value by value. Every other column is replaced by a
     * {@link DiscreteVariable} of the same name whose values are equal-frequency bins (named "0", "1",
     * ...) of the original continuous column.
     */
    private DataSet useErsatzVariables() {
        int numCategories = this.numCategoriesToDiscretize;
        int numRows = this.mixedDataSet.getNumRows();

        List<Node> nodes = new ArrayList<>();
        for (Node x : this.mixedVariables) {
            if (x instanceof ContinuousVariable) {
                nodes.add(new DiscreteVariable(x.getName(), numCategories));
            } else {
                nodes.add(x);
            }
        }

        // Bin names are the same for every discretized column; never mutated after this point.
        List<String> categoryNames = new ArrayList<>(numCategories);
        for (int c = 0; c < numCategories; ++c) {
            categoryNames.add(Integer.toString(c));
        }

        BoxDataSet replaced = new BoxDataSet(
                new VerticalIntDataBox(numRows, this.mixedDataSet.getNumColumns()), nodes);

        for (int j = 0; j < this.mixedVariables.size(); ++j) {
            Node variable = this.mixedVariables.get(j);
            if (variable instanceof DiscreteVariable) {
                for (int i = 0; i < numRows; ++i) {
                    replaced.setInt(i, j, this.mixedDataSet.getInt(i, j));
                }
                continue;
            }
            double[] column = this.continuousData[j];
            double[] breakpoints = Discretizer.getEqualFrequencyBreakPoints(column, numCategories);
            Discretizer.Discretization d =
                    Discretizer.discretize(column, breakpoints, variable.getName(), categoryNames);
            int[] binned = d.getData();
            for (int i = 0; i < numRows; ++i) {
                replaced.setInt(i, j, binned[i]);
            }
        }
        return replaced;
    }

    /**
     * Returns the log-likelihood of variable {@code i} given {@code parents}, with its dof.
     *
     * <p>If the target is discrete and at least one parent is continuous, the likelihood comes from
     * {@link #likelihoodMixed}. Otherwise it is the joint likelihood of parents-plus-target minus the
     * joint likelihood of the parents alone, and the dof is the matching difference.
     *
     * @param i index of the target among the data set's variables.
     * @param parents indices of the parent variables.
     * @return the log-likelihood and degrees of freedom.
     */
    public Ret getLikelihood(int i, int[] parents) {
        Node target = this.mixedVariables.get(i);
        List<ContinuousVariable> X = new ArrayList<>();
        List<DiscreteVariable> A = new ArrayList<>();
        for (int p : parents) {
            Node parent = this.mixedVariables.get(p);
            if (parent instanceof ContinuousVariable) {
                X.add((ContinuousVariable) parent);
            } else {
                A.add((DiscreteVariable) parent);
            }
        }

        if (target instanceof DiscreteVariable && !X.isEmpty()) {
            return this.likelihoodMixed(X, A, (DiscreteVariable) target);
        }

        List<ContinuousVariable> XPlus = new ArrayList<>(X);
        List<DiscreteVariable> APlus = new ArrayList<>(A);
        if (target instanceof ContinuousVariable) {
            XPlus.add((ContinuousVariable) target);
        } else if (target instanceof DiscreteVariable) {
            APlus.add((DiscreteVariable) target);
        }

        Ret ret1 = this.likelihoodJoint(XPlus, APlus, target);
        Ret ret2 = this.likelihoodJoint(X, A, target);
        return new Ret(ret1.getLik() - ret2.getLik(), ret1.getDof() - ret2.getDof());
    }

    /** Returns the penalty discount (used, truncated to int, as a multiplier in the dof). */
    public double getPenaltyDiscount() {
        return this.penaltyDiscount;
    }

    /** Sets the penalty discount; see {@link #getPenaltyDiscount()}. */
    public void setPenaltyDiscount(double penaltyDiscount) {
        this.penaltyDiscount = penaltyDiscount;
    }

    /**
     * Sets the number of equal-frequency bins used to discretize continuous columns.
     *
     * <p>The discretized copy of the data and its AD tree depend on this value, so they are rebuilt
     * here. Before this fix the setter had no effect after construction.
     *
     * @param numCategoriesToDiscretize number of bins; must be at least 1.
     * @throws IllegalArgumentException if {@code numCategoriesToDiscretize < 1}.
     */
    public void setNumCategoriesToDiscretize(int numCategoriesToDiscretize) {
        if (numCategoriesToDiscretize < 1) {
            throw new IllegalArgumentException(
                    "numCategoriesToDiscretize must be at least 1: " + numCategoriesToDiscretize);
        }
        if (numCategoriesToDiscretize == this.numCategoriesToDiscretize) {
            return;
        }
        this.numCategoriesToDiscretize = numCategoriesToDiscretize;
        this.dataSet = this.useErsatzVariables();
        this.adTree = AdTrees.getAdLeafTree(this.dataSet);
    }

    /**
     * Joint log-likelihood of continuous variables {@code X} and discrete variables {@code A}.
     *
     * <p>For each non-empty cell of {@code A} with {@code a} rows, the likelihood adds
     * {@code a * log(a / N)} when {@code A} is non-empty. When {@code X} is non-empty it also adds
     * {@code a} times the maximized Gaussian log-likelihood under the cell's covariance. If the cell is
     * too small for its own covariance, the covariance of all rows is used instead.
     *
     * <p>If the target is discrete, continuous variables in {@code X} are replaced by their discretized
     * stand-ins and moved into {@code A}. The caller's lists are never modified.
     */
    private Ret likelihoodJoint(List<ContinuousVariable> X, List<DiscreteVariable> A, Node target) {
        A = new ArrayList<>(A);
        X = new ArrayList<>(X);
        if (target instanceof DiscreteVariable) {
            for (ContinuousVariable x : new ArrayList<ContinuousVariable>(X)) {
                Node variable = this.dataSet.getVariable(x.getName());
                if (variable == null) {
                    continue;
                }
                A.add((DiscreteVariable) variable);
                X.remove(x);
            }
        }

        int k = X.size();
        int[] continuousCols = new int[k];
        for (int j = 0; j < k; ++j) {
            continuousCols[j] = this.nodesHash.get(X.get(j));
        }

        int N = this.mixedDataSet.getNumRows();
        double c1 = 0.0;
        double c2 = 0.0;
        // All-rows covariance, computed at most once (on first use) instead of once per small cell.
        Matrix pooledCov = null;

        List<List<Integer>> cells = this.adTree.getCellLeaves(A);
        for (List<Integer> cell : cells) {
            int a = cell.size();
            if (a == 0) {
                continue;
            }
            if (!A.isEmpty()) {
                c1 += (double) a * this.multinomialLikelihood(a, N);
            }
            if (k == 0) {
                continue;
            }
            try {
                Matrix cov;
                if (a > k + EXTRA_ROWS_FOR_CELL_COVARIANCE) {
                    cov = this.cov(this.getSubsample(continuousCols, cell));
                } else {
                    if (pooledCov == null) {
                        pooledCov = this.cov(this.getSubsample(continuousCols, this.all));
                    }
                    cov = pooledCov;
                }
                c2 += (double) a * this.gaussianLikelihood(k, cov);
            } catch (Exception ignored) {
                // Best effort: a cell whose covariance cannot be computed or factored contributes
                // nothing to the Gaussian term.
            }
        }

        double lnL = c1 + c2;
        // NOTE(review): the penalty discount is truncated to an int here (e.g. 1.5 -> 1).
        int p = (int) this.getPenaltyDiscount();
        int dof = p * this.f(A) * this.h(X) + this.f(A);
        return new Ret(lnL, dof);
    }

    /** Log of the empirical cell probability {@code a / N}. */
    private double multinomialLikelihood(int a, int N) {
        return FastMath.log((double) a / (double) N);
    }

    /**
     * Maximized per-row Gaussian log-likelihood for dimension {@code k} and covariance {@code sigma}:
     * {@code -0.5 * log|sigma| - 0.5 * k - 0.5 * k * log(2*pi)}.
     */
    private double gaussianLikelihood(int k, Matrix sigma) {
        return -0.5 * this.logdet(sigma) - 0.5 * (double) k - 0.5 * (double) k * LOGMATH2PI;
    }

    /**
     * Log-determinant of {@code m} via Cholesky decomposition.
     *
     * @throws RuntimeException from Commons Math if {@code m} is not symmetric or not positive definite
     *     within {@link #CHOLESKY_TOLERANCE}.
     */
    private double logdet(Matrix m) {
        RealMatrix lt = new CholeskyDecomposition(
                new BlockRealMatrix(m.toArray()), CHOLESKY_TOLERANCE, CHOLESKY_TOLERANCE).getLT();
        double sum = 0.0;
        for (int i = 0; i < lt.getRowDimension(); ++i) {
            sum += FastMath.log(lt.getEntry(i, i));
        }
        return 2.0 * sum;
    }

    /**
     * Log-likelihood of discrete {@code B} given discrete parents {@code A} and continuous parents
     * {@code X} (non-empty).
     *
     * <p>Within each cell of {@code A}, each category v of {@code B} with enough rows is modelled by a
     * Gaussian over {@code X}, fitted to that category's rows, with weight {@code n_v / N}. Every row of
     * category u then contributes {@code log P(B = u | x)} under that mixture. Categories with
     * {@code <= k} rows are skipped, since their sample covariance is necessarily singular, as are
     * categories whose covariance cannot be inverted or has a non-positive determinant.
     */
    private Ret likelihoodMixed(List<ContinuousVariable> X, List<DiscreteVariable> A, DiscreteVariable B) {
        int k = X.size();
        // log of (2*pi)^(-k/2) * exp(-k/2); common to every component, so it cancels in the posterior.
        double logG = -0.5 * (double) k * LOGMATH2PI - 0.5 * (double) k;

        int[] continuousCols = new int[k];
        for (int j = 0; j < k; ++j) {
            continuousCols[j] = this.nodesHash.get(X.get(j));
        }

        double lnL = 0.0;
        int N = this.dataSet.getNumRows();

        List<List<List<Integer>>> cells = this.adTree.getCellLeaves(A, B);
        for (List<List<Integer>> mycells : cells) {
            // Parallel lists: one entry per retained category of B in this cell of A.
            List<Matrix> samples = new ArrayList<>();
            List<Matrix> inverses = new ArrayList<>();
            List<Vector> centers = new ArrayList<>();
            List<Double> logWeights = new ArrayList<>();

            for (List<Integer> cell : mycells) {
                // A sample covariance from n <= k rows has rank at most n - 1 < k, hence is singular.
                if (cell.size() <= k) {
                    continue;
                }
                Matrix subsample = this.getSubsample(continuousCols, cell);
                try {
                    Matrix cov = this.cov(subsample);
                    double det = cov.det();
                    if (!(det > 0.0)) {
                        continue; // numerically singular (or NaN); treat like a too-small cell
                    }
                    Matrix covinv = cov.inverse();
                    Vector center = this.means(subsample);
                    double logWeight = logG - 0.5 * FastMath.log(det)
                            + FastMath.log((double) subsample.rows() / (double) N);

                    // Add only after everything succeeded so the lists stay aligned.
                    samples.add(subsample);
                    inverses.add(covinv);
                    centers.add(center);
                    logWeights.add(logWeight);
                } catch (Exception ignored) {
                    // Best effort: a category whose covariance cannot be inverted contributes nothing.
                }
            }

            int m = samples.size();
            double[] logTerms = new double[m];
            for (int u = 0; u < m; ++u) {
                Matrix xu = samples.get(u);
                for (int i = 0; i < xu.rows(); ++i) {
                    for (int v = 0; v < m; ++v) {
                        Vector xm = xu.getRow(i).minus(centers.get(v));
                        logTerms[v] = logWeights.get(v) - 0.5 * inverses.get(v).times(xm).dotProduct(xm);
                    }
                    // log(num / denom) computed in log space so distant rows do not underflow to 0/0.
                    lnL += logTerms[u] - logSumExp(logTerms);
                }
            }
        }

        // NOTE(review): the penalty discount is truncated to an int here (e.g. 1.5 -> 1).
        int p = (int) this.getPenaltyDiscount();
        int dof = this.f(A) * B.getNumCategories() + this.f(A) * p * this.h(X);
        return new Ret(lnL, dof);
    }

    /** Numerically stable {@code log(sum(exp(values)))}. */
    private static double logSumExp(double[] values) {
        double max = Double.NEGATIVE_INFINITY;
        for (double value : values) {
            max = Math.max(max, value);
        }
        if (Double.isInfinite(max)) {
            return max;
        }
        double sum = 0.0;
        for (double value : values) {
            sum += FastMath.exp(value - max);
        }
        return max + FastMath.log(sum);
    }

    /** Bias-corrected sample covariance of the rows of {@code x}. */
    private Matrix cov(Matrix x) {
        return new Matrix(new Covariance(x.toArray(), true).getCovarianceMatrix().getData());
    }

    /** Mean vector of {@code x}: {@code x.sum(1)} divided by the number of rows. */
    private Vector means(Matrix x) {
        return x.sum(1).scalarMult(1.0 / (double) x.rows());
    }

    /** Matrix of the given continuous columns restricted to the given row indices. */
    private Matrix getSubsample(int[] continuousCols, List<Integer> cell) {
        Matrix subset = new Matrix(cell.size(), continuousCols.length);
        for (int i = 0; i < cell.size(); ++i) {
            int row = cell.get(i);
            for (int j = 0; j < continuousCols.length; ++j) {
                subset.set(i, j, this.continuousData[continuousCols[j]][row]);
            }
        }
        return subset;
    }

    /** Product of the category counts of {@code A} (1 when empty). */
    private int f(List<DiscreteVariable> A) {
        int f = 1;
        for (DiscreteVariable V : A) {
            f *= V.getNumCategories();
        }
        return f;
    }

    /** Number of free entries in a symmetric |X| x |X| covariance matrix: p(p+1)/2. */
    private int h(List<ContinuousVariable> X) {
        int p = X.size();
        return p * (p + 1) / 2;
    }

    /** Immutable pair of a log-likelihood and its degrees of freedom. */
    public static class Ret {
        private final double lik;
        private final int dof;

        private Ret(double lik, int dof) {
            this.lik = lik;
            this.dof = dof;
        }

        /** Returns the log-likelihood. */
        public double getLik() {
            return this.lik;
        }

        /** Returns the degrees of freedom. */
        public int getDof() {
            return this.dof;
        }

        @Override
        public String toString() {
            return "lik = " + this.lik + " dof = " + this.dof;
        }
    }
}

