/*
 * Decompiled with CFR 0.152.
 */
package edu.pitt.dbmi.algo.bayesian.constraint.inference;

import java.util.Arrays;
import org.apache.commons.math3.util.FastMath;

/**
 * Bayesian test of a conditional-independence constraint between two discrete variables.
 *
 * <p>For a query "X independent of Y given Z" the cases are grouped by the observed
 * instantiations of Z. Within each instantiation the marginal likelihood of a model that scores
 * X and Y separately is weighed against a model that scores the joint variable XY. The
 * per-instantiation results are then combined into a posterior probability of the constraint.
 *
 * <p>Input conventions, as read by the code below:
 * <ul>
 *   <li>{@code nodeDimension[v]} is the number of categories of variable {@code v}. The number of
 *       variables is taken to be {@code nodeDimension.length - 2}, and index
 *       {@code numberOfNodes + 1} is used internally for the joint XY variable. Index 0
 *       presumably is unused (1-based variable ids) &mdash; TODO confirm against callers.</li>
 *   <li>{@code cases[c][v]} is the category of variable {@code v} in case {@code c}. Cases run
 *       from 1 to {@code cases.length - 1}; row 0 is never read.</li>
 *   <li>Category values are 1-based. For the variables involved in a query, a value outside
 *       {@code 1..nodeDimension[v]} (including 0) is rejected with
 *       {@link IllegalArgumentException}.</li>
 * </ul>
 */
public class BCCausalInference {
    /** Prior equivalent sample size of the BDeu score computed by {@link #scoringFn1}. */
    private static final double PESS_VALUE = 1.0;

    /** Lanczos-series coefficients used by {@link #gammlnCore}; index 0 is unused. */
    private static final double[] GAMMLN_COEFFICIENTS = {
        0.0, 76.18009173, -86.50532033, 24.01409822, -1.231739516, 0.00120858003, -5.36382E-6
    };

    /** Conservative cap on the length of the scratch arrays allocated per query. */
    private static final int MAX_ARRAY_LENGTH = Integer.MAX_VALUE - 8;

    private final int numberOfNodes;
    private final int numberOfCases;
    /** {@code logFactorial[k] = ln(k!)}; only read by the K2 score in {@link #scoringFn2}. */
    private final double[] logFactorial;
    /** Local score selector: 1 = BDeu ({@link #scoringFn1}), anything else = K2. */
    private final int scoreFn;
    private final int[] nodeDimension;
    private final int[][] cases;

    /**
     * Creates an inference object over a fixed data set. Both arrays are stored by reference,
     * not copied.
     *
     * @param nodeDimension number of categories per variable; its length is interpreted as the
     *                      number of variables plus 2 (see the class documentation)
     * @param cases         the data, {@code cases[c][v]} for cases {@code 1..cases.length - 1}
     */
    public BCCausalInference(int[] nodeDimension, int[][] cases) {
        this.nodeDimension = nodeDimension;
        this.cases = cases;
        this.numberOfNodes = nodeDimension.length - 2;
        this.numberOfCases = cases.length - 1;
        int maximumValues = Arrays.stream(nodeDimension).max().getAsInt();
        this.logFactorial = computeLogFactorial(this.numberOfCases, maximumValues);
        this.scoreFn = 1;
    }

    /**
     * Returns ln(X + Y) given ln(X) and ln(Y), without leaving log space.
     */
    private static double lnXpluslnY(double lnX, double lnY) {
        double larger = lnX;
        double smaller = lnY;
        if (lnY > lnX) {
            larger = lnY;
            smaller = lnX;
        }
        if (larger == Double.NEGATIVE_INFINITY) {
            // Both terms are zero; avoid (-inf) - (-inf) = NaN.
            return larger;
        }
        double diff = smaller - larger;
        // For very negative diff the smaller term cannot change the result.
        return diff < -1022.0 ? larger : FastMath.log1p(FastMath.exp(diff)) + larger;
    }

