/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search;

import edu.cmu.tetrad.data.CovarianceMatrix;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.DataTransforms;
import edu.cmu.tetrad.data.ICovarianceMatrix;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.GraphUtils;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.search.utils.LogUtilsSearch;
import edu.cmu.tetrad.util.ChoiceGenerator;
import edu.cmu.tetrad.util.Matrix;
import edu.cmu.tetrad.util.RandomUtil;
import edu.cmu.tetrad.util.SublistGenerator;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import org.apache.commons.math3.linear.SingularMatrixException;
import org.apache.commons.math3.util.FastMath;

/**
 * Estimates the possible effects of a set of candidate causes on a target variable, given a
 * data set and a CPDAG over (some of) its variables.
 *
 * <p>For a cause {@code x} and a target {@code y}, the target is regressed on {@code x}, the
 * parents of {@code x} in the CPDAG, and each admissible subset of the nodes joined to
 * {@code x} by an edge that is neither into nor out of {@code x}. The absolute value of the
 * coefficient of {@code x} in each regression is one possible effect.
 *
 * <p>NOTE(review): this looks like the IDA procedure of Maathuis, Kalisch and Buehlmann;
 * confirm against the project documentation.
 */
public class Ida {
    /** The data, with numerical discrete columns converted to continuous ones. */
    private final DataSet dataSet;
    /** The CPDAG whose adjacencies determine the regressor sets. */
    private final Graph pattern;
    /** The candidate causes, expressed as variables of the data set. */
    private final List<Node> possibleCauses;
    /** Maps a variable name to its position in the data set and hence in {@link #allCovariances}. */
    private final Map<String, Integer> nodeIndices;
    /** Covariance matrix over all variables of {@link #dataSet}. */
    private final ICovarianceMatrix allCovariances;

    /**
     * Creates an estimator for the given data and CPDAG.
     *
     * @param dataSet        the data; numerical discrete columns are converted to continuous
     * @param cpdag          the CPDAG over the variables of interest
     * @param possibleCauses the candidate causes; replaced by the matching data-set variables
     */
    public Ida(DataSet dataSet, Graph cpdag, List<Node> possibleCauses) {
        this.dataSet = DataTransforms.convertNumericalDiscreteToContinuous(dataSet);
        this.pattern = cpdag;
        this.possibleCauses = GraphUtils.replaceNodes(possibleCauses, dataSet.getVariables());
        this.allCovariances = new CovarianceMatrix(this.dataSet);

        // The indices are used to select rows/columns of the covariance matrix, which is built
        // from the data set. They must therefore follow the data set's variable order, not the
        // order in which the CPDAG happens to list its nodes.
        List<Node> variables = this.dataSet.getVariables();
        this.nodeIndices = new HashMap<String, Integer>();
        for (int i = 0; i < variables.size(); ++i) {
            this.nodeIndices.put(variables.get(i).getName(), i);
        }
    }

    /**
     * Returns the minimum effect of every candidate cause on {@code y}, ordered from the largest
     * absolute effect to the smallest.
     *
     * @param y the target variable
     * @return the causes and their minimum effects, in matching order
     */
    public NodeEffects getSortedMinEffects(Node y) {
        final Map<Node, Double> allEffects = this.calculateMinimumEffectsOnY(y);
        List<Node> nodes = new ArrayList<Node>(allEffects.keySet());

        // Shuffling before sorting presumably randomizes the relative order of causes whose
        // effects tie, instead of always leaving them in map-key order.
        RandomUtil.shuffle(nodes);
        nodes.sort((o1, o2) -> Double.compare(
                FastMath.abs(allEffects.get(o2)), FastMath.abs(allEffects.get(o1))));

        LinkedList<Double> effects = new LinkedList<Double>();
        for (Node node : nodes) {
            effects.add(allEffects.get(node));
        }
        return new NodeEffects(nodes, effects);
    }

