/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.performance;

import edu.cmu.tetrad.bayes.BayesPm;
import edu.cmu.tetrad.bayes.MlBayesIm;
import edu.cmu.tetrad.data.BoxDataSet;
import edu.cmu.tetrad.data.ContinuousVariable;
import edu.cmu.tetrad.data.CovarianceMatrix;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.DiscreteVariable;
import edu.cmu.tetrad.data.DoubleDataBox;
import edu.cmu.tetrad.graph.EdgeListGraph;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.GraphUtils;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.graph.RandomGraph;
import edu.cmu.tetrad.performance.ComparisonParameters;
import edu.cmu.tetrad.performance.ComparisonResult;
import edu.cmu.tetrad.search.BDeuScore;
import edu.cmu.tetrad.search.Cpc;
import edu.cmu.tetrad.search.Fci;
import edu.cmu.tetrad.search.Fges;
import edu.cmu.tetrad.search.GFci;
import edu.cmu.tetrad.search.GraphSearch;
import edu.cmu.tetrad.search.IndTestChiSquare;
import edu.cmu.tetrad.search.IndTestFisherZ;
import edu.cmu.tetrad.search.IndependenceTest;
import edu.cmu.tetrad.search.Pc;
import edu.cmu.tetrad.search.PcLocal;
import edu.cmu.tetrad.search.PcStableMax;
import edu.cmu.tetrad.search.Score;
import edu.cmu.tetrad.search.SearchGraphUtils;
import edu.cmu.tetrad.search.SemBicScore;
import edu.cmu.tetrad.sem.LargeScaleSimulation;
import edu.cmu.tetrad.sem.ScoreType;
import edu.cmu.tetrad.util.MillisecondTimes;
import edu.cmu.tetrad.util.TextTable;
import java.text.DecimalFormat;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.List;

