/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search.score;

import edu.cmu.tetrad.data.DataModel;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.ICovarianceMatrix;
import edu.cmu.tetrad.data.SimpleDataLoader;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.search.score.Score;
import edu.cmu.tetrad.search.utils.LogUtilsSearch;
import edu.cmu.tetrad.util.Matrix;
import edu.cmu.tetrad.util.MatrixUtils;
import edu.cmu.tetrad.util.StatUtils;
import edu.cmu.tetrad.util.TetradLogger;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.apache.commons.math3.linear.SingularMatrixException;
import org.apache.commons.math3.util.FastMath;
import org.jetbrains.annotations.NotNull;

/**
 * Linear-Gaussian BIC score. Each local score is computed from the residual variance of a
 * variable regressed on its parents. That variance comes either from a covariance matrix, or,
 * when the data set contains missing values, from covariances recomputed over only those rows in
 * which every variable involved in the query is observed.
 */
public class SemBicScore implements Score {

    /** Sample size used in the likelihood and penalty terms. */
    private final int sampleSize;
    /** Natural log of {@link #sampleSize}, cached for the penalty term. */
    private final double logN;
    /** True when covariances are recomputed per query over rows without missing values. */
    private boolean calculateRowSubsets;
    /** Model returned by {@link #getDataModel()}; replaced by the covariance matrix once one is set. */
    private DataModel dataModel;
    /** Raw data; only read when {@link #calculateRowSubsets} is true. Null for covariance input. */
    private Matrix data;
    /** Covariance matrix; null when {@link #calculateRowSubsets} is true. */
    private ICovarianceMatrix covariances;
    private List<Node> variables;
    private boolean verbose;
    private double penaltyDiscount = 1.0;
    private double structurePrior;
    /** Cached {@code covariances.getMatrix()}, used for whole-sample covariance selections. */
    private Matrix matrix;
    private RuleType ruleType = RuleType.CHICKERING;

    /**
     * Creates a score backed by a covariance matrix.
     *
     * @param covariances the covariance matrix; its variables and sample size are adopted
     * @throws NullPointerException if {@code covariances} is null
     */
    public SemBicScore(ICovarianceMatrix covariances) {
        Objects.requireNonNull(covariances, "covariances");
        this.setCovariances(covariances);
        this.variables = covariances.getVariables();
        this.sampleSize = covariances.getSampleSize();
        this.logN = FastMath.log(this.sampleSize);
    }

    /**
     * Creates a score from a data set. If the data set has no missing values, a covariance matrix
     * is built from it and used for all queries. Otherwise covariances are recomputed for each
     * query over the rows in which the queried variables are all observed.
     *
     * @param dataSet               the data
     * @param precomputeCovariances forwarded to {@code SimpleDataLoader.getCovarianceMatrix}
     * @throws NullPointerException if {@code dataSet} is null
     */
    public SemBicScore(DataSet dataSet, boolean precomputeCovariances) {
        Objects.requireNonNull(dataSet, "dataSet");
        this.dataModel = dataSet;
        this.data = dataSet.getDoubleData();
        if (dataSet.existsMissingValue()) {
            this.variables = dataSet.getVariables();
            this.sampleSize = dataSet.getNumRows();
            this.calculateRowSubsets = true;
        } else {
            this.setCovariances(this.getCovarianceMatrix(dataSet, precomputeCovariances));
            this.variables = this.covariances.getVariables();
            this.sampleSize = this.covariances.getSampleSize();
            this.calculateRowSubsets = false;
        }
        this.logN = FastMath.log(this.sampleSize);
    }

    /**
     * Returns the residual variance of variable {@code i} after linear regression on
     * {@code parents}. It is computed as {@code b*' S b*}, where {@code S} is the covariance of
     * {@code (i, parents)}, {@code b = Sxx^-1 Sxy} and {@code b* = (1, -b)}.
     *
     * @param i                   column index of the child
     * @param parents             column indices of the parents
     * @param data                raw data; only read when {@code calculateRowSubsets} is true
     * @param covariances         covariance matrix; only read when {@code calculateRowSubsets} is false
     * @param calculateRowSubsets whether to recompute covariances over rows with no missing values
     * @return the residual variance
     * @throws SingularMatrixException if the parents' covariance matrix cannot be inverted
     */
    public static double getVarRy(int i, int[] parents, Matrix data, ICovarianceMatrix covariances,
                                  boolean calculateRowSubsets) throws SingularMatrixException {
        int[] all = SemBicScore.concat(i, parents);
        List<Integer> rows = SemBicScore.getRows(i, parents, data, calculateRowSubsets);
        Matrix cov = SemBicScore.getCov(rows, all, all, data, covariances);
        int[] pp = SemBicScore.indexedParents(parents);
        Matrix covxx = cov.getSelection(pp, pp);
        Matrix covxy = cov.getSelection(pp, new int[]{0});
        Matrix b = covxx.inverse().times(covxy);
        Matrix bStar = SemBicScore.bStar(b);
        return bStar.transpose().times(cov).times(bStar).get(0, 0);
    }

