/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search;

import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.GraphUtils;
import edu.cmu.tetrad.graph.IndependenceFact;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.search.ConditioningSetType;
import edu.cmu.tetrad.search.IndependenceTest;
import edu.cmu.tetrad.search.test.IndependenceResult;
import edu.cmu.tetrad.search.test.MsepTest;
import edu.cmu.tetrad.util.RandomUtil;
import edu.cmu.tetrad.util.SublistGenerator;
import edu.cmu.tetrad.util.UniformityTest;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.Future;
import org.apache.commons.math3.distribution.BinomialDistribution;
import org.apache.commons.math3.util.FastMath;
import org.jetbrains.annotations.NotNull;

/**
 * Checks a graph against data by testing the independence facts that the graph implies.
 * <p>
 * For each variable {@code x} a conditioning set {@code z} is chosen according to the configured
 * {@link ConditioningSetType}. Every other variable {@code y} is then classified as m-separated from
 * {@code x} given {@code z} (an "indep" fact) or m-connected to it (a "dep" fact). Each fact is checked
 * with the supplied {@link IndependenceTest}. Three summary statistics are computed separately for the
 * indep and dep result lists:
 * <ul>
 *   <li>the fraction of results judged dependent;</li>
 *   <li>a {@link UniformityTest} p-value of the p-values against U(0, 1);</li>
 *   <li>a binomial p-value for the number of p-values falling below the test's alpha.</li>
 * </ul>
 * For {@link ConditioningSetType#GLOBAL_MARKOV}, every conditioning subset produced by
 * {@link SublistGenerator} is enumerated instead.
 * <p>
 * NOTE(review): instances hold mutable result state; they do not look designed for concurrent use.
 */
public class MarkovCheck {
    /** The graph under check, with its nodes replaced by the independence test's variables. */
    private final Graph graph;
    /** The test used to judge each implied fact against the data. */
    private final IndependenceTest independenceTest;
    /** Oracle answering m-separation queries in {@link #graph}. */
    private final MsepTest msep;
    /** Results for facts the graph says are m-separated (implied independent). */
    private final List<IndependenceResult> resultsIndep = new ArrayList<>();
    /** Results for facts the graph says are m-connected (implied dependent). */
    private final List<IndependenceResult> resultsDep = new ArrayList<>();
    /** How the conditioning set for each variable is chosen. */
    private ConditioningSetType setType;
    /** Whether independence tests are dispatched to the common fork/join pool. */
    private boolean parallelized = false;
    private double fractionDependentIndep = Double.NaN;
    private double fractionDependentDep = Double.NaN;
    private double ksPValueIndep = Double.NaN;
    private double ksPValueDep = Double.NaN;
    private double bernoulliPIndep = Double.NaN;
    private double bernoulliPDep = Double.NaN;
    /** Number of resamples; only stored and returned, not otherwise used by this class. */
    private int numResamples = 1;

    /**
     * Creates a Markov check.
     *
     * @param graph            the graph whose implied independencies are checked; its nodes are replaced
     *                         by the test's variables
     * @param independenceTest the test used to judge each implied fact
     * @param setType          how conditioning sets are chosen
     */
    public MarkovCheck(Graph graph, IndependenceTest independenceTest, ConditioningSetType setType) {
        this.graph = GraphUtils.replaceNodes(graph, independenceTest.getVariables());
        this.independenceTest = independenceTest;
        this.msep = new MsepTest(this.graph);
        this.setType = setType;
    }

