/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search;

import edu.cmu.tetrad.data.BoxDataSet;
import edu.cmu.tetrad.data.ContinuousVariable;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.DiscreteVariable;
import edu.cmu.tetrad.data.Discretizer;
import edu.cmu.tetrad.data.VerticalIntDataBox;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.search.AdLeafTree;
import edu.cmu.tetrad.util.Matrix;
import edu.cmu.tetrad.util.TetradSerializable;
import edu.cmu.tetrad.util.Vector;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.commons.math3.util.FastMath;

/**
 * Log-likelihood and degrees-of-freedom calculations for a child variable given a set of
 * parents, in a data set that may mix continuous and discrete variables.
 *
 * <p>Rows are partitioned into cells by the values of the discrete parents (via
 * {@link AdLeafTree#getCellLeaves}). Within each cell the child is regressed on a polynomial
 * expansion of the standardized continuous parents:
 * <ul>
 *   <li>a continuous child gets an ordinary least-squares fit with a Gaussian log-likelihood;</li>
 *   <li>a discrete child gets an approximate (linear-probability) multinomial fit on a
 *       one-hot encoding of its categories.</li>
 * </ul>
 */
public class MVPLikelihood {
    /** Number of equal-frequency categories used when continuous variables are discretized. */
    private static final int ERSATZ_NUM_CATEGORIES = 3;

    private final DataSet dataSet;
    private final List<Node> variables;
    /** Variables of the discretized copy of the data; {@code null} unless {@code discretize}. */
    private final List<Node> discreteVariables;
    /** Column index of each variable in {@link #dataSet}. */
    private final Map<Node, Integer> nodesHash;
    /** Column-major copy of continuous columns; entries for non-continuous columns are null. */
    private final double[][] continuousData;
    /**
     * Column-major integer codes. Without discretization only discrete columns are filled;
     * with discretization every column is taken from the discretized copy.
     */
    private final int[][] discreteData;
    private final AdLeafTree adTree;
    private final int fDegree;
    private final double structurePrior;
    private final boolean discretize;

    /**
     * Builds the likelihood calculator and caches column-major copies of the data.
     *
     * @param dataSet        the data; must not be null.
     * @param structurePrior negative: EBIC-style prior with gamma = -structurePrior; zero: no
     *                       structure prior; positive: divided by (number of variables - 1) to
     *                       give a per-edge probability (see {@link #getStructurePrior(int)}).
     * @param fDegree        degree of the polynomial basis for continuous parents. A value
     *                       below 1 means floor(ln r), where r is the number of rows in a cell.
     * @param discretize     if true, continuous variables also get 3-category equal-frequency
     *                       discretizations, which are used as parents when the child is
     *                       discrete.
     * @throws NullPointerException if {@code dataSet} is null.
     */
    public MVPLikelihood(DataSet dataSet, double structurePrior, int fDegree, boolean discretize) {
        if (dataSet == null) {
            throw new NullPointerException("dataSet must not be null");
        }
        this.dataSet = dataSet;
        this.variables = dataSet.getVariables();
        this.structurePrior = structurePrior;
        this.fDegree = fDegree;
        this.discretize = discretize;

        int numRows = dataSet.getNumRows();
        int numColumns = dataSet.getNumColumns();
        this.continuousData = new double[numColumns][];
        this.discreteData = new int[numColumns][];
        this.nodesHash = new HashMap<>();

        for (int j = 0; j < numColumns; ++j) {
            Node v = dataSet.getVariable(j);
            this.nodesHash.put(v, j);
            if (v instanceof ContinuousVariable) {
                double[] column = new double[numRows];
                for (int i = 0; i < numRows; ++i) {
                    column[i] = dataSet.getDouble(i, j);
                }
                this.continuousData[j] = column;
            } else if (v instanceof DiscreteVariable) {
                int[] column = new int[numRows];
                for (int i = 0; i < numRows; ++i) {
                    column[i] = dataSet.getInt(i, j);
                }
                this.discreteData[j] = column;
            }
        }

        if (discretize) {
            // Relies on dataSet, variables and continuousData already being set above.
            DataSet discreteDataSet = this.useErsatzVariables();
            this.discreteVariables = discreteDataSet.getVariables();
            this.adTree = new AdLeafTree(discreteDataSet);
            int discreteRows = discreteDataSet.getNumRows();
            for (int j = 0; j < numColumns; ++j) {
                int[] codes = new int[discreteRows];
                for (int i = 0; i < discreteRows; ++i) {
                    codes[i] = discreteDataSet.getInt(i, j);
                }
                this.discreteData[j] = codes;
            }
        } else {
            this.discreteVariables = null;
            this.adTree = new AdLeafTree(dataSet);
        }
    }

