/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search;

import edu.cmu.tetrad.bayes.BayesPm;
import edu.cmu.tetrad.bayes.DirichletBayesIm;
import edu.cmu.tetrad.bayes.DirichletEstimator;
import edu.cmu.tetrad.bayes.Evidence;
import edu.cmu.tetrad.bayes.Proposition;
import edu.cmu.tetrad.bayes.RowSummingExactUpdater;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.DiscreteVariable;
import edu.cmu.tetrad.data.SimpleDataLoader;
import edu.cmu.tetrad.graph.Edge;
import edu.cmu.tetrad.graph.Edges;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.search.DiscreteClassifier;
import edu.cmu.tetrad.search.IndTestChiSquare;
import edu.cmu.tetrad.search.MbUtils;
import edu.cmu.tetrad.search.Pc;
import edu.cmu.tetrad.search.PcMb;
import edu.cmu.tetrad.util.NumberFormatUtil;
import edu.cmu.tetrad.util.TetradLogger;
import edu.pitt.dbmi.data.reader.Delimiter;
import java.io.File;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class MbClassify
implements DiscreteClassifier {
    private DataSet train;
    private DataSet test;
    private Node target;
    private double alpha;
    private int depth;
    private double prior;
    private int maxMissing;
    private DiscreteVariable targetVariable;
    private double percentCorrect;
    private int[][] crossTabulation;

    public MbClassify(String trainPath, String testPath, String targetString, String alphaString, String depthString, String priorString, String maxMissingString) {
        try {
            String s = "MbClassify " + trainPath + " " + testPath + " " + targetString + " " + alphaString + " " + depthString + " " + priorString + " " + maxMissingString + " ";
            TetradLogger.getInstance().log("info", s);
            DataSet train = SimpleDataLoader.loadContinuousData(new File(trainPath), "//", '\"', "*", true, Delimiter.TAB);
            DataSet test = SimpleDataLoader.loadContinuousData(new File(testPath), "//", '\"', "*", true, Delimiter.TAB);
            double alpha = Double.parseDouble(alphaString);
            int depth = Integer.parseInt(depthString);
            double prior = Double.parseDouble(priorString);
            int maxMissing = Integer.parseInt(maxMissingString);
            this.setup(train, test, this.target, alpha, depth, prior, maxMissing);
        }
        catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    private void setup(DataSet train, DataSet test, Node target, double alpha, int depth, double prior, int maxMissing) {
        this.train = train;
        this.test = test;
        this.alpha = alpha;
        this.target = target;
        this.depth = depth;
        this.prior = prior;
        this.maxMissing = maxMissing;
        this.targetVariable = (DiscreteVariable)target;
        if (this.targetVariable == null) {
            throw new IllegalArgumentException("Target variable not in data: " + target);
        }
    }

    @Override
    public int[] classify() {
        IndTestChiSquare indTest = new IndTestChiSquare(this.train, this.alpha);
        PcMb search = new PcMb(indTest, this.depth);
        search.setDepth(this.depth);
        List<Node> mbPlusTarget = search.findMb(this.target);
        mbPlusTarget.add(this.target);
        DataSet subset = this.train.subsetColumns(mbPlusTarget);
        System.out.println("subset vars = " + subset.getVariables());
        Pc cpdagSearch = new Pc(new IndTestChiSquare(subset, 0.05));
        Graph mbCPDAG = cpdagSearch.search();
        TetradLogger.getInstance().log("details", "CPDAG = " + mbCPDAG);
        MbUtils.trimToMbNodes(mbCPDAG, this.target, true);
        TetradLogger.getInstance().log("details", "Trimmed CPDAG = " + mbCPDAG);
        for (Edge edge : mbCPDAG.getEdges()) {
            if (!Edges.isBidirectedEdge(edge)) continue;
            mbCPDAG.removeEdge(edge);
        }
        Graph selectedDag = MbUtils.getOneMbDag(mbCPDAG);
        TetradLogger.getInstance().log("details", "Selected DAG = " + selectedDag);
        TetradLogger.getInstance().log("details", "Vars = " + selectedDag.getNodes());
        TetradLogger.getInstance().log("details", "\nClassification using selected MB DAG:");
        NumberFormat nf = NumberFormatUtil.getInstance().getNumberFormat();
        List<Node> mbNodes = selectedDag.getNodes();
        DataSet trainDataSubset = this.train.subsetColumns(mbNodes);
        BayesPm bayesPm = new BayesPm(selectedDag);
        List<Node> varsTrain = trainDataSubset.getVariables();
        for (int i1 = 0; i1 < varsTrain.size(); ++i1) {
            DiscreteVariable trainingVar = (DiscreteVariable)varsTrain.get(i1);
            bayesPm.setCategories(mbNodes.get(i1), trainingVar.getCategories());
        }
        TetradLogger.getInstance().log("info", "Estimating Bayes net; please wait...");
        DirichletBayesIm prior = DirichletBayesIm.symmetricDirichletIm(bayesPm, this.prior);
        DirichletBayesIm bayesIm = DirichletEstimator.estimate(prior, trainDataSubset);
        RowSummingExactUpdater updater = new RowSummingExactUpdater(bayesIm);
        DataSet testSubset = this.test.subsetColumns(mbNodes);
        int numCases = testSubset.getNumRows();
        int[] estimatedCategories = new int[numCases];
        Arrays.fill(estimatedCategories, -1);
        List<Node> varsClassify = testSubset.getVariables();
        for (int k = 0; k < numCases; ++k) {
            Proposition proposition = Proposition.tautology(bayesIm);
            int numMissing = 0;
            for (int testIndex = 0; testIndex < varsClassify.size(); ++testIndex) {
                int trainIndex;
                DiscreteVariable var = (DiscreteVariable)varsClassify.get(testIndex);
                if (var.equals(this.targetVariable) || (trainIndex = proposition.getNodeIndex(var.getName())) == -99) continue;
                int testValue = testSubset.getInt(k, testIndex);
                if (testValue == -99) {
                    ++numMissing;
                    continue;
                }
                proposition.setCategory(trainIndex, testValue);
            }
            if (numMissing > this.maxMissing) {
                TetradLogger.getInstance().log("details", "classification(" + k + ") = not done since number of missing values too high (" + numMissing + ").");
                continue;
            }
            Evidence evidence = Evidence.tautology(bayesIm);
            evidence.getProposition().restrictToProposition(proposition);
            updater.setEvidence(evidence);
            int targetIndex = proposition.getNodeIndex(this.targetVariable.getName());
            double highestProb = -0.1;
            int _category = -1;
            for (int category = 0; category < this.targetVariable.getNumCategories(); ++category) {
                double marginal = updater.getMarginal(targetIndex, category);
                if (!(marginal > highestProb)) continue;
                highestProb = marginal;
                _category = category;
            }
            if (_category < 0) {
                System.out.println("classification(" + k + ") is undefined (undefined marginals).");
                continue;
            }
            String estimatedCategory = this.targetVariable.getCategories().get(_category);
            TetradLogger.getInstance().log("details", "classification(" + k + ") = " + estimatedCategory);
            estimatedCategories[k] = _category;
        }
        int targetIndex = varsClassify.indexOf(this.targetVariable);
        int numCategories = this.targetVariable.getNumCategories();
        int[][] crossTabs = new int[numCategories][numCategories];
        int numberCorrect = 0;
        int numberCounted = 0;
        for (int k = 0; k < numCases; ++k) {
            int estimatedCategory = estimatedCategories[k];
            int observedValue = testSubset.getInt(k, targetIndex);
            if (estimatedCategory < 0) continue;
            int[] nArray = crossTabs[observedValue];
            int n = estimatedCategory;
            nArray[n] = nArray[n] + 1;
            ++numberCounted;
            if (observedValue != estimatedCategory) continue;
            ++numberCorrect;
        }
        double percentCorrect1 = 100.0 * (double)numberCorrect / (double)numberCounted;
        TetradLogger.getInstance().log("details", "");
        TetradLogger.getInstance().log("details", "\t\t\tEstimated\t");
        TetradLogger.getInstance().log("details", "Observed\t");
        StringBuilder buf0 = new StringBuilder();
        buf0.append("\t");
        for (int m = 0; m < numCategories; ++m) {
            buf0.append(this.targetVariable.getCategory(m)).append("\t");
        }
        TetradLogger.getInstance().log("details", buf0.toString());
        for (int k = 0; k < numCategories; ++k) {
            StringBuilder buf = new StringBuilder();
            buf.append(this.targetVariable.getCategory(k)).append("\t");
            for (int m = 0; m < numCategories; ++m) {
                buf.append(crossTabs[k][m]).append("\t");
            }
            TetradLogger.getInstance().log("details", buf.toString());
        }
        TetradLogger.getInstance().log("details", "");
        TetradLogger.getInstance().log("details", "Number correct = " + numberCorrect);
        TetradLogger.getInstance().log("details", "Number counted = " + numberCounted);
        TetradLogger.getInstance().log("details", "Percent correct = " + nf.format(percentCorrect1) + "%");
        this.crossTabulation = crossTabs;
        this.percentCorrect = percentCorrect1;
        return estimatedCategories;
    }

    @Override
    public int[][] crossTabulation() {
        return this.crossTabulation;
    }

    @Override
    public double getPercentCorrect() {
        return this.percentCorrect;
    }

    public static void main(String[] args) {
        String trainPath = args[0];
        String testPath = args[1];
        String targetString = args[2];
        String alphaString = args[3];
        String depthString = args[4];
        String priorString = args[5];
        String maxMissingString = args[6];
        new MbClassify(trainPath, testPath, targetString, alphaString, depthString, priorString, maxMissingString);
    }
}