    /**
     * Computes the posterior probability that X and Y are independent (or dependent) given Z.
     *
     * @param constraint {@link OP#INDEPENDENT} to get P(X indep. Y | Z, data); any other value
     *                   yields the complement 1 - P(X indep. Y | Z, data)
     * @param x          index of the first variable
     * @param y          index of the second variable
     * @param z          conditioning set: {@code z[0]} is the number of conditioning variables
     *                   and {@code z[1..z[0]]} are their indices
     * @return the posterior probability of the requested constraint
     * @throws IllegalArgumentException if a case holds a category outside {@code 1..dimension}
     *                                  for x, y or a conditioning variable, or if the scratch
     *                                  tables for this query would be too large to allocate
     */
    public double probConstraint(OP constraint, int x, int y, int[] z) {
        CountsTracker countsTracker = createCountsTracker(x, y, z);
        int[][] parents = countsTracker.parents;
        double[][] scores = countsTracker.scores;
        int[] xyProducts = countsTracker.xyProducts;
        int numberOfParents = z[0];

        // Independent model: X and Y each scored given Z (score lists 1 and 2).
        copyParents(z, numberOfParents, parents[x]);
        double lnMarginalLikelihoodX = scoreNode(x, 1, countsTracker);
        copyParents(z, numberOfParents, parents[y]);
        double lnMarginalLikelihoodY = scoreNode(y, 2, countsTracker);

        double priorIndependent = priorIndependent(x, y, z);
        double lnPriorIndependent = FastMath.log(priorIndependent);
        double scoreIndependent = lnMarginalLikelihoodX + lnMarginalLikelihoodY + lnPriorIndependent;

        // Dependent model: the joint variable XY scored given Z (score list 3). Each (x, y)
        // pair is encoded as the single category (xValue - 1) * dim(y) + yValue, which lies in
        // 1..dim(x) * dim(y) with no collisions.
        int xy = numberOfNodes + 1;
        int yDim = nodeDimension[y];
        for (int casei = 1; casei <= numberOfCases; ++casei) {
            xyProducts[casei] = (cases[casei][x] - 1) * yDim + cases[casei][y];
        }
        copyParents(z, numberOfParents, parents[xy]);
        scoreNode(xy, 3, countsTracker);

        // All three nodes share the same parents and cases, so they discover the same
        // Z-instantiations in the same order; numOfScores is the same for every list. The
        // prior is split so that "independent in every instantiation" has total prior
        // probability priorIndependent.
        int numberOfInstances = countsTracker.numOfScores;
        double lnTermPriorIndependent = lnPriorIndependent / numberOfInstances;
        // ln(1 - exp(a)) computed as ln(-expm1(a)) to avoid cancellation when a is near 0.
        double lnTermPriorDependent = FastMath.log(-FastMath.expm1(lnTermPriorIndependent));

        double scoreAll = 0.0;
        for (int i = 1; i <= numberOfInstances; ++i) {
            scoreAll += lnXpluslnY(
                    lnTermPriorIndependent + (scores[i][1] + scores[i][2]),
                    lnTermPriorDependent + scores[i][3]);
        }
        double probInd = FastMath.exp(scoreIndependent - scoreAll);
        return constraint == OP.INDEPENDENT ? probInd : 1.0 - probInd;
    }

    /** Writes a conditioning set into a parent row: count at index 0, variable ids from 1. */
    private static void copyParents(int[] z, int numberOfParents, int[] row) {
        row[0] = numberOfParents;
        if (numberOfParents > 0) {
            System.arraycopy(z, 1, row, 1, numberOfParents);
        }
    }

    /** Number of categories of {@code node}; ids above numberOfNodes denote the joint XY. */
    private int dimensionOf(int node, CountsTracker countsTracker) {
        return node > numberOfNodes ? countsTracker.xyDim : nodeDimension[node];
    }

    /** Category of {@code node} in case {@code casei}; ids above numberOfNodes denote XY. */
    private int valueOf(int node, int casei, CountsTracker countsTracker) {
        return node > numberOfNodes ? countsTracker.xyProducts[casei] : cases[casei][node];
    }