    /**
     * Gaussian log-likelihood of the least-squares fit of {@code Y} on the columns of {@code X}.
     *
     * <p>Falls back to the intercept-only (mean) model when X has at least as many columns as
     * rows, or when solving the normal equations throws.
     *
     * @return -n/2 * (ln(2*pi) + ln(sigma^2) + 1), where sigma^2 is the mean squared residual.
     *         Returns 0 when the residuals are exactly zero.
     */
    private double multipleRegression(Vector Y, Matrix X) {
        int n = X.rows();
        Vector r;
        if (X.columns() >= n) {
            r = residualsAboutMean(Y, n);
        } else {
            try {
                Matrix Xt = X.transpose();
                Matrix XtX = Xt.times(X);
                r = X.times(XtX.inverse().times(Xt.times(Y))).minus(Y);
            } catch (Exception e) {
                // Solving failed (e.g. the inverse threw); use the intercept-only model.
                r = residualsAboutMean(Y, n);
            }
        }
        // r.r / n is never negative, so no negative-variance branch is needed.
        double sigma2 = r.dotProduct(r) / (double) n;
        if (sigma2 == 0.0) {
            return 0.0;
        }
        return -((double) n / 2.0) * (FastMath.log(Math.PI * 2) + FastMath.log(sigma2) + 1.0);
    }

    /**
     * Returns (mean(y) * 1 - y), i.e. the negated residuals of the intercept-only model.
     * The sign is irrelevant because callers only use the squared norm.
     */
    private static Vector residualsAboutMean(Vector y, int n) {
        Vector ones = new Vector(n);
        for (int i = 0; i < n; ++i) {
            ones.set(i, 1.0);
        }
        return ones.scalarMult(ones.dotProduct(y) / (double) n).minus(y);
    }

    /**
     * Approximate multinomial log-likelihood of a linear-probability fit of {@code Y} on
     * {@code X}.
     *
     * <p>Assumes {@code Y} is an n x d one-hot indicator matrix (as built by
     * {@link #getLik(int, int[])}). The fitted row probabilities P come either from least
     * squares or, when there are too few rows or solving throws, from the marginal category
     * proportions. Each row's log-probability of its observed category is then summed.
     */
    private double approxMultinomialRegression(Matrix Y, Matrix X) {
        int n = X.rows();
        int d = Y.columns();
        Matrix P;
        if (d >= n || X.columns() >= n) {
            P = marginalProportions(Y, n);
        } else {
            try {
                Matrix Xt = X.transpose();
                Matrix XtX = Xt.times(X);
                P = X.times(XtX.inverse().times(Xt.times(Y)));
            } catch (Exception e) {
                // Solving failed (e.g. the inverse threw); use the intercept-only model.
                P = marginalProportions(Y, n);
            }
            shrinkSmallProbabilities(P, n, d, X.columns());
        }
        double lik = 0.0;
        for (int i = 0; i < n; ++i) {
            lik += FastMath.log(P.getRow(i).dotProduct(Y.getRow(i)));
        }
        return lik;
    }

    /** Returns an n x d matrix whose every row holds the column means of {@code y}. */
    private static Matrix marginalProportions(Matrix y, int n) {
        Matrix ones = new Matrix(n, 1);
        for (int i = 0; i < n; ++i) {
            ones.set(i, 0, 1.0);
        }
        return ones.times(ones.transpose().times(y).scalarMult(1.0 / (double) n));
    }

    /**
     * Lifts each row of fitted probabilities whose minimum falls below 1/n.
     *
     * <p>The row is mapped affinely toward the uniform value 1/d, so that its minimum becomes
     * exactly 1/n. A row that sums to 1 still sums to 1 afterwards. Rows are left untouched
     * when there is at most one predictor column.
     */
    private static void shrinkSmallProbabilities(Matrix P, int n, int d, int numPredictors) {
        if (numPredictors <= 1) {
            return;
        }
        double center = 1.0 / (double) d;
        double bound = 1.0 / (double) n;
        for (int i = 0; i < n; ++i) {
            double min = 1.0;
            for (int j = 0; j < d; ++j) {
                min = FastMath.min(min, P.get(i, j));
            }
            if (!(min < bound)) {
                continue;
            }
            // center < bound can't happen here because d < n on this path, so min != center.
            double scale = (bound - center) / (min - center);
            for (int j = 0; j < d; ++j) {
                P.set(i, j, scale * P.get(i, j) + center * (1.0 - scale));
            }
        }
    }

