/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetradapp.editor;

import edu.cmu.tetrad.data.BoxDataSet;
import edu.cmu.tetrad.data.ContinuousVariable;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.DiscreteVariable;
import edu.cmu.tetrad.data.DoubleDataBox;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.search.BayesUpdaterClassifier;
import edu.cmu.tetrad.util.JOptionUtils;
import edu.cmu.tetrad.util.NumberFormatUtil;
import edu.cmu.tetrad.util.RocCalculator;
import edu.cmu.tetradapp.editor.DataDisplay;
import edu.cmu.tetradapp.editor.RocPlot;
import edu.cmu.tetradapp.editor.SaveComponentImage;
import edu.cmu.tetradapp.model.BayesUpdaterClassifierWrapper;
import edu.cmu.tetradapp.util.WatchedProcess;
import edu.cmu.tetradapp.workbench.GraphWorkbench;
import java.awt.BorderLayout;
import java.awt.Color;
import java.awt.Component;
import java.awt.Dimension;
import java.awt.Font;
import java.awt.Window;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import java.awt.event.ItemEvent;
import java.awt.event.ItemListener;
import java.text.NumberFormat;
import java.util.LinkedList;
import java.util.List;
import javax.swing.Box;
import javax.swing.DefaultComboBoxModel;
import javax.swing.JButton;
import javax.swing.JComboBox;
import javax.swing.JLabel;
import javax.swing.JMenu;
import javax.swing.JMenuBar;
import javax.swing.JMenuItem;
import javax.swing.JOptionPane;
import javax.swing.JPanel;
import javax.swing.JScrollPane;
import javax.swing.JTabbedPane;
import javax.swing.JTextArea;
import javax.swing.border.EmptyBorder;

/**
 * Editor panel for a {@link BayesUpdaterClassifier}.
 *
 * <p>The panel shows the classifier's Bayes IM graph and its test data. A toolbar lets the
 * user choose the target variable and the category used for the ROC plot, and a "Classify"
 * button runs the classifier. Once classifications are available, three further tabs are
 * shown: the per-case classification results, an ROC plot for the selected category, and a
 * confusion matrix.
 */
public class BayesUpdaterClassifierEditor extends JPanel {
    /** Title of the tab holding the per-case classification table. */
    private static final String CLASSIFICATION_TAB = "Classification";
    /** Title of the tab holding the ROC plot. */
    private static final String ROC_TAB = "ROC Plot";
    /** Title of the tab holding the confusion matrix. */
    private static final String CONFUSION_TAB = "Confusion Matrix";

    /** The classifier being edited; never null. */
    private final BayesUpdaterClassifier classifier;
    /** Variables of the Bayes IM; the selected one becomes the classification target. */
    private JComboBox<Node> variableDropdown;
    /** Holds the Graph / Test Data tabs plus the result tabs added after classification. */
    private JTabbedPane tabbedPane;
    /** Categories of the selected variable; the selected one is the ROC "positive" category. */
    private JComboBox<String> categoryDropdown;
    /** Workbench displaying the Bayes IM's DAG. */
    private GraphWorkbench workbench;
    /** The ROC plot currently displayed, or null when none is shown. */
    private RocPlot rocPlot;
    /** "Save ROC Plot Image..." menu item; enabled only while an ROC plot is shown. */
    private final JMenuItem saveRoc;

    private BayesUpdaterClassifierEditor(BayesUpdaterClassifier classifier) {
        if (classifier == null) {
            throw new NullPointerException("classifier");
        }
        this.classifier = classifier;
        this.setLayout(new BorderLayout());
        this.setPreferredSize(new Dimension(600, 600));

        Box content = Box.createVerticalBox();
        content.add(this.getToolbar());
        // Builds the tabbed pane and, as a side effect, the graph workbench used below.
        content.add(this.getDisplayPanel());
        this.add(content, BorderLayout.CENTER);

        JMenuBar menuBar = new JMenuBar();
        JMenu file = new JMenu("File");
        menuBar.add(file);
        file.add(new SaveComponentImage(this.workbench, "Save Graph Image..."));
        this.saveRoc = new JMenuItem("Save ROC Plot Image...");
        // Stays disabled until an ROC plot is actually displayed.
        this.saveRoc.setEnabled(false);
        this.saveRoc.addActionListener(new ActionListener() {
            @Override
            public void actionPerformed(ActionEvent e) {
                BayesUpdaterClassifierEditor.this.saveRocImage();
            }
        });
        file.add(this.saveRoc);
        this.add(menuBar, BorderLayout.NORTH);

        // If the classifier has already been run, show its results immediately.
        if (classifier.getClassifications() != null) {
            this.showClassification();
            this.showRocCurve();
            this.showConfusionMatrix();
        }
    }

