/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search.test;

import edu.cmu.tetrad.data.CellTable;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.DiscreteVariable;
import edu.cmu.tetrad.util.CombinationIterator;
import edu.cmu.tetrad.util.ProbUtils;
import java.util.ArrayList;
import java.util.Arrays;
import org.apache.commons.math3.util.FastMath;

/**
 * Conditional independence test for discrete data based on the G-square (likelihood-ratio)
 * statistic {@code G2 = 2 * sum(observed * log(observed / expected))}.
 *
 * <p>For a test of {@code x _||_ y | z1, ..., zk}, the selected columns are cross-tabulated
 * into a {@link CellTable}. For every combination of values of the conditioning variables,
 * the G-square contribution of the corresponding x-by-y sub-table is computed. These
 * contributions are summed together with their degrees of freedom, and the p-value is
 * obtained from {@code ProbUtils.chisqCdf}.
 *
 * <p>Instances are not safe for concurrent use: {@link #calcGSquare(int[])} and
 * {@link #isDetermined(int[], double)} both reload the single shared cell table.
 */
public final class GSquareTest {
    /** The data set whose columns are tested. */
    private final DataSet dataSet;

    /** Number of categories of each column of {@link #dataSet}, indexed by column. */
    private final int[] dims;

    /** Reusable contingency table; reloaded by each test call. */
    private final CellTable cellTable;

    /** Significance level; a test reports independence when its p-value exceeds this. */
    private double alpha;

    /**
     * Constructs a G-square test over the given data set.
     *
     * @param dataSet the data set; every column is cast to {@link DiscreteVariable}.
     * @param alpha   the significance level, in [0, 1].
     * @throws IllegalArgumentException if {@code alpha} is NaN or outside [0, 1].
     * @throws ClassCastException       if some column of {@code dataSet} is not a
     *                                  {@link DiscreteVariable}.
     */
    public GSquareTest(DataSet dataSet, double alpha) {
        checkAlpha(alpha);
        this.dims = new int[dataSet.getNumColumns()];
        for (int i = 0; i < this.dims.length; ++i) {
            DiscreteVariable variable = (DiscreteVariable) dataSet.getVariable(i);
            this.dims[i] = variable.getNumCategories();
        }
        this.dataSet = dataSet;
        this.alpha = alpha;
        this.cellTable = new CellTable(null);
        // NOTE(review): -99 is registered as the cell table's missing-value code; confirm it
        // matches the missing-value encoding used by the DataSet.
        this.cellTable.setMissingValue(-99);
    }

    /**
     * Computes the G-square test of {@code x _||_ y | z1, ..., zk}.
     *
     * @param testIndices column indices into the data set. {@code testIndices[0]} is x,
     *                    {@code testIndices[1]} is y, and any remaining entries are the
     *                    conditioning variables.
     * @return the summed statistic, its p-value, the summed degrees of freedom, and whether
     *     the p-value exceeds {@link #getAlpha()}.
     * @throws IllegalArgumentException if fewer than two indices are given.
     */
    public Result calcGSquare(int[] testIndices) {
        if (testIndices.length < 2) {
            throw new IllegalArgumentException("Need at least two variables for G Square test.");
        }

        this.cellTable.addToTable(this.dataSet, testIndices);

        // Margin selectors: sum out x, sum out y, or sum out both.
        int[] firstVar = {0};
        int[] secondVar = {1};
        int[] bothVars = {0, 1};

        int[] condDims = new int[testIndices.length - 2];
        System.arraycopy(selectFromArray(this.dims, testIndices), 2, condDims, 0,
                condDims.length);

        int[] coords = new int[testIndices.length];
        int numRows = this.cellTable.getNumValues(0);
        int numCols = this.cellTable.getNumValues(1);
        long[] rowSums = new long[numRows];
        long[] colSums = new long[numCols];

        double g2 = 0.0;
        int df = 0;

        CombinationIterator combinationIterator = new CombinationIterator(condDims);
        while (combinationIterator.hasNext()) {
            int[] combination = combinationIterator.next();
            System.arraycopy(combination, 0, coords, 2, combination.length);

            long total = this.cellTable.calcMargin(coords, bothVars);
            if (total == 0L) {
                // Empty stratum: contributes to neither the statistic nor the df.
                continue;
            }

            // Row sums n(i,+) and column sums n(+,j) of this stratum's x-by-y sub-table.
            // A row sum depends only on i (y is summed out), and a column sum depends only
            // on j, so each is computed once instead of once per cell.
            int numAttestedRows = 0;
            for (int i = 0; i < numRows; ++i) {
                coords[0] = i;
                rowSums[i] = this.cellTable.calcMargin(coords, secondVar);
                if (rowSums[i] != 0L) {
                    ++numAttestedRows;
                }
            }
            int numAttestedCols = 0;
            for (int j = 0; j < numCols; ++j) {
                coords[1] = j;
                colSums[j] = this.cellTable.calcMargin(coords, firstVar);
                if (colSums[j] != 0L) {
                    ++numAttestedCols;
                }
            }

            // Degrees of freedom count only rows and columns that have observations.
            // Strata without positive df are dropped entirely.
            int stratumDf = (numAttestedRows - 1) * (numAttestedCols - 1);
            if (stratumDf <= 0) {
                continue;
            }

            double stratumG2 = 0.0;
            for (int i = 0; i < numRows; ++i) {
                if (rowSums[i] == 0L) {
                    continue;
                }
                for (int j = 0; j < numCols; ++j) {
                    if (colSums[j] == 0L) {
                        continue;
                    }
                    coords[0] = i;
                    coords[1] = j;
                    long observed = this.cellTable.getValue(coords);
                    if (observed == 0L) {
                        // The term 0 * log(0) is taken to be 0.
                        continue;
                    }
                    double expected = (double) colSums[j] * (double) rowSums[i] / (double) total;
                    stratumG2 += 2.0 * (double) observed
                            * FastMath.log((double) observed / expected);
                }
            }

            df += stratumDf;
            g2 += stratumG2;
        }

        // No informative stratum: g2 is still 0 here, and df is forced to 1 so the chi-square
        // p-value is defined (presumably 1, i.e. independence).
        if (df == 0) {
            df = 1;
        }

        double pValue = 1.0 - ProbUtils.chisqCdf(g2, df);
        boolean indep = pValue > this.alpha;
        return new Result(g2, pValue, df, indep);
    }

