/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search;

import edu.cmu.tetrad.data.ContinuousVariable;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.DiscreteVariable;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.search.ConditionalGaussianLikelihood;
import edu.cmu.tetrad.search.Score;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.List;
import org.apache.commons.math3.util.FastMath;

/**
 * A {@link Score} whose local scores are built from the likelihood and degrees of freedom reported
 * by {@link ConditionalGaussianLikelihood}.
 *
 * <p>Each local score has the form
 * {@code 2 * (lik + structurePrior(parents)) - penaltyDiscount * dof * ln(N)}, where {@code N} is
 * the number of rows in which the child and every parent are non-missing. Any score that comes out
 * NaN or infinite is reported as {@link Double#NEGATIVE_INFINITY}.
 *
 * <p>Not thread-safe: every call to {@link #localScore(int, int...)} sets the row subset on the
 * single shared likelihood object before querying it.
 */
public class ConditionalGaussianScore implements Score {
    /** Value stored in a discrete column to mark a missing cell; such rows are skipped. */
    private static final int DISCRETE_MISSING = -99;

    private final DataSet dataSet;
    private final List<Node> variables;
    private final ConditionalGaussianLikelihood likelihood;
    private double penaltyDiscount;
    private int numCategoriesToDiscretize = 3;
    private final double structurePrior;

    /**
     * Creates a score over the given data set.
     *
     * @param dataSet         the data to score; must not be null
     * @param penaltyDiscount multiplier applied to the {@code dof * ln(N)} complexity penalty
     * @param structurePrior  if {@code > 0}, each of the other {@code numColumns - 1} variables is
     *                        treated as a parent independently with probability
     *                        {@code structurePrior / (numColumns - 1)}; values {@code <= 0}
     *                        disable the prior
     * @param discretize      forwarded to {@link ConditionalGaussianLikelihood#setDiscretize}
     * @throws NullPointerException if {@code dataSet} is null
     */
    public ConditionalGaussianScore(DataSet dataSet, double penaltyDiscount, double structurePrior, boolean discretize) {
        if (dataSet == null) {
            throw new NullPointerException("dataSet");
        }
        this.dataSet = dataSet;
        this.variables = dataSet.getVariables();
        this.penaltyDiscount = penaltyDiscount;
        this.structurePrior = structurePrior;
        this.likelihood = new ConditionalGaussianLikelihood(dataSet);
        this.likelihood.setNumCategoriesToDiscretize(this.numCategoriesToDiscretize);
        this.likelihood.setPenaltyDiscount(penaltyDiscount);
        this.likelihood.setDiscretize(discretize);
    }

    /**
     * Scores variable {@code i} given the parent set {@code parents}, using only the rows in which
     * none of those variables is missing.
     *
     * @param i       column index of the child variable
     * @param parents column indices of the parent variables
     * @return the penalized score, or {@link Double#NEGATIVE_INFINITY} if it is not finite or no
     *         complete rows exist
     */
    @Override
    public double localScore(int i, int... parents) {
        List<Integer> rows = this.getRows(i, parents);

        // With no complete rows ln(N) = -Infinity and the score can never be finite, so skip
        // handing an empty row set to the likelihood.
        if (rows.isEmpty()) {
            return Double.NEGATIVE_INFINITY;
        }

        this.likelihood.setRows(rows);
        ConditionalGaussianLikelihood.Ret ret = this.likelihood.getLikelihood(i, parents);
        double lik = ret.getLik();
        int dof = ret.getDof();
        double score = 2.0 * (lik + this.getStructurePrior(parents))
                - this.getPenaltyDiscount() * (double) dof * FastMath.log(rows.size());
        if (!Double.isFinite(score)) {
            return Double.NEGATIVE_INFINITY;
        }
        return score;
    }

    /**
     * Collects the indices of the rows in which the child and every parent have a value.
     *
     * @param i       child column index
     * @param parents parent column indices
     * @return the complete rows, in ascending order
     */
    private List<Integer> getRows(int i, int[] parents) {
        List<Integer> rows = new ArrayList<>();
        for (int k = 0; k < this.dataSet.getNumRows(); ++k) {
            if (!this.isMissing(k, i) && !this.anyMissing(k, parents)) {
                rows.add(k);
            }
        }
        return rows;
    }

