/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search.work_in_progress;

import edu.cmu.tetrad.data.BoxDataSet;
import edu.cmu.tetrad.data.DataModel;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.DataUtils;
import edu.cmu.tetrad.data.DiscreteVariable;
import edu.cmu.tetrad.data.VerticalDoubleDataBox;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.GraphUtils;
import edu.cmu.tetrad.graph.IndependenceFact;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.graph.RandomGraph;
import edu.cmu.tetrad.search.IndependenceTest;
import edu.cmu.tetrad.search.test.IndependenceResult;
import edu.cmu.tetrad.search.test.MsepTest;
import edu.cmu.tetrad.search.utils.LogUtilsSearch;
import edu.cmu.tetrad.sem.SemIm;
import edu.cmu.tetrad.sem.SemPm;
import edu.cmu.tetrad.util.CombinationIterator;
import edu.cmu.tetrad.util.Matrix;
import edu.cmu.tetrad.util.Vector;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import org.apache.commons.math3.distribution.ChiSquaredDistribution;
import org.apache.commons.math3.exception.OutOfRangeException;

/**
 * Work-in-progress conditional independence test that checks whether rank-transformed data are
 * "uniformly scattered" over a grid.
 *
 * <p>On construction, every non-discrete column is replaced by {@code (rank - 1) / N}, where N is
 * the number of rows. To test {@code X _||_ Y | Z}, each conditioning variable is cut into
 * {@code numCondCategories} equal-width bins, and every combination of bins defines a slice of the
 * rows. Within each slice the (X, Y) values are counted on an {@code m x m} grid. A Pearson
 * chi-square statistic then compares all cells of the non-empty slices against one common expected
 * count. X and Y are judged independent given Z when the resulting p-value exceeds {@code alpha}.
 */
public final class IndTestUniformScatter implements IndependenceTest {
    /** The data set supplied by the caller; returned by {@link #getData()}. */
    private final DataSet dataSet;
    /** {@link #dataSet} with its non-discrete columns rank-transformed. */
    private final DataSet transformed;
    /** Transformed data, indexed as {@code data[variable][row]}. */
    private final double[][] data;
    /** Significance level; independence is reported when {@code p > alpha}. */
    private final double alpha;
    /** Target average number of rows per grid cell; used to choose the grid size. */
    private final double avgCountPerCell;
    /** Number of equal-width bins each conditioning variable is cut into. */
    private final int numCondCategories;
    /** Verbosity flag; stored and exposed, but not otherwise used by this class. */
    private boolean verbose = false;

    /**
     * Constructs the test and rank-transforms the data once.
     *
     * @param dataSet           the data to test.
     * @param alpha             significance level; a pair is judged independent when {@code p > alpha}.
     * @param avgCountPerCell   target average number of rows per grid cell.
     * @param numCondCategories number of equal-width bins per conditioning variable.
     */
    public IndTestUniformScatter(DataSet dataSet, double alpha, double avgCountPerCell, int numCondCategories) {
        this.alpha = alpha;
        this.dataSet = dataSet;
        this.transformed = IndTestUniformScatter.getUniformTransform(dataSet);
        this.avgCountPerCell = avgCountPerCell;
        // getDoubleData() has one column per variable, so the transpose is indexed [variable][row].
        this.data = this.transformed.getDoubleData().transpose().toArray();
        this.numCondCategories = numCondCategories;
    }