public class Comparison {
    /**
     * Runs a single search-algorithm comparison as described by {@code params}.
     *
     * <p>Steps, in order:
     * <ol>
     *   <li>Obtain a true DAG and a data set: either through the file loaders (when
     *       {@code params.getDataFile()} is non-null) or by generating a random graph and
     *       simulating continuous or discrete data from it.</li>
     *   <li>Build the independence test and/or score selected in {@code params}. Doing so
     *       calls {@code params.setDataType(...)}, so {@code params} is mutated.</li>
     *   <li>Run the selected algorithm and store its output together with the reference
     *       graph (a CPDAG for PC/CPC/PCLocal/PCStableMax/FGES/FGES2, a PAG for FCI/GFCI).</li>
     * </ol>
     *
     * <p>The recorded elapsed time spans constructing the search object, running it, and
     * deriving the reference graph.
     *
     * @param params the comparison settings; its data type may be set as a side effect
     * @return a result holding the estimated graph, the reference graph, the true DAG and
     *         the elapsed time in milliseconds
     * @throws IllegalArgumentException if a required setting is missing or inconsistent, if
     *         no data set or true graph could be obtained, or if the algorithm is not recognized
     */
    public static ComparisonResult compare(ComparisonParameters params) {
        // Built before any params.setDataType(...) call below, exactly as in the original flow.
        ComparisonResult result = new ComparisonResult(params);

        Graph trueDag;
        DataSet dataSet;

        if (params.getDataFile() != null) {
            // Validate the companion graph file before attempting to load anything.
            if (params.getGraphFile() == null) {
                throw new IllegalArgumentException("True graph file not set.");
            }
            dataSet = Comparison.loadDataFile();
            trueDag = Comparison.loadGraphFile();
        } else {
            // Validate every simulation setting up front so that no graph is generated
            // for a request that is going to be rejected anyway.
            if (params.getNumVars() == -1) {
                throw new IllegalArgumentException("Number of variables not set.");
            }
            if (params.getNumEdges() == -1) {
                throw new IllegalArgumentException("Number of edges not set.");
            }
            ComparisonParameters.DataType dataType = params.getDataType();
            if (dataType == null) {
                throw new IllegalArgumentException("Data type not set or inferred.");
            }
            boolean continuous = dataType == ComparisonParameters.DataType.Continuous;
            if (!continuous && dataType != ComparisonParameters.DataType.Discrete) {
                throw new IllegalArgumentException("Unrecognized data type.");
            }
            if (params.getSampleSize() == -1) {
                throw new IllegalArgumentException("Sample size not set.");
            }

            List<Node> nodes = new ArrayList<Node>();
            for (int i = 0; i < params.getNumVars(); ++i) {
                String name = "X" + (i + 1);
                if (continuous) {
                    nodes.add(new ContinuousVariable(name));
                } else {
                    // NOTE(review): 3 is presumably the number of categories - confirm.
                    nodes.add(new DiscreteVariable(name, 3));
                }
            }

            // Literal arguments are carried over unchanged from the original call.
            // NOTE(review): presumably latent count (0) and degree bounds (10, 10, 10) - confirm
            // against RandomGraph.
            trueDag = RandomGraph.randomGraphRandomForwardEdges(nodes, 0, params.getNumEdges(), 10, 10, 10, false, true);

            if (continuous) {
                LargeScaleSimulation sim = new LargeScaleSimulation(trueDag);
                dataSet = sim.simulateDataFisher(params.getSampleSize());
            } else {
                // Identity ordering: tiers[i] == i.
                int[] tiers = new int[nodes.size()];
                for (int i = 0; i < tiers.length; ++i) {
                    tiers[i] = i;
                }
                BayesPm pm = new BayesPm(trueDag, 3, 3);
                // NOTE(review): the literal 1 selects the IM initialization method - confirm
                // which MlBayesIm constant it corresponds to.
                MlBayesIm im = new MlBayesIm(pm, 1);
                dataSet = im.simulateData(params.getSampleSize(), false, tiers);
            }
        }

        // Fail fast on both paths instead of relying on an assert or a later NullPointerException.
        if (dataSet == null) {
            throw new IllegalArgumentException("No data set.");
        }
        if (trueDag == null) {
            throw new IllegalArgumentException("No true graph.");
        }

        // Order matters: the test may fix the data type that the score is then checked against.
        IndependenceTest test = Comparison.createIndependenceTest(params, dataSet);
        Score score = Comparison.createScore(params, dataSet);

        ComparisonParameters.Algorithm algorithm = params.getAlgorithm();
        if (algorithm == null) {
            throw new IllegalArgumentException("Algorithm not set.");
        }

        long startMillis = MillisecondTimes.timeMillis();
        if (algorithm == ComparisonParameters.Algorithm.PC) {
            Pc pc = new Pc(Comparison.requireTest(test));
            result.setResultGraph(pc.search());
            Comparison.setCpdagAsCorrectResult(result, trueDag);
        } else if (algorithm == ComparisonParameters.Algorithm.CPC) {
            Cpc cpc = new Cpc(Comparison.requireTest(test));
            result.setResultGraph(cpc.search());
            Comparison.setCpdagAsCorrectResult(result, trueDag);
        } else if (algorithm == ComparisonParameters.Algorithm.PCLocal) {
            PcLocal pcLocal = new PcLocal(Comparison.requireTest(test));
            result.setResultGraph(pcLocal.search());
            Comparison.setCpdagAsCorrectResult(result, trueDag);
        } else if (algorithm == ComparisonParameters.Algorithm.PCStableMax) {
            PcStableMax pcStableMax = new PcStableMax(Comparison.requireTest(test));
            result.setResultGraph(pcStableMax.search());
            Comparison.setCpdagAsCorrectResult(result, trueDag);
        } else if (algorithm == ComparisonParameters.Algorithm.FGES
                || algorithm == ComparisonParameters.Algorithm.FGES2) {
            // FGES and FGES2 are configured and run identically here.
            Fges fges = new Fges(Comparison.requireScore(score));
            fges.setFaithfulnessAssumed(params.isOneEdgeFaithfulnessAssumed());
            result.setResultGraph(fges.search());
            Comparison.setCpdagAsCorrectResult(result, trueDag);
        } else if (algorithm == ComparisonParameters.Algorithm.FCI) {
            Fci fci = new Fci(Comparison.requireTest(test));
            result.setResultGraph(fci.search());
            result.setCorrectResult(SearchGraphUtils.dagToPag(trueDag));
        } else if (algorithm == ComparisonParameters.Algorithm.GFCI) {
            // GFCI consumes both a test and a score, so both must have been configured.
            GFci gfci = new GFci(Comparison.requireTest(test), Comparison.requireScore(score));
            result.setResultGraph(gfci.search());
            result.setCorrectResult(SearchGraphUtils.dagToPag(trueDag));
        } else {
            throw new IllegalArgumentException("Unrecognized algorithm.");
        }
        long elapsedMillis = MillisecondTimes.timeMillis() - startMillis;

        result.setElapsed(elapsedMillis);
        result.setTrueDag(trueDag);
        return result;
    }

