/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search;

import edu.cmu.tetrad.data.BoxDataSet;
import edu.cmu.tetrad.data.ContinuousVariable;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.DiscreteVariable;
import edu.cmu.tetrad.data.Discretizer;
import edu.cmu.tetrad.data.VerticalIntDataBox;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.util.Matrix;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.commons.math3.stat.correlation.Covariance;
import org.apache.commons.math3.util.FastMath;

/**
 * Computes conditional-Gaussian log-likelihoods for a variable given a set of parents, in a
 * data set that may mix continuous and discrete variables.
 *
 * <p>The log-likelihood of a target given its parents is the difference of two joint
 * log-likelihoods: that of (parents + target) minus that of the parents alone. For a joint
 * log-likelihood, the rows in use are partitioned into cells by the values of the discrete
 * variables. Each cell of size {@code a} then contributes:
 * <ul>
 *   <li>a multinomial term {@code a * log(a / N)}, when there are discrete variables, where
 *       {@code N} is the number of rows in use; and</li>
 *   <li>{@code a} times the per-row Gaussian log-likelihood at the cell's sample covariance,
 *       when there are continuous variables.</li>
 * </ul>
 *
 * <p>A discretized copy of the data is kept, in which continuous columns are replaced by
 * equal-frequency discretizations. When {@link #setDiscretize(boolean)} is enabled,
 * continuous variables in the likelihood of a discrete target are replaced by their
 * discretized versions.
 */
public class ConditionalGaussianLikelihood {
    /** The original data set, as passed to the constructor. */
    private final DataSet mixedDataSet;

    /**
     * Copy of {@link #mixedDataSet} with every continuous column replaced by an
     * equal-frequency discretization. It is rebuilt when the number of categories changes.
     */
    private DataSet dataSet;

    /** Number of categories used when discretizing continuous columns. */
    private int numCategoriesToDiscretize = 3;

    /** Variables of {@link #mixedDataSet}, in column order. */
    private final List<Node> mixedVariables;

    /** Column index of each variable in {@link #mixedDataSet}. */
    private final Map<Node, Integer> nodesHash;

    /** Continuous columns copied out by column; entries for non-continuous columns are null. */
    private final double[][] continuousData;

    /** Stored for callers; not consulted by the computations in this class. */
    private double penaltyDiscount = 1.0;

    /** Row indices over which likelihoods are computed; defaults to all rows. */
    private List<Integer> rows;

    /** Whether continuous variables are discretized when the target is discrete. */
    private boolean discretize;

    /** Natural log of 2 * pi. */
    private static final double LOG2PI = FastMath.log(Math.PI * 2);

    /**
     * Restricts subsequent likelihood computations to the given rows.
     *
     * @param rows row indices into the data set; stored by reference, not copied
     * @throws NullPointerException if {@code rows} is null
     */
    public void setRows(List<Integer> rows) {
        if (rows == null) {
            throw new NullPointerException("rows");
        }
        this.rows = rows;
    }

    /**
     * Creates a likelihood calculator over the given data set, initially using all rows.
     *
     * @param dataSet the data set, possibly mixing continuous and discrete variables
     * @throws NullPointerException if {@code dataSet} is null
     */
    public ConditionalGaussianLikelihood(DataSet dataSet) {
        if (dataSet == null) {
            throw new NullPointerException();
        }
        this.mixedDataSet = dataSet;
        this.mixedVariables = dataSet.getVariables();

        int numRows = dataSet.getNumRows();
        int numColumns = dataSet.getNumColumns();
        this.continuousData = new double[numColumns][];
        this.nodesHash = new HashMap<Node, Integer>();
        for (int j = 0; j < numColumns; ++j) {
            Node v = dataSet.getVariable(j);
            this.nodesHash.put(v, j);
            if (v instanceof ContinuousVariable) {
                double[] col = new double[numRows];
                for (int i = 0; i < numRows; ++i) {
                    col[i] = dataSet.getDouble(i, j);
                }
                this.continuousData[j] = col;
            }
        }

        this.dataSet = this.useErsatzVariables();

        this.rows = new ArrayList<Integer>(numRows);
        for (int i = 0; i < numRows; ++i) {
            this.rows.add(i);
        }
    }