    /**
     * Simulation harness.
     *
     * <p>Simulates linear SEM data from a random graph. Every test {@code X _||_ Y | Z} with a single
     * conditioning variable is compared against m-separation in the true graph, and the confusion
     * counts are printed with "dependent" treated as the positive class.
     *
     * @param args ignored.
     */
    public static void main(String ... args) {
        Graph graph = RandomGraph.randomGraph(10, 0, 10, 100, 100, 100, false);
        System.out.println("True graph = " + graph);
        int N = 1000;
        double alpha = 0.001;
        SemPm pm = new SemPm(graph);
        SemIm im = new SemIm(pm);
        DataSet dataSet = im.simulateData(N, false);
        graph = GraphUtils.replaceNodes(graph, dataSet.getVariables());
        MsepTest msepTest = new MsepTest(graph);
        IndTestUniformScatter test = new IndTestUniformScatter(dataSet, alpha, 5.0, 3);
        List<Node> nodes = graph.getNodes();
        int tp = 0;
        int tn = 0;
        int fp = 0;
        int fn = 0;
        for (int x = 0; x < nodes.size(); ++x) {
            for (int y = x + 1; y < nodes.size(); ++y) {
                for (int z = 0; z < nodes.size(); ++z) {
                    if (z == x || z == y) {
                        continue;
                    }
                    Node X = nodes.get(x);
                    Node Y = nodes.get(y);
                    Node Z = nodes.get(z);
                    Set<Node> cond = new HashSet<>();
                    cond.add(Z);
                    boolean msep = msepTest.checkIndependence(X, Y, cond).isIndependent();
                    boolean independent = test.checkIndependence(X, Y, cond).isIndependent();
                    // Positive = dependent: m-connected is the truth, "dependent" is the prediction.
                    if (msep) {
                        if (independent) {
                            ++tn;
                        } else {
                            ++fp;
                        }
                    } else if (independent) {
                        ++fn;
                    } else {
                        ++tp;
                    }
                    System.out.println(X + " _||_ " + Y + " | " + Z + " " + (msep ? "D-SEPARATED" : "d-connected") + " " + (independent ? "INDEPENDENT" : "dependent"));
                }
            }
        }
        System.out.println("TP = " + tp);
        System.out.println("TN = " + tn);
        System.out.println("FP = " + fp);
        System.out.println("FN = " + fn);
        System.out.println("Precision = " + (double) tp / (double) (tp + fp));
        System.out.println("Recall = " + (double) tp / (double) (tp + fn));
    }

    /**
     * Returns a copy of {@code dataSet} whose non-discrete columns are replaced by
     * {@code (rank - 1) / N}.
     *
     * <p>Assuming {@code DataUtils.ranks} yields ranks in {@code 1..N} (TODO confirm, including its
     * tie handling), the transformed values lie in {@code [0, 1)}. That range is what the grid
     * binning relies on.
     *
     * @param dataSet the data to transform.
     * @return the transformed data set, or {@code dataSet} itself if an {@link OutOfRangeException}
     *         is raised while transforming.
     */
    private static DataSet getUniformTransform(DataSet dataSet) {
        try {
            Matrix original = dataSet.getDoubleData();
            Matrix uniform = original.like();
            double N = dataSet.getNumRows();
            for (int j = 0; j < original.getNumColumns(); ++j) {
                double[] column = Arrays.copyOf(original.getColumn(j).toArray(), original.getNumRows());
                if (dataSet.getVariable(j) instanceof DiscreteVariable) {
                    // NOTE(review): discrete columns keep their raw category values. If one is used
                    // as X or Y with values outside [0, 1), the grid indexing will go out of bounds.
                    // As a conditioning variable, such rows fall outside every bin. Confirm intent.
                    uniform.assignColumn(j, new Vector(column));
                    continue;
                }
                double[] ranks = DataUtils.ranks(column);
                for (int i = 0; i < ranks.length; ++i) {
                    ranks[i] = (ranks[i] - 1.0) / N;
                }
                uniform.assignColumn(j, new Vector(ranks));
            }
            return new BoxDataSet(new VerticalDoubleDataBox(uniform.transpose().toArray()), dataSet.getVariables());
        } catch (OutOfRangeException e) {
            // NOTE(review): the failure is only printed and the untransformed data are returned.
            // Later grid indexing assumes values in [0, 1), so this fallback likely fails downstream.
            e.printStackTrace();
            return dataSet;
        }
    }

    /**
     * Computes the chi-square p-value for the (X, Y) grid counts across all conditioning slices.
     *
     * @param data              transformed data indexed {@code [variable][row]}.
     * @param x                 index of X.
     * @param y                 index of Y.
     * @param z                 indices of the conditioning variables.
     * @param m                 grid cells per axis.
     * @param numCondCategories bins per conditioning variable.
     * @return the upper-tail chi-square p-value, or {@code NaN} when fewer than two cells are
     *         available (degrees of freedom < 1).
     */
    private static double getConditionallyIndependentUniformPvalue(double[][] data, int x, int y, int[] z, int m, int numCondCategories) {
        int[][][] dataCounts = IndTestUniformScatter.countConditionalDataOnGrid(data, x, y, z, m, numCondCategories);
        int N = data[0].length;
        // NOTE(review): the denominator counts every slice, including empty ones (which are skipped
        // in the sum below), and uses one flat expectation for all slices. Expected totals over the
        // summed cells therefore need not equal the observed totals. Confirm this is intended.
        double expectedCount = (double) N / (double) (m * m * dataCounts.length);
        double chiSquare = 0.0;
        int numCells = 0;
        for (int[][] slice : dataCounts) {
            for (int[] gridRow : slice) {
                for (int observed : gridRow) {
                    double deviation = observed - expectedCount;
                    chiSquare += deviation * deviation / expectedCount;
                    ++numCells;
                }
            }
        }
        int dof = numCells - 1;
        if (dof < 1) {
            return Double.NaN;
        }
        ChiSquaredDistribution chiSquaredDistribution = new ChiSquaredDistribution(dof);
        return 1.0 - chiSquaredDistribution.cumulativeProbability(chiSquare);
    }

