/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search;

import edu.cmu.tetrad.bayes.BayesIm;
import edu.cmu.tetrad.bayes.BayesPm;
import edu.cmu.tetrad.bayes.Evidence;
import edu.cmu.tetrad.bayes.MlBayesEstimator;
import edu.cmu.tetrad.bayes.MlBayesIm;
import edu.cmu.tetrad.bayes.RowSummingExactUpdater;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.DiscreteVariable;
import edu.cmu.tetrad.graph.Dag;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.search.DiscreteClassifier;
import edu.cmu.tetrad.search.HitonOld;
import edu.cmu.tetrad.search.IndTestChiSquare;
import edu.cmu.tetrad.util.TetradLogger;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;

/**
 * Classifies the cases of a discrete target variable using a Bayes net fit over the target's
 * Markov blanket.
 *
 * <p>The Markov blanket is found in the training data by {@link HitonOld} driven by an
 * {@link IndTestChiSquare} test. A Bayes IM over the blanket is estimated from the training
 * data with {@link MlBayesEstimator}. Each testing case is then assigned the target category
 * whose updated marginal, given the case's values for the other blanket variables, is highest.
 */
public final class HitonClassifier
implements DiscreteClassifier {
    private final DataSet trainingData;
    private final DataSet testingData;
    private final String target;
    /** Cached result of the last {@link #crossTabulation()}; NaN until one has run. */
    private double percentCorrect;
    private final double alpha;
    private final int depth;
    private final DiscreteVariable targetVariable;
    /** Nodes of the most recently found Markov blanket; set by {@link #classify()}. */
    private List<Node> markovBlanketNodes;

    /**
     * Creates a classifier for the named target variable.
     *
     * @param ddsTrain    data used both to search for the Markov blanket and to estimate the
     *                    Bayes IM over it
     * @param ddsClassify data whose cases are classified; must have the same variables, in the
     *                    same order, as {@code ddsTrain}
     * @param target      name of the target variable, which must be a {@link DiscreteVariable}
     * @param alpha       value passed as the alpha of the {@link IndTestChiSquare} test
     * @param depth       value passed as the depth of the {@link HitonOld} search
     * @throws IllegalArgumentException if the data sets do not contain the same variables, or
     *                                  if the target is missing or not discrete
     */
    public HitonClassifier(DataSet ddsTrain, DataSet ddsClassify, String target, double alpha, int depth) {
        this.trainingData = ddsTrain;
        this.testingData = ddsClassify;
        this.target = target;
        this.percentCorrect = Double.NaN;
        this.alpha = alpha;
        this.depth = depth;
        List<Node> trainVars = ddsTrain.getVariables();
        List<Node> classifyVars = ddsClassify.getVariables();
        // Compare sizes first. Without this, a shorter testing list would overflow the index
        // and extra trailing testing variables would go unnoticed.
        if (trainVars.size() != classifyVars.size()) {
            throw new IllegalArgumentException("Datasets must contain same vars.");
        }
        for (int i = 0; i < trainVars.size(); ++i) {
            if (!trainVars.get(i).equals(classifyVars.get(i))) {
                throw new IllegalArgumentException("Datasets must contain same vars.");
            }
        }
        this.targetVariable = findTargetVariable(trainVars, target);
    }

    /**
     * Looks up the variable with the given name, matching on name before casting. This way,
     * unrelated non-discrete variables do not cause a {@code ClassCastException}.
     */
    private static DiscreteVariable findTargetVariable(List<Node> variables, String name) {
        for (Node variable : variables) {
            if (!variable.getName().equals(name)) {
                continue;
            }
            if (!(variable instanceof DiscreteVariable)) {
                throw new IllegalArgumentException("Target variable must be discrete: " + name);
            }
            return (DiscreteVariable) variable;
        }
        throw new IllegalArgumentException("Target variable not in data.");
    }

    /**
     * Searches for the target's Markov blanket, fits a Bayes IM over it, and classifies every
     * case of the testing data.
     *
     * <p>As a side effect, this records the Markov blanket nodes used by
     * {@link #crossTabulation()}.
     *
     * @return one entry per testing row: the index of the target category with the highest
     *         updated marginal (ties go to the highest index), or -1 when no category had a
     *         valid marginal
     */
    @Override
    public int[] classify() {
        IndTestChiSquare test = new IndTestChiSquare(this.trainingData, this.alpha);
        HitonOld hitons = new HitonOld(test, this.depth);
        Graph markovBlanket = hitons.search(this.target);
        // Copy into a fresh list; never append to the list being iterated over.
        this.markovBlanketNodes = new ArrayList<Node>(markovBlanket.getNodes());

        MlBayesIm mbBayesIm = estimateMarkovBlanketIm(markovBlanket);
        RowSummingExactUpdater bayesUpdaterMb = new RowSummingExactUpdater(mbBayesIm);

        DataSet ddsMBClassify = this.testingData.subsetColumns(this.markovBlanketNodes);
        List<Node> varsClassify = ddsMBClassify.getVariables();
        // Loop-invariant: the target's column in the blanket-restricted testing data.
        int targetColumn = varsClassify.indexOf(this.targetVariable);
        int ncases = ddsMBClassify.getNumRows();
        int[] estimatedValues = new int[ncases];
        Arrays.fill(estimatedValues, -1);

        for (int i = 0; i < ncases; ++i) {
            Evidence evidence = Evidence.tautology(mbBayesIm);
            int itarget = evidence.getNodeIndex(this.target);
            // Leave every target category allowed; only the other blanket variables are fixed.
            evidence.getProposition().setVariable(itarget, true);
            for (int j = 0; j < varsClassify.size(); ++j) {
                if (j == targetColumn) {
                    continue;
                }
                int value = ddsMBClassify.getInt(i, j);
                // A negative code is not a category index (presumably a missing-value marker,
                // to be confirmed against the data convention). Leave that variable
                // unconstrained instead of asserting an invalid category.
                if (value < 0) {
                    continue;
                }
                int iother = evidence.getNodeIndex(varsClassify.get(j).getName());
                evidence.getProposition().setCategory(iother, value);
            }
            bayesUpdaterMb.setEvidence(evidence);
            int estimatedValue = mostProbableTargetCategory(bayesUpdaterMb);
            if (estimatedValue < 0) {
                TetradLogger.getInstance().log("details", "Case " + i + " does not return valid marginal.");
                continue;
            }
            estimatedValues[i] = estimatedValue;
        }
        return estimatedValues;
    }

    /**
     * Estimates a Bayes IM over the given Markov blanket from the training data. The number
     * of categories for each node is taken from the training data's variables.
     */
    private MlBayesIm estimateMarkovBlanketIm(Graph markovBlanket) {
        DataSet ddsMBTrain = this.trainingData.subsetColumns(this.markovBlanketNodes);
        List<Node> varsTrain = ddsMBTrain.getVariables();
        BayesPm mbBayesPm = new BayesPm(new Dag(markovBlanket));
        // Assumes subsetColumns keeps the requested column order, so that index i of the
        // subset corresponds to markovBlanketNodes.get(i). TODO confirm.
        for (int i = 0; i < varsTrain.size(); ++i) {
            DiscreteVariable dv = (DiscreteVariable) varsTrain.get(i);
            mbBayesPm.setNumCategories(this.markovBlanketNodes.get(i), dv.getNumCategories());
        }
        MlBayesEstimator estimator = new MlBayesEstimator();
        return (MlBayesIm) estimator.estimate(mbBayesPm, ddsMBTrain);
    }

    /**
     * Returns the target category with the highest marginal under the updater's current
     * evidence. Ties go to the highest index. Returns -1 if no marginal compares
     * {@code >= -0.1} (for example, all NaN).
     */
    private int mostProbableTargetCategory(RowSummingExactUpdater updater) {
        BayesIm updatedIM = updater.getBayesIm();
        int indexTargetBN = updatedIM.getNodeIndex(this.targetVariable);
        double highestProb = -0.1;
        int estimatedValue = -1;
        for (int k = 0; k < this.targetVariable.getNumCategories(); ++k) {
            double marginal = updater.getMarginal(indexTargetBN, k);
            if (marginal >= highestProb) {
                highestProb = marginal;
                estimatedValue = k;
            }
        }
        return estimatedValue;
    }

    /**
     * Classifies the testing data and tabulates observed against estimated target categories.
     * This also updates the value returned by {@link #getPercentCorrect()}.
     *
     * <p>Cases that could not be classified, or whose observed value is not a valid category
     * index, are left out of the table. They still count in the denominator of the percent
     * correct.
     *
     * @return a square table indexed {@code [observed][estimated]}, sized by the number of
     *         target categories
     */
    @Override
    public int[][] crossTabulation() {
        int[] estimatedValues = this.classify();
        DataSet ddsMBClassify = this.testingData.subsetColumns(this.markovBlanketNodes);
        int indexTargetDDS = ddsMBClassify.getVariables().indexOf(this.targetVariable);
        int nvalues = this.targetVariable.getNumCategories();
        // Java zero-initializes int arrays; no explicit clearing needed.
        int[][] crosstabs = new int[nvalues][nvalues];
        int ntot = 0;
        int numberCorrect = 0;
        int ncases = ddsMBClassify.getNumRows();
        for (int i = 0; i < ncases; ++i) {
            int estimatedValue = estimatedValues[i];
            int observedValue = ddsMBClassify.getInt(i, indexTargetDDS);
            // An out-of-range observed value would index outside the table.
            if (estimatedValue < 0 || observedValue < 0 || observedValue >= nvalues) {
                continue;
            }
            ++ntot;
            crosstabs[observedValue][estimatedValue]++;
            if (observedValue == estimatedValue) {
                ++numberCorrect;
            }
        }
        this.percentCorrect = 100.0 * numberCorrect / ncases;
        TetradLogger.getInstance().log("details", "Total no usable cases= " + ntot + " out of " + ncases);
        return crosstabs;
    }

    /**
     * Returns the percentage of testing cases classified correctly. If no tabulation has
     * produced a value yet (the cached value is NaN), this runs {@link #crossTabulation()}
     * first.
     */
    @Override
    public double getPercentCorrect() {
        if (Double.isNaN(this.percentCorrect)) {
            this.crossTabulation();
        }
        return this.percentCorrect;
    }

    /** Returns the discrete target variable taken from the training data. */
    public DiscreteVariable getTargetVariable() {
        return this.targetVariable;
    }
}

