/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search;

import de.bwaldvogel.liblinear.FeatureNode;
import de.bwaldvogel.liblinear.Linear;
import de.bwaldvogel.liblinear.Model;
import de.bwaldvogel.liblinear.Parameter;
import de.bwaldvogel.liblinear.Problem;
import de.bwaldvogel.liblinear.SolverType;
import edu.cmu.tetrad.data.ContinuousVariable;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.DiscreteVariable;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.search.AdLeafTree;
import edu.cmu.tetrad.util.Matrix;
import edu.cmu.tetrad.util.TetradSerializable;
import edu.cmu.tetrad.util.Vector;
import java.io.OutputStream;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.commons.math3.util.FastMath;

/**
 * Computes per-family log-likelihoods and degrees of freedom for mixed continuous/discrete data.
 *
 * <p>For a child variable and a set of parents, the records are partitioned into cells by the
 * joint values of the discrete parents (via {@link AdLeafTree#getCellLeaves}). Within each cell
 * the continuous parents are centered and scaled by their within-cell mean and standard deviation
 * and expanded into a polynomial basis of degree {@code fDegree}. When {@code fDegree < 1}, the
 * degree is {@code floor(ln r)} for a cell of size {@code r}. An all-ones intercept column is
 * appended.
 *
 * <p>A continuous child is fit by ordinary least squares and scored with the maximized Gaussian
 * log-likelihood. A discrete child is fit by one L2-regularized logistic regression (liblinear)
 * per category, in one-vs-rest form.
 *
 * <p>Not thread-safe: the logistic-regression path temporarily redirects {@link System#out},
 * which is process-global.
 */
public class MNLRLikelihood {
    /** Discards everything written to it; used to silence liblinear's training output. */
    private static final PrintStream NULL_OUT = new PrintStream(new OutputStream() {

        @Override
        public void write(int b) {
            // Intentionally discarded.
        }
    });

    private final DataSet dataSet;
    private final List<Node> variables;
    /** Maps each variable to its column index in {@link #dataSet}. */
    private final Map<Node, Integer> nodesHash;
    /** Column-major copy of the continuous columns; entries for non-continuous columns are null. */
    private final double[][] continuousData;
    /** Column-major copy of the discrete columns; entries for non-discrete columns are null. */
    private final int[][] discreteData;
    private final AdLeafTree adTree;
    private final int fDegree;
    private final double structurePrior;

    /**
     * Caches the data column-wise and builds the AD-tree used to partition records into cells.
     *
     * @param dataSet        the data; must not be null
     * @param structurePrior negative selects the EBIC prior with {@code gamma = -structurePrior};
     *                       0 disables the structure prior; positive selects a binomial prior
     *                       (see {@link #getStructurePrior(int)})
     * @param fDegree        polynomial degree of the per-cell basis; values below 1 select
     *                       {@code floor(ln r)} for a cell of {@code r} records
     * @throws NullPointerException if {@code dataSet} is null
     */
    public MNLRLikelihood(DataSet dataSet, double structurePrior, int fDegree) {
        if (dataSet == null) {
            throw new NullPointerException("dataSet");
        }
        this.dataSet = dataSet;
        this.variables = dataSet.getVariables();
        this.structurePrior = structurePrior;
        this.fDegree = fDegree;

        int numColumns = dataSet.getNumColumns();
        int numRows = dataSet.getNumRows();
        this.continuousData = new double[numColumns][];
        this.discreteData = new int[numColumns][];
        for (int j = 0; j < numColumns; j++) {
            Node v = dataSet.getVariable(j);
            if (v instanceof ContinuousVariable) {
                double[] col = new double[numRows];
                for (int i = 0; i < numRows; i++) {
                    col[i] = dataSet.getDouble(i, j);
                }
                this.continuousData[j] = col;
            } else if (v instanceof DiscreteVariable) {
                int[] col = new int[numRows];
                for (int i = 0; i < numRows; i++) {
                    col[i] = dataSet.getInt(i, j);
                }
                this.discreteData[j] = col;
            }
        }

        this.nodesHash = new HashMap<>();
        for (int j = 0; j < numColumns; j++) {
            this.nodesHash.put(dataSet.getVariable(j), j);
        }
        this.adTree = new AdLeafTree(dataSet);
    }