    /** Returns true if any of the given columns is missing in row {@code row}. */
    private boolean anyMissing(int row, int[] columns) {
        for (int column : columns) {
            if (this.isMissing(row, column)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Returns true if the cell at ({@code row}, {@code column}) is missing.
     *
     * <p>The check depends on the type of that column's own variable. Discrete cells are missing
     * when they hold {@link #DISCRETE_MISSING}. Continuous cells are missing when they are NaN.
     * Variables of any other type are never treated as missing.
     */
    private boolean isMissing(int row, int column) {
        Node variable = this.variables.get(column);
        if (variable instanceof DiscreteVariable) {
            return this.dataSet.getInt(row, column) == DISCRETE_MISSING;
        }
        if (variable instanceof ContinuousVariable) {
            return Double.isNaN(this.dataSet.getDouble(row, column));
        }
        return false;
    }

    /**
     * Log prior for a parent set of the given size.
     *
     * <p>Each of the {@code n = numColumns - 1} candidate parents is included independently with
     * probability {@code p = structurePrior / n}, giving {@code k ln p + (n - k) ln(1 - p)}.
     * NOTE(review): values of {@code structurePrior >= n} make {@code ln(1 - p)} non-finite.
     *
     * @param parents the parent set being scored
     * @return the log prior, or 0 when the prior is disabled or there are no candidate parents
     */
    private double getStructurePrior(int[] parents) {
        if (this.structurePrior <= 0.0) {
            return 0.0;
        }
        double n = this.dataSet.getNumColumns() - 1;
        if (n <= 0.0) {
            // A single-variable data set: the only possible parent set is empty, so the prior is
            // trivially 1 (log 0). Dividing by n here would otherwise produce NaN.
            return 0.0;
        }
        int k = parents.length;
        double p = this.structurePrior / n;
        return (double) k * FastMath.log(p) + (n - (double) k) * FastMath.log(1.0 - p);
    }

    /** Score change for {@code y} from adding {@code x} to the parent set {@code z}. */
    @Override
    public double localScoreDiff(int x, int y, int[] z) {
        return this.localScore(y, this.append(z, x)) - this.localScore(y, z);
    }

    /** Score change for {@code y} from adding {@code x} as its only parent. */
    @Override
    public double localScoreDiff(int x, int y) {
        return this.localScore(y, x) - this.localScore(y);
    }

    /** Returns a new array holding {@code parents} followed by {@code extra}. */
    private int[] append(int[] parents, int extra) {
        int[] all = new int[parents.length + 1];
        System.arraycopy(parents, 0, all, 0, parents.length);
        all[parents.length] = extra;
        return all;
    }

    /** Scores {@code i} with the single parent {@code parent}. */
    @Override
    public double localScore(int i, int parent) {
        return this.localScore(i, new int[]{parent});
    }

    /** Scores {@code i} with no parents. */
    @Override
    public double localScore(int i) {
        return this.localScore(i, new int[0]);
    }

    /** Returns the total number of rows in the data set, including rows with missing values. */
    @Override
    public int getSampleSize() {
        return this.dataSet.getNumRows();
    }

    /** Treats any strictly positive score bump as an effect edge. */
    @Override
    public boolean isEffectEdge(double bump) {
        return bump > 0.0;
    }

    /** Returns the variable list obtained from the data set at construction (not a copy). */
    @Override
    public List<Node> getVariables() {
        return this.variables;
    }

    /**
     * Looks up a variable by name.
     *
     * @param targetName the exact name to match
     * @return the first variable with that name, or null if none matches
     */
    @Override
    public Node getVariable(String targetName) {
        for (Node node : this.variables) {
            if (node.getName().equals(targetName)) {
                return node;
            }
        }
        return null;
    }

    /** Returns {@code ceil(ln(numRows))}. */
    @Override
    public int getMaxDegree() {
        return (int) FastMath.ceil(FastMath.log(this.dataSet.getNumRows()));
    }

    /** This score never reports that {@code z} determines {@code y}. */
    @Override
    public boolean determines(List<Node> z, Node y) {
        return false;
    }

    /** Returns the multiplier applied to the complexity penalty. */
    public double getPenaltyDiscount() {
        return this.penaltyDiscount;
    }

    /**
     * Sets the penalty multiplier. The value is also forwarded to the likelihood so that both
     * always agree, as they do after construction.
     */
    public void setPenaltyDiscount(double penaltyDiscount) {
        this.penaltyDiscount = penaltyDiscount;
        this.likelihood.setPenaltyDiscount(penaltyDiscount);
    }

    /**
     * Sets the number of categories used when discretizing. The value is forwarded to the
     * likelihood; previously it was only stored locally and had no effect after construction.
     */
    public void setNumCategoriesToDiscretize(int numCategoriesToDiscretize) {
        this.numCategoriesToDiscretize = numCategoriesToDiscretize;
        this.likelihood.setNumCategoriesToDiscretize(numCategoriesToDiscretize);
    }

    @Override
    public String toString() {
        DecimalFormat nf = new DecimalFormat("0.00");
        return "Conditional Gaussian Score Penalty " + nf.format(this.penaltyDiscount);
    }
}