    /**
     * Counts every case for {@code node} grouped by its parents' instantiation and scores each
     * observed instantiation.
     *
     * @param node          variable to score (its parents must already be in {@code parents})
     * @param whichList     column of {@code scores} that receives the per-instantiation scores
     * @param countsTracker per-query scratch state; its pointers and counts are reset here
     * @return the sum of the per-instantiation local scores
     */
    private double scoreNode(int node, int whichList, CountsTracker countsTracker) {
        int[][] parents = countsTracker.parents;
        int[] counts = countsTracker.counts;
        int[] countsTree = countsTracker.countsTree;
        double[][] scores = countsTracker.scores;
        int nodeDim = dimensionOf(node, countsTracker);
        int numberOfParents = parents[node][0];

        if (numberOfParents > 0) {
            // The trie's root block is indexed by the first parent's value.
            int firstParentSize = nodeDimension[parents[node][1]];
            Arrays.fill(countsTree, 1, firstParentSize + 1, 0);
            countsTracker.countsTreePtr = firstParentSize + 1;
            countsTracker.countsPtr = 1;
        } else {
            // No parents: a single counts block at 1..nodeDim.
            countsTracker.countsTreePtr = 1;
            countsTracker.countsPtr = nodeDim + 1;
            Arrays.fill(counts, 1, nodeDim + 1, 0);
        }
        for (int casei = 1; casei <= numberOfCases; ++casei) {
            fileCase(node, casei, countsTracker);
        }

        // Kept as a double: an int product overflows once there are many parents.
        double q = 1.0;
        for (int i = 1; i <= numberOfParents; ++i) {
            q *= nodeDimension[parents[node][i]];
        }

        countsTracker.numOfScores = 0;
        double totalScore = 0.0;
        for (int instancePtr = 1; instancePtr < countsTracker.countsPtr; instancePtr += nodeDim) {
            double score = scoreFn == 1
                    ? scoringFn1(node, instancePtr, q, countsTracker)
                    : scoringFn2(node, instancePtr, countsTracker);
            ++countsTracker.numOfScores;
            scores[countsTracker.numOfScores][whichList] = score;
            totalScore += score;
        }
        return totalScore;
    }

    /**
     * BDeu local score, with equivalent sample size {@link #PESS_VALUE}, of one parent
     * instantiation.
     *
     * @param node          variable being scored
     * @param instancePtr   start of this instantiation's block in {@code counts}
     * @param q             number of possible parent instantiations (product of parent dimensions)
     * @param countsTracker per-query scratch state holding the counts
     * @return ln G(a/q) - ln G(N_ij + a/q) + sum_k [ln G(N_ijk + a/(q r)) - ln G(a/(q r))]
     */
    private double scoringFn1(int node, int instancePtr, double q, CountsTracker countsTracker) {
        int[] counts = countsTracker.counts;
        int r = dimensionOf(node, countsTracker);
        double pessDivQR = PESS_VALUE / (q * r);
        double pessDivQ = PESS_VALUE / q;
        double lngammPessDivQR = gammln(pessDivQR);
        int nij = 0;
        double scoreOfSum = 0.0;
        for (int k = 0; k < r; ++k) {
            int nijk = counts[instancePtr + k];
            nij += nijk;
            scoreOfSum += gammln(nijk + pessDivQR) - lngammPessDivQR;
        }
        return gammln(pessDivQ) - gammln(nij + pessDivQ) + scoreOfSum;
    }

    /**
     * K2 (Cooper-Herskovits) local score of one parent instantiation:
     * ln((r-1)! / (N_ij + r - 1)!) + sum_k ln(N_ijk!).
     */
    private double scoringFn2(int node, int instancePtr, CountsTracker countsTracker) {
        int[] counts = countsTracker.counts;
        int nodeDim = dimensionOf(node, countsTracker);
        int hits = 0;
        double scoreNI = 0.0;
        for (int k = 0; k < nodeDim; ++k) {
            int count = counts[instancePtr + k];
            hits += count;
            scoreNI += logFactorial[count];
        }
        return scoreNI + (logFactorial[nodeDim - 1] - logFactorial[hits + nodeDim - 1]);
    }