    /**
     * Builds the column vector {@code (1, -b_1, ..., -b_m)} from an m x 1 coefficient vector.
     *
     * @param b regression coefficients as an m x 1 matrix
     * @return an (m + 1) x 1 matrix
     */
    @NotNull
    public static Matrix bStar(Matrix b) {
        Matrix byx = new Matrix(b.getNumRows() + 1, 1);
        byx.set(0, 0, 1.0);
        for (int j = 0; j < b.getNumRows(); ++j) {
            byx.set(j + 1, 0, -b.get(j, 0));
        }
        return byx;
    }

    /** Returns {@code 1..parents.length}: the parents' positions inside {@code concat(i, parents)}. */
    private static int[] indexedParents(int[] parents) {
        int[] pp = new int[parents.length];
        for (int j = 0; j < pp.length; ++j) {
            pp[j] = j + 1;
        }
        return pp;
    }

    /** Returns a new array holding {@code i} followed by all of {@code parents}. */
    private static int[] concat(int i, int[] parents) {
        int[] all = new int[parents.length + 1];
        all[0] = i;
        System.arraycopy(parents, 0, all, 1, parents.length);
        return all;
    }

    /**
     * Returns the covariance block for {@code rowVars} x {@code colVars}.
     *
     * <p>If {@code rows} is null, the block is selected from {@code covarianceMatrix}. Otherwise it
     * is computed from {@code data} over the listed rows, centering each column on its mean over
     * those rows and dividing the cross-product sum by {@code rows.size()}.
     */
    private static Matrix getCov(List<Integer> rows, int[] rowVars, int[] colVars, Matrix data,
                                 ICovarianceMatrix covarianceMatrix) {
        if (rows == null) {
            return covarianceMatrix.getSelection(rowVars, colVars);
        }
        // Means are computed once per column rather than once per matrix cell.
        double[] rowMeans = SemBicScore.columnMeans(rows, rowVars, data);
        double[] colMeans = SemBicScore.columnMeans(rows, colVars, data);
        double n = rows.size();
        Matrix cov = new Matrix(rowVars.length, colVars.length);
        for (int a = 0; a < rowVars.length; ++a) {
            for (int b = 0; b < colVars.length; ++b) {
                double sum = 0.0;
                for (int k : rows) {
                    sum += (data.get(k, rowVars[a]) - rowMeans[a]) * (data.get(k, colVars[b]) - colMeans[b]);
                }
                cov.set(a, b, sum / n);
            }
        }
        return cov;
    }

    /** Returns the mean of each listed column over the given rows (NaN if {@code rows} is empty). */
    private static double[] columnMeans(List<Integer> rows, int[] columns, Matrix data) {
        double[] means = new double[columns.length];
        for (int c = 0; c < columns.length; ++c) {
            double sum = 0.0;
            for (int k : rows) {
                sum += data.get(k, columns[c]);
            }
            means[c] = sum / rows.size();
        }
        return means;
    }

    /**
     * Returns the indices of rows in which column {@code i} and every column in {@code parents} are
     * non-NaN. Returns null when {@code calculateRowSubsets} is false, which tells the caller to use
     * the covariance matrix instead.
     */
    private static List<Integer> getRows(int i, int[] parents, Matrix data, boolean calculateRowSubsets) {
        if (!calculateRowSubsets) {
            return null;
        }
        List<Integer> rows = new ArrayList<>();
        for (int k = 0; k < data.getNumRows(); ++k) {
            if (SemBicScore.isObserved(data, k, i, parents)) {
                rows.add(k);
            }
        }
        return rows;
    }

    /** True if row {@code row} has no NaN in column {@code i} or in any of {@code parents}. */
    private static boolean isObserved(Matrix data, int row, int i, int[] parents) {
        if (Double.isNaN(data.get(row, i))) {
            return false;
        }
        for (int p : parents) {
            if (Double.isNaN(data.get(row, p))) {
                return false;
            }
        }
        return true;
    }

    /** Builds the covariance matrix for a data set via {@code SimpleDataLoader}. */
    @NotNull
    private ICovarianceMatrix getCovarianceMatrix(DataSet dataSet, boolean precomputeCovariances) {
        return SimpleDataLoader.getCovarianceMatrix(dataSet, precomputeCovariances);
    }

    /**
     * Returns the score change from adding {@code x} as a parent of {@code y}, given parents
     * {@code z}. Uses {@link #nandyBic} under {@link RuleType#NANDY}; otherwise returns the
     * difference of two {@link #localScore} calls.
     */
    @Override
    public double localScoreDiff(int x, int y, int[] z) {
        if (this.ruleType == RuleType.NANDY) {
            return this.nandyBic(x, y, z);
        }
        return this.localScore(y, this.append(z, x)) - this.localScore(y, z);
    }