    /**
     * Returns the maximized Gaussian log-likelihood of an OLS fit of {@code Y} on {@code X}:
     * {@code -(n/2) * (ln(2*pi) + ln(sigma2) + 1)}, where {@code sigma2} is the mean squared
     * residual.
     *
     * <p>If solving the normal equations throws (e.g. a singular {@code X'X}), the
     * intercept-only model is used instead. If the residual variance is not positive, the
     * intercept-only residuals are recomputed with {@code max(n, 2)} as the divisor of
     * {@code sum(Y)}.
     *
     * @param Y the response, one entry per row of {@code X}
     * @param X the design matrix, one row per record
     * @return the log-likelihood (non-finite values are also printed to stdout)
     */
    private double multipleRegression(Vector Y, Matrix X) {
        int n = X.rows();
        Vector r;
        try {
            Matrix Xt = X.transpose();
            Matrix XtX = Xt.times(X);
            r = X.times(XtX.inverse().times(Xt.times(Y))).minus(Y);
        } catch (Exception inversionFailure) {
            // Kept broad as in the original: any failure to fit falls back to the intercept-only model.
            r = interceptOnlyResiduals(Y, n, n);
        }
        double sigma2 = r.dotProduct(r) / (double) n;
        if (sigma2 <= 0.0) {
            // Perfect fit (zero residual variance) would make ln(sigma2) = -inf.
            r = interceptOnlyResiduals(Y, n, FastMath.max(n, 2));
            sigma2 = r.dotProduct(r) / (double) n;
        }
        double lik = -((double) n / 2.0) * (FastMath.log(Math.PI * 2) + FastMath.log(sigma2) + 1.0);
        if (Double.isInfinite(lik) || Double.isNaN(lik)) {
            System.out.println(lik);
        }
        return lik;
    }

    /** Returns {@code (sum(Y) / divisor) * ones(n) - Y}, the residuals of a constant fit. */
    private static Vector interceptOnlyResiduals(Vector Y, int n, double divisor) {
        Vector ones = new Vector(n);
        for (int i = 0; i < n; i++) {
            ones.set(i, 1.0);
        }
        return ones.scalarMult(ones.dotProduct(Y) / divisor).minus(Y);
    }

    /**
     * Returns the log-likelihood of the child given its parents.
     *
     * <p>The result is the sum over discrete-parent cells of the per-cell fit. Cells with at most
     * one record are skipped.
     *
     * @param child_index column index of the child variable
     * @param parents     column indices of the parent variables
     * @return the summed log-likelihood over cells
     * @throws ClassCastException if a parent, or a non-continuous child, is not a
     *                            {@link DiscreteVariable}
     */
    public double getLik(int child_index, int[] parents) {
        Node child = this.variables.get(child_index);
        List<ContinuousVariable> continuousParents = new ArrayList<>();
        List<DiscreteVariable> discreteParents = new ArrayList<>();
        splitParents(parents, continuousParents, discreteParents);

        int p = continuousParents.size();
        List<List<Integer>> cells = this.adTree.getCellLeaves(discreteParents);
        int[] continuousCols = new int[p];
        for (int j = 0; j < p; j++) {
            continuousCols[j] = this.nodesHash.get(continuousParents.get(j));
        }

        double lik = 0.0;
        for (List<Integer> cell : cells) {
            int r = cell.size();
            if (r <= 1) {
                continue;
            }
            Matrix subset = designMatrix(cell, continuousCols);
            if (child instanceof ContinuousVariable) {
                Vector target = new Vector(r);
                for (int i = 0; i < r; i++) {
                    target.set(i, this.continuousData[child_index][cell.get(i)]);
                }
                lik += multipleRegression(target, subset);
            } else {
                // One column per category: +1 for the record's category, -1 elsewhere (one-vs-rest).
                int numCategories = ((DiscreteVariable) child).getNumCategories();
                Matrix target = new Matrix(r, numCategories);
                for (int i = 0; i < r; i++) {
                    for (int k = 0; k < numCategories; k++) {
                        target.set(i, k, -1.0);
                    }
                    target.set(i, this.discreteData[child_index][cell.get(i)], 1.0);
                }
                lik += multinomialLogisticRegression(target, subset);
            }
        }
        return lik;
    }