    /**
     * Computes the effect of {@code x} on {@code y} implied by a known DAG.
     *
     * @param x       the cause
     * @param y       the target
     * @param trueDag the DAG taken as ground truth
     * @return 0 if {@code x} is not an ancestor of {@code y} in {@code trueDag}; otherwise the
     *         absolute value of the coefficient of {@code x} when {@code y} is regressed on
     *         {@code x} and the parents of {@code x} in {@code trueDag}
     * @throws IllegalArgumentException if {@code x} and {@code y} are the same object
     */
    public double trueEffect(Node x, Node y, Graph trueDag) {
        if (x == y) {
            throw new IllegalArgumentException("x == y");
        }
        if (!trueDag.paths().isAncestorOf(x, y)) {
            return 0.0;
        }
        Graph dagOverData = GraphUtils.replaceNodes(trueDag, this.dataSet.getVariables());
        List<Node> regressors = new ArrayList<Node>();
        regressors.add(x);
        regressors.addAll(dagOverData.getParents(x));
        return FastMath.abs(this.getBeta(regressors, y));
    }

    /**
     * Measures how far a true effect lies from the range spanned by a list of estimated effects.
     *
     * @param effects    the estimated effects; the list itself is not modified
     * @param trueEffect the true effect
     * @return {@code NaN} if {@code effects} is empty; the absolute difference if it holds a
     *         single value; 0 if {@code trueEffect} lies within [min, max] of the effects;
     *         otherwise the distance to the nearer of min and max
     */
    public double distance(LinkedList<Double> effects, double trueEffect) {
        // Work on a copy so that sorting does not reorder the caller's list.
        LinkedList<Double> sorted = new LinkedList<Double>(effects);
        if (sorted.isEmpty()) {
            return Double.NaN;
        }
        if (sorted.size() == 1) {
            return FastMath.abs(sorted.getFirst() - trueEffect);
        }
        Collections.sort(sorted);
        double min = sorted.getFirst();
        double max = sorted.getLast();
        if (trueEffect >= min && trueEffect <= max) {
            return 0.0;
        }
        double toMin = FastMath.abs(trueEffect - min);
        double toMax = FastMath.abs(trueEffect - max);
        return FastMath.min(toMin, toMax);
    }

    /**
     * Computes the possible effects of {@code x} on {@code y}, one per admissible subset of the
     * siblings of {@code x} (nodes adjacent to {@code x} that are neither parents nor children).
     *
     * <p>A subset whose regression throws is reported on standard error and skipped, so the
     * returned list may be empty.
     *
     * @param x the cause
     * @param y the target
     * @return the absolute effects in ascending order
     */
    private LinkedList<Double> getEffects(Node x, Node y) {
        List<Node> parents = this.pattern.getParents(x);
        List<Node> children = this.pattern.getChildren(x);
        List<Node> siblings = new ArrayList<Node>(this.pattern.getAdjacentNodes(x));
        siblings.removeAll(parents);
        siblings.removeAll(children);

        SublistGenerator gen = new SublistGenerator(siblings.size(), siblings.size());
        LinkedList<Double> effects = new LinkedList<Double>();
        int[] choice;
        while ((choice = gen.next()) != null) {
            try {
                List<Node> sibbled = GraphUtils.asList(choice, siblings);
                if (!this.isLocallyValid(parents, sibbled)) {
                    continue;
                }

                List<Node> regressors = new ArrayList<Node>();
                regressors.add(x);
                this.addAllAbsent(regressors, parents);
                this.addAllAbsent(regressors, sibbled);

                // If y itself would be conditioned on, record a zero effect for this subset
                // rather than regressing y on itself.
                if (regressors.contains(y)) {
                    effects.add(0.0);
                    continue;
                }
                effects.add(FastMath.abs(this.getBeta(regressors, y)));
            }
            catch (Exception e) {
                // Best effort: a failure for one subset must not abort the remaining subsets.
                e.printStackTrace();
            }
        }
        Collections.sort(effects);
        return effects;
    }

    /**
     * Tells whether a sibling subset is admissible: every two chosen siblings must be adjacent
     * in the CPDAG, and every chosen sibling must be adjacent to every parent.
     */
    private boolean isLocallyValid(List<Node> parents, List<Node> sibbled) {
        for (int i = 0; i < sibbled.size(); ++i) {
            for (int j = i + 1; j < sibbled.size(); ++j) {
                if (!this.pattern.isAdjacentTo(sibbled.get(i), sibbled.get(j))) {
                    return false;
                }
            }
        }
        for (Node p : parents) {
            for (Node s : sibbled) {
                if (!this.pattern.isAdjacentTo(p, s)) {
                    return false;
                }
            }
        }
        return true;
    }

