/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search.utils;

import edu.cmu.tetrad.data.CorrelationMatrix;
import edu.cmu.tetrad.data.CovarianceMatrix;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.DataTransforms;
import edu.cmu.tetrad.data.ICovarianceMatrix;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.search.utils.Tetrad;
import edu.cmu.tetrad.util.Matrix;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import org.apache.commons.math3.distribution.ChiSquaredDistribution;

/**
 * Delta-method test of the hypothesis that one or more tetrad differences vanish.
 *
 * <p>A tetrad (i, j, k, l) is evaluated as {@code s(i,j) * s(k,l) - s(i,k) * s(j,l)}, where
 * {@code s} is read from the covariance (or correlation) matrix. For the vector of tetrad values
 * {@code t}, the statistic is {@code N * t' (D' S D)^-1 t}. Here {@code D} holds the partial
 * derivatives of each tetrad with respect to the covariances it uses, and {@code S} is the
 * asymptotic covariance of those sample covariances. The p-value comes from a chi-square
 * distribution with one degree of freedom per tetrad.
 *
 * <p>The most recent statistic and its degrees of freedom are kept in unsynchronized mutable
 * fields. {@link #getPValue()} therefore refers to the last call of {@link #calcChiSquare}, and an
 * instance should not be shared between threads without external locking.
 */
public class DeltaTetradTest {
    /** Sample size that multiplies the quadratic form. */
    private final int sampleSize;
    /** Source of second moments; never null after construction. */
    private final ICovarianceMatrix cov;
    /** Variables, in the order used to index {@link #cov} and {@link #data}. */
    private final List<Node> variables;
    /** Position of each variable in {@link #variables}. */
    private final Map<Node, Integer> variablesHash;
    /** Copy of the input data as returned by DataTransforms.center; null when built from a matrix. */
    private final DataSet dataSet;
    /** {@code data[v]} holds the values of variable v from {@link #dataSet}; null when built from a matrix. */
    private final double[][] data;
    /** Degrees of freedom of the last statistic (the number of tetrads tested). */
    private int df;
    /** Last statistic computed by {@link #calcChiSquare}. */
    private double chisq;

    /**
     * Builds a test from tabular data. Covariances come from a {@link CovarianceMatrix} of the
     * data. The asymptotic covariance of the sample covariances is estimated from fourth moments
     * of the centered data.
     *
     * @param dataSet a continuous data set.
     * @throws NullPointerException if {@code dataSet} is null.
     * @throws IllegalArgumentException if {@code dataSet} is not continuous.
     */
    public DeltaTetradTest(DataSet dataSet) {
        if (dataSet == null) {
            throw new NullPointerException("dataSet must not be null.");
        }
        if (!dataSet.isContinuous()) {
            throw new IllegalArgumentException("Data set must be continuous.");
        }
        this.cov = new CovarianceMatrix(dataSet);

        ArrayList<DataSet> toCenter = new ArrayList<>();
        toCenter.add(dataSet);
        this.dataSet = DataTransforms.center(toCenter).get(0);
        // Transposed so that each row holds one variable's column of values.
        this.data = this.dataSet.getDoubleData().transpose().toArray();

        this.sampleSize = dataSet.getNumRows();
        this.variables = dataSet.getVariables();
        this.variablesHash = indexVariables(this.variables);
    }

    /**
     * Builds a test from a covariance or correlation matrix alone. Without raw data, the
     * asymptotic covariance of the sample covariances uses a closed form. For a
     * {@link CorrelationMatrix} it uses the asymptotic covariance of correlations. Otherwise it
     * uses the normal-theory expression {@code s(e,g) s(f,h) + s(e,h) s(f,g)}.
     *
     * @param cov the covariance (or correlation) matrix; its sample size is used as N.
     * @throws NullPointerException if {@code cov} is null.
     */
    public DeltaTetradTest(ICovarianceMatrix cov) {
        if (cov == null) {
            throw new NullPointerException("cov must not be null.");
        }
        this.cov = cov;
        this.dataSet = null;
        this.data = null;
        this.sampleSize = cov.getSampleSize();
        this.variables = cov.getVariables();
        this.variablesHash = indexVariables(this.variables);
    }

    /**
     * Computes the delta-method chi-square statistic for the joint hypothesis that every given
     * tetrad vanishes. The statistic and its degrees of freedom are stored for
     * {@link #getPValue()}.
     *
     * @param tetrads the tetrads to test jointly; at least one is required.
     * @return {@code N * t' (D' S D)^-1 t} for the given tetrads.
     * @throws IllegalArgumentException if no tetrads are given, or if a tetrad mentions a variable
     *     that is not in the covariance matrix or data set.
     */
    public double calcChiSquare(Tetrad... tetrads) {
        if (tetrads == null || tetrads.length == 0) {
            throw new IllegalArgumentException("At least one tetrad is required.");
        }
        this.df = tetrads.length;

        List<Sigma> boldSigma = collectSigmas(tetrads);
        Matrix sigmaSs = sigmaCovariance(boldSigma);
        Matrix del = derivatives(boldSigma, tetrads);
        Matrix t = tetradValues(tetrads);

        // Delta method: asymptotic covariance of the tetrad values is D' S D.
        Matrix sigmaTt = del.transpose().times(sigmaSs).times(del);
        Matrix quadraticForm = t.transpose().times(sigmaTt.inverse()).times(t);

        this.chisq = (double) this.sampleSize * quadraticForm.get(0, 0);
        return this.chisq;
    }