    /**
     * Log-likelihood of variable {@code child_index} given the variables at indices
     * {@code parents}, summed over the cells defined by the discrete parents.
     *
     * <p>Cells with fewer than two rows contribute nothing.
     */
    public double getLik(int child_index, int[] parents) {
        Node c = this.variables.get(child_index);
        List<ContinuousVariable> continuousParents = new ArrayList<>();
        List<DiscreteVariable> discreteParents = new ArrayList<>();
        this.splitParents(c, parents, continuousParents, discreteParents);

        int p = continuousParents.size();
        int[] continuousCols = new int[p];
        for (int j = 0; j < p; ++j) {
            continuousCols[j] = this.nodesHash.get(continuousParents.get(j));
        }

        double lik = 0.0;
        for (List<Integer> cell : this.adTree.getCellLeaves(discreteParents)) {
            int r = cell.size();
            if (r <= 1) {
                continue;
            }
            Matrix design = this.designMatrix(cell, continuousCols, this.polynomialDegree(r));
            if (c instanceof ContinuousVariable) {
                double[] childColumn = this.continuousData[child_index];
                Vector target = new Vector(r);
                for (int i = 0; i < r; ++i) {
                    target.set(i, childColumn[cell.get(i)]);
                }
                lik += this.multipleRegression(target, design);
            } else {
                assert (c instanceof DiscreteVariable);
                int[] childCodes = this.discreteData[child_index];
                // One-hot encoding of the child's category in each row of the cell.
                Matrix target = new Matrix(r, ((DiscreteVariable) c).getNumCategories());
                for (int i = 0; i < r; ++i) {
                    target.set(i, childCodes[cell.get(i)], 1.0);
                }
                lik += this.approxMultinomialRegression(target, design);
            }
        }
        return lik;
    }

    /**
     * Number of free parameters that {@link #getLik(int, int[])} fits for the same child and
     * parents, summed over non-empty cells.
     */
    public double getDoF(int child_index, int[] parents) {
        Node c = this.variables.get(child_index);
        List<ContinuousVariable> continuousParents = new ArrayList<>();
        List<DiscreteVariable> discreteParents = new ArrayList<>();
        this.splitParents(c, parents, continuousParents, discreteParents);

        double dof = 0.0;
        for (List<Integer> cell : this.adTree.getCellLeaves(discreteParents)) {
            int r = cell.size();
            // NOTE(review): single-row cells are counted here although getLik skips them
            // (r <= 1) -- confirm this asymmetry is intended.
            if (r <= 0) {
                continue;
            }
            int perCategory = this.polynomialDegree(r) * continuousParents.size() + 1;
            if (c instanceof ContinuousVariable) {
                dof += (double) perCategory;
                continue;
            }
            assert (c instanceof DiscreteVariable);
            dof += (double) (perCategory * (((DiscreteVariable) c).getNumCategories() - 1));
        }
        return dof;
    }

    /**
     * Splits the parents into continuous and discrete lists.
     *
     * <p>When discretizing and the child is discrete, every parent is taken from the discretized
     * variable list (so all of them are discrete). Otherwise parents come from the original
     * variables; any non-continuous parent is cast to {@link DiscreteVariable}.
     */
    private void splitParents(Node child, int[] parents,
                              List<ContinuousVariable> continuousParents,
                              List<DiscreteVariable> discreteParents) {
        if (child instanceof DiscreteVariable && this.discretize) {
            for (int index : parents) {
                discreteParents.add((DiscreteVariable) this.discreteVariables.get(index));
            }
            return;
        }
        for (int index : parents) {
            Node parent = this.variables.get(index);
            if (parent instanceof ContinuousVariable) {
                continuousParents.add((ContinuousVariable) parent);
            } else {
                discreteParents.add((DiscreteVariable) parent);
            }
        }
    }

    /** Polynomial degree for a cell of {@code r} rows: fDegree, or floor(ln r) if fDegree < 1. */
    private int polynomialDegree(int r) {
        return this.fDegree < 1 ? (int) FastMath.floor(FastMath.log(r)) : this.fDegree;
    }