    /** Saves the current ROC plot as an image, or tells the user there is nothing to save. */
    private void saveRocImage() {
        if (this.rocPlot == null) {
            JOptionPane.showMessageDialog(JOptionUtils.centeringComp(), "Nothing to save.");
            return;
        }
        SaveComponentImage action = new SaveComponentImage(this.rocPlot, "");
        action.actionPerformed(new ActionEvent(this, ActionEvent.ACTION_PERFORMED, "Save"));
    }

    /** Creates the tabbed pane with the "Graph" and "Test Data" tabs and wraps it in a panel. */
    private Component getDisplayPanel() {
        JPanel panel = new JPanel(new BorderLayout());
        this.tabbedPane = new JTabbedPane();
        this.tabbedPane.addTab("Graph", this.getGraphPanel());
        this.tabbedPane.addTab("Test Data", this.getDataPanel());
        panel.add(this.tabbedPane, BorderLayout.CENTER);
        return panel;
    }

    /** Returns a scrollable table showing the classifier's test data. */
    private Component getDataPanel() {
        DataSet dataSet = this.getClassifier().getTestData();
        return new JScrollPane(new DataDisplay(dataSet));
    }

    /** Creates the graph workbench for the Bayes IM's DAG and returns it in a scroll pane. */
    private Component getGraphPanel() {
        Graph graph = this.getClassifier().getBayesIm().getDag();
        this.workbench = new GraphWorkbench(graph);
        return new JScrollPane(this.workbench);
    }

    /**
     * Builds the toolbar row: target-variable dropdown, ROC category dropdown and the
     * "Classify" button, and wires up their listeners.
     */
    private Component getToolbar() {
        JButton classifyButton = new JButton("Classify");
        classifyButton.addActionListener(new ActionListener() {
            @Override
            public void actionPerformed(ActionEvent e) {
                Window owner = (Window) BayesUpdaterClassifierEditor.this.getTopLevelAncestor();
                // NOTE(review): if WatchedProcess runs watch() off the Event Dispatch Thread,
                // the show* calls below touch Swing components off the EDT; confirm and
                // marshal them through SwingUtilities.invokeLater if so.
                new WatchedProcess(owner) {
                    @Override
                    public void watch() {
                        BayesUpdaterClassifierEditor.this.doClassify();
                        BayesUpdaterClassifierEditor.this.showClassification();
                        BayesUpdaterClassifierEditor.this.showRocCurve();
                        BayesUpdaterClassifierEditor.this.showConfusionMatrix();
                    }
                };
            }
        });

        List<Node> nodes = this.getClassifier().getBayesImVars();
        this.variableDropdown = new JComboBox<Node>(nodes.toArray(new Node[0]));
        this.variableDropdown.setBackground(Color.WHITE);
        this.variableDropdown.setMaximumSize(new Dimension(200, 50));

        // Empty when there are no variables (or the selection is not discrete) instead of NPE/CCE.
        this.categoryDropdown = new JComboBox<String>(
                categoriesOf(this.variableDropdown.getSelectedItem()));
        this.categoryDropdown.setBackground(Color.WHITE);
        this.categoryDropdown.setMaximumSize(new Dimension(200, 50));

        this.variableDropdown.addItemListener(new ItemListener() {
            @Override
            public void itemStateChanged(ItemEvent e) {
                // Item listeners fire for both DESELECTED and SELECTED; react once, to the new one.
                if (e.getStateChange() != ItemEvent.SELECTED) {
                    return;
                }
                Object selected = BayesUpdaterClassifierEditor.this.getVariableDropdown().getSelectedItem();
                BayesUpdaterClassifierEditor.this.getCategoryDropdown().setModel(
                        new DefaultComboBoxModel<String>(categoriesOf(selected)));
            }
        });
        this.categoryDropdown.addItemListener(new ItemListener() {
            @Override
            public void itemStateChanged(ItemEvent e) {
                // Avoid recomputing the ROC curve twice (once per DESELECTED/SELECTED pair).
                if (e.getStateChange() == ItemEvent.SELECTED) {
                    BayesUpdaterClassifierEditor.this.showRocCurve();
                }
            }
        });

        Box row = Box.createHorizontalBox();
        row.add(Box.createHorizontalStrut(5));
        row.add(new JLabel("Target = "));
        row.add(this.variableDropdown);
        row.add(Box.createHorizontalStrut(5));
        row.add(new JLabel("Category for ROC ="));
        row.add(this.categoryDropdown);
        row.add(Box.createHorizontalStrut(10));
        row.add(classifyButton);
        row.add(Box.createHorizontalGlue());

        Box toolbar = Box.createVerticalBox();
        toolbar.add(row);
        toolbar.add(Box.createVerticalStrut(5));
        toolbar.setBorder(new EmptyBorder(2, 2, 2, 2));
        return toolbar;
    }