    /**
     * Returns the upper-tail chi-square probability of the statistic from the most recent
     * {@link #calcChiSquare} call.
     *
     * <p>If no statistic has been computed yet, the degrees of freedom are 0, which commons-math's
     * {@link ChiSquaredDistribution} rejects.
     *
     * @return {@code 1 - CDF(chisq)} with {@code df} degrees of freedom.
     */
    public double getPValue() {
        ChiSquaredDistribution distribution = new ChiSquaredDistribution(this.df);
        return 1.0 - distribution.cumulativeProbability(this.chisq);
    }

    /**
     * Computes the statistic for the given tetrads and returns its p-value.
     *
     * @param tetrads the tetrads to test jointly.
     * @return the p-value of the joint test.
     */
    public double getPValue(Tetrad... tetrads) {
        this.calcChiSquare(tetrads);
        return this.getPValue();
    }

    /** Maps each variable to its position in {@code variables}. */
    private static Map<Node, Integer> indexVariables(List<Node> variables) {
        Map<Node, Integer> index = new HashMap<>();
        for (int i = 0; i < variables.size(); ++i) {
            index.put(variables.get(i), i);
        }
        return index;
    }

    /**
     * Collects, without duplicates and in first-seen order, every covariance whose derivative
     * appears in a tetrad {@code s(I,J) s(K,L) - s(I,K) s(J,L)}. Omitting any of these four would
     * drop its contribution to the delta-method variance.
     */
    private static List<Sigma> collectSigmas(Tetrad[] tetrads) {
        LinkedHashSet<Sigma> sigmas = new LinkedHashSet<>();
        for (Tetrad tetrad : tetrads) {
            Node i = tetrad.getI();
            Node j = tetrad.getJ();
            Node k = tetrad.getK();
            Node l = tetrad.getL();
            sigmas.add(new Sigma(i, j));
            sigmas.add(new Sigma(k, l));
            sigmas.add(new Sigma(i, k));
            sigmas.add(new Sigma(j, l));
        }
        return new ArrayList<>(sigmas);
    }

    /** Builds S, the matrix of asymptotic covariances between every pair of sample covariances. */
    private Matrix sigmaCovariance(List<Sigma> sigmas) {
        int m = sigmas.size();
        Matrix sigmaSs = new Matrix(m, m);
        for (int r = 0; r < m; ++r) {
            Sigma ef = sigmas.get(r);
            for (int c = 0; c < m; ++c) {
                Sigma gh = sigmas.get(c);
                sigmaSs.set(r, c, covarianceOfCovariances(ef.getA(), ef.getB(), gh.getA(), gh.getB()));
            }
        }
        return sigmaSs;
    }

    /**
     * Asymptotic covariance between the sample moments s(e,f) and s(g,h).
     *
     * <p>The formula depends on the source of the moments. A correlation matrix uses the closed
     * form for correlations. A covariance matrix without data uses the normal-theory form. With
     * data, it uses the empirical fourth moment minus the product of second moments.
     */
    private double covarianceOfCovariances(Node e, Node f, Node g, Node h) {
        if (this.cov instanceof CorrelationMatrix) {
            double ref = this.sxy(e, f);
            double rgh = this.sxy(g, h);
            double reg = this.sxy(e, g);
            double reh = this.sxy(e, h);
            double rfg = this.sxy(f, g);
            double rfh = this.sxy(f, h);
            return 0.5 * (ref * rgh) * (reg * reg + reh * reh + rfg * rfg + rfh * rfh)
                    + reg * rfh + reh * rfg
                    - ref * (rfg * rfh + reg * reh)
                    - rgh * (rfg * reg + rfh * reh);
        }
        if (this.dataSet == null) {
            // Normal theory: E[efgh] - s_ef s_gh = s_eg s_fh + s_eh s_fg, which keeps this branch
            // consistent with the data-based branch below.
            return this.sxy(e, g) * this.sxy(f, h) + this.sxy(e, h) * this.sxy(f, g);
        }
        return this.sxyzw(e, f, g, h) - this.sxy(e, f) * this.sxy(g, h);
    }

    /** Builds D: row r, column c is the derivative of tetrad c with respect to covariance r. */
    private Matrix derivatives(List<Sigma> sigmas, Tetrad[] tetrads) {
        Matrix del = new Matrix(sigmas.size(), tetrads.length);
        for (int r = 0; r < sigmas.size(); ++r) {
            Sigma sigma = sigmas.get(r);
            for (int c = 0; c < tetrads.length; ++c) {
                Tetrad tetrad = tetrads[c];
                del.set(r, c, this.getDerivative(tetrad.getI(), tetrad.getJ(), tetrad.getK(),
                        tetrad.getL(), sigma.getA(), sigma.getB()));
            }
        }
        return del;
    }