    /**
     * Builds a copy of the data for use by the discrete machinery.
     *
     * <p>Each continuous column is replaced by an equal-frequency discretization into
     * {@link #numCategoriesToDiscretize} categories named "0", "1", and so on. All other
     * columns are copied as integer codes.
     *
     * @return the discretized data set, with variables in the original column order
     */
    private DataSet useErsatzVariables() {
        int numCategories = this.numCategoriesToDiscretize;
        int numRows = this.mixedDataSet.getNumRows();

        List<Node> nodes = new ArrayList<Node>(this.mixedVariables.size());
        for (Node x : this.mixedVariables) {
            if (x instanceof ContinuousVariable) {
                nodes.add(new DiscreteVariable(x.getName(), numCategories));
            } else {
                nodes.add(x);
            }
        }

        // Every discretized column uses the same category names, so build the list once.
        List<String> categoryNames = new ArrayList<String>(numCategories);
        for (int c = 0; c < numCategories; ++c) {
            categoryNames.add("" + c);
        }

        BoxDataSet replaced = new BoxDataSet(
                new VerticalIntDataBox(numRows, this.mixedDataSet.getNumColumns()), nodes);
        for (int j = 0; j < this.mixedVariables.size(); ++j) {
            Node variable = this.mixedVariables.get(j);
            if (variable instanceof ContinuousVariable) {
                // Use the same type test as the node list above, so continuousData[j] is
                // always non-null here.
                double[] column = this.continuousData[j];
                double[] breakpoints = Discretizer.getEqualFrequencyBreakPoints(column, numCategories);
                Discretizer.Discretization d =
                        Discretizer.discretize(column, breakpoints, variable.getName(), categoryNames);
                int[] codes = d.getData();
                for (int i = 0; i < numRows; ++i) {
                    replaced.setInt(i, j, codes[i]);
                }
            } else {
                for (int i = 0; i < numRows; ++i) {
                    replaced.setInt(i, j, this.mixedDataSet.getInt(i, j));
                }
            }
        }
        return replaced;
    }

    /**
     * Computes the log-likelihood of variable {@code i} given the specified parents, over the
     * current rows.
     *
     * @param i       column index of the target variable
     * @param parents column indices of the parent variables
     * @return {@code lnL(parents + target) - lnL(parents)}, together with the corresponding
     *         difference in degrees of freedom
     */
    public Ret getLikelihood(int i, int[] parents) {
        Node target = this.mixedVariables.get(i);
        List<ContinuousVariable> continuousParents = new ArrayList<ContinuousVariable>();
        List<DiscreteVariable> discreteParents = new ArrayList<DiscreteVariable>();
        for (int p : parents) {
            Node parent = this.mixedVariables.get(p);
            if (parent instanceof ContinuousVariable) {
                continuousParents.add((ContinuousVariable) parent);
            } else {
                discreteParents.add((DiscreteVariable) parent);
            }
        }

        List<ContinuousVariable> continuousPlus = new ArrayList<ContinuousVariable>(continuousParents);
        List<DiscreteVariable> discretePlus = new ArrayList<DiscreteVariable>(discreteParents);
        if (target instanceof ContinuousVariable) {
            continuousPlus.add((ContinuousVariable) target);
        } else if (target instanceof DiscreteVariable) {
            discretePlus.add((DiscreteVariable) target);
        }

        Ret withTarget = this.likelihoodJoint(continuousPlus, discretePlus, target, this.rows);
        Ret parentsOnly = this.likelihoodJoint(continuousParents, discreteParents, target, this.rows);
        return new Ret(withTarget.getLik() - parentsOnly.getLik(),
                withTarget.getDof() - parentsOnly.getDof());
    }

