/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search.mb;

import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.search.IndependenceTest;
import edu.cmu.tetrad.search.MbSearch;
import edu.cmu.tetrad.util.ChoiceGenerator;
import edu.cmu.tetrad.util.DepthChoiceGenerator;
import edu.cmu.tetrad.util.TetradLogger;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * Markov blanket search for a single target variable in the style of HITON-MB.
 * {@link #getAlgorithmName()} reports "HITON-MB", or "HITON-MB-SYM" when the
 * symmetric parents/children correction is enabled.
 *
 * <p>The search proceeds in three stages:
 * <ol>
 *   <li>It estimates the parents/children set PC(t) of the target with an
 *       interleaved HITON-PC procedure ({@link #hitonPc}).</li>
 *   <li>It estimates PC(y) for every y in PC(t). The union of those sets is PCPC.</li>
 *   <li>For each x in PCPC \ (PC(t) ∪ {t}), it looks for a set S with t ⊥ x | S.
 *       It then adds x to the blanket if, for some y in PC(t) with x in PC(y),
 *       t and x are dependent given S ∪ {y}. Those are the candidate spouses.</li>
 * </ol>
 * The result is PC(t) plus the accepted candidates.
 *
 * <p>Per-search state (PC cache, sorted variables, test counter) is kept in fields,
 * so a single instance should not run concurrent searches.
 */
public class VanderbiltHitonMb implements MbSearch {
    /** When true, x is kept in PC(t) only if t is also in PC(x). */
    private boolean symmetric = false;
    /** Test used for every (conditional) independence judgment and association score. */
    private IndependenceTest independenceTest;
    /** Variables of the test, in the order returned by {@code test.getVariables()}. */
    private List<Node> variables;
    /** {@link #variables} sorted by decreasing association with the current target. */
    private List<Node> sortedVariables;
    /**
     * Maximum conditioning-set size. A negative value is treated as unlimited by
     * {@link #hitonPc}. NOTE(review): this assumes DepthChoiceGenerator, used in
     * {@link #findSepset}, also treats negative depth as unlimited; confirm.
     */
    private int depth;
    /** Number of independence tests run by the most recent {@link #findMb} call. */
    private int numIndTests = 0;
    /** Cache of PC sets computed during the current search, keyed by variable. */
    Map<Node, List<Node>> pc;
    /** Variables whose cached PC set has already been trimmed for symmetry. */
    private Set<Node> trimmed;

    /**
     * Creates a search over the variables of {@code test}.
     *
     * @param test      the independence test to use; must not be null
     * @param depth     maximum conditioning-set size (negative = unlimited, see {@link #depth})
     * @param symmetric whether to apply the symmetric PC correction
     * @throws NullPointerException if {@code test} is null
     */
    public VanderbiltHitonMb(IndependenceTest test, int depth, boolean symmetric) {
        if (test == null) {
            throw new NullPointerException();
        }
        this.independenceTest = test;
        this.variables = test.getVariables();
        this.depth = depth;
        this.symmetric = symmetric;
    }

    /**
     * Runs the search for the variable named {@code targetName}.
     *
     * <p>Resets the PC cache and the test counter. It then logs the target, the elapsed
     * time, and the number of tests through {@link TetradLogger}.
     *
     * @param targetName name of the target variable
     * @return the estimated Markov blanket of the target (the target itself is excluded)
     * @throws IllegalArgumentException if no variable has that name
     */
    @Override
    public List<Node> findMb(String targetName) {
        TetradLogger.getInstance().log("info", "target = " + targetName);
        this.numIndTests = 0;
        long start = System.currentTimeMillis();
        this.pc = new HashMap<Node, List<Node>>();
        this.trimmed = new HashSet<Node>();
        Node t = this.getVariableForName(targetName);
        this.sortedVariables = this.sortByAssociation(t);
        List<Node> nodes = this.hitonMb(t);
        long elapsed = System.currentTimeMillis() - start;
        TetradLogger.getInstance().log("info", "Number of seconds: " + (double) elapsed / 1000.0);
        TetradLogger.getInstance().log("info", "Number of independence tests performed: " + this.numIndTests);
        return nodes;
    }

    /**
     * Returns the variables sorted by decreasing association with {@code t}.
     *
     * <p>Each variable is scored exactly once before sorting. The target gets score 1.0;
     * every other variable gets {@link #association}. Scoring inside the comparator would
     * rerun an independence test on every comparison. It would also give a comparator
     * that is inconsistent for NaN scores, which {@code Collections.sort} may reject.
     */
    private List<Node> sortByAssociation(Node t) {
        final Map<Node, Double> scores = new HashMap<Node, Double>();
        for (Node v : this.variables) {
            if (!scores.containsKey(v)) {
                scores.put(v, v == t ? 1.0 : this.association(v, t));
            }
        }
        List<Node> sorted = new LinkedList<Node>(this.variables);
        Collections.sort(sorted, new Comparator<Node>() {
            @Override
            public int compare(Node o1, Node o2) {
                // Arguments swapped: higher association first.
                return Double.compare(scores.get(o2), scores.get(o1));
            }
        });
        return sorted;
    }

    /**
     * Builds the Markov blanket of {@code t} from PC(t) and the PC sets of its members.
     *
     * @param t the target variable
     * @return PC(t) plus every second-level candidate that passes the spouse check
     */
    private List<Node> hitonMb(Node t) {
        List<Node> pcOfT = this.getPc(t);

        // PCPC: union of PC(y) over all y in PC(t).
        Set<Node> pcpcSet = new HashSet<Node>();
        for (Node y : pcOfT) {
            pcpcSet.addAll(this.getPc(y));
        }
        List<Node> pcpc = new ArrayList<Node>(pcpcSet);

        // Candidates: PCPC members that are neither the target nor already in PC(t).
        Set<Node> candidates = new HashSet<Node>(pcpc);
        candidates.removeAll(pcOfT);
        candidates.remove(t);

        Set<Node> mb = new HashSet<Node>();
        for (Node x : candidates) {
            List<Node> sepset = this.findSepset(t, x, pcpc);
            if (sepset == null) {
                TetradLogger.getInstance().log("info", "S not found.");
                continue;
            }
            if (this.isDependentGivenSepsetPlusNeighbor(t, x, sepset, pcOfT)) {
                mb.add(x);
            }
        }
        mb.addAll(pcOfT);
        return new LinkedList<Node>(mb);
    }

    /**
     * Searches subsets of {@code pcpc} (without {@code t} and {@code x}) for a set S with
     * t ⊥ x | S. Subsets are enumerated in {@code DepthChoiceGenerator} order.
     *
     * @return the first separating set found, or null if none exists up to {@link #depth}
     */
    private List<Node> findSepset(Node t, Node x, List<Node> pcpc) {
        // Never condition on the two variables being tested.
        List<Node> pool = new ArrayList<Node>(pcpc);
        pool.remove(t);
        pool.remove(x);

        DepthChoiceGenerator generator = new DepthChoiceGenerator(pool.size(), this.depth);
        int[] choice;
        while ((choice = generator.next()) != null) {
            List<Node> s = new ArrayList<Node>(choice.length);
            for (int index : choice) {
                s.add(pool.get(index));
            }
            ++this.numIndTests;
            if (this.independenceTest.isIndependent(t, x, s)) {
                return s;
            }
        }
        return null;
    }

    /**
     * Returns true if some y in PC(t) with x in PC(y) makes t and x dependent given
     * {@code sepset} ∪ {y}.
     *
     * <p>A y already contained in {@code sepset} is skipped. For such a y the set
     * {@code sepset} ∪ {y} is just {@code sepset}, already known to separate t and x.
     * Appending y again would only give the test a duplicated conditioning variable.
     */
    private boolean isDependentGivenSepsetPlusNeighbor(Node t, Node x, List<Node> sepset, List<Node> pcOfT) {
        for (Node y : pcOfT) {
            if (!this.getPc(y).contains(x) || sepset.contains(y)) {
                continue;
            }
            List<Node> conditioning = new ArrayList<Node>(sepset);
            conditioning.add(y);
            ++this.numIndTests;
            if (!this.independenceTest.isIndependent(t, x, conditioning)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Interleaved HITON-PC: estimates the parents/children of {@code t}.
     *
     * <p>Variables are admitted in {@link #sortedVariables} order. After each admission,
     * every current member x is temporarily removed. It is dropped for good if some subset
     * of the remaining members separates x from t. Only subsets involving the newly
     * admitted variable are tested (see {@link #isSeparatedWithNewVariable}).
     *
     * @param t the variable whose PC set is estimated
     * @return the estimated PC set (never contains {@code t})
     */
    private List<Node> hitonPc(Node t) {
        LinkedList<Node> remaining = new LinkedList<Node>(this.sortedVariables);
        remaining.remove(t);
        List<Node> cpc = new ArrayList<Node>();
        while (!remaining.isEmpty()) {
            Node vi = remaining.removeFirst();
            cpc.add(vi);
            for (Node x : new ArrayList<Node>(cpc)) {
                cpc.remove(x);
                if (!this.isSeparatedWithNewVariable(x, t, cpc, vi)) {
                    cpc.add(x);
                }
            }
        }
        return cpc;
    }

    /**
     * Returns true if some subset S of {@code cpc} with |S| ≤ depth gives x ⊥ t | S.
     *
     * <p>When x is not the newly admitted {@code vi}, only subsets containing {@code vi}
     * are tested. The other subsets were tested in earlier rounds. A negative
     * {@link #depth} means unlimited.
     */
    private boolean isSeparatedWithNewVariable(Node x, Node t, List<Node> cpc, Node vi) {
        int maxDepth = this.depth < 0 ? cpc.size() : Math.min(cpc.size(), this.depth);
        for (int d = 0; d <= maxDepth; ++d) {
            ChoiceGenerator generator = new ChoiceGenerator(cpc.size(), d);
            int[] choice;
            while ((choice = generator.next()) != null) {
                List<Node> s = new ArrayList<Node>(choice.length);
                for (int index : choice) {
                    s.add(cpc.get(index));
                }
                if (x != vi && !s.contains(vi)) {
                    continue;
                }
                ++this.numIndTests;
                if (this.independenceTest.isIndependent(x, t, s)) {
                    return true;
                }
            }
        }
        return false;
    }

    /**
     * Returns the (cached) PC set of {@code t}, computing it on first use.
     *
     * <p>In symmetric mode the set is trimmed once per search.
     */
    private List<Node> getPc(Node t) {
        if (!this.pc.containsKey(t)) {
            this.pc.put(t, this.hitonPc(t));
        }
        if (this.symmetric && !this.trimmed.contains(t)) {
            this.trimPc(t);
            this.trimmed.add(t);
        }
        return this.pc.get(t);
    }

    /**
     * Removes from PC(t) every x whose own (untrimmed) PC set does not contain {@code t}.
     * Any missing PC(x) is computed and cached first.
     */
    private void trimPc(Node t) {
        List<Node> pcOfT = this.pc.get(t);
        for (Node x : new ArrayList<Node>(pcOfT)) {
            if (!this.pc.containsKey(x)) {
                this.pc.put(x, this.hitonPc(x));
            }
            if (!this.pc.get(x).contains(t)) {
                pcOfT.remove(x);
            }
        }
    }

    /**
     * Association score used for ordering: 1 − p-value of the unconditional test of x and y.
     * Counts as one independence test.
     */
    private double association(Node x, Node y) {
        ++this.numIndTests;
        this.independenceTest.isIndependent(x, y, new LinkedList<Node>());
        return 1.0 - this.independenceTest.getPValue();
    }

    /** @return "HITON-MB-SYM" in symmetric mode, otherwise "HITON-MB" */
    @Override
    public String getAlgorithmName() {
        return this.symmetric ? "HITON-MB-SYM" : "HITON-MB";
    }

    /** @return the number of independence tests run by the most recent search */
    @Override
    public int getNumIndependenceTests() {
        return this.numIndTests;
    }

    /**
     * Finds the first variable whose name equals {@code targetName}.
     *
     * @throws IllegalArgumentException if there is no such variable
     */
    private Node getVariableForName(String targetName) {
        for (Node v : this.variables) {
            if (v.getName().equals(targetName)) {
                return v;
            }
        }
        throw new IllegalArgumentException("Target variable not in dataset: " + targetName);
    }
}