    /**
     * Adds case {@code casei} to the counts of {@code node}.
     *
     * <p>{@code countsTree} is a trie keyed by the parent values, one level per parent. Its
     * last-level links point to the start of the instantiation's {@code nodeDim}-cell block in
     * {@code counts}. Missing trie blocks and the counts block are allocated and zeroed on
     * first use.
     *
     * @throws IllegalArgumentException if the node or a parent has a value outside
     *                                  {@code 1..dimension}, or a scratch table is exhausted
     */
    private void fileCase(int node, int casei, CountsTracker countsTracker) {
        int[] counts = countsTracker.counts;
        int[] countsTree = countsTracker.countsTree;
        int[] nodeParents = countsTracker.parents[node];
        int nodeDim = dimensionOf(node, countsTracker);
        int nodeValue = valueOf(node, casei, countsTracker);
        checkCategory(node, casei, nodeValue, nodeDim);

        int numberOfParents = nodeParents[0];
        for (int i = 1; i <= numberOfParents; ++i) {
            int parent = nodeParents[i];
            checkCategory(parent, casei, valueOf(parent, casei, countsTracker), nodeDimension[parent]);
        }

        // Follow the trie along this case's parent values as far as it already exists.
        int ctPtr = 1;
        int ptr = 1;
        int firstMissingLevel = 0;
        for (int i = 1; i <= numberOfParents; ++i) {
            int parentValue = valueOf(nodeParents[i], casei, countsTracker);
            ptr = countsTree[ctPtr + parentValue - 1];
            if (ptr <= 0) {
                firstMissingLevel = i;
                break;
            }
            ctPtr = ptr;
        }

        int cPtr;
        if (ptr > 0) {
            // Known instantiation (or no parents): ctPtr already addresses its counts block.
            cPtr = ctPtr;
        } else {
            // New instantiation: extend the trie from the first missing level downwards.
            for (int i = firstMissingLevel; i <= numberOfParents; ++i) {
                int slot = ctPtr + valueOf(nodeParents[i], casei, countsTracker) - 1;
                if (i == numberOfParents) {
                    countsTree[slot] = countsTracker.countsPtr;
                } else {
                    int treeBlockSize = nodeDimension[nodeParents[i + 1]];
                    int treeBlockStart = countsTracker.countsTreePtr;
                    ensureCapacity(countsTree, treeBlockStart, treeBlockSize, "countsTree");
                    countsTree[slot] = treeBlockStart;
                    Arrays.fill(countsTree, treeBlockStart, treeBlockStart + treeBlockSize, 0);
                    ctPtr = treeBlockStart;
                    countsTracker.countsTreePtr = treeBlockStart + treeBlockSize;
                }
            }
            int countsBlockStart = countsTracker.countsPtr;
            ensureCapacity(counts, countsBlockStart, nodeDim, "counts");
            Arrays.fill(counts, countsBlockStart, countsBlockStart + nodeDim, 0);
            cPtr = countsBlockStart;
            countsTracker.countsPtr = countsBlockStart + nodeDim;
        }
        ++counts[cPtr + nodeValue - 1];
    }

    /** Rejects a category outside {@code 1..dimension}. */
    private static void checkCategory(int node, int casei, int value, int dimension) {
        if (value < 1 || value > dimension) {
            throw new IllegalArgumentException("Case " + casei + ", variable " + node
                    + ": value " + value + " is outside 1.." + dimension);
        }
    }

    /** Fails before writing {@code length} cells at {@code start} past the end of {@code array}. */
    private static void ensureCapacity(int[] array, int start, int length, String name) {
        if (start + length > array.length) {
            throw new IllegalArgumentException("Capacity of " + name + " exceeded");
        }
    }

    /** Converts a required array length to int, failing clearly when it cannot be allocated. */
    private static int toArrayLength(long length, String name) {
        if (length > MAX_ARRAY_LENGTH) {
            throw new IllegalArgumentException(name + " would need " + length + " cells");
        }
        return (int) length;
    }

