/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search;

import edu.cmu.tetrad.bayes.BayesIm;
import edu.cmu.tetrad.bayes.Evidence;
import edu.cmu.tetrad.bayes.MlBayesIm;
import edu.cmu.tetrad.bayes.RowSummingExactUpdater;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.DataUtils;
import edu.cmu.tetrad.data.DiscreteVariable;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.search.DiscreteClassifier;
import edu.cmu.tetrad.util.TetradLogger;
import edu.cmu.tetrad.util.TetradSerializable;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;

/**
 * Classifies the rows of a discrete test data set with respect to one target
 * variable of a Bayes IM. For every row, the non-missing values of the
 * non-target variables are entered as evidence into a
 * {@link RowSummingExactUpdater}; the target category with the highest
 * updated marginal is the classification for that row.
 *
 * <p>Typical use: construct, call {@link #setTarget(String)}, then call
 * {@link #classify()}, {@link #crossTabulation()} or
 * {@link #getPercentCorrect()}.
 */
public final class BayesUpdaterClassifier
implements DiscreteClassifier,
TetradSerializable {
    static final long serialVersionUID = 23L;

    /**
     * Integer code treated as "no value": observed cells equal to this are not
     * entered as evidence, and rows that cannot be classified get this as their
     * classification. NOTE(review): presumably mirrors the data set's
     * missing-value code for discrete columns -- confirm against DiscreteVariable.
     */
    private static final int MISSING_VALUE = -99;

    /** The Bayes IM used for updating; never null. */
    private final BayesIm bayesIm;
    /** The data whose rows are classified; never null. */
    private final DataSet testData;
    /** Percentage of usable cases classified correctly; NaN until computed. */
    private double percentCorrect;
    /** The variable being predicted; null until {@link #setTarget(String)} is called. */
    private DiscreteVariable targetVariable;
    /** Copy of the Bayes IM's variable list taken at construction time. */
    private final List<Node> bayesImVars;
    /** Result of the most recent {@link #classify()} call, one entry per row. */
    private int[] classifications;
    /** Target marginals from the most recent {@link #classify()} call, indexed [category][row]. */
    private double[][] marginals;
    /** Number of rows seen by the most recent {@link #classify()} call; -1 before that. */
    private int numCases = -1;
    /** Number of rows counted in the most recent cross tabulation. */
    private int totalUsableCases;

    /**
     * @param bayesIm  the Bayes IM to update with; must not be null.
     * @param testData the data set to classify; must not be null.
     * @throws IllegalArgumentException if either argument is null.
     */
    public BayesUpdaterClassifier(BayesIm bayesIm, DataSet testData) {
        if (bayesIm == null) {
            throw new IllegalArgumentException("BayesIm must not be null.");
        }
        if (testData == null) {
            throw new IllegalArgumentException("DataSet must not be null.");
        }
        this.bayesIm = bayesIm;
        this.testData = testData;
        this.percentCorrect = Double.NaN;
        this.bayesImVars = new LinkedList<Node>(bayesIm.getVariables());
    }

    /**
     * Builds a simple exemplar of this class from the exemplars of
     * {@link MlBayesIm} and {@link DataUtils}.
     */
    public static BayesUpdaterClassifier serializableInstance() {
        return new BayesUpdaterClassifier(MlBayesIm.serializableInstance(), DataUtils.discreteSerializableInstance());
    }

    /**
     * Selects the target variable by name from the Bayes IM's variables.
     *
     * @param target the name of the variable to classify.
     * @throws IllegalArgumentException if no Bayes IM variable has that name.
     */
    public void setTarget(String target) {
        DiscreteVariable found = null;
        for (Node node : this.bayesImVars) {
            DiscreteVariable candidate = (DiscreteVariable)node;
            if (candidate.getName().equals(target)) {
                found = candidate;
                break;
            }
        }
        if (found == null) {
            throw new IllegalArgumentException("Not an available target: " + target);
        }
        this.targetVariable = found;
    }

    /**
     * Classifies every row of the test data. The results are also stored and
     * can be read back through {@link #getClassifications()},
     * {@link #getMarginals()} and {@link #getNumCases()}.
     *
     * @return one estimated target category index per row, or -99 for a row
     *         for which no marginal compared as a valid probability.
     * @throws NullPointerException     if the target has not been set.
     * @throws IllegalArgumentException if a non-target Bayes IM variable cannot
     *                                  be found among the data's variables.
     */
    @Override
    public int[] classify() {
        if (this.targetVariable == null) {
            throw new NullPointerException("Target not set.");
        }
        RowSummingExactUpdater bayesUpdater = new RowSummingExactUpdater(this.getBayesIm());

        // bayesImVars is a LinkedList, so indexed get()/indexOf() are linear.
        // Snapshot it once instead of walking the list for every cell.
        Node[] imVars = this.bayesImVars.toArray(new Node[0]);
        int nvars = imVars.length;
        int targetPosition = this.bayesImVars.indexOf(this.targetVariable);
        int ncases = this.testData.getNumRows();

        // Map each Bayes IM variable to its column in the test data. The target's
        // slot is left at 0; that column of the subset is never read below.
        int[] varIndices = new int[nvars];
        List<Node> dataVars = this.testData.getVariables();
        for (int i = 0; i < nvars; ++i) {
            Node variable = imVars[i];
            if (variable == this.targetVariable) {
                continue;
            }
            varIndices[i] = dataVars.indexOf(variable);
            if (varIndices[i] == -1) {
                throw new IllegalArgumentException("Can't find the (non-target) variable " + variable + " in the data. Either it's not there, or else its categories are in a different order.");
            }
        }
        DataSet selectedData = this.testData.subsetColumns(varIndices);

        this.numCases = ncases;
        int[] estimatedValues = new int[ncases];
        Arrays.fill(estimatedValues, -1);
        int numTargetCategories = this.targetVariable.getNumCategories();
        double[][] probOfClassifiedValues = new double[numTargetCategories][ncases];

        // The target's node index in the Bayes IM does not depend on the row.
        String targetName = this.targetVariable.getName();
        int indexTargetBN = this.bayesIm.getNodeIndex(this.bayesIm.getNode(targetName));

        for (int i = 0; i < ncases; ++i) {
            Evidence evidence = Evidence.tautology(this.getBayesIm());
            int itarget = evidence.getNodeIndex(targetName);
            evidence.getProposition().setVariable(itarget, true);

            // Condition on every observed (non-missing) non-target value in this row.
            for (int j = 0; j < nvars; ++j) {
                if (j == targetPosition) {
                    continue;
                }
                int observedValue = selectedData.getInt(i, j);
                if (observedValue == MISSING_VALUE) {
                    continue;
                }
                int jIndex = evidence.getNodeIndex(imVars[j].getName());
                evidence.getProposition().setCategory(jIndex, observedValue);
            }
            bayesUpdater.setEvidence(evidence);

            // Pick the category with the largest marginal; on ties the later
            // category wins. A NaN marginal never satisfies ">=".
            int estimatedValue = -1;
            double highestProb = -0.1;
            for (int j = 0; j < numTargetCategories; ++j) {
                double marginal = bayesUpdater.getMarginal(indexTargetBN, j);
                probOfClassifiedValues[j][i] = marginal;
                if (marginal >= highestProb) {
                    highestProb = marginal;
                    estimatedValue = j;
                }
            }

            if (estimatedValue < 0) {
                TetradLogger.getInstance().log("details", "Case " + i + " does not return valid marginal.");
                for (int m = 0; m < nvars; ++m) {
                    TetradLogger.getInstance().log("details", "  " + selectedData.getDouble(i, m));
                }
                estimatedValues[i] = MISSING_VALUE;
                continue;
            }
            estimatedValues[i] = estimatedValue;
        }

        this.classifications = estimatedValues;
        this.marginals = probOfClassifiedValues;
        return estimatedValues;
    }

    /**
     * Runs {@link #classify()} and tabulates observed against estimated target
     * categories. A row is usable, and therefore counted, only when both its
     * estimated and its observed value are non-negative. As a side effect this
     * updates the values returned by {@link #getPercentCorrect()} and
     * {@link #getTotalUsableCases()}.
     *
     * @return a square table indexed [observed][estimated], or null if the test
     *         data has no variable with the target's name.
     * @throws NullPointerException if the target has not been set.
     */
    @Override
    public int[][] crossTabulation() {
        int[] estimatedValues = this.classify();
        Node variable = this.testData.getVariable(this.targetVariable.getName());
        if (variable == null) {
            return null;
        }
        int varIndex = this.testData.getVariables().indexOf(variable);
        int ncases = this.testData.getNumRows();
        int nvalues = this.targetVariable.getNumCategories();
        int[][] crosstabs = new int[nvalues][nvalues];

        int numberCorrect = 0;
        int ntot = 0;
        for (int i = 0; i < ncases; ++i) {
            int estimatedValue = estimatedValues[i];
            int observedValue = this.testData.getInt(i, varIndex);
            if (estimatedValue < 0 || observedValue < 0) {
                continue;
            }
            ++ntot;
            crosstabs[observedValue][estimatedValue]++;
            if (observedValue == estimatedValue) {
                ++numberCorrect;
            }
        }

        // Accuracy is taken over the usable rows only, so that rows without a
        // known target value or without a classification are not counted as
        // errors. With no usable rows this is 0/0, i.e. NaN.
        this.percentCorrect = 100.0 * (double)numberCorrect / (double)ntot;
        this.totalUsableCases = ntot;
        return crosstabs;
    }

    /**
     * Returns the percentage of usable cases that were classified correctly,
     * running {@link #crossTabulation()} first if no value is available yet.
     *
     * @return the percentage, or NaN if it could not be computed.
     */
    @Override
    public double getPercentCorrect() {
        if (Double.isNaN(this.percentCorrect)) {
            this.crossTabulation();
        }
        return this.percentCorrect;
    }

    /** @return the target variable, or null if none has been set. */
    public DiscreteVariable getTargetVariable() {
        return this.targetVariable;
    }

    /** @return the Bayes IM given at construction. */
    public BayesIm getBayesIm() {
        return this.bayesIm;
    }

    /** @return the test data given at construction. */
    public DataSet getTestData() {
        return this.testData;
    }

    /** @return the array produced by the last {@link #classify()} call, or null if it has not run. */
    public int[] getClassifications() {
        return this.classifications;
    }

    /**
     * @return the target marginals from the last {@link #classify()} call,
     *         indexed [category][row], or null if it has not run.
     */
    public double[][] getMarginals() {
        return this.marginals;
    }

    /** @return the number of rows in the last {@link #classify()} call, or -1 if it has not run. */
    public int getNumCases() {
        return this.numCases;
    }

    /** @return the number of rows counted by the last {@link #crossTabulation()} call. */
    public int getTotalUsableCases() {
        return this.totalUsableCases;
    }

    /** @return the list of Bayes IM variables held by this classifier (not a copy). */
    public List<Node> getBayesImVars() {
        return this.bayesImVars;
    }

    /**
     * Restores the object from a stream and rejects streams in which any of
     * the required fields is null.
     *
     * @throws NullPointerException if the Bayes IM, the test data or the
     *                              variable list is null after reading.
     */
    private void readObject(ObjectInputStream s) throws IOException, ClassNotFoundException {
        s.defaultReadObject();
        if (this.bayesIm == null) {
            throw new NullPointerException();
        }
        if (this.testData == null) {
            throw new NullPointerException();
        }
        if (this.getBayesImVars() == null) {
            throw new NullPointerException();
        }
    }
}

