/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search;

import edu.cmu.tetrad.data.BoxDataSet;
import edu.cmu.tetrad.data.CorrelationMatrix;
import edu.cmu.tetrad.data.CovarianceMatrix;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.DataUtils;
import edu.cmu.tetrad.data.ICovarianceMatrix;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.search.IntSextad;
import edu.cmu.tetrad.util.Matrix;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import org.apache.commons.math3.distribution.ChiSquaredDistribution;
import org.apache.commons.math3.linear.SingularMatrixException;

/**
 * Delta-method chi-square test for one or more sextads.
 *
 * <p>Each {@link IntSextad} names six variable indices. The first three index the rows and the last
 * three the columns of a 3x3 submatrix of the covariance (or correlation) matrix. The statistic is
 * {@code N * t' * Sigma_tt^-1 * t}, where {@code t} holds the determinants of those submatrices and
 * {@code Sigma_tt = del' * Sigma_ss * del} is their delta-method covariance. Here {@code Sigma_ss}
 * is the covariance of the involved second moments and {@code del} holds the partial derivatives
 * of each determinant with respect to each second moment.
 */
public class DeltaSextadTest {
    static final long serialVersionUID = 23L;

    /** Centered data, one row per variable; {@code null} when built from a covariance matrix. */
    private final double[][] data;
    /** Sample size used to scale the chi-square statistic. */
    private final int N;
    /** Source of second moments; never {@code null} (both constructors set it). */
    private final ICovarianceMatrix cov;
    private final List<Node> variables;

    /**
     * Constructs a test from a continuous data set. Second moments come from a
     * {@link CovarianceMatrix} of the data. Fourth moments come from the centered raw data.
     *
     * @param dataSet a continuous data set.
     * @throws NullPointerException if {@code dataSet} is null.
     * @throws IllegalArgumentException if {@code dataSet} is not continuous.
     */
    public DeltaSextadTest(DataSet dataSet) {
        if (dataSet == null) {
            throw new NullPointerException("dataSet");
        }
        if (!dataSet.isContinuous()) {
            throw new IllegalArgumentException("Data set must be continuous.");
        }
        this.cov = new CovarianceMatrix(dataSet);
        Matrix centered = DataUtils.centerData(dataSet.getDoubleData());
        // Transposed so that data[v] is the column of values for variable v.
        this.data = centered.transpose().toArray();
        this.N = dataSet.getNumRows();
        this.variables = dataSet.getVariables();
    }

    /**
     * Constructs a test from a covariance (or correlation) matrix.
     *
     * <p>No raw data is available. The covariance of the second moments is therefore computed from
     * second moments only:
     * <ul>
     *   <li>If {@code cov} is a {@link CorrelationMatrix}, the correlation-specific formula is used.</li>
     *   <li>Otherwise, the formula {@code s_eg*s_fh + s_eh*s_fg} is used.</li>
     * </ul>
     *
     * @param cov the covariance matrix; its sample size is used as N.
     * @throws NullPointerException if {@code cov} is null.
     */
    public DeltaSextadTest(ICovarianceMatrix cov) {
        if (cov == null) {
            throw new NullPointerException("cov");
        }
        this.cov = cov;
        this.data = null;
        this.N = cov.getSampleSize();
        this.variables = cov.getVariables();
    }

    /** Returns an instance built from {@link BoxDataSet#serializableInstance()}. */
    public static DeltaSextadTest serializableInstance() {
        return new DeltaSextadTest(BoxDataSet.serializableInstance());
    }

    /**
     * Returns the upper-tail chi-square p-value of {@link #calcChiSquare} for the given sextads.
     *
     * <p>NOTE(review): the degrees of freedom come from {@link #dofHarman} applied to the number of
     * sextads, clamped to at least 1. A Wald statistic over k non-redundant constraints would
     * usually use k degrees of freedom. Confirm the intended formula.
     *
     * @param sextads the sextads to test jointly; at least one is required.
     * @return {@code 1 - F(chisq)} for the chi-square CDF {@code F}.
     */
    public double getPValue(IntSextad... sextads) {
        int df = this.dofHarman(sextads.length);
        double chisq = this.calcChiSquare(sextads);
        double cdf = new ChiSquaredDistribution(df).cumulativeProbability(chisq);
        return 1.0 - cdf;
    }

    /**
     * Computes the delta-method chi-square statistic {@code N * t' * Sigma_tt^-1 * t} for the given
     * sextads.
     *
     * @param sextads the sextads to test jointly; must be non-empty.
     * @return the chi-square statistic.
     * @throws IllegalArgumentException if {@code sextads} is empty.
     * @throws RuntimeException wrapping a {@link SingularMatrixException} if {@code Sigma_tt}
     *     cannot be inverted.
     */
    public double calcChiSquare(IntSextad[] sextads) {
        if (sextads.length == 0) {
            throw new IllegalArgumentException("At least one sextad is required.");
        }
        List<Sigma> boldSigma = this.collectSigmas(sextads);
        Matrix sigma_ss = this.sigmaCovarianceMatrix(boldSigma);
        Matrix del = this.derivativeMatrix(sextads, boldSigma);
        Matrix t = this.sextadDeterminants(sextads);
        Matrix sigma_tt = del.transpose().times(sigma_ss).times(del);
        try {
            return (double) this.N * t.transpose().times(sigma_tt.inverse()).times(t).get(0, 0);
        } catch (SingularMatrixException e) {
            throw new RuntimeException("Singularity problem.", e);
        }
    }

