/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search;

import edu.cmu.tetrad.data.CovarianceMatrix;
import edu.cmu.tetrad.data.DataModel;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.ICovarianceMatrix;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.search.ISemBicScore;
import edu.cmu.tetrad.search.Score;
import edu.cmu.tetrad.search.SemBicScore;
import edu.cmu.tetrad.util.Matrix;
import edu.cmu.tetrad.util.StatUtils;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.apache.commons.math3.linear.SingularMatrixException;
import org.apache.commons.math3.util.FastMath;

public class SemBicScoreMultiFas
implements ISemBicScore {
    private final List<SemBicScore> semBicScores;
    private final List<Node> variables;
    private final int sampleSize;
    private double penaltyDiscount = 2.0;
    private boolean verbose;
    private final Map<String, Integer> indexMap;
    private final Map<Score, ICovarianceMatrix> covMap;

    public SemBicScoreMultiFas(List<DataModel> dataModels) {
        if (dataModels == null) {
            throw new NullPointerException();
        }
        ArrayList<SemBicScore> semBicScores = new ArrayList<SemBicScore>();
        for (DataModel model : dataModels) {
            if (model instanceof DataSet) {
                DataSet dataSet = (DataSet)model;
                if (!dataSet.isContinuous()) {
                    throw new IllegalArgumentException("Datasets must be continuous.");
                }
                SemBicScore semBicScore = new SemBicScore(new CovarianceMatrix(dataSet));
                semBicScore.setPenaltyDiscount(this.penaltyDiscount);
                semBicScores.add(semBicScore);
                continue;
            }
            if (model instanceof ICovarianceMatrix) {
                SemBicScore semBicScore = new SemBicScore((ICovarianceMatrix)model);
                semBicScore.setPenaltyDiscount(this.penaltyDiscount);
                semBicScores.add(semBicScore);
                continue;
            }
            throw new IllegalArgumentException("Only continuous data sets and covariance matrices may be used as input.");
        }
        List<Node> variables = ((SemBicScore)semBicScores.get(0)).getVariables();
        for (int i = 2; i < semBicScores.size(); ++i) {
            ((SemBicScore)semBicScores.get(i)).setVariables(variables);
        }
        this.semBicScores = semBicScores;
        this.variables = variables;
        this.sampleSize = ((SemBicScore)semBicScores.get(0)).getSampleSize();
        this.indexMap = this.indexMap(this.variables);
        this.covMap = this.covMap(this.semBicScores);
    }

    @Override
    public double localScoreDiff(int x, int y, int[] z) {
        double sum = 0.0;
        Node _x = this.variables.get(x);
        Node _y = this.variables.get(y);
        List<Node> _z = this.getVariableList(z);
        for (SemBicScore score : this.semBicScores) {
            double r;
            try {
                r = this.partialCorrelation(_x, _y, _z, score);
            }
            catch (SingularMatrixException e) {
                return Double.NaN;
            }
            int p = 2 + z.length;
            int N = this.covMap.get(score).getSampleSize();
            sum += (double)(-N) * FastMath.log(1.0 - r * r) - (double)p * this.getPenaltyDiscount() * FastMath.log(N);
        }
        return sum;
    }

    @Override
    public double localScoreDiff(int x, int y) {
        return this.localScoreDiff(x, y, new int[0]);
    }

    @Override
    public double localScore(int i, int[] parents) {
        double sum = 0.0;
        int count = 0;
        for (SemBicScore score : this.semBicScores) {
            double _score = score.localScore(i, parents);
            if (Double.isNaN(_score)) continue;
            sum += _score;
            ++count;
        }
        double score = sum / (double)count;
        if (Double.isNaN(score) || Double.isInfinite(score)) {
            return Double.NaN;
        }
        return score;
    }

    public double localScore(int i, int[] parents, int index) {
        return this.localScoreOneDataSet(i, parents, index);
    }

    private double localScoreOneDataSet(int i, int[] parents, int index) {
        return this.semBicScores.get(index).localScore(i, parents);
    }

    @Override
    public double localScore(int i, int parent) {
        double sum = 0.0;
        int count = 0;
        for (SemBicScore score : this.semBicScores) {
            double _score = score.localScore(i, parent);
            if (Double.isNaN(_score)) continue;
            sum += _score;
            ++count;
        }
        return sum / (double)count;
    }

    @Override
    public double localScore(int i) {
        double sum = 0.0;
        int count = 0;
        for (SemBicScore score : this.semBicScores) {
            double _score = score.localScore(i);
            if (Double.isNaN(_score)) continue;
            sum += _score;
            ++count;
        }
        return sum / (double)count;
    }

    @Override
    public double getPenaltyDiscount() {
        return this.penaltyDiscount;
    }

    @Override
    public boolean isEffectEdge(double bump) {
        return bump > -0.25 * this.getPenaltyDiscount() * FastMath.log(this.sampleSize);
    }

    public DataSet getDataSet() {
        throw new UnsupportedOperationException();
    }

    @Override
    public void setPenaltyDiscount(double penaltyDiscount) {
        this.penaltyDiscount = penaltyDiscount;
        for (SemBicScore score : this.semBicScores) {
            score.setPenaltyDiscount(penaltyDiscount);
        }
    }

    public boolean isVerbose() {
        return this.verbose;
    }

    public void setVerbose(boolean verbose) {
        this.verbose = verbose;
    }

    @Override
    public List<Node> getVariables() {
        return this.variables;
    }

    public boolean getAlternativePenalty() {
        return false;
    }

    @Override
    public int getSampleSize() {
        return this.sampleSize;
    }

    private List<Node> getVariableList(int[] indices) {
        ArrayList<Node> variables = new ArrayList<Node>();
        for (int i : indices) {
            variables.add(this.variables.get(i));
        }
        return variables;
    }

    private double partialCorrelation(Node x, Node y, List<Node> z, SemBicScore score) throws SingularMatrixException {
        int[] indices = new int[z.size() + 2];
        indices[0] = this.indexMap.get(x.getName());
        indices[1] = this.indexMap.get(y.getName());
        for (int i = 0; i < z.size(); ++i) {
            indices[i + 2] = this.indexMap.get(z.get(i).getName());
        }
        Matrix submatrix = this.covMap.get(score).getSubmatrix(indices).getMatrix();
        return StatUtils.partialCorrelation(submatrix);
    }

    private Map<String, Integer> indexMap(List<Node> variables) {
        HashMap<String, Integer> indexMap = new HashMap<String, Integer>();
        for (int i = 0; i < variables.size(); ++i) {
            indexMap.put(variables.get(i).getName(), i);
        }
        return indexMap;
    }

    private Map<Score, ICovarianceMatrix> covMap(List<SemBicScore> scores) {
        HashMap<Score, ICovarianceMatrix> covMap = new HashMap<Score, ICovarianceMatrix>();
        Iterator<SemBicScore> iterator = scores.iterator();
        while (iterator.hasNext()) {
            SemBicScore semBicScore;
            SemBicScore score = semBicScore = iterator.next();
            covMap.put(score, score.getCovariances());
        }
        return covMap;
    }

    @Override
    public Node getVariable(String targetName) {
        for (Node node : this.variables) {
            if (!node.getName().equals(targetName)) continue;
            return node;
        }
        return null;
    }

    @Override
    public int getMaxDegree() {
        return 1000;
    }

    @Override
    public boolean determines(List<Node> z, Node y) {
        return false;
    }
}