    /**
     * Returns the categories of {@code item} when it is a {@link DiscreteVariable}, otherwise an
     * empty array (e.g. when nothing is selected).
     */
    private static String[] categoriesOf(Object item) {
        if (!(item instanceof DiscreteVariable)) {
            return new String[0];
        }
        return ((DiscreteVariable) item).getCategories().toArray(new String[0]);
    }

    /** Sets the classifier's target to the selected variable and runs the classification. */
    private void doClassify() {
        Node variable = (Node) this.getVariableDropdown().getSelectedItem();
        if (variable == null) {
            // No variables available; there is nothing to classify.
            return;
        }
        this.getClassifier().setTarget(variable.getName());
        this.getClassifier().classify();
    }

    /**
     * (Re)builds the "Classification" tab: one row per case, with the predicted category in
     * column "Result" followed by one score column per category index of the target.
     */
    private void showClassification() {
        int tabIndex = this.removeTab(CLASSIFICATION_TAB);

        int[] classifications = this.getClassifier().getClassifications();
        double[][] marginals = this.getClassifier().getMarginals();
        if (classifications == null || marginals == null) {
            // Not classified yet (or classification failed); nothing to display.
            return;
        }

        int maxCategory = 0;
        for (int classification : classifications) {
            maxCategory = Math.max(maxCategory, classification);
        }

        LinkedList<Node> variables = new LinkedList<Node>();
        DiscreteVariable targetVariable = this.getClassifier().getTargetVariable();
        DiscreteVariable classVar = new DiscreteVariable(targetVariable.getName(), maxCategory + 1);
        variables.add(classVar);
        // marginals is indexed [categoryIndex][case]; one score column per category index.
        for (int c = 0; c < marginals.length; ++c) {
            variables.add(new ContinuousVariable("P(" + targetVariable + "=" + c + ")"));
        }
        classVar.setName("Result");

        BoxDataSet dataSet = new BoxDataSet(
                new DoubleDataBox(classifications.length, variables.size()), variables);
        for (int row = 0; row < classifications.length; ++row) {
            dataSet.setInt(row, 0, classifications[row]);
            for (int c = 0; c < marginals.length; ++c) {
                dataSet.setDouble(row, c + 1, marginals[c][row]);
            }
        }

        this.putTab(CLASSIFICATION_TAB, new JScrollPane(new DataDisplay(dataSet)), tabIndex);
    }

    /**
     * (Re)builds the "ROC Plot" tab for the category currently selected in the category
     * dropdown. If no usable classification exists, or the selected category does not belong
     * to the classified target, the old plot is removed and no new plot is shown.
     */
    private void showRocCurve() {
        int tabIndex = this.removeTab(ROC_TAB);
        this.rocPlot = null;
        this.saveRoc.setEnabled(false);

        if (this.getClassifier().getClassifications() == null) {
            return; // Not classified yet.
        }
        double[][] marginals = this.getClassifier().getMarginals();
        DiscreteVariable targetVariable = this.getClassifier().getTargetVariable();
        if (marginals == null || targetVariable == null) {
            return;
        }

        DataSet testData = this.getClassifier().getTestData();
        Node testVariable = testData.getVariable(targetVariable.getName());
        int varIndex = testData.getVariables().indexOf(testVariable);
        if (varIndex == -1 || !(testVariable instanceof DiscreteVariable)) {
            return;
        }

        String category = (String) this.getCategoryDropdown().getSelectedItem();
        if (category == null) {
            return;
        }
        int catIndex = ((DiscreteVariable) testVariable).getIndex(category);
        // The dropdown may list a different variable's categories if the target was changed
        // without re-classifying; such a category has no row in marginals.
        if (catIndex < 0 || catIndex >= marginals.length) {
            return;
        }

        int ncases = this.getClassifier().getNumCases();
        boolean[] inCategory = new boolean[ncases];
        for (int row = 0; row < ncases; ++row) {
            inCategory[row] = testData.getInt(row, varIndex) == catIndex;
        }

        double[] scores = marginals[catIndex];
        // Third argument kept as in the original; its meaning is defined by RocCalculator.
        RocCalculator rocc = new RocCalculator(scores, inCategory, 0);
        double[][] points = rocc.getScaledRocPlot();
        NumberFormat nf = NumberFormatUtil.getInstance().getNumberFormat();
        String info = "AUC = " + nf.format(rocc.getAuc());
        String title = "ROC Plot, " + targetVariable + " = " + category;

        RocPlot plot = new RocPlot(points, title, info);
        this.rocPlot = plot;
        this.saveRoc.setEnabled(true);
        this.putTab(ROC_TAB, plot, tabIndex);
    }