    /**
     * Builds the r x (p * degree + 1) design matrix for one cell.
     *
     * <p>Column {@code p * d + j} holds z_j^(d+1), where z_j is continuous parent j standardized
     * by the cell's mean and (population) standard deviation. The last column is an intercept.
     * A parent whose within-cell spread is zero (or not positive because of rounding) is
     * centered only (divisor 1), so it does not produce 0/0 = NaN entries.
     */
    private Matrix designMatrix(List<Integer> cell, int[] continuousCols, int degree) {
        int r = cell.size();
        int p = continuousCols.length;
        double[] mean = new double[p];
        double[] sd = new double[p];
        for (int j = 0; j < p; ++j) {
            double[] column = this.continuousData[continuousCols[j]];
            double sum = 0.0;
            double sumSq = 0.0;
            for (int row : cell) {
                sum += column[row];
                sumSq += FastMath.pow(column[row], 2);
            }
            mean[j] = sum / (double) r;
            // One-pass variance can dip slightly below zero through rounding; clamp it.
            double variance = FastMath.max(sumSq / (double) r - FastMath.pow(mean[j], 2), 0.0);
            double s = FastMath.sqrt(variance);
            sd[j] = s > 0.0 ? s : 1.0;
        }

        Matrix design = new Matrix(r, p * degree + 1);
        for (int i = 0; i < r; ++i) {
            int row = cell.get(i);
            design.set(i, p * degree, 1.0);
            for (int j = 0; j < p; ++j) {
                double z = (this.continuousData[continuousCols[j]][row] - mean[j]) / sd[j];
                for (int d = 0; d < degree; ++d) {
                    design.set(i, p * d + j, FastMath.pow(z, d + 1));
                }
            }
        }
        return design;
    }

    /**
     * Log structure prior for a node with {@code k} parents.
     *
     * <ul>
     *   <li>A negative structurePrior delegates to {@link #getEBICprior()}.</li>
     *   <li>Zero gives 0.</li>
     *   <li>Otherwise p = structurePrior / (number of variables - 1) is used as a per-edge
     *       probability: k * ln p + (n - k) * ln(1 - p).</li>
     * </ul>
     *
     * <p>With a single variable there are no possible parents, so the prior is 0.
     */
    public double getStructurePrior(int k) {
        if (this.structurePrior < 0.0) {
            return this.getEBICprior();
        }
        if (this.structurePrior == 0.0) {
            return 0.0;
        }
        double n = this.dataSet.getNumColumns() - 1;
        if (n <= 0.0) {
            // Avoids structurePrior / 0 and the resulting 0 * Infinity = NaN.
            return 0.0;
        }
        double p = this.structurePrior / n;
        return (double) k * FastMath.log(p) + (n - (double) k) * FastMath.log(1.0 - p);
    }

    /**
     * EBIC-style prior term: gamma * ln(number of variables), with gamma = -structurePrior.
     * The value does not depend on the number of parents.
     */
    public double getEBICprior() {
        double n = this.dataSet.getNumColumns();
        double gamma = -this.structurePrior;
        return gamma * FastMath.log(n);
    }

    /**
     * Returns a copy of the data in which every continuous variable is replaced by a discrete
     * variable of the same name.
     *
     * <p>Each replacement has {@value #ERSATZ_NUM_CATEGORIES} categories, using equal-frequency
     * break points. Discrete variables are copied through unchanged. Every non-discrete column
     * is assumed to be continuous.
     */
    private DataSet useErsatzVariables() {
        int numCategories = ERSATZ_NUM_CATEGORIES;
        int numRows = this.dataSet.getNumRows();

        List<Node> nodes = new ArrayList<>();
        for (Node x : this.variables) {
            nodes.add(x instanceof ContinuousVariable
                    ? new DiscreteVariable(x.getName(), numCategories)
                    : x);
        }

        List<String> categoryNames = new ArrayList<>();
        for (int i = 0; i < numCategories; ++i) {
            categoryNames.add("" + i);
        }

        BoxDataSet replaced = new BoxDataSet(
                new VerticalIntDataBox(numRows, this.dataSet.getNumColumns()), nodes);
        for (int j = 0; j < this.variables.size(); ++j) {
            if (this.variables.get(j) instanceof DiscreteVariable) {
                for (int i = 0; i < numRows; ++i) {
                    replaced.setInt(i, j, this.dataSet.getInt(i, j));
                }
                continue;
            }
            double[] column = this.continuousData[j];
            double[] breakpoints = Discretizer.getEqualFrequencyBreakPoints(column, numCategories);
            Discretizer.Discretization d = Discretizer.discretize(
                    column, breakpoints, this.variables.get(j).getName(), categoryNames);
            int[] codes = d.getData();
            for (int i = 0; i < numRows; ++i) {
                replaced.setInt(i, j, codes[i]);
            }
        }
        return replaced;
    }
}