    /**
     * Classifies conditioning sets of {@code graph} by m-separation.
     * <p>
     * The enumeration covers every ordered pair {@code (x, y)} of distinct nodes, in sorted order. Both
     * {@code (x, y)} and {@code (y, x)} appear. For each pair, every conditioning set drawn from the
     * remaining nodes by {@code SublistGenerator(n, n)} is used, which is presumably all subsets.
     *
     * @param graph the graph to query
     * @return the m-separated facts and the m-connected facts
     */
    @NotNull
    public static AllSubsetsIndependenceFacts getAllSubsetsIndependenceFacts(Graph graph) {
        MsepTest msepTest = new MsepTest(graph);
        List<Node> nodes = new ArrayList<>(graph.getNodes());
        Collections.sort(nodes);

        List<IndependenceFact> msep = new ArrayList<>();
        List<IndependenceFact> mconn = new ArrayList<>();

        for (Node x : nodes) {
            List<Node> others = new ArrayList<>(nodes);
            others.remove(x);

            for (Node y : others) {
                List<Node> candidates = new ArrayList<>(others);
                candidates.remove(y);

                SublistGenerator generator = new SublistGenerator(candidates.size(), candidates.size());
                int[] choice;

                while ((choice = generator.next()) != null) {
                    Set<Node> z = GraphUtils.asSet(choice, candidates);
                    IndependenceFact fact = new IndependenceFact(x, y, z);

                    if (msepTest.isMSeparated(x, y, z)) {
                        msep.add(fact);
                    } else {
                        mconn.add(fact);
                    }
                }
            }
        }

        return new AllSubsetsIndependenceFacts(msep, mconn);
    }

    /**
     * Recomputes all test results and summary statistics.
     * <p>
     * Previous results are discarded. The conditioning sets used depend on {@link #getSetType()}.
     *
     * @throws IllegalArgumentException if the set type is {@code ORDERED_LOCAL_MARKOV} and the graph has no
     *                                  valid order, or if the set type is not supported
     */
    public void generateResults() {
        this.resultsIndep.clear();
        this.resultsDep.clear();

        if (this.setType == ConditioningSetType.GLOBAL_MARKOV) {
            AllSubsetsIndependenceFacts facts = getAllSubsetsIndependenceFacts(this.graph);
            checkFacts(facts.getMsep(), this.resultsIndep);
            checkFacts(facts.getMconn(), this.resultsDep);
        } else {
            generateLocalResults();
        }

        calcStats(true);
        calcStats(false);
    }

    /** Tests x against every other variable given x's local conditioning set, for each variable x. */
    private void generateLocalResults() {
        List<Node> nodes = new ArrayList<>(this.independenceTest.getVariables());
        Collections.sort(nodes);

        List<Node> candidates = new ArrayList<>(this.graph.getNodes());
        Collections.sort(candidates);

        // Only the ordered variant needs a causal order, so only compute it then.
        List<Node> order = this.setType == ConditioningSetType.ORDERED_LOCAL_MARKOV
                ? this.graph.paths().getValidOrder(this.graph.getNodes(), true)
                : null;

        for (Node x : nodes) {
            Set<Node> z = conditioningSet(x, order);

            // Insertion-ordered so results follow the sorted candidate order deterministically.
            Set<Node> separated = new LinkedHashSet<>();
            Set<Node> connected = new LinkedHashSet<>();

            for (Node y : candidates) {
                if (y == x || z.contains(x) || z.contains(y)) continue;

                if (this.msep.isMSeparated(x, y, z)) {
                    separated.add(y);
                } else {
                    connected.add(y);
                }
            }

            checkFacts(toFacts(x, z, separated), this.resultsIndep);
            checkFacts(toFacts(x, z, connected), this.resultsDep);
        }
    }

    /**
     * Returns the conditioning set for {@code x} under the current set type.
     *
     * @param x     the variable being tested
     * @param order a valid order of the graph's nodes; only consulted for {@code ORDERED_LOCAL_MARKOV}
     */
    private Set<Node> conditioningSet(Node x, List<Node> order) {
        switch (this.setType) {
            case LOCAL_MARKOV:
                return new HashSet<>(this.graph.getParents(x));
            case ORDERED_LOCAL_MARKOV: {
                if (order == null) {
                    throw new IllegalArgumentException("No valid order found.");
                }
                Set<Node> z = new HashSet<>(this.graph.getParents(x));
                int xIndex = order.indexOf(x);
                // Keep only parents that precede x in the order.
                z.removeIf(w -> order.indexOf(w) >= xIndex);
                return z;
            }
            case MARKOV_BLANKET:
                return GraphUtils.markovBlanket(x, this.graph);
            default:
                throw new IllegalArgumentException("Unknown separation set type: " + this.setType);
        }
    }