    /** @return the stored penalty discount (not used by this class's computations) */
    public double getPenaltyDiscount() {
        return this.penaltyDiscount;
    }

    /** @param penaltyDiscount the penalty discount to store for callers */
    public void setPenaltyDiscount(double penaltyDiscount) {
        this.penaltyDiscount = penaltyDiscount;
    }

    /**
     * Sets whether continuous variables should be replaced by their discretized versions
     * when the target variable is discrete.
     *
     * @param discretize true to discretize for discrete targets
     */
    public void setDiscretize(boolean discretize) {
        this.discretize = discretize;
    }

    /**
     * Sets the number of equal-frequency categories used to discretize continuous columns,
     * and rebuilds the discretized data accordingly.
     *
     * @param numCategoriesToDiscretize number of categories; must be at least 1
     * @throws IllegalArgumentException if {@code numCategoriesToDiscretize < 1}
     */
    public void setNumCategoriesToDiscretize(int numCategoriesToDiscretize) {
        if (numCategoriesToDiscretize < 1) {
            throw new IllegalArgumentException(
                    "Number of categories must be at least 1: " + numCategoriesToDiscretize);
        }
        if (numCategoriesToDiscretize == this.numCategoriesToDiscretize) {
            return;
        }
        this.numCategoriesToDiscretize = numCategoriesToDiscretize;
        // The discretized copy was built in the constructor with the previous value. Rebuild
        // it, otherwise this setting would have no effect.
        this.dataSet = this.useErsatzVariables();
    }

    /**
     * Computes the joint log-likelihood of the given continuous and discrete variables over
     * the given rows.
     *
     * @param X      continuous variables (not modified)
     * @param A      discrete variables (not modified)
     * @param target the target variable; when discretization is enabled and this is
     *               discrete, the continuous variables are treated as discrete
     * @param rows   row indices to use
     * @return the log-likelihood and degrees of freedom
     */
    private Ret likelihoodJoint(List<ContinuousVariable> X, List<DiscreteVariable> A, Node target, List<Integer> rows) {
        List<DiscreteVariable> discrete = new ArrayList<DiscreteVariable>(A);
        List<ContinuousVariable> continuous = new ArrayList<ContinuousVariable>(X);
        if (this.discretize && target instanceof DiscreteVariable) {
            for (ContinuousVariable x : X) {
                Node variable = this.dataSet.getVariable(x.getName());
                if (variable == null) {
                    continue;
                }
                discrete.add((DiscreteVariable) variable);
                continuous.remove(x);
            }
        }

        int k = continuous.size();
        int[] continuousCols = new int[k];
        for (int j = 0; j < k; ++j) {
            continuousCols[j] = this.nodesHash.get(continuous.get(j));
        }

        int n = rows.size();
        double c1 = 0.0;
        double c2 = 0.0;
        for (List<Integer> cell : this.partition(discrete, rows)) {
            int a = cell.size();
            if (a == 0) {
                continue;
            }
            if (!discrete.isEmpty()) {
                c1 += (double) a * this.multinomialLikelihood(a, n);
            }
            if (k == 0) {
                continue;
            }
            try {
                double gl = this.gaussianLikelihood(k, this.cov(this.getSubsample(continuousCols, cell)));
                // A zero or negative determinant makes the log-likelihood infinite or NaN.
                // Such degenerate cells contribute nothing; otherwise one cell could make the
                // whole score infinite and the difference in getLikelihood NaN.
                if (!Double.isNaN(gl) && !Double.isInfinite(gl)) {
                    c2 += (double) a * gl;
                }
            } catch (RuntimeException ignored) {
                // Best effort: if no covariance/determinant can be computed for this cell
                // (e.g. too few rows), the cell contributes nothing.
            }
        }

        double lnL = c1 + c2;
        int dof = this.f(discrete) * this.h(continuous) + this.f(discrete);
        return new Ret(lnL, dof);
    }