    /**
     * (Re)builds the "Confusion Matrix" tab: usable-case counts, an observed (rows) by
     * estimated (columns) table of counts, and the percentage correctly classified.
     */
    private void showConfusionMatrix() {
        int tabIndex = this.removeTab(CONFUSION_TAB);

        int[][] crossTabs = this.getClassifier().crossTabulation();
        if (crossTabs == null) {
            return;
        }
        DiscreteVariable targetVariable = this.getClassifier().getTargetVariable();
        int nvalues = targetVariable.getNumCategories();
        int ncases = this.getClassifier().getNumCases();
        int ntot = this.getClassifier().getTotalUsableCases();

        StringBuilder buf = new StringBuilder();
        buf.append("Total number of usable cases = ").append(ntot);
        buf.append(" out of ").append(ncases);
        buf.append("\n\nTarget Variable ").append(targetVariable);
        buf.append("\n\t\tEstimated\t");
        buf.append("\nObserved\t");
        // Header row: category names separated by tabs.
        for (int c = 0; c < nvalues; ++c) {
            if (c > 0) {
                buf.append('\t');
            }
            buf.append(targetVariable.getCategory(c));
        }
        // One row per observed category; columns are the estimated categories.
        for (int observed = 0; observed < nvalues; ++observed) {
            buf.append('\n').append(targetVariable.getCategory(observed)).append('\t');
            for (int estimated = 0; estimated < nvalues; ++estimated) {
                if (estimated > 0) {
                    buf.append('\t');
                }
                buf.append(crossTabs[observed][estimated]);
            }
        }
        buf.append("\n\nPercentage correctly classified:  ");
        buf.append(this.getClassifier().getPercentCorrect());

        JTextArea label = new JTextArea(buf.toString());
        label.setFont(new Font("SansSerif", Font.PLAIN, 14));

        Box row = Box.createHorizontalBox();
        row.add(Box.createHorizontalStrut(5));
        row.add(label);
        row.add(Box.createHorizontalGlue());
        Box column = Box.createVerticalBox();
        column.add(row);
        column.add(Box.createVerticalGlue());

        JPanel panel = new JPanel(new BorderLayout());
        panel.setBackground(Color.WHITE);
        panel.add(column, BorderLayout.CENTER);

        this.putTab(CONFUSION_TAB, new JScrollPane(panel), tabIndex);
    }

    /**
     * Removes every tab titled {@code title}.
     *
     * @return the index of the first removed tab, or -1 if there was none
     */
    private int removeTab(String title) {
        JTabbedPane pane = this.getTabbedPane();
        int firstIndex = -1;
        int index;
        while ((index = pane.indexOfTab(title)) != -1) {
            pane.removeTabAt(index);
            if (firstIndex == -1) {
                firstIndex = index;
            }
        }
        return firstIndex;
    }

    /**
     * Adds {@code component} as a tab titled {@code title}, at {@code index} when it is not -1
     * (so a rebuilt tab keeps its old position), otherwise at the end.
     */
    private void putTab(String title, Component component, int index) {
        if (index == -1) {
            this.getTabbedPane().addTab(title, component);
        } else {
            this.getTabbedPane().insertTab(title, null, component, null, index);
        }
    }

    /** Creates an editor for the classifier held by {@code wrapper}. */
    public BayesUpdaterClassifierEditor(BayesUpdaterClassifierWrapper wrapper) {
        this(wrapper.getClassifier());
    }

    private BayesUpdaterClassifier getClassifier() {
        return this.classifier;
    }

    private JComboBox<Node> getVariableDropdown() {
        return this.variableDropdown;
    }

    private JTabbedPane getTabbedPane() {
        return this.tabbedPane;
    }

    private JComboBox<String> getCategoryDropdown() {
        return this.categoryDropdown;
    }
}