    /**
     * Score difference for adding {@code x -> y} given {@code z}, computed from the partial
     * correlation r of x and y given z:
     * {@code -n * log(1 - r^2) - c * log(n) - 2 * (prior(|z| + 1) - prior(|z|))}.
     *
     * @return the score difference; NaN if the partial correlation cannot be computed
     */
    public double nandyBic(int x, int y, int[] z) {
        double sp1 = this.getStructurePrior(z.length + 1);
        double sp2 = this.getStructurePrior(z.length);
        int[] yz = SemBicScore.concat(y, z);
        // Column order (x, y, z...) is the layout the correlation matrix is handed over in.
        int[] cols = SemBicScore.concat(x, yz);
        // Rows where x, y and every z are observed; null when the covariance matrix is used directly.
        List<Integer> rows = SemBicScore.getRows(x, yz, this.data, this.calculateRowSubsets);
        double r = this.partialCorrelation(cols, rows);
        double c = this.getPenaltyDiscount();
        return (double) (-this.sampleSize) * FastMath.log(1.0 - r * r) - c * this.logN - 2.0 * (sp1 - sp2);
    }

    /**
     * Returns the BIC local score of {@code i} given {@code parents}:
     * {@code -(n/2) * log(residual variance) - c * (k/2) * log(n) - structurePrior(k)}.
     * The caller's {@code parents} array is not modified.
     *
     * @return the score, or NaN if the parents' covariance is singular or the score is not finite
     * @throws IllegalStateException if the configured rule type is not supported
     */
    @Override
    public double localScore(int i, int... parents) {
        if (this.ruleType != RuleType.CHICKERING && this.ruleType != RuleType.NANDY) {
            throw new IllegalStateException("That rule type is not implemented: " + this.ruleType);
        }
        // Sort a copy so the caller's array is not reordered as a side effect.
        int[] sorted = parents.clone();
        Arrays.sort(sorted);
        int k = sorted.length;
        double lik;
        try {
            double varey = SemBicScore.getVarRy(i, sorted, this.data, this.covariances, this.calculateRowSubsets);
            lik = -((double) this.sampleSize / 2.0) * FastMath.log(varey);
        } catch (SingularMatrixException e) {
            TetradLogger.getInstance().forceLogMessage("Singularity encountered when scoring "
                    + LogUtilsSearch.getScoreFact(i, sorted, this.variables));
            return Double.NaN;
        }
        double c = this.getPenaltyDiscount();
        double score = lik - c * ((double) k / 2.0) * this.logN - this.getStructurePrior(k);
        if (Double.isNaN(score) || Double.isInfinite(score)) {
            return Double.NaN;
        }
        return score;
    }

    /** Returns the multiplier applied to the BIC complexity penalty. */
    public double getPenaltyDiscount() {
        return this.penaltyDiscount;
    }

    /** Sets the multiplier applied to the BIC complexity penalty. */
    public void setPenaltyDiscount(double penaltyDiscount) {
        this.penaltyDiscount = penaltyDiscount;
    }

    /** Returns the structure-prior parameter; 0 disables the prior. */
    public double getStructurePrior() {
        return this.structurePrior;
    }

    /** Sets the structure-prior parameter; 0 disables the prior. */
    public void setStructurePrior(double structurePrior) {
        this.structurePrior = structurePrior;
    }

    /** Returns the covariance matrix, or null if covariances are recomputed from data with missing values. */
    public ICovarianceMatrix getCovariances() {
        return this.covariances;
    }

    /** Installs the covariance matrix, caches its matrix form, and makes it the reported data model. */
    private void setCovariances(ICovarianceMatrix covariances) {
        this.covariances = covariances;
        this.matrix = this.covariances.getMatrix();
        this.dataModel = covariances;
    }

    @Override
    public int getSampleSize() {
        return this.sampleSize;
    }

    /** An edge is considered effective when it strictly increases the score. */
    @Override
    public boolean isEffectEdge(double bump) {
        return bump > 0.0;
    }

    /**
     * Returns the data model. This is the covariance matrix whenever one is in use, including one
     * built from a complete data set; otherwise it is the data set.
     */
    public DataModel getDataModel() {
        return this.dataModel;
    }

    public boolean isVerbose() {
        return this.verbose;
    }

    public void setVerbose(boolean verbose) {
        this.verbose = verbose;
    }

    /** Returns a copy of the variable list. */
    @Override
    public List<Node> getVariables() {
        return new ArrayList<>(this.variables);
    }