    /** Builds the independence test selected in {@code params}, or returns null if none is selected. */
    private static IndependenceTest createIndependenceTest(ComparisonParameters params, DataSet dataSet) {
        ComparisonParameters.IndependenceTestType type = params.getIndependenceTest();
        if (type == ComparisonParameters.IndependenceTestType.FisherZ) {
            Comparison.requireDataTypeUnsetOr(params, ComparisonParameters.DataType.Continuous, "continuous");
            if (Double.isNaN(params.getAlpha())) {
                throw new IllegalArgumentException("Alpha not set.");
            }
            IndependenceTest fisherZ = new IndTestFisherZ(dataSet, params.getAlpha());
            params.setDataType(ComparisonParameters.DataType.Continuous);
            return fisherZ;
        }
        if (type == ComparisonParameters.IndependenceTestType.ChiSquare) {
            Comparison.requireDataTypeUnsetOr(params, ComparisonParameters.DataType.Discrete, "discrete");
            if (Double.isNaN(params.getAlpha())) {
                throw new IllegalArgumentException("Alpha not set.");
            }
            IndependenceTest chiSquare = new IndTestChiSquare(dataSet, params.getAlpha());
            params.setDataType(ComparisonParameters.DataType.Discrete);
            return chiSquare;
        }
        return null;
    }

    /** Builds the score selected in {@code params}, or returns null if none is selected. */
    private static Score createScore(ComparisonParameters params, DataSet dataSet) {
        if (params.getScore() == ScoreType.SemBic) {
            Comparison.requireDataTypeUnsetOr(params, ComparisonParameters.DataType.Continuous, "continuous");
            if (Double.isNaN(params.getPenaltyDiscount())) {
                throw new IllegalArgumentException("Penalty discount not set.");
            }
            SemBicScore semBicScore = new SemBicScore(new CovarianceMatrix(dataSet));
            semBicScore.setPenaltyDiscount(params.getPenaltyDiscount());
            params.setDataType(ComparisonParameters.DataType.Continuous);
            return semBicScore;
        }
        if (params.getScore() == ScoreType.BDeu) {
            Comparison.requireDataTypeUnsetOr(params, ComparisonParameters.DataType.Discrete, "discrete");
            if (Double.isNaN(params.getSamplePrior())) {
                throw new IllegalArgumentException("Sample prior not set.");
            }
            if (Double.isNaN(params.getStructurePrior())) {
                throw new IllegalArgumentException("Structure prior not set.");
            }
            BDeuScore bdeuScore = new BDeuScore(dataSet);
            bdeuScore.setSamplePrior(params.getSamplePrior());
            bdeuScore.setStructurePrior(params.getStructurePrior());
            params.setDataType(ComparisonParameters.DataType.Discrete);
            return bdeuScore;
        }
        return null;
    }

    /** Rejects a data type that was already set to something other than {@code expected}. */
    private static void requireDataTypeUnsetOr(ComparisonParameters params, ComparisonParameters.DataType expected, String label) {
        ComparisonParameters.DataType current = params.getDataType();
        if (current != null && current != expected) {
            throw new IllegalArgumentException("Data type previously set to something other than " + label + ".");
        }
    }

