/*
 * Decompiled with CFR 0.152.
 */
package edu.pitt.csb.mgm;

import edu.cmu.tetrad.data.ContinuousVariable;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.DiscreteVariable;
import edu.cmu.tetrad.data.ICovarianceMatrix;
import edu.cmu.tetrad.graph.IndependenceFact;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.regression.LogisticRegression;
import edu.cmu.tetrad.regression.RegressionDataset;
import edu.cmu.tetrad.regression.RegressionResult;
import edu.cmu.tetrad.search.IndependenceResult;
import edu.cmu.tetrad.search.IndependenceTest;
import edu.cmu.tetrad.search.SearchLogUtils;
import edu.cmu.tetrad.util.Matrix;
import edu.cmu.tetrad.util.ProbUtils;
import edu.cmu.tetrad.util.TetradLogger;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.commons.math3.util.FastMath;

/**
 * Conditional independence test for mixed continuous/discrete data based on regression
 * coefficient tests.
 *
 * <p>Discrete variables with three or more categories are expanded into binary indicator
 * columns, one per category except the first. These columns are appended to an internal copy
 * of the data. To test X _||_ Y | Z, one of the two variables is regressed on the (expanded)
 * other variable plus the (expanded) conditioning set. The smallest p-value among the
 * coefficients belonging to the other variable's columns is then compared against alpha:
 * <ul>
 *   <li>when the response is discrete, a logistic regression is fit for each of its columns,
 *       and each relevant coefficient is tested with a two-sided t test on |coef / stdErr|;</li>
 *   <li>otherwise a linear regression is fit and its reported coefficient p-values are used.</li>
 * </ul>
 */
public class IndTestMultinomialLogisticRegressionWald implements IndependenceTest {
    /**
     * Value stored in a discrete column for a missing cell.
     * NOTE(review): assumed to match Tetrad's discrete missing-value code; confirm.
     */
    private static final int MISSING_DISCRETE_VALUE = -99;

    private final DataSet originalData;
    private final List<Node> searchVariables;
    /** Copy of the data extended with the indicator columns created by {@link #expandVariable}. */
    private final DataSet internalData;
    private double alpha;
    private double lastP;
    /** Maps each original variable to the column(s) that represent it in regressions. */
    private final Map<Node, List<Node>> variablesPerNode = new HashMap<>();
    private final LogisticRegression logisticRegression;
    private final RegressionDataset regression;
    private final boolean preferLinear;
    private boolean verbose;

    /**
     * Creates the test.
     *
     * @param data         the data to test on; it is copied, so the test works on its own copy
     * @param alpha        significance level, must lie in [0, 1]
     * @param preferLinear when {@code true}, a discrete/continuous pair is tested by linear
     *                     regression of the continuous variable instead of logistic regression
     *                     of the discrete one
     * @throws IllegalArgumentException if {@code alpha} is not in [0, 1] (NaN is rejected too)
     */
    public IndTestMultinomialLogisticRegressionWald(DataSet data, double alpha, boolean preferLinear) {
        if (!(alpha >= 0.0 && alpha <= 1.0)) {
            throw new IllegalArgumentException("Alpha must be in [0, 1]");
        }
        this.searchVariables = data.getVariables();
        this.originalData = data.copy();
        this.alpha = alpha;
        this.preferLinear = preferLinear;

        DataSet internalData = data.copy();
        // Iterate over a snapshot, since expandVariable appends columns to internalData.
        for (Node node : new ArrayList<>(internalData.getVariables())) {
            this.variablesPerNode.put(node, this.expandVariable(internalData, node));
        }
        this.internalData = internalData;
        this.logisticRegression = new LogisticRegression(internalData);
        this.regression = new RegressionDataset(internalData);
    }

    /** Not supported by this test. */
    @Override
    public IndependenceTest indTestSubset(List<Node> vars) {
        throw new UnsupportedOperationException();
    }