    /** Appends to {@code target} every node of {@code candidates} it does not already contain. */
    private void addAllAbsent(List<Node> target, List<Node> candidates) {
        for (Node n : candidates) {
            if (!target.contains(n)) {
                target.add(n);
            }
        }
    }

    /**
     * Computes, for every candidate cause contained in the CPDAG, the smallest of its possible
     * absolute effects on {@code y}.
     *
     * <p>Nothing is computed unless the CPDAG also contains {@code y}. A cause for which no
     * effect could be computed at all is left out of the result.
     *
     * @param y the target variable
     * @return a map from cause to its minimum absolute effect, sorted by the natural order of
     *         the nodes
     */
    public Map<Node, Double> calculateMinimumEffectsOnY(Node y) {
        Map<Node, Double> minEffects = new TreeMap<Node, Double>();
        for (Node x : this.possibleCauses) {
            if (!this.pattern.containsNode(x) || !this.pattern.containsNode(y)) {
                continue;
            }
            LinkedList<Double> effects = this.getEffects(x, y);
            // getEffects skips subsets whose regression failed, so the list can be empty.
            if (effects.isEmpty()) {
                continue;
            }
            // The list is sorted ascending, so the first entry is the minimum.
            minEffects.put(x, effects.getFirst());
        }
        return minEffects;
    }

    /**
     * Regresses {@code child} on {@code regressors} using the covariance matrix and returns the
     * coefficient of the first regressor.
     *
     * @param regressors the regressors; the coefficient of element 0 is returned
     * @param child      the regressand
     * @return the coefficient, or 0 if the regressors' covariance matrix is singular
     * @throws IllegalArgumentException if a variable is not part of the data set
     */
    private double getBeta(List<Node> regressors, Node child) {
        int yIndex = this.indexOf(child);
        int[] xIndices = new int[regressors.size()];
        for (int i = 0; i < xIndices.length; ++i) {
            xIndices[i] = this.indexOf(regressors.get(i));
        }
        Matrix rX = this.allCovariances.getSelection(xIndices, xIndices);
        Matrix rY = this.allCovariances.getSelection(xIndices, new int[]{yIndex});
        try {
            return rX.inverse().times(rY).get(0, 0);
        }
        catch (SingularMatrixException e) {
            System.out.println("Singularity encountered when regressing " + LogUtilsSearch.getScoreFact(child, regressors));
            return 0.0;
        }
    }

    /** Returns the position of {@code node} in the data set, looked up by name. */
    private int indexOf(Node node) {
        Integer index = this.nodeIndices.get(node.getName());
        if (index == null) {
            throw new IllegalArgumentException("Variable is not in the data set: " + node.getName());
        }
        return index;
    }

    /** A list of nodes paired, position by position, with a list of effects. */
    public static class NodeEffects {
        private List<Node> nodes;
        private LinkedList<Double> effects;

        /**
         * @param nodes   the nodes
         * @param effects the effects, where entry {@code i} belongs to node {@code i}
         */
        NodeEffects(List<Node> nodes, LinkedList<Double> effects) {
            // Assign directly rather than through the overridable setters.
            this.nodes = nodes;
            this.effects = effects;
        }

        /** @return the nodes */
        public List<Node> getNodes() {
            return this.nodes;
        }

        /** @param nodes the nodes to store */
        public void setNodes(List<Node> nodes) {
            this.nodes = nodes;
        }

        /** @return the effects */
        public LinkedList<Double> getEffects() {
            return this.effects;
        }

        /** @param effects the effects to store */
        public void setEffects(LinkedList<Double> effects) {
            this.effects = effects;
        }

        /** @return the pairs formatted as {@code node=effect}, each followed by a space */
        @Override
        public String toString() {
            StringBuilder b = new StringBuilder();
            for (int i = 0; i < this.nodes.size(); ++i) {
                b.append(this.nodes.get(i)).append("=").append(this.effects.get(i)).append(" ");
            }
            return b.toString();
        }
    }
}