    /** Collects the distinct (unordered) row/column index pairs used by any of the sextads. */
    private List<Sigma> collectSigmas(IntSextad[] sextads) {
        HashSet<Sigma> boldSigmaSet = new HashSet<>();
        for (IntSextad sextad : sextads) {
            List<Integer> nodes = sextad.getNodes();
            for (int k1 = 0; k1 < 3; ++k1) {
                for (int k2 = 0; k2 < 3; ++k2) {
                    boldSigmaSet.add(new Sigma(nodes.get(k1), nodes.get(3 + k2)));
                }
            }
        }
        return new ArrayList<>(boldSigmaSet);
    }

    /** Builds the symmetric covariance matrix of the second moments listed in {@code boldSigma}. */
    private Matrix sigmaCovarianceMatrix(List<Sigma> boldSigma) {
        int p = boldSigma.size();
        Matrix sigma_ss = new Matrix(p, p);
        for (int i = 0; i < p; ++i) {
            Sigma sigmaef = boldSigma.get(i);
            for (int j = i; j < p; ++j) {
                Sigma sigmagh = boldSigma.get(j);
                double value = this.covarianceOfSigmas(
                        sigmaef.getA(), sigmaef.getB(), sigmagh.getA(), sigmagh.getB());
                sigma_ss.set(i, j, value);
                sigma_ss.set(j, i, value);
            }
        }
        return sigma_ss;
    }

    /** Returns the covariance entry between second moments (e,f) and (g,h). */
    private double covarianceOfSigmas(int e, int f, int g, int h) {
        if (this.cov instanceof CorrelationMatrix) {
            // Correlation-specific formula. Only second moments are needed, so this works without
            // raw data (the previous version also evaluated an unused fourth-moment term that threw
            // NullPointerException here).
            double ref = this.r(e, f);
            double rgh = this.r(g, h);
            double reg = this.r(e, g);
            double reh = this.r(e, h);
            double rfg = this.r(f, g);
            double rfh = this.r(f, h);
            return 0.5 * (ref * rgh) * (reg * reg + reh * reh + rfg * rfg + rfh * rfh)
                    + reg * rfh + reh * rfg
                    - ref * (rfg * rfh + reg * reh)
                    - rgh * (rfg * reg + rfh * reh);
        }
        if (this.data == null) {
            // No raw data: use second moments only.
            return this.r(e, g) * this.r(f, h) + this.r(e, h) * this.r(f, g);
        }
        // Raw data available: fourth moment minus product of second moments.
        return this.r(e, f, g, h) - this.r(e, f) * this.r(g, h);
    }

    /**
     * Builds {@code del}, where {@code del(i, j)} is the derivative of sextad j's determinant with
     * respect to second moment i.
     */
    private Matrix derivativeMatrix(IntSextad[] sextads, List<Sigma> boldSigma) {
        Matrix del = new Matrix(boldSigma.size(), sextads.length);
        for (int j = 0; j < sextads.length; ++j) {
            IntSextad sextad = sextads[j];
            for (int i = 0; i < boldSigma.size(); ++i) {
                del.set(i, j, this.getDerivative(sextad, boldSigma.get(i)));
            }
        }
        return del;
    }

    /**
     * Returns the column vector of determinants of each sextad's 3x3 submatrix. Rows are its first
     * three nodes and columns are its last three nodes.
     */
    private Matrix sextadDeterminants(IntSextad[] sextads) {
        Matrix t = new Matrix(sextads.length, 1);
        for (int i = 0; i < sextads.length; ++i) {
            List<Integer> nodes = sextads[i].getNodes();
            Matrix m = new Matrix(3, 3);
            for (int k1 = 0; k1 < 3; ++k1) {
                for (int k2 = 0; k2 < 3; ++k2) {
                    m.set(k1, k2, this.r(nodes.get(k1), nodes.get(3 + k2)));
                }
            }
            t.set(i, 0, m.det());
        }
        return t;
    }

    /**
     * Returns the second moment between variables i and j. It is read from {@code cov}, falling
     * back to computing it from the raw data.
     */
    private double r(int i, int j) {
        if (this.cov != null) {
            return this.cov.getValue(i, j);
        }
        double[] arr1 = this.data[i];
        double[] arr2 = this.data[j];
        return this.r(arr1, arr2, arr1.length);
    }

