/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search.work_in_progress;

import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.search.IMbSearch;
import edu.cmu.tetrad.search.IndependenceTest;
import edu.cmu.tetrad.search.test.IndependenceResult;
import edu.cmu.tetrad.util.MillisecondTimes;
import edu.cmu.tetrad.util.SublistGenerator;
import edu.cmu.tetrad.util.TetradLogger;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * Markov blanket search in the style of MMMB (Max-Min Markov Blanket).
 *
 * <p>For a target T the search:
 * <ol>
 *   <li>computes PC(T), the parents/children of T, with MMPC (forward max-min association
 *       selection followed by backward conditioning);</li>
 *   <li>computes PC(Y) for every Y in PC(T);</li>
 *   <li>for every node X in the union of those PC(Y) sets that is neither T nor in PC(T), looks
 *       for a set S with T independent of X given S, then adds X to the blanket if T and X become
 *       dependent given S plus some Y in PC(T) with X in PC(Y) (spouse detection).</li>
 * </ol>
 *
 * <p>When {@code symmetric} is true, a PC set is trimmed on first request so that X stays in
 * PC(T) only if T is also in PC(X).
 *
 * <p>PC sets are cached in {@link #pc}; {@link #findMb(Node)} clears the cache on entry. The
 * class keeps mutable caches and counters without synchronization, so instances should not be
 * shared between threads.
 */
public final class Mmmb implements IMbSearch {
    /** If true, PC sets are trimmed so that X stays in PC(T) only when T is in PC(X). */
    private final boolean symmetric;
    /** Test used for every (conditional) independence / association query. */
    private final IndependenceTest independenceTest;
    /** Variables of the independence test; the candidate pool for MMPC. */
    private final List<Node> variables;
    /**
     * Maximum conditioning-set size passed to {@link SublistGenerator}. The constructor accepts
     * -1, which presumably means "no limit" (TODO confirm against SublistGenerator).
     */
    int depth;
    /** Cache of computed PC sets, keyed by node. Reset by {@link #findMb(Node)}. */
    Map<Node, List<Node>> pc;
    /** Number of independence tests run since the last reset in {@link #findMb(Node)}. */
    private int numIndTests;
    /** Nodes whose cached PC set has already been symmetrically trimmed. */
    private Set<Node> trimmed;

    /**
     * Creates a new search.
     *
     * @param test      independence test to use; must not be null
     * @param depth     maximum conditioning-set size; must be at least -1
     * @param symmetric whether to enforce symmetry of PC sets (X in PC(T) only if T in PC(X))
     * @throws NullPointerException     if {@code test} is null
     * @throws IllegalArgumentException if {@code depth < -1}
     */
    public Mmmb(IndependenceTest test, int depth, boolean symmetric) {
        if (test == null) {
            throw new NullPointerException("test");
        }
        if (depth < -1) {
            throw new IllegalArgumentException("depth must be >= -1: " + depth);
        }
        this.independenceTest = test;
        this.variables = test.getVariables();
        this.depth = depth;
        this.symmetric = symmetric;
        this.pc = new HashMap<>();
        this.trimmed = new HashSet<>();
    }

    /**
     * Finds the Markov blanket of {@code target}.
     *
     * <p>Resets the PC cache and the independence-test counter, then logs the elapsed time and the
     * number of tests performed.
     *
     * @param target node whose Markov blanket is sought
     * @return a new set holding PC(target) plus the detected spouses
     */
    @Override
    public Set<Node> findMb(Node target) {
        TetradLogger.getInstance().log("info", "target = " + target);
        this.numIndTests = 0;
        long start = MillisecondTimes.timeMillis();
        // Start from an empty cache so the result does not depend on earlier calls.
        this.pc = new HashMap<>();
        this.trimmed = new HashSet<>();
        Set<Node> nodes = this.mmmb(target);
        long elapsed = MillisecondTimes.timeMillis() - start;
        TetradLogger.getInstance().log("info", "Number of seconds: " + (double) elapsed / 1000.0);
        TetradLogger.getInstance().log("info", "Number of independence tests performed: " + this.numIndTests);
        return nodes;
    }

    /**
     * Core MMMB step: returns PC(t) plus every spouse of {@code t} found by {@link #isSpouse}.
     */
    private Set<Node> mmmb(Node t) {
        List<Node> pcOfT = this.getPc(t);

        // Union of PC(Y) over Y in PC(T). getPc caches each PC(Y) in this.pc, which isSpouse
        // relies on below.
        Set<Node> pcOfPc = new HashSet<>();
        for (Node y : pcOfT) {
            pcOfPc.addAll(this.getPc(y));
        }

        // Spouse candidates: nodes two steps from T that are not already in PC(T).
        Set<Node> candidates = new HashSet<>(pcOfPc);
        candidates.removeAll(pcOfT);
        candidates.remove(t);

        Set<Node> mb = new HashSet<>(pcOfT);
        for (Node x : candidates) {
            if (this.isSpouse(t, x, pcOfT, pcOfPc)) {
                mb.add(x);
            }
        }
        return mb;
    }

    /**
     * Decides whether {@code x} is a spouse of {@code t}.
     *
     * <p>First finds a set S (drawn from {@code pool}) that separates t and x. Then x is accepted
     * if t and x are dependent given S plus some Y in PC(t) that has x in PC(Y).
     *
     * @param pcOfT  PC(t)
     * @param pool   nodes from which the separating set may be drawn
     * @return true if x should be added to the Markov blanket
     */
    private boolean isSpouse(Node t, Node x, List<Node> pcOfT, Set<Node> pool) {
        Set<Node> sepset = this.findSepset(t, x, pool);
        if (sepset == null) {
            TetradLogger.getInstance().log("details", "S not found for " + x + ".");
            return false;
        }
        for (Node y : pcOfT) {
            // PC(y) was cached by mmmb before any spouse check, so this lookup is non-null.
            if (!this.pc.get(y).contains(x)) {
                continue;
            }
            Set<Node> cond = new HashSet<>(sepset);
            cond.add(y);
            ++this.numIndTests;
            if (!this.independenceTest.checkIndependence(t, x, cond).isIndependent()) {
                return true;
            }
        }
        return false;
    }

    /**
     * Searches subsets of {@code pool} (up to {@link #depth} elements, in SublistGenerator order)
     * for one that makes {@code t} and {@code x} independent.
     *
     * <p>{@code t} and {@code x} are removed from the pool first. Conditioning on one of the
     * variables being tested is degenerate and would trivially "separate" them.
     *
     * @return the first separating set found, or null if none exists
     */
    private Set<Node> findSepset(Node t, Node x, Set<Node> pool) {
        List<Node> candidates = new ArrayList<>(pool);
        candidates.remove(t);
        candidates.remove(x);

        SublistGenerator generator = new SublistGenerator(candidates.size(), this.depth);
        int[] choice;
        while ((choice = generator.next()) != null) {
            Set<Node> s = new HashSet<>();
            for (int index : choice) {
                s.add(candidates.get(index));
            }
            ++this.numIndTests;
            if (this.independenceTest.checkIndependence(t, x, s).isIndependent()) {
                return s;
            }
        }
        return null;
    }

    /**
     * MMPC: computes the parents/children of {@code t} (no symmetry trimming).
     *
     * <p>Forward phase: repeatedly add the variable with the largest minimum association to
     * {@code t}. Stop when none is left or the best one is independent of {@code t} given its
     * minimizing set. Backward phase: see {@link #backwardsConditioning}.
     */
    private List<Node> mmpc(Node t) {
        List<Node> pcOfT = new ArrayList<>();
        // Variables already judged independent of t; they are skipped in later rounds.
        Set<Node> indepOfT = new HashSet<>();
        while (true) {
            MaxMinAssocResult best = this.maxMinAssoc(t, pcOfT, indepOfT);
            Node f = best.getNode();
            if (f == null) {
                break;
            }
            ++this.numIndTests;
            if (this.independenceTest.checkIndependence(f, t, best.getAssocSet()).isIndependent()) {
                break;
            }
            pcOfT.add(f);
        }
        this.backwardsConditioning(pcOfT, t);
        TetradLogger.getInstance().log("details", "PC(" + t + ") = " + pcOfT);
        return pcOfT;
    }

    /**
     * Returns the PC set of {@code t}, computing it with MMPC on first request.
     *
     * <p>In symmetric mode the cached set is trimmed once per node (see {@link #trimPc}). The
     * returned list is the internal cached instance, not a copy.
     *
     * @param t node whose parents/children are requested
     * @return the cached PC list for {@code t}
     */
    public List<Node> getPc(Node t) {
        if (!this.pc.containsKey(t)) {
            this.pc.put(t, this.mmpc(t));
        }
        if (this.symmetric && !this.trimmed.contains(t)) {
            this.trimPc(t);
            this.trimmed.add(t);
        }
        return this.pc.get(t);
    }

    /**
     * Removes each X from the cached PC(t) unless t is in PC(X).
     *
     * <p>PC(X) is computed with MMPC (untrimmed) if it is not cached yet.
     */
    private void trimPc(Node t) {
        // Iterate over a snapshot because PC(t) is modified inside the loop.
        for (Node x : new ArrayList<>(this.pc.get(t))) {
            if (!this.pc.containsKey(x)) {
                this.pc.put(x, this.mmpc(x));
            }
            if (!this.pc.get(x).contains(t)) {
                this.pc.get(t).remove(x);
            }
        }
    }

    /**
     * Finds the variable (not t, not in {@code pc}, not in {@code indepOfT}) whose minimum
     * association with {@code t} over subsets of {@code pc} is largest.
     *
     * <p>Side effect: every scanned variable whose minimum association is below
     * {@code 1 - alpha} (that is, p-value above alpha) is added to {@code indepOfT}.
     *
     * @return the best variable and its minimizing set; the node is null if no variable has
     *         association greater than 0
     */
    private MaxMinAssocResult maxMinAssoc(Node t, List<Node> pc, Set<Node> indepOfT) {
        Node best = null;
        Set<Node> bestAssocSet = null;
        double maxAssoc = 0.0;
        for (Node v : this.variables) {
            if (t == v || pc.contains(v) || indepOfT.contains(v)) {
                continue;
            }
            Set<Node> minAssocSet = this.minAssoc(v, t, pc);
            double assoc = this.association(v, t, minAssocSet);
            if (assoc < 1.0 - this.independenceTest.getAlpha()) {
                indepOfT.add(v);
            }
            if (assoc > maxAssoc) {
                bestAssocSet = minAssocSet;
                maxAssoc = assoc;
                best = v;
            }
        }
        return new MaxMinAssocResult(best, bestAssocSet);
    }

    /**
     * Returns the subset of {@code pc} (up to {@link #depth} elements) that minimizes the
     * association between {@code x} and {@code target}.
     *
     * <p>Only subsets containing the last element of {@code pc} are scored. Others are skipped
     * without running a test. NOTE(review): this looks like the incremental-MMPC shortcut, but the
     * minimum from earlier rounds is not carried over here; confirm intent. If no scored subset
     * has association below 1.0, the empty set is returned.
     *
     * @throws IllegalArgumentException if {@code x} or {@code target} is in {@code pc}, or if
     *                                  {@code x == target}
     */
    private Set<Node> minAssoc(Node x, Node target, List<Node> pc) {
        if (pc.contains(x)) {
            throw new IllegalArgumentException("x is already in pc: " + x);
        }
        if (pc.contains(target)) {
            throw new IllegalArgumentException("target is in pc: " + target);
        }
        if (x == target) {
            throw new IllegalArgumentException("x and target are the same node: " + x);
        }
        Node newest = pc.isEmpty() ? null : pc.get(pc.size() - 1);
        double minAssoc = 1.0;
        Set<Node> minSet = new HashSet<>();
        SublistGenerator generator = new SublistGenerator(pc.size(), this.depth);
        int[] choice;
        while ((choice = generator.next()) != null) {
            Set<Node> s = new HashSet<>();
            for (int index : choice) {
                s.add(pc.get(index));
            }
            if (newest != null && !s.contains(newest)) {
                continue;
            }
            double assoc = this.association(x, target, s);
            if (assoc < minAssoc) {
                minAssoc = assoc;
                minSet = s;
            }
        }
        return minSet;
    }

    /**
     * Backward phase of MMPC: removes each X from {@code pc} if X is independent of
     * {@code target} given its minimum-association subset of the remaining members.
     */
    private void backwardsConditioning(List<Node> pc, Node target) {
        // Iterate over a snapshot because pc shrinks inside the loop.
        for (Node x : new ArrayList<>(pc)) {
            List<Node> others = new ArrayList<>(pc);
            others.remove(x);
            others.remove(target);
            Set<Node> minAssocSet = this.minAssoc(x, target, others);
            ++this.numIndTests;
            if (this.independenceTest.checkIndependence(x, target, minAssocSet).isIndependent()) {
                pc.remove(x);
            }
        }
    }

    /**
     * Association between {@code x} and {@code target} given {@code s}, defined as
     * {@code 1 - p-value}. Counts as one independence test.
     */
    private double association(Node x, Node target, Set<Node> s) {
        ++this.numIndTests;
        IndependenceResult result = this.independenceTest.checkIndependence(x, target, s);
        return 1.0 - result.getPValue();
    }

    /** Returns "MMMB-SYM" in symmetric mode, otherwise "MMMB". */
    @Override
    public String getAlgorithmName() {
        return this.symmetric ? "MMMB-SYM" : "MMMB";
    }

    /** Returns the number of independence tests run since the last {@link #findMb} reset. */
    @Override
    public int getNumIndependenceTests() {
        return this.numIndTests;
    }

    /**
     * Looks up a variable by name.
     *
     * <p>Currently unused within this class.
     *
     * @throws IllegalArgumentException if no variable has the given name
     */
    private Node getVariableForName(String targetName) {
        for (Node v : this.variables) {
            if (v.getName().equals(targetName)) {
                return v;
            }
        }
        throw new IllegalArgumentException("Target variable not in dataset: " + targetName);
    }

    /** Result of {@link #maxMinAssoc}: the selected node and its minimizing conditioning set. */
    private static class MaxMinAssocResult {
        private final Node node;
        private final Set<Node> assocSet;

        public MaxMinAssocResult(Node node, Set<Node> assocSet) {
            this.node = node;
            this.assocSet = assocSet;
        }

        public Node getNode() {
            return this.node;
        }

        public Set<Node> getAssocSet() {
            return this.assocSet;
        }
    }
}