    /**
     * Tests x _||_ y | z.
     *
     * <p>Dispatch rules:
     * <ul>
     *   <li>if both x and y are discrete, x is the logistic response;</li>
     *   <li>if {@code preferLinear} is false, a discrete variable (x first, then y) is used as
     *       the logistic response; two continuous variables use linear regression of x;</li>
     *   <li>if {@code preferLinear} is true, the non-discrete variable is the linear
     *       response.</li>
     * </ul>
     */
    @Override
    public IndependenceResult checkIndependence(Node x, Node y, List<Node> z) {
        if (x instanceof DiscreteVariable && y instanceof DiscreteVariable) {
            return this.isIndependentMultinomialLogisticRegression(x, y, z);
        }
        if (!this.preferLinear) {
            if (x instanceof DiscreteVariable) {
                return this.isIndependentMultinomialLogisticRegression(x, y, z);
            }
            if (y instanceof DiscreteVariable) {
                return this.isIndependentMultinomialLogisticRegression(y, x, z);
            }
            return this.isIndependentRegression(x, y, z);
        }
        if (x instanceof DiscreteVariable) {
            return this.isIndependentRegression(y, x, z);
        }
        return this.isIndependentRegression(x, y, z);
    }

    /**
     * Returns the columns that represent {@code node} in regressions, adding indicator columns
     * to {@code dataSet} when needed.
     *
     * <p>Continuous variables, and discrete variables with fewer than three categories, are
     * used as-is. A discrete variable with three or more categories gets one new binary column
     * per category except the first (the reference level). A row's value in that column is 1
     * when the variable takes that category, and 0 otherwise (including for missing cells).
     *
     * @throws IllegalArgumentException if {@code node} is neither continuous nor discrete
     */
    private List<Node> expandVariable(DataSet dataSet, Node node) {
        if (node instanceof ContinuousVariable) {
            return Collections.singletonList(node);
        }
        if (!(node instanceof DiscreteVariable)) {
            throw new IllegalArgumentException("Unsupported variable type: " + node);
        }
        DiscreteVariable discrete = (DiscreteVariable) node;
        if (discrete.getNumCategories() < 3) {
            return Collections.singletonList(node);
        }

        List<String> categories = discrete.getCategories();
        int numRows = dataSet.getNumRows();
        List<Node> indicators = new ArrayList<>();

        // Category 0 is the reference level and gets no indicator column.
        for (String category : categories.subList(1, categories.size())) {
            String name = uniqueName(dataSet, node.getName() + "MULTINOM." + category);
            DiscreteVariable indicator = new DiscreteVariable(name, 2);
            dataSet.addVariable(indicator);

            int sourceColumn = dataSet.getColumn(node);
            int indicatorColumn = dataSet.getColumn(indicator);
            int categoryIndex = discrete.getIndex(category);
            for (int row = 0; row < numRows; row++) {
                int cellIndex = dataSet.getInt(row, sourceColumn);
                dataSet.setInt(row, indicatorColumn, cellIndex == categoryIndex ? 1 : 0);
            }
            indicators.add(indicator);
        }
        return indicators;
    }

    /**
     * Returns {@code base}, or {@code base} followed by ".1", ".2", ..., so that the result
     * names no existing variable in {@code dataSet}.
     */
    private static String uniqueName(DataSet dataSet, String base) {
        String name = base;
        for (int suffix = 1; dataSet.getVariable(name) != null; suffix++) {
            name = base + "." + suffix;
        }
        return name;
    }

    /** Throws if {@code x}, {@code y} or any member of {@code z} is not a variable of this test. */
    private void requireKnown(Node x, Node y, List<Node> z) {
        requireKnown(x);
        requireKnown(y);
        for (Node node : z) {
            requireKnown(node);
        }
    }

    /** Throws {@link IllegalArgumentException} if {@code node} has no entry in the variable map. */
    private void requireKnown(Node node) {
        if (!this.variablesPerNode.containsKey(node)) {
            throw new IllegalArgumentException("Unrecognized node: " + node);
        }
    }