    /** @return {@code log(a / N)}, the log of the empirical probability of a cell of size a */
    private double multinomialLikelihood(int a, int N) {
        return FastMath.log((double) a / (double) N);
    }

    /**
     * Computes the per-row Gaussian log-likelihood at covariance {@code sigma}:
     * {@code -1/2 log det(sigma) - k/2 (1 + log 2 pi)}.
     *
     * @param k     dimension of the Gaussian
     * @param sigma k x k covariance matrix
     * @return the log-likelihood (NaN or infinite if {@code sigma} is degenerate)
     */
    private double gaussianLikelihood(int k, Matrix sigma) {
        return -0.5 * FastMath.log(sigma.det()) - 0.5 * (double) k * (1.0 + LOG2PI);
    }

    /**
     * Computes the bias-corrected sample covariance of the columns of {@code x}.
     *
     * @param x data matrix, one row per observation
     * @return the sample covariance matrix of the columns
     */
    private Matrix cov(Matrix x) {
        return new Matrix(new Covariance(x.toArray(), true).getCovarianceMatrix().getData());
    }

    /**
     * Extracts the given continuous columns for the given rows.
     *
     * @param continuousCols column indices into {@link #continuousData}
     * @param cell           row indices to extract
     * @return a {@code cell.size() x continuousCols.length} matrix
     */
    private Matrix getSubsample(int[] continuousCols, List<Integer> cell) {
        Matrix subset = new Matrix(cell.size(), continuousCols.length);
        for (int i = 0; i < cell.size(); ++i) {
            int row = cell.get(i);
            for (int j = 0; j < continuousCols.length; ++j) {
                subset.set(i, j, this.continuousData[continuousCols[j]][row]);
            }
        }
        return subset;
    }

    /** @return the product of the category counts of the given discrete variables (1 if empty) */
    private int f(List<DiscreteVariable> A) {
        int f = 1;
        for (DiscreteVariable V : A) {
            f *= V.getNumCategories();
        }
        return f;
    }

    /** @return {@code p(p+1)/2}, the number of free entries in a p x p covariance matrix */
    private int h(List<ContinuousVariable> X) {
        int p = X.size();
        return p * (p + 1) / 2;
    }

    /**
     * Groups the given rows by their joint values of the given discrete variables.
     *
     * @param discrete_parents discrete variables whose value combinations define the cells
     * @param rows             row indices to partition
     * @return non-empty cells of row indices, in order of first appearance; a single cell
     *         holding all rows when {@code discrete_parents} is empty (none if rows is empty)
     */
    private List<List<Integer>> partition(List<DiscreteVariable> discrete_parents, List<Integer> rows) {
        // Column lookups do not depend on the row, so resolve them once.
        int[] columns = new int[discrete_parents.size()];
        for (int p = 0; p < columns.length; ++p) {
            columns[p] = this.dataSet.getColumn(discrete_parents.get(p));
        }

        List<List<Integer>> cells = new ArrayList<List<Integer>>();
        Map<List<Integer>, List<Integer>> cellByKey = new HashMap<List<Integer>, List<Integer>>();
        for (int i : rows) {
            List<Integer> key = new ArrayList<Integer>(columns.length);
            for (int column : columns) {
                key.add(this.dataSet.getInt(i, column));
            }
            List<Integer> cell = cellByKey.get(key);
            if (cell == null) {
                cell = new ArrayList<Integer>();
                cellByKey.put(key, cell);
                cells.add(cell);
            }
            cell.add(i);
        }
        return cells;
    }

    /** A log-likelihood paired with its degrees of freedom. */
    public static class Ret {
        private final double lik;
        private final int dof;

        private Ret(double lik, int dof) {
            this.lik = lik;
            this.dof = dof;
        }

        /** @return the log-likelihood */
        public double getLik() {
            return this.lik;
        }

        /** @return the degrees of freedom */
        public int getDof() {
            return this.dof;
        }

        @Override
        public String toString() {
            return "lik = " + this.lik + " dof = " + this.dof;
        }
    }
}

