/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search.work_in_progress;

import edu.cmu.tetrad.data.ContinuousVariable;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.DiscreteVariable;
import edu.cmu.tetrad.graph.IndependenceFact;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.regression.LogisticRegression;
import edu.cmu.tetrad.regression.RegressionDataset;
import edu.cmu.tetrad.regression.RegressionResult;
import edu.cmu.tetrad.search.IndependenceTest;
import edu.cmu.tetrad.search.test.IndependenceResult;
import edu.cmu.tetrad.search.utils.LogUtilsSearch;
import edu.cmu.tetrad.util.TetradLogger;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.commons.math3.distribution.ChiSquaredDistribution;

/**
 * Conditional independence test for mixed continuous/discrete data based on logistic regression
 * (when at least one of the tested variables is discrete) or linear regression (otherwise).
 *
 * <p>On construction, the data are copied. Every discrete variable with three or more categories
 * is expanded into binary indicator columns, one per category except the first, which are appended
 * to the internal copy. Continuous variables and discrete variables with fewer than three
 * categories are used as-is.
 */
public class IndTestMultinomialLogisticRegression
implements IndependenceTest {
    private final DataSet originalData;
    private final List<Node> searchVariables;
    private final DataSet internalData;
    /** Maps each original variable to the column(s) representing it in {@code internalData}. */
    private final Map<Node, List<Node>> variablesPerNode = new HashMap<Node, List<Node>>();
    private final LogisticRegression logisticRegression;
    private final RegressionDataset regression;
    /** Lazily built array of every row index of {@code internalData}; see getNonMissingRows. */
    int[] _rows;
    private double alpha;
    private double lastP;
    private boolean verbose;

    /**
     * Constructs the test.
     *
     * @param data  the data set; it is copied, and the internal copy is extended with indicator
     *              columns for discrete variables having three or more categories
     * @param alpha the significance level
     */
    public IndTestMultinomialLogisticRegression(DataSet data, double alpha) {
        this.searchVariables = data.getVariables();
        this.originalData = data.copy();
        DataSet internalData = data.copy();
        this.alpha = alpha;
        // Iterate over a snapshot: expandVariable appends columns to internalData while we loop.
        List<Node> variables = new ArrayList<Node>(internalData.getVariables());
        for (Node node : variables) {
            List<Node> nodes = this.expandVariable(internalData, node);
            this.variablesPerNode.put(node, nodes);
        }
        this.internalData = internalData;
        this.logisticRegression = new LogisticRegression(internalData);
        this.regression = new RegressionDataset(internalData);
    }

    @Override
    public IndependenceTest indTestSubset(List<Node> vars) {
        throw new UnsupportedOperationException();
    }

    /**
     * Tests whether {@code x} and {@code y} are independent given {@code z}.
     *
     * <p>If {@code x} is discrete, a logistic-regression likelihood-ratio test is run with
     * {@code x} as the response. Otherwise, if {@code y} is discrete, the roles of {@code x} and
     * {@code y} are swapped. If neither is discrete, {@code x} is linearly regressed on {@code y}
     * and {@code z}.
     *
     * @throws IllegalArgumentException if {@code x}, {@code y} or any member of {@code z} is not
     *                                  a variable of the data set this test was built from
     */
    @Override
    public IndependenceResult checkIndependence(Node x, Node y, Set<Node> z) {
        if (x instanceof DiscreteVariable) {
            return this.isIndependentMultinomialLogisticRegression(x, y, z);
        }
        if (y instanceof DiscreteVariable) {
            return this.isIndependentMultinomialLogisticRegression(y, x, z);
        }
        return this.isIndependentRegression(x, y, z);
    }

    /**
     * Returns the column(s) that represent {@code node} in {@code dataSet}.
     *
     * <p>For a discrete variable with three or more categories, this appends one binary indicator
     * column per category except the first to {@code dataSet}. A row's indicator is 1 when the
     * row's value of {@code node} is that category and 0 otherwise.
     *
     * @throws IllegalArgumentException if {@code node} is neither continuous nor discrete
     */
    private List<Node> expandVariable(DataSet dataSet, Node node) {
        if (node instanceof ContinuousVariable) {
            return Collections.singletonList(node);
        }
        if (!(node instanceof DiscreteVariable)) {
            throw new IllegalArgumentException();
        }
        DiscreteVariable discrete = (DiscreteVariable) node;
        if (discrete.getNumCategories() < 3) {
            return Collections.singletonList(node);
        }
        List<String> varCats = new ArrayList<String>(discrete.getCategories());
        // The first category gets no indicator column.
        varCats.remove(0);
        int numCases = dataSet.getNumRows();
        List<Node> variables = new ArrayList<Node>();
        for (String cat : varCats) {
            String newVarName = uniqueVariableName(dataSet, node.getName() + "MULTINOM." + cat);
            DiscreteVariable newVar = new DiscreteVariable(newVarName, 2);
            variables.add(newVar);
            dataSet.addVariable(newVar);
            int newVarIndex = dataSet.getColumn(newVar);
            int sourceColumn = dataSet.getColumn(node);
            int catIndex = discrete.getIndex(cat);
            for (int row = 0; row < numCases; ++row) {
                Object dataCell = dataSet.getObject(row, sourceColumn);
                int dataCellIndex = discrete.getIndex(dataCell.toString());
                dataSet.setInt(row, newVarIndex, dataCellIndex == catIndex ? 1 : 0);
            }
        }
        return variables;
    }

    /**
     * Returns {@code baseName} if no variable of {@code dataSet} has that name. Otherwise returns
     * {@code baseName + "." + k} for the smallest {@code k >= 1} that is unused.
     */
    private static String uniqueVariableName(DataSet dataSet, String baseName) {
        String name = baseName;
        int suffix = 1;
        while (dataSet.getVariable(name) != null) {
            name = baseName + "." + suffix;
            suffix++;
        }
        return name;
    }

    /** Rejects any of {@code x}, {@code y} or the members of {@code z} that this test does not know. */
    private void checkKnownNodes(Node x, Node y, Set<Node> z) {
        if (!this.variablesPerNode.containsKey(x)) {
            throw new IllegalArgumentException("Unrecogized node: " + x);
        }
        if (!this.variablesPerNode.containsKey(y)) {
            throw new IllegalArgumentException("Unrecogized node: " + y);
        }
        for (Node node : z) {
            if (!this.variablesPerNode.containsKey(node)) {
                throw new IllegalArgumentException("Unrecogized node: " + node);
            }
        }
    }

    /** Concatenates, in iteration order, the internal columns representing each of {@code nodes}. */
    private List<Node> expandedColumns(Collection<Node> nodes) {
        List<Node> columns = new ArrayList<Node>();
        for (Node node : nodes) {
            columns.addAll(this.variablesPerNode.get(node));
        }
        return columns;
    }

    /**
     * Likelihood-ratio test with the discrete variable {@code x} as the response.
     *
     * <p>For each column representing {@code x}, a logistic regression on the columns of {@code z}
     * is compared with one on the columns of {@code y} plus {@code z}. The statistic
     * {@code ll0 - ll1} is referred to a chi-square distribution whose degrees of freedom equal the
     * number of columns of {@code y}. The reported p-value is the minimum over the columns of
     * {@code x}.
     *
     * <p>NOTE(review): treating {@code ll0 - ll1} as a chi-square statistic assumes
     * {@code getLogLikelihood()} returns a -2 log-likelihood style quantity; confirm against
     * LogisticRegression.
     */
    private IndependenceResult isIndependentMultinomialLogisticRegression(Node x, Node y, Set<Node> z) {
        checkKnownNodes(x, y, z);
        this.logisticRegression.setRows(this.getNonMissingRows());

        // These do not depend on which indicator of x is the response, so build them once.
        List<Node> regressors0 = expandedColumns(z);
        List<Node> yColumns = this.variablesPerNode.get(y);
        List<Node> regressors1 = new ArrayList<Node>(yColumns);
        regressors1.addAll(regressors0);
        ChiSquaredDistribution chiSquared = new ChiSquaredDistribution(yColumns.size());

        double p = 1.0;
        for (Node _x : this.variablesPerNode.get(x)) {
            LogisticRegression.Result result0 = this.logisticRegression.regress((DiscreteVariable) _x, regressors0);
            LogisticRegression.Result result1 = this.logisticRegression.regress((DiscreteVariable) _x, regressors1);
            double chisq = result0.getLogLikelihood() - result1.getLogLikelihood();
            double pComponent = 1.0 - chiSquared.cumulativeProbability(chisq);
            if (Double.isNaN(pComponent)) {
                throw new RuntimeException("Undefined p-value encountered when testing " + LogUtilsSearch.independenceFact(x, y, z));
            }
            // Keep the minimum p-value: every indicator of x must look independent of y.
            if (pComponent < p) {
                p = pComponent;
            }
        }
        boolean independent = p > this.alpha;
        this.lastP = p;
        if (this.verbose && independent) {
            TetradLogger.getInstance().forceLogMessage(LogUtilsSearch.independenceFactMsg(x, y, z, p));
        }
        return new IndependenceResult(new IndependenceFact(x, y, z), independent, p, this.alpha - p);
    }

    /**
     * Returns the row indices used for regressions: currently every row of the internal data.
     *
     * <p>NOTE(review): despite the name, no rows with missing values are filtered out here.
     */
    private int[] getNonMissingRows() {
        if (this._rows == null) {
            this._rows = new int[this.internalData.getNumRows()];
            for (int k = 0; k < this._rows.length; ++k) {
                this._rows[k] = k;
            }
        }
        return this._rows;
    }

    /**
     * Linear-regression test of {@code x} on {@code y} and the columns of {@code z}. It uses the
     * p-value at index 1 of the regression result, which corresponds to {@code y}.
     *
     * <p>NOTE(review): this assumes index 0 of {@code getP()} is the intercept; confirm.
     */
    private IndependenceResult isIndependentRegression(Node x, Node y, Set<Node> z) {
        checkKnownNodes(x, y, z);
        List<Node> regressors = new ArrayList<Node>();
        regressors.add(this.internalData.getVariable(y.getName()));
        regressors.addAll(expandedColumns(z));
        this.regression.setRows(this.getNonMissingRows());
        RegressionResult result;
        try {
            result = this.regression.regress(x, regressors);
        }
        catch (Exception e) {
            // Deliberate best-effort: any regression failure is reported as dependence with an
            // undefined p-value (and does not update lastP). The cause is discarded.
            return new IndependenceResult(new IndependenceFact(x, y, z), false, Double.NaN, Double.NaN);
        }
        double p = result.getP()[1];
        this.lastP = p;
        if (Double.isNaN(p)) {
            throw new RuntimeException("Undefined p-value encountered when testing " + LogUtilsSearch.independenceFact(x, y, z));
        }
        boolean indep = p > this.alpha;
        if (this.verbose) {
            if (indep) {
                TetradLogger.getInstance().log("independencies", LogUtilsSearch.independenceFactMsg(x, y, z, p));
            } else {
                TetradLogger.getInstance().log("dependencies", LogUtilsSearch.dependenceFactMsg(x, y, z, p));
            }
        }
        return new IndependenceResult(new IndependenceFact(x, y, z), indep, p, this.alpha - p);
    }

    /**
     * Returns the p-value recorded by the most recent test. A linear-regression test whose
     * regression threw does not update it.
     */
    public double getPValue() {
        return this.lastP;
    }

    /** Returns the variables of the data set passed to the constructor. */
    @Override
    public List<Node> getVariables() {
        return this.searchVariables;
    }

    /** Always returns {@code false}; this test does not detect deterministic relationships. */
    public boolean determines(List<Node> z, Node y) {
        return false;
    }

    @Override
    public double getAlpha() {
        return this.alpha;
    }

    @Override
    public void setAlpha(double alpha) {
        this.alpha = alpha;
    }

    /** Returns a copy of the original data, made before indicator columns were added. */
    @Override
    public DataSet getData() {
        return this.originalData;
    }

    @Override
    public String toString() {
        DecimalFormat nf = new DecimalFormat("0.0000");
        return "Multinomial Logistic Regression, alpha = " + nf.format(this.getAlpha());
    }

    @Override
    public boolean isVerbose() {
        return this.verbose;
    }

    @Override
    public void setVerbose(boolean verbose) {
        this.verbose = verbose;
    }
}