    /**
     * Tests x _||_ y | z with x (discrete) as the logistic-regression response.
     *
     * <p>Each column of x is regressed on the columns of y and z. For every coefficient of a
     * y column, a two-sided p-value is computed from a t distribution with n - k degrees of
     * freedom. Here k is the number of regressors plus one (the intercept). The minimum
     * p-value is kept, and the method returns "dependent" as soon as it drops to alpha or
     * below.
     */
    private IndependenceResult isIndependentMultinomialLogisticRegression(Node x, Node y, List<Node> z) {
        requireKnown(x, y, z);

        int[] rows = this.getNonMissingRows(x, y, z);
        this.logisticRegression.setRows(rows);

        // The regressors do not depend on which column of x is the response, so build them once.
        List<Node> yColumns = this.variablesPerNode.get(y);
        List<Node> regressors = new ArrayList<>(yColumns);
        for (Node zNode : z) {
            regressors.addAll(this.variablesPerNode.get(zNode));
        }
        int n = rows.length;
        int k = regressors.size() + 1;

        double p = 1.0;
        for (Node xColumn : this.variablesPerNode.get(x)) {
            LogisticRegression.Result result = this.logisticRegression.regress((DiscreteVariable) xColumn, regressors);
            // NOTE(review): coefficient 0 is assumed to be the intercept, followed by the
            // regressors in the order passed, so y's columns occupy indices 1..|yColumns|.
            for (int i = 1; i <= yColumns.size(); i++) {
                double wald = FastMath.abs(result.getCoefs()[i] / result.getStdErrs()[i]);
                double val = (1.0 - ProbUtils.tCdf(wald, n - k)) * 2.0;
                if (val < p) {
                    p = val;
                }
                if (p <= this.alpha) {
                    // One significant coefficient already establishes dependence.
                    this.lastP = p;
                    TetradLogger.getInstance().log("dependencies", SearchLogUtils.dependenceFactMsg(x, y, z, p));
                    return new IndependenceResult(new IndependenceFact(x, y, z), false, p);
                }
            }
        }

        boolean independent = p > this.alpha;
        this.lastP = p;
        if (this.verbose && independent) {
            TetradLogger.getInstance().forceLogMessage(SearchLogUtils.independenceFactMsg(x, y, z, p));
        }
        return new IndependenceResult(new IndependenceFact(x, y, z), independent, p);
    }

    /**
     * Returns the row indices to fit on. Every row of the internal data is currently used, and
     * no missing-value filtering is performed, so {@code x}, {@code y} and {@code z} are
     * ignored.
     */
    private int[] getNonMissingRows(Node x, Node y, List<Node> z) {
        int[] rows = new int[this.internalData.getNumRows()];
        for (int i = 0; i < rows.length; i++) {
            rows[i] = i;
        }
        return rows;
    }

    /**
     * Returns whether cell (i, x) of the internal data is missing: NaN for continuous
     * variables, {@link #MISSING_DISCRETE_VALUE} for discrete ones. It is not called at the
     * moment; see {@link #getNonMissingRows}.
     */
    @SuppressWarnings("unused")
    private boolean isMissing(Node x, int i) {
        int j = this.internalData.getColumn(x);
        if (x instanceof DiscreteVariable) {
            return this.internalData.getInt(i, j) == MISSING_DISCRETE_VALUE;
        }
        if (x instanceof ContinuousVariable) {
            return Double.isNaN(this.internalData.getDouble(i, j));
        }
        return false;
    }

