/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search.work_in_progress;

import edu.cmu.tetrad.data.ContinuousVariable;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.DiscreteVariable;
import edu.cmu.tetrad.graph.IndependenceFact;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.regression.LogisticRegression;
import edu.cmu.tetrad.regression.RegressionDataset;
import edu.cmu.tetrad.regression.RegressionResult;
import edu.cmu.tetrad.search.IndependenceTest;
import edu.cmu.tetrad.search.test.IndependenceResult;
import edu.cmu.tetrad.search.utils.LogUtilsSearch;
import edu.cmu.tetrad.util.ProbUtils;
import edu.cmu.tetrad.util.TetradLogger;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import org.apache.commons.math3.distribution.ChiSquaredDistribution;
import org.apache.commons.math3.util.FastMath;

/**
 * Conditional independence test for data sets that mix continuous and discrete variables.
 *
 * <p>At construction time every discrete variable with three or more categories is expanded,
 * inside a private copy of the data, into binary (0/1) indicator columns, one per category
 * except the first. A test of {@code x _||_ y | z} is then carried out either by a linear
 * regression or by a set of logistic regressions (one per indicator column of the response),
 * and the p-values of the coefficients belonging to {@code y} are combined by comparing
 * {@code -2 * sum(ln p)} against a chi-square distribution with two degrees of freedom per
 * combined p-value (Fisher's combination).
 *
 * <p>The instance keeps mutable state (the last p-value, alpha, the row selection of the shared
 * regression objects); nothing in this class synchronizes access to it.
 */