    /**
     * Allocates the scratch state for one query, sized to the exact worst case.
     *
     * <p>Every case can open at most one new Z-instantiation. Each instantiation uses one counts
     * block of at most max(dim x, dim y, dim x * dim y) cells. It also uses at most one trie
     * block per parent level below the first.
     */
    private CountsTracker createCountsTracker(int x, int y, int[] z) {
        int numberOfParents = Math.max(z[0], 0);
        int xDim = nodeDimension[x];
        int yDim = nodeDimension[y];
        int xyDim = xDim * yDim;

        long maxNodeDim = Math.max(xyDim, Math.max(xDim, yDim));
        // Without parents there is always exactly one instantiation, even with no cases.
        long maxInstances = Math.max(numberOfCases, 1);
        long countsLength = 1L + maxInstances * maxNodeDim;
        long treeLength = 1L;
        if (numberOfParents > 0) {
            long deeperLevels = 0L;
            for (int i = 2; i <= numberOfParents; ++i) {
                deeperLevels += nodeDimension[z[i]];
            }
            treeLength += nodeDimension[z[1]] + (long) numberOfCases * deeperLevels;
        }

        CountsTracker tracker = new CountsTracker();
        tracker.xyDim = xyDim;
        tracker.counts = new int[toArrayLength(countsLength, "counts")];
        tracker.countsTree = new int[toArrayLength(treeLength, "countsTree")];
        tracker.scores = new double[(int) maxInstances + 1][4];
        tracker.xyProducts = new int[numberOfCases + 1];
        // Rows are indexed by variable id, but only x, y and the joint XY are ever read.
        int rowLength = numberOfParents + 1;
        tracker.parents = new int[numberOfNodes + 2][];
        tracker.parents[x] = new int[rowLength];
        tracker.parents[y] = new int[rowLength];
        tracker.parents[numberOfNodes + 1] = new int[rowLength];
        return tracker;
    }

    /** Prior probability of independence; currently uniform (the arguments are unused). */
    private double priorIndependent(int x, int y, int[] z) {
        return 0.5;
    }

    /**
     * Returns {@code ln(k!)} for every k the K2 score can request. {@link #scoringFn2} reads
     * index {@code hits + nodeDim - 1}, where hits is at most maxCases and the joint XY variable
     * can have up to maxValues^2 categories.
     */
    private static double[] computeLogFactorial(int maxCases, int maxValues) {
        long size = Math.max(2L * maxCases + maxValues, (long) maxCases + (long) maxValues * maxValues);
        double[] logFact = new double[toArrayLength(size + 1, "logFactorial")];
        for (int i = 1; i < logFact.length; ++i) {
            logFact[i] = FastMath.log(i) + logFact[i - 1];
        }
        return logFact;
    }

    /**
     * ln Gamma(xx). Arguments below 1 are handled with the reflection formula
     * Gamma(1 - z) * Gamma(1 + z) = pi * z / sin(pi * z), where z = 1 - xx.
     */
    private static double gammln(double xx) {
        if (xx == 1.0) {
            return 0.0;
        }
        if (xx > 1.0) {
            return gammlnCore(xx);
        }
        double z = 1.0 - xx;
        return FastMath.log(Math.PI * z) - gammlnCore(1.0 + z) - FastMath.log(FastMath.sin(Math.PI * z));
    }

    /** Lanczos approximation of ln Gamma(xx) for xx > 1 (Numerical Recipes {@code gammln}). */
    private static double gammlnCore(double xx) {
        double x = xx - 1.0;
        double tmp = x + 5.5;
        tmp = (x + 0.5) * FastMath.log(tmp) - tmp;
        double ser = 1.0;
        for (int j = 1; j <= 6; ++j) {
            x += 1.0;
            ser += GAMMLN_COEFFICIENTS[j] / x;
        }
        return tmp + FastMath.log(2.50662827465 * ser);
    }

    /** Per-query scratch state shared by {@link #scoreNode} and {@link #fileCase}. */
    private static final class CountsTracker {
        /** Number of parent instantiations found by the last scoreNode call. */
        int numOfScores;
        /** Next free slot in {@link #countsTree}. */
        int countsTreePtr;
        /** Next free slot in {@link #counts}. */
        int countsPtr;
        /** Number of categories of the joint XY variable. */
        int xyDim;
        /** {@code parents[node][0]} = parent count, {@code parents[node][1..]} = parent ids. */
        int[][] parents;
        /** Trie over parent values; last-level links point to blocks in {@link #counts}. */
        int[] countsTree;
        /** Per-instantiation blocks of category counts. */
        int[] counts;
        /** {@code scores[i][list]}: local score of instantiation i for list 1 (X), 2 (Y), 3 (XY). */
        double[][] scores;
        /** Encoded joint XY category of each case (1-based case index). */
        int[] xyProducts;
    }

    /** The constraint whose probability {@link #probConstraint} returns. */
    public enum OP {
        DEPENDENT,
        INDEPENDENT
    }
}