    /** Builds one fact {@code x _||_ y | z} for each {@code y} not already in {@code z}. */
    private static List<IndependenceFact> toFacts(Node x, Set<Node> z, Collection<Node> ys) {
        List<IndependenceFact> facts = new ArrayList<>(ys.size());
        for (Node y : ys) {
            if (!z.contains(y)) {
                facts.add(new IndependenceFact(x, y, z));
            }
        }
        return facts;
    }

    /**
     * Runs the independence test on every fact and appends the non-NaN results to {@code sink}.
     * <p>
     * Facts are processed in chunks. The chunks run either inline or on the common fork/join pool,
     * depending on {@link #setParallelized(boolean)}.
     */
    private void checkFacts(List<IndependenceFact> facts, List<IndependenceResult> sink) {
        if (facts.isEmpty()) return;

        // Silence the shared test once for the whole batch and always restore it. Toggling it per fact
        // from several worker threads races and can leave the user's verbose setting lost.
        boolean verbose = this.independenceTest.isVerbose();
        this.independenceTest.setVerbose(false);

        try {
            int chunkSize = getChunkSize(facts.size());
            List<IndCheckTask> tasks = new ArrayList<>();

            for (int from = 0; from < facts.size() && !Thread.currentThread().isInterrupted(); from += chunkSize) {
                int to = FastMath.min(facts.size(), from + chunkSize);
                IndCheckTask task = new IndCheckTask(from, to, facts, this.independenceTest);

                if (this.parallelized) {
                    tasks.add(task);
                } else {
                    sink.addAll(task.call());
                }
            }

            if (this.parallelized) {
                List<Future<List<IndependenceResult>>> futures = ForkJoinPool.commonPool().invokeAll(tasks);
                for (Future<List<IndependenceResult>> future : futures) {
                    try {
                        sink.addAll(future.get());
                    } catch (InterruptedException e) {
                        Thread.currentThread().interrupt();
                        throw new RuntimeException(e);
                    } catch (ExecutionException e) {
                        throw new RuntimeException(e);
                    }
                }
            }
        } finally {
            this.independenceTest.setVerbose(verbose);
        }
    }

    /** Returns the conditioning set type. */
    public ConditioningSetType getSetType() {
        return this.setType;
    }

    /** Sets the conditioning set type used by the next {@link #generateResults()} call. */
    public void setSetType(ConditioningSetType setType) {
        this.setType = setType;
    }

    /** Sets whether independence tests are run on the common fork/join pool. */
    public void setParallelized(boolean parallelized) {
        this.parallelized = parallelized;
    }

    /**
     * Returns a copy of the results for implied-independent ({@code indep == true}) or
     * implied-dependent facts.
     */
    public List<IndependenceResult> getResults(boolean indep) {
        return new ArrayList<>(getResultsLocal(indep));
    }

    /** Extracts the p-value of each result, in order. */
    public List<Double> getPValues(List<IndependenceResult> results) {
        List<Double> pValues = new ArrayList<>(results.size());
        for (IndependenceResult result : results) {
            pValues.add(result.getPValue());
        }
        return pValues;
    }

    /**
     * Returns the fraction of results judged dependent.
     * <p>
     * The value is NaN before {@link #generateResults()} has run, or when there were no results.
     */
    public double getFractionDependent(boolean indep) {
        return indep ? this.fractionDependentIndep : this.fractionDependentDep;
    }

    /**
     * Returns the {@link UniformityTest} p-value of the p-values against U(0, 1).
     * <p>
     * The value is NaN if there were fewer than two results.
     */
    public double getKsPValue(boolean indep) {
        return indep ? this.ksPValueIndep : this.ksPValueDep;
    }