    /**
     * Replaces the variables, forwarding them to the covariance matrix if there is one. A copy of
     * the list is stored.
     */
    public void setVariables(List<Node> variables) {
        Objects.requireNonNull(variables, "variables");
        if (this.covariances != null) {
            this.covariances.setVariables(variables);
        }
        this.variables = new ArrayList<>(variables);
    }

    /** Returns {@code ceil(ln n)}. */
    @Override
    public int getMaxDegree() {
        return (int) FastMath.ceil(this.logN);
    }

    /**
     * Returns true if {@code y} is (numerically) determined by {@code z}. This is the case when the
     * covariance of {@code z} is singular, when the residual variance of y given z is not strictly
     * positive, or when it cannot be computed.
     */
    @Override
    public boolean determines(List<Node> z, Node y) {
        int i = this.variables.indexOf(y);
        int[] parents = new int[z.size()];
        for (int t = 0; t < z.size(); ++t) {
            parents[t] = this.variables.indexOf(z.get(t));
        }
        Arrays.sort(parents);
        try {
            // Called directly rather than through localScore, which converts
            // SingularMatrixException into NaN and so would hide it from this check.
            double varRy = SemBicScore.getVarRy(i, parents, this.data, this.covariances, this.calculateRowSubsets);
            // Zero, negative or NaN residual variance is treated as determined, like a failure below.
            return !(varRy > 0.0);
        } catch (RuntimeException e) {
            // Includes SingularMatrixException for a singular covariance of z.
            TetradLogger.getInstance().forceLogMessage(e.getMessage());
            return true;
        }
    }

    /** Same as {@link #getDataModel()}. */
    public DataModel getData() {
        return this.dataModel;
    }

    /**
     * Returns {@code -(k * log(p) + (m - k) * log(1 - p))} with {@code p = structurePrior / m}, where
     * m is the number of variables and k = {@code parents}. Returns 0 when the structure prior is 0.
     */
    private double getStructurePrior(int parents) {
        if (this.getStructurePrior() == 0.0) {
            return 0.0;
        }
        double p = this.getStructurePrior() / (double) this.variables.size();
        return -((double) parents * FastMath.log(p) + (double) (this.variables.size() - parents) * FastMath.log(1.0 - p));
    }

    /**
     * Partial correlation of the first two of {@code cols} given the rest, from the correlation
     * matrix of {@code cols}. Best effort: any failure yields NaN.
     */
    private double partialCorrelation(int[] cols, List<Integer> rows) {
        try {
            return StatUtils.partialCorrelation(MatrixUtils.convertCovToCorr(this.getCov(rows, cols)));
        } catch (Exception e) {
            // Deliberately best-effort: callers receive NaN rather than an exception.
            return Double.NaN;
        }
    }

    /**
     * Returns the covariance of {@code cols}. The block is selected from the covariance matrix when
     * {@code rows} is null; otherwise it is computed from the data over {@code rows}.
     */
    private Matrix getCov(List<Integer> rows, int[] cols) {
        if (rows == null) {
            return this.matrix.getSelection(cols, cols);
        }
        return SemBicScore.getCov(rows, cols, cols, this.data, this.covariances);
    }

    /**
     * Selects how {@link #localScoreDiff} is computed.
     *
     * @throws NullPointerException if {@code ruleType} is null
     */
    public void setRuleType(RuleType ruleType) {
        this.ruleType = Objects.requireNonNull(ruleType, "ruleType");
    }

    /**
     * Returns a score over the given variables, built from the corresponding covariance submatrix.
     * Penalty discount, structure prior, rule type and verbosity are copied to the new score.
     *
     * @param pi2 variables to keep, in the order they should appear
     * @throws IllegalStateException    if this score has no covariance matrix (data with missing values)
     * @throws IllegalArgumentException if a node in {@code pi2} is not a variable of this score
     */
    public SemBicScore subset(List<Node> pi2) {
        ICovarianceMatrix allCovariances = this.getCovariances();
        if (allCovariances == null) {
            throw new IllegalStateException(
                    "subset requires a covariance matrix; this score recomputes covariances from data with missing values");
        }
        int[] cols = new int[pi2.size()];
        for (int i = 0; i < cols.length; ++i) {
            cols[i] = this.variables.indexOf(pi2.get(i));
            if (cols[i] < 0) {
                throw new IllegalArgumentException("Variable not in this score: " + pi2.get(i));
            }
        }
        SemBicScore sub = new SemBicScore(allCovariances.getSubmatrix(cols));
        sub.setPenaltyDiscount(this.penaltyDiscount);
        sub.setStructurePrior(this.structurePrior);
        sub.setRuleType(this.ruleType);
        sub.setVerbose(this.verbose);
        return sub;
    }

    @Override
    public String toString() {
        return "SEM BIC Score";
    }

    /** How {@link #localScoreDiff} is computed: a difference of local scores, or {@link #nandyBic}. */
    public enum RuleType {
        CHICKERING,
        NANDY
    }
}