    /**
     * Tests x _||_ y | z with x as the response of a linear regression on y's column(s) and
     * z's columns.
     *
     * <p>The p-value is that of y's coefficient; for an expanded discrete y it is the minimum
     * over its indicator coefficients. If the regression throws, the pair is reported as
     * dependent with p = NaN.
     */
    private IndependenceResult isIndependentRegression(Node x, Node y, List<Node> z) {
        requireKnown(x, y, z);

        List<Node> regressors = new ArrayList<>();
        if (y instanceof ContinuousVariable) {
            regressors.add(this.internalData.getVariable(y.getName()));
        } else {
            regressors.addAll(this.variablesPerNode.get(y));
        }
        for (Node zNode : z) {
            regressors.addAll(this.variablesPerNode.get(zNode));
        }

        this.regression.setRows(this.getNonMissingRows(x, y, z));
        RegressionResult result;
        try {
            result = this.regression.regress(x, regressors);
        } catch (Exception ignored) {
            // Best effort: any failure of the fit is reported as dependence with an undefined
            // p-value. lastP is updated so getPValue()/getScore() agree with the result.
            this.lastP = Double.NaN;
            return new IndependenceResult(new IndependenceFact(x, y, z), false, Double.NaN);
        }

        // Index 0 of getP() is skipped; y's coefficient(s) are read from index 1 onwards.
        double p = 1.0;
        if (y instanceof ContinuousVariable) {
            p = result.getP()[1];
        } else {
            int numYColumns = this.variablesPerNode.get(y).size();
            for (int i = 1; i <= numYColumns; i++) {
                double val = result.getP()[i];
                if (val < p) {
                    p = val;
                }
            }
        }

        this.lastP = p;
        boolean independent = p > this.alpha;
        if (this.verbose && independent) {
            TetradLogger.getInstance().forceLogMessage(SearchLogUtils.independenceFactMsg(x, y, z, p));
        }
        return new IndependenceResult(new IndependenceFact(x, y, z), independent, p);
    }

    /** Returns the p-value of the most recent independence check. */
    public double getPValue() {
        return this.lastP;
    }

    /** Returns the variables of the data passed to the constructor. */
    @Override
    public List<Node> getVariables() {
        return this.searchVariables;
    }

    /** Returns the names of {@link #getVariables()}, in the same order. */
    @Override
    public List<String> getVariableNames() {
        List<Node> variables = this.getVariables();
        List<String> names = new ArrayList<>(variables.size());
        for (Node variable : variables) {
            names.add(variable.getName());
        }
        return names;
    }

    /** Returns the first variable with the given name, or {@code null} if there is none. */
    @Override
    public Node getVariable(String name) {
        for (Node variable : this.getVariables()) {
            if (variable.getName().equals(name)) {
                return variable;
            }
        }
        return null;
    }

    /** Not supported by this test; always {@code false}. */
    @Override
    public boolean determines(List<Node> z, Node y) {
        return false;
    }

    @Override
    public double getAlpha() {
        return this.alpha;
    }

    /** Sets the significance level (no range check is performed here). */
    @Override
    public void setAlpha(double alpha) {
        this.alpha = alpha;
    }

    /** Returns the copy of the data made at construction time, without indicator columns. */
    @Override
    public DataSet getData() {
        return this.originalData;
    }

    /** Not available for this test; returns {@code null}. */
    @Override
    public ICovarianceMatrix getCov() {
        return null;
    }

    /** Not available for this test; returns {@code null}. */
    @Override
    public List<DataSet> getDataSets() {
        return null;
    }

    /** Returns the number of rows in the data. */
    @Override
    public int getSampleSize() {
        return this.originalData.getNumRows();
    }

    /** Not available for this test; returns {@code null}. */
    @Override
    public List<Matrix> getCovMatrices() {
        return null;
    }

    /** Returns the p-value of the most recent independence check. */
    @Override
    public double getScore() {
        return this.getPValue();
    }

    @Override
    public String toString() {
        DecimalFormat nf = new DecimalFormat("0.0000");
        return "Multinomial Logistic Regression, alpha = " + nf.format(this.getAlpha());
    }

    @Override
    public boolean isVerbose() {
        return this.verbose;
    }

    @Override
    public void setVerbose(boolean verbose) {
        this.verbose = verbose;
    }
}

