/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search;

import edu.cmu.tetrad.data.BoxDataSet;
import edu.cmu.tetrad.data.ContinuousVariable;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.DiscreteVariable;
import edu.cmu.tetrad.data.DoubleDataBox;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.search.Score;
import edu.cmu.tetrad.util.Matrix;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.apache.commons.collections4.map.HashedMap;
import org.apache.commons.math3.linear.BlockRealMatrix;
import org.apache.commons.math3.util.FastMath;

/**
 * Degenerate-Gaussian BIC-style score for data that may mix continuous and discrete variables.
 *
 * <p>Each discrete variable is replaced by 0/1 indicator columns, one per observed category
 * except a dropped reference category; each continuous variable is kept as a single column.
 * The local score of a variable given a parent set is computed from log-determinants of the
 * maximum-likelihood covariance of the embedded columns, using only rows with no NaN entry in
 * the columns involved.
 */
public class DegenerateGaussianScore
implements Score {
    /** Embedded (indicator-expanded) copy of the data; columns are addressed via {@link #embedding}. */
    private final BoxDataSet ddata;
    /** The data set passed to the constructor. */
    private final DataSet dataSet;
    /** Variables of the original data set, in column order. */
    private final List<Node> variables;
    private double penaltyDiscount = 1.0;
    private double structurePrior;
    /** Maps an original variable index to the indices of its columns in {@link #ddata}. */
    private final Map<Integer, List<Integer>> embedding;
    /** log(2 * pi * e): the per-dimension constant of the Gaussian maximum log-likelihood. */
    private static final double L2PE = FastMath.log(17.079468445347132);

    /**
     * Builds the embedded data set used for scoring.
     *
     * @param dataSet the data to score; must not be null.
     * @throws NullPointerException if {@code dataSet} is null.
     */
    public DegenerateGaussianScore(DataSet dataSet) {
        if (dataSet == null) {
            throw new NullPointerException();
        }
        this.dataSet = dataSet;
        this.variables = dataSet.getVariables();
        this.embedding = new HashMap<>();

        int n = dataSet.getNumRows();
        List<Node> embeddedNodes = new ArrayList<>();
        List<double[]> embeddedColumns = new ArrayList<>();
        int dummyCount = 0;

        for (int v = 0; v < this.variables.size(); ++v) {
            Node node = this.variables.get(v);

            if (node instanceof DiscreteVariable) {
                // Category code -> index of its indicator column in embeddedColumns.
                // NOTE(review): missing discrete values are not converted to NaN here, so they
                // are embedded like any other category and getRows() will not drop them.
                Map<Integer, Integer> columnOfCategory = new HashMap<>();
                int lastCategory = 0;
                for (int row = 0; row < n; ++row) {
                    int category = dataSet.getInt(row, v);
                    Integer column = columnOfCategory.get(category);
                    if (column == null) {
                        column = embeddedColumns.size();
                        columnOfCategory.put(category, column);
                        lastCategory = category;
                        embeddedNodes.add(new ContinuousVariable("V__" + ++dummyCount));
                        embeddedColumns.add(new double[n]);
                    }
                    embeddedColumns.get(column)[row] = 1.0;
                }

                // The indicators of all categories sum to one in every row, which would make the
                // covariance singular; drop the last category seen as the reference level. The
                // emptiness guard keeps a variable with no rows from removing another's column.
                if (!columnOfCategory.isEmpty()) {
                    int dropped = columnOfCategory.remove(lastCategory);
                    embeddedNodes.remove(dropped);
                    embeddedColumns.remove(dropped);
                }

                List<Integer> columns = new ArrayList<>(columnOfCategory.values());
                columns.sort(null); // deterministic column order, independent of hash iteration
                this.embedding.put(v, columns);
            } else {
                double[] column = new double[n];
                for (int row = 0; row < n; ++row) {
                    column[row] = dataSet.getDouble(row, v);
                }
                List<Integer> columns = new ArrayList<>();
                columns.add(embeddedColumns.size());
                this.embedding.put(v, columns);
                embeddedNodes.add(node);
                embeddedColumns.add(column);
            }
        }

        // Transpose the column list into a row-major matrix for the data box.
        double[][] matrix = new double[n][embeddedColumns.size()];
        for (int c = 0; c < embeddedColumns.size(); ++c) {
            double[] column = embeddedColumns.get(c);
            for (int row = 0; row < n; ++row) {
                matrix[row][c] = column[row];
            }
        }
        this.ddata = new BoxDataSet(new DoubleDataBox(matrix), embeddedNodes);
    }

    /**
     * Scores variable {@code i} given the parent set {@code parents}.
     *
     * @param i       index of the child variable.
     * @param parents indices of the parent variables.
     * @return the BIC-style local score, or {@code Double.NaN} if it is NaN or infinite.
     */
    @Override
    public double localScore(int i, int ... parents) {
        List<Integer> rows = this.getRows(i, parents);
        int N = rows.size();

        List<Integer> childColumns = this.embedding.get(i);
        List<Integer> parentColumns = new ArrayList<>();
        for (int parent : parents) {
            parentColumns.addAll(this.embedding.get(parent));
        }

        // A_ = child columns followed by parent columns; B_ = parent columns only.
        int[] A_ = new int[childColumns.size() + parentColumns.size()];
        int[] B_ = new int[parentColumns.size()];
        for (int c = 0; c < childColumns.size(); ++c) {
            A_[c] = childColumns.get(c);
        }
        for (int c = 0; c < parentColumns.size(); ++c) {
            A_[childColumns.size() + c] = parentColumns.get(c);
            B_[c] = parentColumns.get(c);
        }

        // Covariance parameters added by including the child's columns.
        int dof = (A_.length * (A_.length + 1) - B_.length * (B_.length + 1)) / 2;
        double ldetA = FastMath.log(this.getCov(rows, A_).det());
        double ldetB = FastMath.log(this.getCov(rows, B_).det());

        // Conditional Gaussian maximum log-likelihood logL(A_) - logL(B_), where
        // logL(S) = -N/2 * (log det(cov_S) + |S| * log(2 pi e)). The 0.5 keeps the
        // likelihood on the same scale as the BIC penalty below (score = 2 logL - dof log N).
        double lik = 0.5 * N * (ldetB - ldetA + L2PE * (B_.length - A_.length));
        double score = 2.0 * (lik + this.calculateStructurePrior(parents.length))
                - dof * this.getPenaltyDiscount() * FastMath.log(N);
        if (Double.isNaN(score) || Double.isInfinite(score)) {
            return Double.NaN;
        }
        return score;
    }

    /**
     * Log structure prior for a node with {@code k} parents; 0 when the prior is disabled
     * ({@code structurePrior <= 0}) or there are no other variables to act as parents.
     */
    private double calculateStructurePrior(int k) {
        if (this.structurePrior <= 0.0) {
            return 0.0;
        }
        double n = this.variables.size() - 1;
        if (n <= 0.0) {
            return 0.0; // avoids structurePrior / 0 when there is a single variable
        }
        double p = this.structurePrior / n;
        return (double)k * FastMath.log(p) + (n - (double)k) * FastMath.log(1.0 - p);
    }

    @Override
    public double localScoreDiff(int x, int y, int[] z) {
        return this.localScore(y, this.append(z, x)) - this.localScore(y, z);
    }

    @Override
    public double localScoreDiff(int x, int y) {
        return this.localScore(y, x) - this.localScore(y);
    }

    /** Returns a copy of {@code parents} with {@code extra} appended. */
    private int[] append(int[] parents, int extra) {
        int[] all = new int[parents.length + 1];
        System.arraycopy(parents, 0, all, 0, parents.length);
        all[parents.length] = extra;
        return all;
    }

    @Override
    public double localScore(int i, int parent) {
        return this.localScore(i, new int[]{parent});
    }

    @Override
    public double localScore(int i) {
        return this.localScore(i, new int[0]);
    }

    @Override
    public int getSampleSize() {
        return this.dataSet.getNumRows();
    }

    @Override
    public boolean isEffectEdge(double bump) {
        return bump > 0.0;
    }

    @Override
    public List<Node> getVariables() {
        return this.variables;
    }

    /** Returns the first variable whose name equals {@code targetName}, or null if none does. */
    @Override
    public Node getVariable(String targetName) {
        for (Node node : this.variables) {
            if (node.getName().equals(targetName)) {
                return node;
            }
        }
        return null;
    }

    /** Returns ceil(ln(number of rows)). */
    @Override
    public int getMaxDegree() {
        return (int)FastMath.ceil(FastMath.log(this.dataSet.getNumRows()));
    }

    @Override
    public boolean determines(List<Node> z, Node y) {
        return false;
    }

    public double getPenaltyDiscount() {
        return this.penaltyDiscount;
    }

    public void setPenaltyDiscount(double penaltyDiscount) {
        this.penaltyDiscount = penaltyDiscount;
    }

    public double getStructurePrior() {
        return this.structurePrior;
    }

    public void setStructurePrior(double structurePrior) {
        this.structurePrior = structurePrior;
    }

    @Override
    public String toString() {
        DecimalFormat nf = new DecimalFormat("0.00");
        return "Degenerate Gaussian Score Penalty " + nf.format(this.penaltyDiscount);
    }

    /**
     * Maximum-likelihood (divide-by-N) covariance of the embedded columns {@code cols},
     * restricted to {@code rows}. Returns a 0x0 matrix when {@code rows} is empty.
     */
    private Matrix getCov(List<Integer> rows, int[] cols) {
        if (rows.isEmpty()) {
            return new Matrix(0, 0);
        }
        int d = cols.length;
        double N = rows.size();

        // Column means, computed once per column rather than once per matrix entry.
        double[] means = new double[d];
        for (int c = 0; c < d; ++c) {
            double sum = 0.0;
            for (int k : rows) {
                sum += this.ddata.getDouble(k, cols[c]);
            }
            means[c] = sum / N;
        }

        Matrix cov = new Matrix(d, d);
        for (int a = 0; a < d; ++a) {
            for (int b = a; b < d; ++b) {
                double sum = 0.0;
                for (int k : rows) {
                    sum += (this.ddata.getDouble(k, cols[a]) - means[a])
                            * (this.ddata.getDouble(k, cols[b]) - means[b]);
                }
                double value = sum / N;
                cov.set(a, b, value);
                cov.set(b, a, value); // symmetric
            }
        }
        return cov;
    }

    /**
     * Returns the indices of rows in which neither variable {@code i} nor any of
     * {@code parents} has a NaN in its embedded columns.
     */
    private List<Integer> getRows(int i, int[] parents) {
        List<Integer> rows = new ArrayList<>();
        for (int k = 0; k < this.dataSet.getNumRows(); ++k) {
            if (this.hasMissing(k, i)) {
                continue;
            }
            boolean complete = true;
            for (int parent : parents) {
                if (this.hasMissing(k, parent)) {
                    complete = false;
                    break;
                }
            }
            if (complete) {
                rows.add(k);
            }
        }
        return rows;
    }

    /** True if any embedded column of {@code variable} is NaN in row {@code row}. */
    private boolean hasMissing(int row, int variable) {
        for (int column : this.embedding.get(variable)) {
            if (Double.isNaN(this.ddata.getDouble(row, column))) {
                return true;
            }
        }
        return false;
    }
}