    /**
     * Sums all entries of a (possibly jagged) count table. Not referenced elsewhere in this class.
     *
     * @param counts the table to sum.
     * @return the total count.
     */
    private static double count(int[][] counts) {
        int count = 0;
        for (int[] row : counts) {
            for (int value : row) {
                count += value;
            }
        }
        return count;
    }

    /**
     * Unconditional version of the grid test. Not referenced elsewhere in this class.
     *
     * <p>Counts (X, Y) on an {@code m x m} grid, with {@code m = floor(sqrt(N / avgCountPerCell))}.
     * The Pearson chi-square against a flat expectation is compared with the
     * {@code 1 - alpha} quantile on {@code m*m - 1} degrees of freedom.
     *
     * @return {@code true} if the statistic does not exceed the critical value.
     */
    private static boolean isMarginallyIndependentUniform(double[][] data, int x, int y, double avgCountPerCell, double alpha) {
        int N = data[0].length;
        double numCells = (double) N / avgCountPerCell;
        int m = (int) Math.pow(numCells, 0.5);
        double[][] dataCounts = IndTestUniformScatter.countDataOnGrid(data, x, y, m, m);
        double expectedCount = (double) N / (double) (m * m);
        double chiSquare = 0.0;
        for (int k = 0; k < m; ++k) {
            for (int l = 0; l < m; ++l) {
                double deviation = dataCounts[k][l] - expectedCount;
                chiSquare += deviation * deviation / expectedCount;
            }
        }
        int degreesOfFreedom = m * m - 1;
        ChiSquaredDistribution chiSquaredDistribution = new ChiSquaredDistribution(degreesOfFreedom);
        double criticalValue = chiSquaredDistribution.inverseCumulativeProbability(1.0 - alpha);
        return chiSquare <= criticalValue;
    }

    /**
     * Counts the (X, Y) values of every row on an {@code m1 x m2} grid.
     *
     * <p>The cell is {@code (int)(value * m)} on each axis, so the values are expected to lie in
     * {@code [0, 1)}.
     */
    private static double[][] countDataOnGrid(double[][] data, int x, int y, int m1, int m2) {
        double[] xData = data[x];
        double[] yData = data[y];
        double[][] dataCounts = new double[m1][m2];
        for (int i = 0; i < xData.length; ++i) {
            int row = (int) (xData[i] * (double) m1);
            int column = (int) (yData[i] * (double) m2);
            dataCounts[row][column] += 1.0;
        }
        return dataCounts;
    }

    /**
     * Counts the (X, Y) values on an {@code m x m} grid separately for each combination of
     * conditioning bins.
     *
     * <p>Slice {@code s} corresponds to the {@code s}-th combination produced by
     * {@link CombinationIterator}. A row belongs to a slice when
     * {@code (int)(value * numCondCategories)} equals the slice's bin for every conditioning
     * variable. Rows whose conditioning bin falls outside {@code [0, numCondCategories)} belong to
     * no slice. Enumerated slices that receive no rows, which is all of them when {@code m == 0},
     * are replaced by {@code int[0][0]}.
     *
     * <p>Rows are routed to their slices in a single pass rather than one scan per slice.
     */
    private static int[][][] countConditionalDataOnGrid(double[][] data, int x, int y, int[] z, int m, int numCondCategories) {
        int numSlices = (int) Math.pow(numCondCategories, z.length);
        int[][][] dataCounts = new int[numSlices][m][m];

        // Assign slice indices in the iterator's enumeration order. Remember which slice each bin
        // combination (identified by its base-numCondCategories code) belongs to.
        int[] dimensions = new int[z.length];
        Arrays.fill(dimensions, numCondCategories);
        int[] sliceOfCode = new int[numSlices];
        Arrays.fill(sliceOfCode, -1);
        boolean[] enumerated = new boolean[numSlices];
        int slice = -1;
        CombinationIterator combinationIterator = new CombinationIterator(dimensions);
        while (combinationIterator.hasNext()) {
            int[] combination = combinationIterator.next();
            enumerated[++slice] = true;
            int code = IndTestUniformScatter.combinationCode(combination, numCondCategories);
            if (code >= 0) {
                sliceOfCode[code] = slice;
            }
        }

        // Single pass over the rows: route each row to its slice and count it on the (X, Y) grid.
        boolean[] occupied = new boolean[numSlices];
        if (m > 0) {
            int numRows = data[x].length;
            for (int i = 0; i < numRows; ++i) {
                int code = IndTestUniformScatter.rowCode(data, z, i, numCondCategories);
                if (code < 0 || sliceOfCode[code] < 0) {
                    continue;
                }
                int target = sliceOfCode[code];
                int row = (int) (data[x][i] * (double) m);
                int column = (int) (data[y][i] * (double) m);
                ++dataCounts[target][row][column];
                occupied[target] = true;
            }
        }

        // Empty slices contribute neither to the chi-square statistic nor to its degrees of freedom.
        for (int s = 0; s < numSlices; ++s) {
            if (enumerated[s] && !occupied[s]) {
                dataCounts[s] = new int[0][0];
            }
        }
        return dataCounts;
    }