    /**
     * Returns the binomial p-value {@code P(X >= k)}, where {@code X ~ Binomial(n, alpha)}.
     * <p>
     * Here {@code k} is the number of p-values below the test's alpha. The value is NaN if there were
     * fewer than two results.
     */
    public double getBernoulliPValue(boolean indep) {
        return indep ? this.bernoulliPIndep : this.bernoulliPDep;
    }

    /** Returns a copy of the independence test's variables. */
    public List<Node> getVariables() {
        return new ArrayList<>(this.independenceTest.getVariables());
    }

    /** Looks up a variable by name via the independence test. */
    public Node getVariable(String name) {
        return this.independenceTest.getVariable(name);
    }

    /** Returns the independence test. */
    public IndependenceTest getIndependenceTest() {
        return this.independenceTest;
    }

    /**
     * Returns a random subsample, without replacement, of {@code ceil(sampleSize * v)} row indices.
     * <p>
     * The size is clamped to {@code [0, sampleSize]}. This method is currently unused in this class.
     */
    private List<Integer> getSubsampleRows(double v) {
        int sampleSize = this.independenceTest.getSampleSize();
        int subsampleSize = (int) FastMath.max(0, FastMath.min(sampleSize, FastMath.ceil(sampleSize * v)));

        List<Integer> rows = new ArrayList<>(sampleSize);
        for (int i = 0; i < sampleSize; i++) {
            rows.add(i);
        }

        // Fisher-Yates shuffle through RandomUtil, so it shares the project's random source
        // (presumably seedable -- confirm) rather than Collections.shuffle's private RNG.
        for (int i = sampleSize - 1; i > 0; i--) {
            Collections.swap(rows, i, RandomUtil.getInstance().nextInt(i + 1));
        }

        // Copy so callers don't hold a view onto the full shuffled list.
        return new ArrayList<>(rows.subList(0, subsampleSize));
    }

    /**
     * Returns {@code floor(sampleSize * v)} row indices drawn uniformly with replacement.
     * <p>
     * This method is currently unused in this class.
     */
    private List<Integer> getBootstrapSample(double v) {
        int sampleSize = this.independenceTest.getSampleSize();
        int bootstrapSize = (int) FastMath.floor(sampleSize * v);

        List<Integer> rows = new ArrayList<>(Math.max(0, bootstrapSize));
        for (int i = 0; i < bootstrapSize; i++) {
            rows.add(RandomUtil.getInstance().nextInt(sampleSize));
        }
        return rows;
    }

    /** Recomputes the summary statistics for the indep or dep result list. */
    private void calcStats(boolean indep) {
        List<IndependenceResult> results = getResultsLocal(indep);

        int dependent = 0;
        for (IndependenceResult result : results) {
            if (result.isDependent() && !Double.isNaN(result.getPValue())) {
                dependent++;
            }
        }

        // 0 / 0 deliberately yields NaN when there are no results.
        double fraction = (double) dependent / (double) results.size();

        List<Double> pValues = getPValues(results);
        boolean enough = pValues.size() >= 2;
        double ks = enough ? UniformityTest.getPValue(pValues, 0.0, 1.0) : Double.NaN;
        double binomial = enough ? getBernoulliP(pValues, this.independenceTest.getAlpha()) : Double.NaN;

        if (indep) {
            this.fractionDependentIndep = fraction;
            this.ksPValueIndep = ks;
            this.bernoulliPIndep = binomial;
        } else {
            this.fractionDependentDep = fraction;
            this.ksPValueDep = ks;
            this.bernoulliPDep = binomial;
        }
    }

    /**
     * One-sided binomial test for an excess of rejections.
     * <p>
     * If each test rejects with probability {@code alpha}, the number of rejections follows
     * {@code Binomial(n, alpha)}. This returns {@code P(X >= k)}, where {@code k} is the observed number
     * of p-values below {@code alpha}.
     */
    private double getBernoulliP(List<Double> pValues, double alpha) {
        int rejections = 0;
        for (double pValue : pValues) {
            if (pValue < alpha) {
                rejections++;
            }
        }

        BinomialDistribution bd = new BinomialDistribution(pValues.size(), alpha);
        // P(X >= k) = 1 - P(X <= k - 1); cumulativeProbability(-1) is 0, so k = 0 gives 1.
        return 1.0 - bd.cumulativeProbability(rejections - 1);
    }