    /** Builds the column vector t of tetrad values {@code s(I,J) s(K,L) - s(I,K) s(J,L)}. */
    private Matrix tetradValues(Tetrad[] tetrads) {
        Matrix t = new Matrix(tetrads.length, 1);
        for (int r = 0; r < tetrads.length; ++r) {
            Tetrad tetrad = tetrads[r];
            Node i = tetrad.getI();
            Node j = tetrad.getJ();
            Node k = tetrad.getK();
            Node l = tetrad.getL();
            t.set(r, 0, this.sxy(i, j) * this.sxy(k, l) - this.sxy(i, k) * this.sxy(j, l));
        }
        return t;
    }

    /**
     * Empirical fourth moment of four variables over the centered data.
     *
     * @throws IllegalArgumentException if the test was built without tabular data.
     */
    private double sxyzw(Node e, Node f, Node g, Node h) {
        if (this.dataSet == null) {
            throw new IllegalArgumentException("To calculate sxyzw, tabular data is needed.");
        }
        return this.sxyzw(this.indexOf(e), this.indexOf(f), this.indexOf(g), this.indexOf(h));
    }

    /**
     * Looks up a variable's position.
     *
     * @throws IllegalArgumentException if the variable is unknown to this test.
     */
    private int indexOf(Node node) {
        Integer index = this.variablesHash.get(node);
        if (index == null) {
            throw new IllegalArgumentException(
                    "Variable is not in the covariance matrix or data set: " + node);
        }
        return index;
    }

    /**
     * Second moment of two variables. It is read from {@link #cov} when present (always, after
     * construction); otherwise it is computed from the data rows.
     */
    private double sxy(Node node1, Node node2) {
        int i = this.indexOf(node1);
        int j = this.indexOf(node2);
        if (this.cov != null) {
            return this.cov.getValue(i, j);
        }
        double[] row1 = this.data[i];
        double[] row2 = this.data[j];
        return this.sxy(row1, row2, row1.length);
    }

    /**
     * Partial derivative of {@code s(n1,n2) s(n3,n4) - s(n1,n3) s(n2,n4)} with respect to the
     * symmetric covariance s(a,b). Returns 0 when s(a,b) does not occur. When several terms match,
     * the first match in the order below is used.
     */
    private double getDerivative(Node node1, Node node2, Node node3, Node node4, Node a, Node b) {
        if (isPair(node1, node2, a, b)) {
            return this.sxy(node3, node4);
        }
        if (isPair(node3, node4, a, b)) {
            return this.sxy(node1, node2);
        }
        if (isPair(node1, node3, a, b)) {
            return -this.sxy(node2, node4);
        }
        if (isPair(node2, node4, a, b)) {
            return -this.sxy(node1, node3);
        }
        return 0.0;
    }

    /** True if {x, y} and {a, b} are the same unordered pair (nodes compared with equals). */
    private static boolean isPair(Node x, Node y, Node a, Node b) {
        return x.equals(a) && y.equals(b) || x.equals(b) && y.equals(a);
    }

    /** Mean of the element-wise product of four data rows. */
    private double sxyzw(int x, int y, int z, int w) {
        double[] xs = this.data[x];
        double[] ys = this.data[y];
        double[] zs = this.data[z];
        double[] ws = this.data[w];
        int n = xs.length;
        double sum = 0.0;
        for (int r = 0; r < n; ++r) {
            sum += xs[r] * ys[r] * zs[r] * ws[r];
        }
        return 1.0 / (double) n * sum;
    }

    /** Mean of the element-wise product of the first {@code n} entries of two arrays. */
    private double sxy(double[] array1, double[] array2, int n) {
        double sum = 0.0;
        for (int r = 0; r < n; ++r) {
            sum += array1[r] * array2[r];
        }
        return 1.0 / (double) n * sum;
    }

    /** Unordered pair of variables identifying the symmetric covariance s(a,b) = s(b,a). */
    private static final class Sigma {
        private final Node a;
        private final Node b;

        Sigma(Node a, Node b) {
            this.a = a;
            this.b = b;
        }

        Node getA() {
            return this.a;
        }

        Node getB() {
            return this.b;
        }

        /** Order-insensitive equality; returns false (never throws) for non-Sigma arguments. */
        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (!(o instanceof Sigma)) {
                return false;
            }
            Sigma other = (Sigma) o;
            return other.a.equals(this.a) && other.b.equals(this.b)
                    || other.b.equals(this.a) && other.a.equals(this.b);
        }

        /** Symmetric in (a, b), consistent with {@link #equals}. */
        @Override
        public int hashCode() {
            return this.a.hashCode() + this.b.hashCode();
        }

        @Override
        public String toString() {
            return "Sigma(" + this.getA() + ", " + this.getB() + ")";
        }
    }
}