    /** Returns {@code test}, or throws if no independence test was configured. */
    private static IndependenceTest requireTest(IndependenceTest test) {
        if (test == null) {
            throw new IllegalArgumentException("Test not set.");
        }
        return test;
    }

    /** Returns {@code score}, or throws if no score was configured. */
    private static Score requireScore(Score score) {
        if (score == null) {
            throw new IllegalArgumentException("Score not set.");
        }
        return score;
    }

    /** Stores the CPDAG of a copy of {@code trueDag} as the reference graph of {@code result}. */
    private static void setCpdagAsCorrectResult(ComparisonResult result, Graph trueDag) {
        result.setCorrectResult(SearchGraphUtils.cpdagForDag(new EdgeListGraph(trueDag)));
    }

    /**
     * Stub for loading the true graph from a file.
     *
     * <p>Takes no arguments, reads nothing and always returns {@code null}.
     * NOTE(review): this looks like an unimplemented placeholder - confirm before relying
     * on the data-file path of {@code compare}.
     *
     * @return always {@code null}
     */
    private static Graph loadGraphFile() {
        return null;
    }

    /**
     * Stub for loading the data set from a file.
     *
     * <p>Takes no arguments, reads nothing and always returns {@code null}.
     * NOTE(review): this looks like an unimplemented placeholder - confirm before relying
     * on the data-file path of {@code compare}.
     *
     * @return always {@code null}
     */
    private static DataSet loadDataFile() {
        return null;
    }

    /**
     * Prints every result to standard output and renders the requested statistics as a
     * text table with one row per result plus an average row.
     *
     * <p>Each requested column is written at the position of its first occurrence in
     * {@code tableColumns}.
     *
     * @param results      the comparison results to summarize
     * @param tableColumns the statistics to report, in display order
     * @return the rendered table as a string
     */
    public static String summarize(List<ComparisonResult> results, List<TableColumn> tableColumns) {
        // One continuous variable per requested column, named after the column.
        List<Node> columnVariables = new ArrayList<Node>();
        for (TableColumn requested : tableColumns) {
            columnVariables.add(new ContinuousVariable(requested.toString()));
        }

        BoxDataSet stats = new BoxDataSet(new DoubleDataBox(0, columnVariables.size()), columnVariables);
        stats.setNumberFormat(new DecimalFormat("0"));

        int runNumber = 0;
        for (ComparisonResult run : results) {
            ++runNumber;
            System.out.println("\nRun " + runNumber + "\n" + run);
        }
        System.out.println();

        // Columns are filled in this fixed order, independent of their display order.
        TableColumn[] fillOrder = {
            TableColumn.AdjCor, TableColumn.AdjFn, TableColumn.AdjFp,
            TableColumn.AhdCor, TableColumn.AhdFn, TableColumn.AhdFp,
            TableColumn.AdjPrec, TableColumn.AdjRec, TableColumn.AhdPrec, TableColumn.AhdRec,
            TableColumn.Elapsed, TableColumn.SHD
        };

        for (ComparisonResult run : results) {
            GraphUtils.GraphComparison comparison =
                    SearchGraphUtils.getGraphComparison2(run.getCorrectResult(), run.getResultGraph());
            // The row index is fixed once per result, before any cell is written.
            int row = stats.getNumRows();
            for (TableColumn column : fillOrder) {
                int position = tableColumns.indexOf(column);
                if (position < 0) {
                    continue;
                }
                stats.setDouble(row, position, Comparison.statisticFor(column, comparison, run));
            }
        }

        int[] allColumns = new int[tableColumns.size()];
        for (int c = 0; c < allColumns.length; ++c) {
            allColumns[c] = c;
        }
        TextTable table = Comparison.getTextTable(stats, allColumns, new DecimalFormat("0.00"));
        return table.toString();
    }