    /**
     * Returns the number of categories of each data-set column.
     *
     * @return the internal array itself (not a copy), indexed by column.
     */
    public int[] getDims() {
        return this.dims;
    }

    /**
     * Returns the cell table reused by the test methods.
     *
     * @return the shared cell table.
     */
    public CellTable getCellTable() {
        return this.cellTable;
    }

    /**
     * Returns the significance level.
     *
     * @return the current significance level.
     */
    public double getAlpha() {
        return this.alpha;
    }

    /**
     * Sets the significance level.
     *
     * @param alpha the new significance level, in [0, 1].
     * @throws IllegalArgumentException if {@code alpha} is NaN or outside [0, 1].
     */
    public void setAlpha(double alpha) {
        checkAlpha(alpha);
        this.alpha = alpha;
    }

    /**
     * Returns the data set being tested.
     *
     * @return the data set passed to the constructor.
     */
    public DataSet getDataSet() {
        return this.dataSet;
    }

    /**
     * Checks whether the first variable is (nearly) determined by the remaining ones.
     *
     * <p>For every combination of values of {@code testIndices[1..]} that occurs in the data,
     * there must be some value of the first variable whose share of the rows with that
     * combination is at least {@code p}.
     *
     * @param testIndices column indices. {@code testIndices[0]} is the possibly determined
     *                    variable; the rest are the determining variables.
     * @param p           the required share of the dominant value.
     * @return {@code true} if every non-empty combination has such a dominant value.
     * @throws IllegalArgumentException if {@code testIndices} is empty.
     */
    public boolean isDetermined(int[] testIndices, double p) {
        if (testIndices.length < 1) {
            throw new IllegalArgumentException("Need at least one variable to test determinism.");
        }

        this.cellTable.addToTable(this.dataSet, testIndices);

        int[] firstVar = {0};
        int[] condDims = new int[testIndices.length - 1];
        System.arraycopy(selectFromArray(this.dims, testIndices), 1, condDims, 0,
                condDims.length);

        int[] coords = new int[testIndices.length];
        int numValues = this.cellTable.getNumValues(0);

        CombinationIterator combinationIterator = new CombinationIterator(condDims);
        while (combinationIterator.hasNext()) {
            int[] combination = combinationIterator.next();
            System.arraycopy(combination, 0, coords, 1, combination.length);

            long total = this.cellTable.calcMargin(coords, firstVar);
            if (total == 0L) {
                // Combination never observed: imposes no constraint.
                continue;
            }

            boolean dominates = false;
            for (int i = 0; i < numValues && !dominates; ++i) {
                coords[0] = i;
                long count = this.cellTable.getValue(coords);
                dominates = (double) count / (double) total >= p;
            }
            if (!dominates) {
                return false;
            }
        }
        return true;
    }

    /**
     * Validates a significance level.
     *
     * <p>The check is written as a negated conjunction so that NaN is rejected.
     */
    private static void checkAlpha(double alpha) {
        if (!(alpha >= 0.0 && alpha <= 1.0)) {
            throw new IllegalArgumentException("Significance level must be in [0, 1]: " + alpha);
        }
    }

    /** Returns {@code arr[indices[0]], arr[indices[1]], ...} as a new array. */
    private static int[] selectFromArray(int[] arr, int[] indices) {
        int[] retArr = new int[indices.length];
        for (int i = 0; i < indices.length; ++i) {
            retArr[i] = arr[indices[i]];
        }
        return retArr;
    }

    /** Immutable outcome of a G-square test. */
    public static final class Result {
        /** The G-square statistic summed over all informative strata. */
        private final double gSquare;

        /** The p-value of {@link #gSquare} at {@link #df} degrees of freedom. */
        private final double pValue;

        /** The summed degrees of freedom. */
        private final int df;

        /** Whether the test judged the variables independent. */
        private final boolean isIndep;

        /**
         * Creates a result.
         *
         * @param gSquare the G-square statistic.
         * @param pValue  its p-value.
         * @param df      the degrees of freedom.
         * @param isIndep whether independence was judged to hold.
         */
        public Result(double gSquare, double pValue, int df, boolean isIndep) {
            this.gSquare = gSquare;
            this.pValue = pValue;
            this.df = df;
            this.isIndep = isIndep;
        }

        /**
         * Returns the G-square statistic.
         *
         * @return the statistic.
         */
        public double getGSquare() {
            return this.gSquare;
        }

        /**
         * Returns the p-value.
         *
         * @return the p-value.
         */
        public double getPValue() {
            return this.pValue;
        }

        /**
         * Returns the degrees of freedom.
         *
         * @return the degrees of freedom.
         */
        public int getDf() {
            return this.df;
        }

        /**
         * Returns whether independence was judged to hold.
         *
         * @return the independence verdict.
         */
        public boolean isIndep() {
            return this.isIndep;
        }
    }
}