public class IndTestMixedMultipleTTest
implements IndependenceTest {
    /** Copy of the data set handed to the constructor; returned by {@link #getData()}. */
    private final DataSet originalData;
    /** The variables of the data set handed to the constructor (before indicator expansion). */
    private final List<Node> searchVariables;
    /** Private copy of the data, extended with the indicator columns. */
    private final DataSet internalData;
    /** Maps each variable to the columns that represent it in {@link #internalData}. */
    private final Map<Node, List<Node>> variablesPerNode = new HashMap<>();
    private final LogisticRegression logisticRegression;
    private final RegressionDataset regression;
    /** Lazily created row selection passed to the regressions; see {@link #getNonMissingRows()}. */
    int[] _rows;
    private double alpha;
    /** P-value of the most recent test. */
    private double lastP;
    private boolean verbose;
    /** When true, a linear regression is used whenever at least one of x, y is not discrete. */
    private boolean preferLinear = true;

    /**
     * Builds the test for the given data.
     *
     * @param data  the data set to test; it is copied, the caller's instance is not modified here
     * @param alpha the significance level; a test reports independence when its p-value exceeds it
     * @throws IllegalArgumentException if a variable is neither continuous nor discrete
     */
    public IndTestMixedMultipleTTest(DataSet data, double alpha) {
        this.searchVariables = data.getVariables();
        this.originalData = data.copy();
        DataSet expandedData = data.copy();
        this.alpha = alpha;
        // The variable list is fetched once up front; expandVariable() appends indicator
        // columns to expandedData while we iterate over this list.
        List<Node> variables = expandedData.getVariables();
        for (Node node : variables) {
            this.variablesPerNode.put(node, this.expandVariable(expandedData, node));
        }
        this.internalData = expandedData;
        this.logisticRegression = new LogisticRegression(expandedData);
        this.regression = new RegressionDataset(expandedData);
    }

    /**
     * Chooses between linear and logistic regression for mixed (continuous/discrete) pairs.
     *
     * @param preferLinear true to regress the continuous variable linearly, false to regress the
     *                     discrete variable logistically
     */
    public void setPreferLinear(boolean preferLinear) {
        this.preferLinear = preferLinear;
    }

    /**
     * Not supported.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public IndependenceTest indTestSubset(List<Node> vars) {
        throw new UnsupportedOperationException();
    }

    /**
     * Tests whether {@code x} is independent of {@code y} given {@code z}.
     *
     * <ul>
     *   <li>both discrete: logistic regressions with {@code x} as response;</li>
     *   <li>exactly one discrete: linear regression with the continuous variable as response when
     *       {@code preferLinear} is set, otherwise logistic regressions with the discrete variable
     *       as response;</li>
     *   <li>neither discrete: linear regression with {@code x} as response.</li>
     * </ul>
     *
     * @param x the first variable
     * @param y the second variable
     * @param z the conditioning set
     * @return the result of the test, including its p-value
     * @throws IllegalArgumentException if any of the nodes is not a variable of this test
     * @throws NullPointerException     if the linear regression could not be computed
     */
    @Override
    public IndependenceResult checkIndependence(Node x, Node y, Set<Node> z) {
        boolean xDiscrete = x instanceof DiscreteVariable;
        boolean yDiscrete = y instanceof DiscreteVariable;
        if (xDiscrete && yDiscrete) {
            return this.isIndependentMultinomialLogisticRegression(x, y, z);
        }
        if (xDiscrete) {
            // y is the non-discrete one here, so it becomes the linear response.
            return this.preferLinear
                    ? this.isIndependentRegression(y, x, z)
                    : this.isIndependentMultinomialLogisticRegression(x, y, z);
        }
        if (yDiscrete && !this.preferLinear) {
            return this.isIndependentMultinomialLogisticRegression(y, x, z);
        }
        return this.isIndependentRegression(x, y, z);
    }

    /**
     * @return the p-value of the most recently executed test
     */
    public double getPValue() {
        return this.lastP;
    }

    /**
     * @return the variables of the data set this test was built from
     */
    @Override
    public List<Node> getVariables() {
        return this.searchVariables;
    }

    /**
     * Not supported.
     *
     * @throws UnsupportedOperationException always
     */
    public boolean determines(List<Node> z, Node y) {
        throw new UnsupportedOperationException("Method not implemented.");
    }

    /**
     * @return the significance level currently used by this test
     */
    @Override
    public double getAlpha() {
        // Previously threw UnsupportedOperationException, which also made toString() unusable.
        return this.alpha;
    }

    /**
     * @param alpha the new significance level
     */
    @Override
    public void setAlpha(double alpha) {
        this.alpha = alpha;
    }

    /**
     * @return the copy of the data set taken at construction time (without indicator columns)
     */
    @Override
    public DataSet getData() {
        return this.originalData;
    }

    /**
     * @return a short description of the test including the current alpha
     */
    @Override
    public String toString() {
        DecimalFormat nf = new DecimalFormat("0.0000");
        return "Multinomial Logistic Regression, alpha = " + nf.format(this.getAlpha());
    }

    @Override
    public boolean isVerbose() {
        return this.verbose;
    }

    @Override
    public void setVerbose(boolean verbose) {
        this.verbose = verbose;
    }

    /**
     * Returns the columns that represent {@code node} in {@code dataSet}, adding indicator
     * columns to {@code dataSet} when needed.
     *
     * <p>Continuous variables and discrete variables with fewer than three categories represent
     * themselves. Any other discrete variable gets one new two-category column per category except
     * the first; the column holds 1 in rows where the variable takes that category and 0 elsewhere.
     *
     * @param dataSet the data set to read from and to extend
     * @param node    the variable to expand
     * @return the representing columns, in category order
     * @throws IllegalArgumentException if {@code node} is neither continuous nor discrete
     */
    private List<Node> expandVariable(DataSet dataSet, Node node) {
        if (node instanceof ContinuousVariable) {
            return Collections.singletonList(node);
        }
        if (!(node instanceof DiscreteVariable)) {
            throw new IllegalArgumentException("Unsupported variable type: " + node);
        }
        DiscreteVariable discrete = (DiscreteVariable)node;
        if (discrete.getNumCategories() < 3) {
            return Collections.singletonList(node);
        }
        List<String> varCats = new ArrayList<>(discrete.getCategories());
        // The first category is the reference level and gets no indicator column.
        varCats.remove(0);
        List<Node> variables = new ArrayList<>(varCats.size());
        for (String cat : varCats) {
            String baseName = node.getName() + "MULTINOM." + cat;
            String newVarName = baseName;
            // Pick a name that is not yet used in the data set. The former loop retried the
            // very same name forever when it was already taken.
            int suffix = 0;
            while (dataSet.getVariable(newVarName) != null) {
                newVarName = baseName + "." + ++suffix;
            }
            DiscreteVariable newVar = new DiscreteVariable(newVarName, 2);
            variables.add(newVar);
            dataSet.addVariable(newVar);
            int newVarIndex = dataSet.getColumn(newVar);
            int nodeColumn = dataSet.getColumn(node);
            int catIndex = discrete.getIndex(cat);
            int numCases = dataSet.getNumRows();
            for (int row = 0; row < numCases; ++row) {
                Object dataCell = dataSet.getObject(row, nodeColumn);
                int dataCellIndex = discrete.getIndex(dataCell.toString());
                dataSet.setInt(row, newVarIndex, dataCellIndex == catIndex ? 1 : 0);
            }
        }
        return variables;
    }

    /**
     * Verifies that x, y and every member of z were present in the constructor's data set.
     *
     * @throws IllegalArgumentException naming the first node that is unknown
     */
    private void requireKnownNodes(Node x, Node y, Set<Node> z) {
        if (!this.variablesPerNode.containsKey(x)) {
            throw new IllegalArgumentException("Unrecogized node: " + x);
        }
        if (!this.variablesPerNode.containsKey(y)) {
            throw new IllegalArgumentException("Unrecogized node: " + y);
        }
        for (Node node : z) {
            if (!this.variablesPerNode.containsKey(node)) {
                throw new IllegalArgumentException("Unrecogized node: " + node);
            }
        }
    }

    /** Returns the regressor variables: y first, followed by z in its iteration order. */
    private static List<Node> regressorVariables(Node y, Set<Node> z) {
        List<Node> yzList = new ArrayList<>(z.size() + 1);
        yzList.add(y);
        yzList.addAll(z);
        return yzList;
    }

    /** Concatenates, in order, the representing columns of each of the given variables. */
    private List<Node> expandedColumns(List<Node> nodes) {
        List<Node> columns = new ArrayList<>();
        for (Node node : nodes) {
            columns.addAll(this.variablesPerNode.get(node));
        }
        return columns;
    }

    /**
     * Computes, via logistic regressions with (the indicator columns of) {@code x} as response,
     * one combined p-value per regressor variable.
     *
     * @param x the response; every column representing it must be a {@link DiscreteVariable}
     * @param y the variable under test; its p-value is element 0 of the result
     * @param z the conditioning set; its p-values follow in the set's iteration order
     * @return combined p-values for y and each member of z
     * @throws IllegalArgumentException if a node is unknown to this test
     */
    private double[] dependencePvalsLogit(Node x, Node y, Set<Node> z) {
        this.requireKnownNodes(x, y, z);
        this.logisticRegression.setRows(this.getNonMissingRows());
        List<Node> yzList = this.regressorVariables(y, z);
        List<Node> yzDumList = this.expandedColumns(yzList);
        List<Node> xColumns = this.variablesPerNode.get(x);
        int n = this.originalData.getNumRows();
        int k = yzDumList.size();
        double[] sumLnP = new double[yzList.size()];
        for (Node xColumn : xColumns) {
            LogisticRegression.Result result = this.logisticRegression.regress((DiscreteVariable)xColumn, yzDumList);
            double[] coefs = result.getCoefs();
            double[] stdErrs = result.getStdErrs();
            // Index 0 is skipped (presumably the intercept); the remaining coefficients
            // follow the order of yzDumList.
            int coefIndex = 1;
            for (int j = 0; j < yzList.size(); ++j) {
                int numColumns = this.variablesPerNode.get(yzList.get(j)).size();
                for (int dum = 0; dum < numColumns; ++dum) {
                    double wald = FastMath.abs(coefs[coefIndex] / stdErrs[coefIndex]);
                    // Two-sided p-value of the coefficient with n - k degrees of freedom.
                    double val = (1.0 - ProbUtils.tCdf(wald, n - k)) * 2.0;
                    sumLnP[j] += FastMath.log(val);
                    ++coefIndex;
                }
            }
        }
        double[] pVec = new double[sumLnP.length];
        for (int i = 0; i < pVec.length; ++i) {
            if (sumLnP[i] == Double.NEGATIVE_INFINITY) {
                // At least one coefficient p-value was exactly 0.
                pVec[i] = 0.0;
                continue;
            }
            // Two degrees of freedom per p-value that went into the sum.
            int df = 2 * xColumns.size() * this.variablesPerNode.get(yzList.get(i)).size();
            pVec[i] = 1.0 - new ChiSquaredDistribution(df).cumulativeProbability(-2.0 * sumLnP[i]);
        }
        return pVec;
    }

    /**
     * Runs the logistic-regression based test, records its p-value and wraps the outcome.
     */
    private IndependenceResult isIndependentMultinomialLogisticRegression(Node x, Node y, Set<Node> z) {
        double p = this.dependencePvalsLogit(x, y, z)[0];
        boolean independent = p > this.alpha;
        this.lastP = p;
        if (this.verbose && independent) {
            TetradLogger.getInstance().forceLogMessage(LogUtilsSearch.independenceFactMsg(x, y, z, this.getPValue()));
        }
        return new IndependenceResult(new IndependenceFact(x, y, z), independent, p, this.alpha - p);
    }

    /**
     * Returns the row selection handed to the regressions. Despite the name, no filtering for
     * missing values happens here: the array simply lists every row index of the internal data
     * and is created once and then reused.
     */
    private int[] getNonMissingRows() {
        if (this._rows == null) {
            int[] rows = new int[this.internalData.getNumRows()];
            for (int k = 0; k < rows.length; ++k) {
                rows[k] = k;
            }
            this._rows = rows;
        }
        return this._rows;
    }

    /**
     * Computes, via one linear regression of {@code x} on the columns representing {@code y} and
     * {@code z}, one p-value per regressor variable. A variable represented by a single column
     * gets that column's coefficient p-value; a variable represented by several indicator columns
     * gets the combination of their p-values.
     *
     * @param x the response variable
     * @param y the variable under test; its p-value is element 0 of the result
     * @param z the conditioning set; its p-values follow in the set's iteration order
     * @return the p-values, or {@code null} if the regression threw an exception
     * @throws IllegalArgumentException if a node is unknown to this test
     */
    private double[] dependencePvalsLinear(Node x, Node y, Set<Node> z) {
        this.requireKnownNodes(x, y, z);
        List<Node> yzList = this.regressorVariables(y, z);
        List<Node> yzDumList = this.expandedColumns(yzList);
        this.regression.setRows(this.getNonMissingRows());
        RegressionResult result;
        try {
            result = this.regression.regress(x, yzDumList);
        }
        catch (Exception e) {
            // Best effort: a failed regression is reported to the caller as "no result".
            return null;
        }
        double[] pCoef = result.getP();
        double[] pVec = new double[yzList.size()];
        // Index 0 is skipped (presumably the intercept); the remaining p-values follow the
        // order of yzDumList.
        int coeffInd = 1;
        for (int i = 0; i < pVec.length; ++i) {
            int numColumns = this.variablesPerNode.get(yzList.get(i)).size();
            if (numColumns == 1) {
                pVec[i] = pCoef[coeffInd++];
                continue;
            }
            double sumLnP = 0.0;
            for (int dum = 0; dum < numColumns; ++dum) {
                sumLnP += FastMath.log(pCoef[coeffInd++]);
            }
            pVec[i] = sumLnP == Double.NEGATIVE_INFINITY
                    ? 0.0
                    : 1.0 - new ChiSquaredDistribution(2 * numColumns).cumulativeProbability(-2.0 * sumLnP);
        }
        return pVec;
    }

    /**
     * Runs the linear-regression based test, records its p-value and wraps the outcome.
     *
     * @throws NullPointerException if the regression of {@code x} on {@code y} and {@code z} failed
     */
    private IndependenceResult isIndependentRegression(Node x, Node y, Set<Node> z) {
        double[] pVals = Objects.requireNonNull(
                this.dependencePvalsLinear(x, y, z),
                "Linear regression failed while testing " + x + " against " + y + " given " + z);
        double p = pVals[0];
        this.lastP = p;
        boolean independent = p > this.alpha;
        if (this.verbose && independent) {
            TetradLogger.getInstance().forceLogMessage(LogUtilsSearch.independenceFactMsg(x, y, z, this.getPValue()));
        }
        return new IndependenceResult(new IndependenceFact(x, y, z), independent, p, this.alpha - p);
    }
}