    /** Splits {@code n} items into about five chunks per available processor, with at least one item each. */
    private int getChunkSize(int n) {
        int chunk = (int) FastMath.ceil((double) n / (double) (5 * Runtime.getRuntime().availableProcessors()));
        return Math.max(1, chunk);
    }

    /** Returns the live (not copied) result list for the indep or dep facts. */
    private List<IndependenceResult> getResultsLocal(boolean indep) {
        return indep ? this.resultsIndep : this.resultsDep;
    }

    /** Returns the stored number of resamples. */
    public int getNumResamples() {
        return this.numResamples;
    }

    /** Stores the number of resamples. */
    public void setNumResamples(int numResamples) {
        this.numResamples = numResamples;
    }

    /**
     * Runs the independence test on {@code facts[from, to)}.
     * <p>
     * Results whose p-value is NaN are dropped. Each kept result is rebuilt from the fact, the
     * independence judgment and the p-value.
     */
    private static final class IndCheckTask implements Callable<List<IndependenceResult>> {
        private final int from;
        private final int to;
        private final List<IndependenceFact> facts;
        private final IndependenceTest test;

        IndCheckTask(int from, int to, List<IndependenceFact> facts, IndependenceTest test) {
            this.from = from;
            this.to = to;
            this.facts = facts;
            this.test = test;
        }

        @Override
        public List<IndependenceResult> call() {
            List<IndependenceResult> results = new ArrayList<>(Math.max(0, this.to - this.from));

            // isInterrupted() rather than Thread.interrupted(): the latter clears the flag and would
            // hide the interrupt from the caller's own loop.
            for (int i = this.from; i < this.to && !Thread.currentThread().isInterrupted(); i++) {
                IndependenceFact fact = this.facts.get(i);
                IndependenceResult result;

                try {
                    result = this.test.checkIndependence(fact.getX(), fact.getY(), fact.getZ());
                } catch (Exception e) {
                    if (e instanceof InterruptedException) {
                        Thread.currentThread().interrupt();
                    }
                    throw new RuntimeException(e);
                }

                double pValue = result.getPValue();
                if (Double.isNaN(pValue)) continue;

                results.add(new IndependenceResult(fact, result.isIndependent(), pValue, Double.NaN));
            }

            return results;
        }
    }

    /** The m-separated and m-connected facts produced by {@link #getAllSubsetsIndependenceFacts(Graph)}. */
    public static class AllSubsetsIndependenceFacts {
        private final List<IndependenceFact> msep;
        private final List<IndependenceFact> mconn;

        /**
         * @param msep  facts whose endpoints are m-separated given the conditioning set
         * @param mconn facts whose endpoints are m-connected given the conditioning set
         */
        public AllSubsetsIndependenceFacts(List<IndependenceFact> msep, List<IndependenceFact> mconn) {
            this.msep = msep;
            this.mconn = mconn;
        }

        /** Lists the m-separated facts, one per line. */
        public String toStringIndep() {
            StringBuilder builder = new StringBuilder("All subsets independence facts:\n");
            for (IndependenceFact fact : this.msep) {
                builder.append(fact).append("\n");
            }
            return builder.toString();
        }

        /** Lists the m-connected facts, one per line. */
        public String toStringDep() {
            StringBuilder builder = new StringBuilder("All subsets dependence facts:\n");
            for (IndependenceFact fact : this.mconn) {
                builder.append(fact).append("\n");
            }
            return builder.toString();
        }

        /** Returns the m-separated facts (the live list, not a copy). */
        public List<IndependenceFact> getMsep() {
            return this.msep;
        }

        /** Returns the m-connected facts (the live list, not a copy). */
        public List<IndependenceFact> getMconn() {
            return this.mconn;
        }
    }
}