    /**
     * Returns the base-{@code numCategories} code of a combination of bin indices, or -1 if any
     * index lies outside {@code [0, numCategories)}.
     */
    private static int combinationCode(int[] bins, int numCategories) {
        int code = 0;
        for (int bin : bins) {
            if (bin < 0 || bin >= numCategories) {
                return -1;
            }
            code = code * numCategories + bin;
        }
        return code;
    }

    /**
     * Returns the base-{@code numCategories} code of row {@code i}'s conditioning bins, or -1 if
     * any bin lies outside {@code [0, numCategories)}.
     */
    private static int rowCode(double[][] data, int[] z, int i, int numCategories) {
        int code = 0;
        for (int zIndex : z) {
            int bin = (int) (data[zIndex][i] * (double) numCategories);
            if (bin < 0 || bin >= numCategories) {
                return -1;
            }
            code = code * numCategories + bin;
        }
        return code;
    }

    /**
     * Returns the position of {@code node} in {@code nodes}.
     *
     * @throws IllegalArgumentException if {@code node} is not one of {@code nodes}.
     */
    private static int requireIndex(List<Node> nodes, Node node) {
        int index = nodes.indexOf(node);
        if (index < 0) {
            throw new IllegalArgumentException("Variable not in data set: " + node);
        }
        return index;
    }

    /**
     * Tests {@code x _||_ y | z}.
     *
     * <p>The grid size is chosen so that, across the {@code numCondCategories^|z|} slices, each cell
     * holds about {@code avgCountPerCell} rows: {@code m = floor(sqrt(N / avgCountPerCell / K^|z|))}.
     * The p-value is passed both as the p-value and as the final numeric argument of
     * {@link IndependenceResult}. NOTE(review): confirm what that last argument should hold.
     *
     * @throws IllegalArgumentException if {@code x}, {@code y} or a member of {@code z} is not a
     *                                  variable of the data set.
     * @throws RuntimeException         if the p-value is undefined (e.g. the grid is too coarse).
     */
    @Override
    public IndependenceResult checkIndependence(Node x, Node y, Set<Node> z) {
        List<Node> nodes = this.transformed.getVariables();
        int xIndex = IndTestUniformScatter.requireIndex(nodes, x);
        int yIndex = IndTestUniformScatter.requireIndex(nodes, y);
        int[] zIndices = new int[z.size()];
        int k = 0;
        for (Node node : z) {
            zIndices[k++] = IndTestUniformScatter.requireIndex(nodes, node);
        }
        int N = this.data[0].length;
        double numCells = (double) N / this.avgCountPerCell;
        double numCellsPerTable = numCells / Math.pow(this.numCondCategories, z.size());
        int m = (int) Math.pow(numCellsPerTable, 0.5);
        double p = IndTestUniformScatter.getConditionallyIndependentUniformPvalue(this.data, xIndex, yIndex, zIndices, m, this.numCondCategories);
        if (Double.isNaN(p)) {
            throw new RuntimeException("Undefined p-value encountered when testing " + LogUtilsSearch.independenceFact(x, y, z));
        }
        return new IndependenceResult(new IndependenceFact(x, y, z), p > this.alpha, p, p);
    }

    /** Returns the variables of the original (untransformed) data set. */
    @Override
    public List<Node> getVariables() {
        return this.dataSet.getVariables();
    }

    /** Returns the original (untransformed) data set. */
    @Override
    public DataModel getData() {
        return this.dataSet;
    }

    /** Returns the stored verbosity flag. */
    @Override
    public boolean isVerbose() {
        return this.verbose;
    }

    /** Sets the stored verbosity flag. */
    @Override
    public void setVerbose(boolean verbose) {
        this.verbose = verbose;
    }
}