    /**
     * Builds the design matrix for one cell.
     *
     * <p>For continuous parent {@code j} and power {@code d} in {@code 1..degree}, column
     * {@code p*(d-1) + j} holds {@code ((x - mean) / std)^d}, with the mean and std computed over
     * the cell's records. The last column is the all-ones intercept.
     */
    private Matrix designMatrix(List<Integer> cell, int[] continuousCols) {
        int r = cell.size();
        int p = continuousCols.length;
        double[] mean = new double[p];
        double[] std = new double[p];
        for (int j = 0; j < p; j++) {
            double[] column = this.continuousData[continuousCols[j]];
            double sum = 0.0;
            double sumSq = 0.0;
            for (int row : cell) {
                sum += column[row];
                sumSq += FastMath.pow(column[row], 2);
            }
            mean[j] = sum / (double) r;
            // E[x^2] - mean^2 can dip just below zero through rounding; clamp so sqrt is not NaN.
            // FastMath.max propagates NaN, so NaN input data is still reported below.
            double s = FastMath.sqrt(FastMath.max(sumSq / (double) r - FastMath.pow(mean[j], 2), 0.0));
            if (Double.isNaN(s)) {
                System.out.println(s);
            }
            // A constant column has no spread: scale by 1 so the centered column is zeros, not NaN/Inf.
            std[j] = s == 0.0 ? 1.0 : s;
        }

        int degree = degreeFor(r);
        Matrix subset = new Matrix(r, p * degree + 1);
        for (int i = 0; i < r; i++) {
            int row = cell.get(i);
            subset.set(i, p * degree, 1.0);
            for (int j = 0; j < p; j++) {
                double z = (this.continuousData[continuousCols[j]][row] - mean[j]) / std[j];
                for (int d = 0; d < degree; d++) {
                    subset.set(i, p * d + j, FastMath.pow(z, d + 1));
                }
            }
        }
        return subset;
    }

    /** Returns {@code fDegree}, or {@code floor(ln cellSize)} when {@code fDegree < 1}. */
    private int degreeFor(int cellSize) {
        if (this.fDegree < 1) {
            return (int) FastMath.floor(FastMath.log(cellSize));
        }
        return this.fDegree;
    }

    /**
     * Splits parent indices into continuous and discrete variables, preserving their order.
     *
     * <p>Any parent that is not a {@link ContinuousVariable} is cast to {@link DiscreteVariable}.
     */
    private void splitParents(int[] parents, List<ContinuousVariable> continuousParents,
                              List<DiscreteVariable> discreteParents) {
        for (int index : parents) {
            Node parent = this.variables.get(index);
            if (parent instanceof ContinuousVariable) {
                continuousParents.add((ContinuousVariable) parent);
            } else {
                discreteParents.add((DiscreteVariable) parent);
            }
        }
    }

    /**
     * Returns the number of free parameters of the child-given-parents model.
     *
     * <p>The count is summed over non-empty discrete-parent cells. Each cell contributes
     * {@code degree * |continuous parents| + 1} for a continuous child. For a discrete child with
     * {@code K} categories, that count is multiplied by {@code K - 1}.
     *
     * @param child_index column index of the child variable
     * @param parents     column indices of the parent variables
     * @return the degrees of freedom
     */
    public double getDoF(int child_index, int[] parents) {
        Node child = this.variables.get(child_index);
        List<ContinuousVariable> continuousParents = new ArrayList<>();
        List<DiscreteVariable> discreteParents = new ArrayList<>();
        splitParents(parents, continuousParents, discreteParents);

        double dof = 0.0;
        for (List<Integer> cell : this.adTree.getCellLeaves(discreteParents)) {
            int r = cell.size();
            // NOTE(review): getLik skips cells with r <= 1, but single-record cells are still
            // counted here; confirm whether this asymmetry is intended.
            if (r <= 0) {
                continue;
            }
            int perCategory = degreeFor(r) * continuousParents.size() + 1;
            if (child instanceof ContinuousVariable) {
                dof += perCategory;
            } else {
                dof += perCategory * (((DiscreteVariable) child).getNumCategories() - 1);
            }
        }
        return dof;
    }