    /** Reads the value reported under {@code column} from the graph comparison or the result. */
    private static double statisticFor(TableColumn column, GraphUtils.GraphComparison comparison, ComparisonResult run) {
        switch (column) {
            case AdjCor:
                return comparison.getAdjCor();
            case AdjFn:
                return comparison.getAdjFn();
            case AdjFp:
                return comparison.getAdjFp();
            case AhdCor:
                return comparison.getAhdCor();
            case AhdFn:
                return comparison.getAhdFn();
            case AhdFp:
                return comparison.getAhdFp();
            case AdjPrec:
                return comparison.getAdjPrec();
            case AdjRec:
                return comparison.getAdjRec();
            case AhdPrec:
                return comparison.getAhdPrec();
            case AhdRec:
                return comparison.getAhdRec();
            case Elapsed:
                return run.getElapsed();
            case SHD:
                return comparison.getShd();
            default:
                // Not reachable: every TableColumn constant is handled above.
                throw new IllegalStateException("Unhandled column: " + column);
        }
    }

    /**
     * Lays out the selected columns of {@code dataSet} as a text table.
     *
     * <p>Row 0 holds the header ("Run #" followed by the variable names), rows
     * {@code 1..numRows} hold the 1-based run number and the formatted cell values, and the
     * last row holds "Avg" followed by each column's mean formatted as {@code 0.00}.
     *
     * @param dataSet the data to render
     * @param columns indices of the data set columns to include, in display order
     * @param nf      the format applied to the individual cell values
     * @return the populated table
     */
    private static TextTable getTextTable(DataSet dataSet, int[] columns, NumberFormat nf) {
        final int rowCount = dataSet.getNumRows();
        final int averageRow = rowCount + 1;
        TextTable table = new TextTable(rowCount + 2, columns.length + 1);

        // Header row.
        table.setToken(0, 0, "Run #");
        for (int c = 0; c < columns.length; c++) {
            table.setToken(0, c + 1, dataSet.getVariable(columns[c]).getName());
        }

        // One row per run: label first, then the formatted values.
        for (int r = 0; r < rowCount; r++) {
            table.setToken(r + 1, 0, Integer.toString(r + 1));
            for (int c = 0; c < columns.length; c++) {
                table.setToken(r + 1, c + 1, nf.format(dataSet.getDouble(r, columns[c])));
            }
        }

        // Average row; it uses its own fixed format rather than nf.
        DecimalFormat averageFormat = new DecimalFormat("0.00");
        table.setToken(averageRow, 0, "Avg");
        for (int c = 0; c < columns.length; c++) {
            double total = 0.0;
            for (int r = 0; r < rowCount; r++) {
                total += dataSet.getDouble(r, columns[c]);
            }
            table.setToken(averageRow, c + 1, averageFormat.format(total / (double) rowCount));
        }
        return table;
    }

    /**
     * Statistics that {@code summarize} can report, one table column per constant.
     *
     * <p>Each constant is filled from the getter noted beside it. The expansions of the
     * abbreviations are inferred from the names - TODO confirm against
     * {@code GraphUtils.GraphComparison}.
     */
    public static enum TableColumn {
        // GraphComparison.getAdjCor(); presumably the count of correct adjacencies.
        AdjCor,
        // GraphComparison.getAdjFn(); presumably adjacency false negatives.
        AdjFn,
        // GraphComparison.getAdjFp(); presumably adjacency false positives.
        AdjFp,
        // GraphComparison.getAhdCor(); presumably the count of correct arrowheads.
        AhdCor,
        // GraphComparison.getAhdFn(); presumably arrowhead false negatives.
        AhdFn,
        // GraphComparison.getAhdFp(); presumably arrowhead false positives.
        AhdFp,
        // GraphComparison.getShd(); presumably the structural Hamming distance.
        SHD,
        // GraphComparison.getAdjPrec(); presumably adjacency precision.
        AdjPrec,
        // GraphComparison.getAdjRec(); presumably adjacency recall.
        AdjRec,
        // GraphComparison.getAhdPrec(); presumably arrowhead precision.
        AhdPrec,
        // GraphComparison.getAhdRec(); presumably arrowhead recall.
        AhdRec,
        // ComparisonResult.getElapsed(); compare() stores a millisecond timer difference there.
        Elapsed;

    }
}