    /**
     * Returns the derivative of the sextad's determinant with respect to {@code sigma}.
     *
     * <p>A {@code Sigma} is an unordered pair, so both orientations (a,b) and (b,a) are tried. At
     * most one of them should place the first index among the sextad's rows and the second among
     * its columns.
     */
    private double getDerivative(IntSextad sextad, Sigma sigma) {
        int a = sigma.getA();
        int b = sigma.getB();
        int n1 = sextad.getI();
        int n2 = sextad.getJ();
        int n3 = sextad.getK();
        int n4 = sextad.getL();
        int n5 = sextad.getM();
        int n6 = sextad.getN();
        double x1 = this.derivative(a, b, n1, n2, n3, n4, n5, n6);
        double x2 = this.derivative(b, a, n1, n2, n3, n4, n5, n6);
        if (x1 == 0.0) {
            return x2;
        }
        if (x2 == 0.0) {
            return x1;
        }
        throw new IllegalStateException("Both nonzero at the same time: x1 = " + x1 + " x2 = " + x2);
    }

    /**
     * Returns the cofactor of the entry at row {@code a}, column {@code b} of the 3x3 matrix.
     * Rows are n1..n3 and columns are n4..n6. This is the partial derivative of its determinant
     * with respect to that entry. Returns 0 if {@code a} is not a row or {@code b} is not a column.
     */
    private double derivative(int a, int b, int n1, int n2, int n3, int n4, int n5, int n6) {
        if (a == n1) {
            if (b == n4) {
                return this.r(n2, n5) * this.r(n3, n6) - this.r(n2, n6) * this.r(n3, n5);
            }
            if (b == n5) {
                return -this.r(n2, n4) * this.r(n3, n6) + this.r(n3, n4) * this.r(n2, n6);
            }
            if (b == n6) {
                return this.r(n2, n4) * this.r(n3, n5) - this.r(n3, n4) * this.r(n2, n5);
            }
        } else if (a == n2) {
            if (b == n4) {
                return this.r(n3, n5) * this.r(n1, n6) - this.r(n1, n5) * this.r(n3, n6);
            }
            if (b == n5) {
                return this.r(n1, n4) * this.r(n3, n6) - this.r(n3, n4) * this.r(n1, n6);
            }
            if (b == n6) {
                return -this.r(n1, n4) * this.r(n3, n5) + this.r(n3, n4) * this.r(n1, n5);
            }
        } else if (a == n3) {
            if (b == n4) {
                return this.r(n1, n5) * this.r(n2, n6) - this.r(n2, n5) * this.r(n1, n6);
            }
            if (b == n5) {
                return -this.r(n1, n4) * this.r(n2, n6) + this.r(n2, n4) * this.r(n1, n6);
            }
            if (b == n6) {
                return this.r(n1, n4) * this.r(n2, n5) - this.r(n2, n4) * this.r(n1, n5);
            }
        }
        return 0.0;
    }

    /** Returns the variables of the underlying data set or covariance matrix. */
    public List<Node> getVariables() {
        return this.variables;
    }

    /**
     * Returns the fourth moment {@code (1/N) * sum x*y*z*w} over the centered raw data.
     *
     * @throws IllegalStateException if this test was built from a covariance matrix (no raw data).
     */
    private double r(int x, int y, int z, int w) {
        if (this.data == null) {
            throw new IllegalStateException("Fourth moments require raw data.");
        }
        double sxyzw = 0.0;
        double[] _x = this.data[x];
        double[] _y = this.data[y];
        double[] _z = this.data[z];
        double[] _w = this.data[w];
        int N = _x.length;
        for (int j = 0; j < N; ++j) {
            sxyzw += _x[j] * _y[j] * _z[j] * _w[j];
        }
        return 1.0 / (double) N * sxyzw;
    }

    /** Returns {@code (1/N) * sum array1[i]*array2[i]} over the first N entries. */
    private double r(double[] array1, double[] array2, int N) {
        double sum = 0.0;
        for (int i = 0; i < N; ++i) {
            sum += array1[i] * array2[i];
        }
        return 1.0 / (double) N * sum;
    }

    /** Alternative degrees-of-freedom formula (clamped to at least 1); currently unused. */
    private int dofDrton(int n) {
        int dof = (n - 2) * (n - 3) / 2 - 2;
        if (dof < 1) {
            dof = 1;
        }
        return dof;
    }

    /** Degrees-of-freedom formula {@code n(n-5)/2 + 1}, clamped to at least 1. */
    private int dofHarman(int n) {
        int dof = n * (n - 5) / 2 + 1;
        if (dof < 1) {
            dof = 1;
        }
        return dof;
    }

    /** Unordered pair of variable indices identifying one second moment. */
    private static final class Sigma {
        private final int a;
        private final int b;

        public Sigma(int a, int b) {
            this.a = a;
            this.b = b;
        }

        public int getA() {
            return this.a;
        }

        public int getB() {
            return this.b;
        }

        /** Equal if the same pair in either order; non-{@code Sigma} objects are never equal. */
        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (!(o instanceof Sigma)) {
                return false;
            }
            Sigma other = (Sigma) o;
            return (other.a == this.a && other.b == this.b)
                    || (other.b == this.a && other.a == this.b);
        }

        /** Symmetric in (a, b), consistent with {@link #equals}. */
        @Override
        public int hashCode() {
            return this.a + this.b;
        }

        @Override
        public String toString() {
            return "Sigma(" + this.getA() + ", " + this.getB() + ")";
        }
    }
}