    /**
     * Returns the log structure prior for a node with {@code k} parents.
     *
     * <p>The prior depends on {@code structurePrior}:
     * <ul>
     *   <li>negative: the EBIC prior (see {@link #getEBICprior()});</li>
     *   <li>zero: 0, i.e. no prior;</li>
     *   <li>positive: the binomial term {@code k ln p + (n - k) ln(1 - p)}, with
     *       {@code n = numColumns - 1} candidate parents and {@code p = structurePrior / n}.</li>
     * </ul>
     */
    public double getStructurePrior(int k) {
        if (this.structurePrior < 0.0) {
            return getEBICprior();
        }
        if (this.structurePrior == 0.0) {
            return 0.0;
        }
        double n = this.dataSet.getNumColumns() - 1;
        double p = this.structurePrior / n;
        return (double) k * FastMath.log(p) + (n - (double) k) * FastMath.log(1.0 - p);
    }

    /** Returns {@code gamma * ln(numColumns)} with {@code gamma = -structurePrior}. */
    public double getEBICprior() {
        double n = this.dataSet.getNumColumns();
        double gamma = -this.structurePrior;
        return gamma * FastMath.log(n);
    }

    /**
     * Returns the log-likelihood of the observed categories under per-category one-vs-rest
     * logistic models.
     *
     * <p>For each record, the probability of its category is the {@code +1} probability of that
     * category's model, normalized by the sum of the {@code +1} probabilities of all models.
     *
     * @param targets {@code r x K} matrix with +1 in the observed category's column, -1 elsewhere
     * @param subset  {@code r x m} design matrix whose last column is the all-ones intercept
     * @return the summed log-probability over records
     */
    private double multinomialLogisticRegression(Matrix targets, Matrix subset) {
        Problem problem = new Problem();
        problem.l = targets.rows();
        problem.n = subset.columns();
        problem.x = new FeatureNode[problem.l][problem.n];
        // NOTE(review): with bias >= 0 liblinear treats the last feature as the bias feature; here
        // that column is the explicit all-ones intercept. Confirm for the liblinear version in use.
        problem.bias = 0.0;
        for (int i = 0; i < problem.l; i++) {
            for (int j = 0; j < problem.n; j++) {
                problem.x[i][j] = new FeatureNode(j + 1, subset.get(i, j));
            }
        }

        double C = 1.0;
        double eps = 1.0E-4;
        Parameter parameter = new Parameter(SolverType.L2R_LR, C, eps);

        int numCategories = targets.columns();
        List<Model> models = new ArrayList<>(numCategories);
        // Silence liblinear while training, and always restore whatever stdout was current.
        PrintStream stdout = System.out;
        System.setOut(NULL_OUT);
        try {
            for (int k = 0; k < numCategories; k++) {
                problem.y = targets.getColumn(k).toArray();
                models.add(Linear.train(problem, parameter));
            }
        } finally {
            System.setOut(stdout);
        }

        double lik = 0.0;
        for (int row = 0; row < problem.l; row++) {
            double num = 0.0;
            double den = 0.0;
            for (int k = 0; k < numCategories; k++) {
                double positive = positiveProbability(models.get(k), problem.x[row]);
                if (targets.get(row, k) == 1.0) {
                    num = positive;
                }
                den += positive;
            }
            lik += FastMath.log(num / den);
        }
        return lik;
    }

    /**
     * Returns the probability {@code model} assigns to the {@code +1} label for instance {@code x}.
     *
     * <p>Returns 0 when the model was trained without any {@code +1} examples. The label index is
     * looked up rather than assumed, so the result does not depend on liblinear's label ordering.
     */
    private static double positiveProbability(Model model, FeatureNode[] x) {
        double[] probabilities = new double[model.getNrClass()];
        Linear.predictProbability(model, x, probabilities);
        int[] labels = model.getLabels();
        for (int c = 0; c < labels.length && c < probabilities.length; c++) {
            if (labels[c] == 1) {
                return probabilities[c];
            }
        }
        return 0.0;
    }
}

