/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search;

import edu.cmu.tetrad.data.CovarianceMatrix;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.Knowledge;
import edu.cmu.tetrad.graph.Edge;
import edu.cmu.tetrad.graph.EdgeListGraph;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.GraphNode;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.search.Fges;
import edu.cmu.tetrad.search.IndTestFisherZ;
import edu.cmu.tetrad.search.SemBicScore;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.SortedSet;
import java.util.TreeMap;
import java.util.TreeSet;
import org.apache.commons.math3.linear.SingularMatrixException;

public class DMSearch {
    private int[] inputs;
    private int[] outputs;
    private double alphaSober = 0.05;
    private double alphaPC = 0.05;
    private double gesDiscount = 10.0;
    private int gesDepth = 0;
    private int minDiscount = 4;
    private boolean useGES = true;
    private int[] trueInputs;
    private DataSet data;
    private CovarianceMatrix cov;
    private LatentStructure dmStructure;

    public void setMinDiscount(int minDiscount) {
        this.minDiscount = minDiscount;
    }

    public int getMinDepth() {
        return this.minDiscount;
    }

    public void setGesDepth(int gesDepth) {
        this.gesDepth = gesDepth;
    }

    public int getGesDepth() {
        return this.gesDepth;
    }

    public void setTrueInputs(int[] trueInputs) {
        this.trueInputs = trueInputs;
    }

    public void setInputs(int[] inputs) {
        this.inputs = inputs;
    }

    public void setOutputs(int[] outputs) {
        this.outputs = outputs;
    }

    public void setData(DataSet data) {
        this.data = data;
    }

    public int[] getTrueInputs() {
        return this.trueInputs;
    }

    public DataSet getData() {
        return this.data;
    }

    public int[] getInputs() {
        return this.inputs;
    }

    public int[] getOutputs() {
        return this.outputs;
    }

    public LatentStructure getDmStructure() {
        return this.dmStructure;
    }

    public void setDmStructure(LatentStructure structure) {
        this.dmStructure = structure;
    }

    public void setAlphaSober(double alpha) {
        this.alphaSober = alpha;
    }

    public void setAlphaPC(double alpha) {
        this.alphaPC = alpha;
    }

    public void setDiscount(double discount) {
        this.gesDiscount = discount;
    }

    public void setUseFgES(boolean set) {
        this.useGES = set;
    }

    public Graph search() {
        int[] trueInputs = this.getTrueInputs();
        DataSet data = this.getData();
        this.cov = new CovarianceMatrix(data);
        Knowledge knowledge = new Knowledge(data.getVariableNames());
        for (int i : this.getInputs()) {
            knowledge.addToTier(0, data.getVariable(i).getName());
        }
        for (int i : this.getOutputs()) {
            knowledge.addToTier(1, data.getVariable(i).getName());
        }
        knowledge.setTierForbiddenWithin(0, true);
        knowledge.setTierForbiddenWithin(1, true);
        HashSet<String> inputString = new HashSet<String>();
        HashSet<Integer> actualInputs = new HashSet<Integer>();
        for (int i = 0; i < trueInputs.length; ++i) {
            actualInputs.add(trueInputs[i]);
        }
        for (int i : this.inputs) {
            if (!actualInputs.contains(i)) continue;
            inputString.add(data.getVariable(i).getName());
        }
        Graph pattern = new EdgeListGraph();
        if (this.useGES) {
            Fges ges = new Fges(new SemBicScore(this.cov));
            pattern = this.recursiveGES(pattern, knowledge, this.gesDiscount, this.getMinDepth(), data, inputString);
        } else {
            this.cov = new CovarianceMatrix(data);
            System.out.println("Running PC Search");
            double penalty = 2.0;
            IndTestFisherZ ind = new IndTestFisherZ(this.cov, this.alphaPC);
            for (int i = 0; i < this.getInputs().length; ++i) {
                if (!pattern.containsNode(data.getVariable(i))) {
                    pattern.addNode(data.getVariable(i));
                }
                if (!actualInputs.contains(i)) continue;
                for (int j = this.getInputs().length; j < data.getNumColumns(); ++j) {
                    if (!pattern.containsNode(data.getVariable(j))) {
                        pattern.addNode(data.getVariable(j));
                    }
                    if (!ind.checkIndependence(data.getVariable(i), data.getVariable(j), new Node[0]).dependent()) continue;
                    pattern.addDirectedEdge(data.getVariable(i), data.getVariable(j));
                }
            }
            System.out.println("Running DM search");
            this.applyDmSearch(pattern, inputString, penalty);
        }
        return this.getDmStructure().latentStructToEdgeListGraph(this.getDmStructure());
    }

    public LatentStructure applyDmSearch(Graph pattern, Set<String> inputString, double penalty) {
        ArrayList<TreeSet<Node>> outputParentsList = new ArrayList<TreeSet<Node>>();
        List<Node> patternNodes = pattern.getNodes();
        Collections.sort(patternNodes, new Comparator<Node>(){

            @Override
            public int compare(Node node1, Node node2) {
                if (node1.getName().length() > node2.getName().length()) {
                    return 1;
                }
                if (node1.getName().length() < node2.getName().length()) {
                    return -1;
                }
                int n1 = Integer.parseInt(node1.getName().substring(1));
                int n2 = Integer.parseInt(node2.getName().substring(1));
                return n1 - n2;
            }
        });
        System.out.println("Sorted patternNodes");
        TreeSet<Node> outputNodes = new TreeSet<Node>();
        for (int i : this.getOutputs()) {
            outputNodes.add(patternNodes.get(i));
        }
        System.out.println("Got output nodes");
        Object object = outputNodes.iterator();
        while (object.hasNext()) {
            Node node = (Node)object.next();
            outputParentsList.add(new TreeSet<Node>(this.getInputParents(node, inputString, pattern)));
        }
        System.out.println("Created list of output node parents");
        int sublistStart = 1;
        int nLatents = 0;
        LatentStructure structure = new LatentStructure();
        for (Set set : outputParentsList) {
            TreeSet<Node> sameSetParents = new TreeSet<Node>(new Comparator<Node>(){

                @Override
                public int compare(Node node1, Node node2) {
                    if (node1.getName().length() > node2.getName().length()) {
                        return 1;
                    }
                    if (node1.getName().length() < node2.getName().length()) {
                        return -1;
                    }
                    int n1 = Integer.parseInt(node1.getName().substring(1));
                    int n2 = Integer.parseInt(node2.getName().substring(1));
                    return n1 - n2;
                }
            });
            List nextSet = outputParentsList.subList(sublistStart, outputParentsList.size());
            if (nextSet.isEmpty()) {
                sameSetParents.addAll(set);
            }
            for (Object set2 : nextSet) {
                if (set.size() != 0 && set2.size() != 0 && set.equals(set2)) {
                    sameSetParents.addAll(set);
                    continue;
                }
                if (set.size() <= 0) continue;
                sameSetParents.addAll(set);
            }
            if (sameSetParents.size() > 0) {
                Object set2;
                GraphNode tempLatent = new GraphNode("L" + nLatents);
                if (this.setContained(structure, structure.inputs.keySet(), sameSetParents) && !structure.inputs.isEmpty()) continue;
                structure.latents.add(tempLatent);
                structure.inputs.put(tempLatent, sameSetParents);
                ++nLatents;
                set2 = outputNodes.iterator();
                while (set2.hasNext()) {
                    Node node = (Node)set2.next();
                    if (!new TreeSet<Node>(this.getInputParents(node, inputString, pattern)).equals(sameSetParents)) continue;
                    if (structure.outputs.get(tempLatent) == null) {
                        TreeSet<Node> outputNode = new TreeSet<Node>();
                        outputNode.add(node);
                        structure.outputs.put(tempLatent, outputNode);
                        continue;
                    }
                    structure.outputs.get(tempLatent).add(node);
                }
            }
            System.out.println("Completed starting point: " + sublistStart + " out of #" + outputParentsList.size() + " sets, and is " + set.size() + " units large.");
            ++sublistStart;
        }
        System.out.println("created initial sets");
        TreeMap latentsSortedByInputSetSize = this.sortMapByValue(structure.inputs, structure.latents, structure);
        System.out.println("Finding initial latent-latent effects");
        TreeSet<Node> treeSet = new TreeSet<Node>(new Comparator<Node>(){

            @Override
            public int compare(Node node1, Node node2) {
                if (node1.getName().length() > node2.getName().length()) {
                    return 1;
                }
                if (node1.getName().length() < node2.getName().length()) {
                    return -1;
                }
                int n1 = Integer.parseInt(node1.getName().substring(1));
                int n2 = Integer.parseInt(node2.getName().substring(1));
                return n1 - n2;
            }
        });
        TreeSet<Node> inputs2 = new TreeSet<Node>(new Comparator<Node>(){

            @Override
            public int compare(Node node1, Node node2) {
                if (node1.getName().length() > node2.getName().length()) {
                    return 1;
                }
                if (node1.getName().length() < node2.getName().length()) {
                    return -1;
                }
                int n1 = Integer.parseInt(node1.getName().substring(1));
                int n2 = Integer.parseInt(node2.getName().substring(1));
                return n1 - n2;
            }
        });
        HashSet<Object> alreadyLookedAt = new HashSet<Object>();
        for (int i = 0; i <= latentsSortedByInputSetSize.keySet().size(); ++i) {
            Iterator<Node> sortedInputs = new TreeSet<TreeSet<Node>>(new Comparator<TreeSet<Node>>(){

                @Override
                public int compare(TreeSet<Node> o1, TreeSet<Node> o2) {
                    int size = o1.size() - o2.size();
                    if (size == 0) {
                        if (o1.equals(o2)) {
                            return 0;
                        }
                        return o1.hashCode() - o2.hashCode();
                    }
                    return size;
                }
            });
            ((TreeSet)((Object)sortedInputs)).addAll(latentsSortedByInputSetSize.keySet());
            TreeSet<Node> treeSet2 = this.findFirstUnseenElement((TreeSet<TreeSet<Node>>)((Object)sortedInputs), alreadyLookedAt, latentsSortedByInputSetSize);
            HashSet<Node> alreadyLookedAtInnerLoop = new HashSet<Node>();
            Node latent1 = latentsSortedByInputSetSize.get(treeSet2);
            if (treeSet2.first().getName().equals("alreadySeenEverything")) continue;
            for (int j = 0; j <= latentsSortedByInputSetSize.keySet().size(); ++j) {
                TreeSet<TreeSet<Node>> temp2 = new TreeSet<TreeSet<Node>>(new Comparator<TreeSet<Node>>(){

                    @Override
                    public int compare(TreeSet<Node> o1, TreeSet<Node> o2) {
                        int size = o1.size() - o2.size();
                        if (size == 0) {
                            if (o1.equals(o2)) {
                                return 0;
                            }
                            return o1.hashCode() - o2.hashCode();
                        }
                        return size;
                    }
                });
                inputs2 = this.findFirstUnseenElement((TreeSet<TreeSet<Node>>)((Object)sortedInputs), alreadyLookedAtInnerLoop, latentsSortedByInputSetSize);
                Node latent2 = latentsSortedByInputSetSize.get(inputs2);
                if (inputs2.first().getName().equals("alreadySeenEverything")) continue;
                alreadyLookedAtInnerLoop.add(latent2);
                if (latent1.equals(latent2) || structure.getInputs(latent2).equals(structure.getInputs(latent1)) || !structure.getInputs(latent2).containsAll(structure.getInputs(latent1))) continue;
                if (structure.latentEffects.get(latent1) == null) {
                    TreeSet<Node> latentEffects = new TreeSet<Node>(new Comparator<Node>(){

                        @Override
                        public int compare(Node node1, Node node2) {
                            if (node1.getName().length() > node2.getName().length()) {
                                return 1;
                            }
                            if (node1.getName().length() < node2.getName().length()) {
                                return -1;
                            }
                            int n1 = Integer.parseInt(node1.getName().substring(1));
                            int n2 = Integer.parseInt(node2.getName().substring(1));
                            return n1 - n2;
                        }
                    });
                    latentEffects.add(latent2);
                    structure.latentEffects.put(latent1, latentEffects);
                } else {
                    structure.latentEffects.get(latent1).add(latent2);
                }
                latentsSortedByInputSetSize = this.removeSetInputs(structure, structure.getInputs(latent1), structure.getInputs(latent2).size(), latent2, latentsSortedByInputSetSize);
            }
            alreadyLookedAt.add(latent1);
        }
        TreeSet<Node> emptyTreeSet = new TreeSet<Node>(new Comparator<Node>(){

            @Override
            public int compare(Node node1, Node node2) {
                if (node1.getName().length() > node2.getName().length()) {
                    return 1;
                }
                if (node1.getName().length() < node2.getName().length()) {
                    return -1;
                }
                int n1 = Integer.parseInt(node1.getName().substring(1));
                int n2 = Integer.parseInt(node2.getName().substring(1));
                return n1 - n2;
            }
        });
        for (Node latent : structure.getLatents()) {
            if (structure.latentEffects.get(latent) != null) continue;
            structure.latentEffects.put(latent, emptyTreeSet);
        }
        System.out.println("Structure prior to Sober's step:");
        System.out.println("Applying Sober's step ");
        for (Node latent : structure.getLatents()) {
            if (!structure.latentEffects.containsKey(latent)) continue;
            for (Node latentEffect : structure.getLatentEffects(latent)) {
                this.applySobersStep(structure.getInputs(latent), structure.getInputs(latentEffect), structure.getOutputs(latent), structure.getOutputs(latentEffect), pattern, structure, latent, latentEffect);
            }
        }
        this.setDmStructure(structure);
        File file = new File("src/edu/cmu/tetradproj/amurrayw/DM_output_GES_penalty" + penalty + "_.txt");
        try {
            FileOutputStream out = new FileOutputStream(file);
            PrintStream outStream = new PrintStream(out);
            outStream.println(structure.latentStructToEdgeListGraph(structure));
        }
        catch (FileNotFoundException e) {
            System.out.println("Can't write to file.");
        }
        return structure;
    }

    private TreeSet<Node> findFirstUnseenElement(TreeSet<TreeSet<Node>> set, HashSet alreadySeen, TreeMap map) {
        for (TreeSet<Node> currentSet : set) {
            if (alreadySeen.contains(map.get(currentSet)) || map.get(currentSet) == null) continue;
            return currentSet;
        }
        GraphNode end = new GraphNode("alreadySeenEverything");
        TreeSet<Node> seenEverything = new TreeSet<Node>();
        seenEverything.add(end);
        return seenEverything;
    }

    private TreeSet nthElementOn(TreeSet set, int startingElementPos) {
        for (int i = 0; i < set.size() - startingElementPos; ++i) {
            set = this.rest(set);
        }
        return set;
    }

    private TreeSet<TreeSet<Node>> rest(TreeSet set) {
        set.remove(set.first());
        return set;
    }

    private TreeSet<TreeSet<Node>> second(TreeSet<TreeSet<Node>> set) {
        TreeSet<Object> secondNodeSet = new TreeSet();
        secondNodeSet = this.rest(set);
        secondNodeSet.first();
        return secondNodeSet;
    }

    private boolean allEqual(SortedSet<Node> set1, SortedSet<Node> set2) {
        for (Node i : set1) {
            for (Node j : set2) {
                if (i.equals(j)) continue;
                return false;
            }
        }
        for (Node i : set2) {
            for (Node j : set1) {
                if (i.equals(j)) continue;
                return false;
            }
        }
        return true;
    }

    private Graph recursiveGES(Graph previousGES, Knowledge knowledge, double penalty, double minPenalty, DataSet data, Set<String> inputString) {
        for (Edge edge : previousGES.getEdges()) {
            knowledge.setRequired(edge.getNode1().getName(), edge.getNode2().getName());
        }
        previousGES = null;
        this.cov = new CovarianceMatrix(data);
        Fges ges = new Fges(new SemBicScore(this.cov));
        ges.setKnowledge(knowledge);
        Graph pattern = ges.search();
        File file = new File("src/edu/cmu/tetradproj/amurrayw/ges_output_" + penalty + "_.txt");
        try {
            FileOutputStream out = new FileOutputStream(file);
            PrintStream outStream = new PrintStream(out);
            outStream.println(pattern);
        }
        catch (FileNotFoundException e) {
            System.out.println("Can't write to file.");
        }
        if (penalty > minPenalty) {
            this.applyDmSearch(pattern, inputString, penalty);
            return this.recursiveGES(pattern, knowledge, penalty - 1.0, minPenalty, data, inputString);
        }
        this.applyDmSearch(pattern, inputString, penalty);
        return pattern;
    }

    private void applySobersStep(SortedSet<Node> inputsLatent, SortedSet<Node> inputsLatentEffect, SortedSet<Node> outputsLatent, SortedSet<Node> outputsLatentEffect, Graph pattern, LatentStructure structure, Node latent, Node latentEffect) {
        ArrayList<Node> latentList = new ArrayList<Node>();
        latentList.addAll(inputsLatent);
        IndTestFisherZ test = new IndTestFisherZ(this.cov, this.alphaSober);
        boolean testResult = false;
        try {
            testResult = test.checkIndependence(outputsLatent.first(), outputsLatentEffect.first(), latentList).independent();
        }
        catch (SingularMatrixException error) {
            System.out.println(error);
            System.out.println("SingularMatrixException Error!!!!!! Evaluated as:");
            System.out.println(testResult);
            System.out.println("outputsLatent.first()");
            System.out.println(outputsLatent.first());
            System.out.println("outputsLatentEffect.first()");
            System.out.println(outputsLatentEffect.first());
        }
        if (testResult) {
            structure.latentEffects.get(latent).remove(latentEffect);
            structure.inputs.get(latentEffect).addAll(inputsLatent);
        }
    }

    private TreeMap removeSetInputs(LatentStructure structure, SortedSet<Node> set, int sizeOfSuperset, Node latentForSuperset, TreeMap<TreeSet<Node>, Node> map) {
        for (Node latent : structure.latents) {
            if (structure.inputs.get(latent).equals(set) || structure.inputs.get(latent).size() <= sizeOfSuperset && !latent.equals(latentForSuperset) || !structure.inputs.get(latent).containsAll(set)) continue;
            structure.inputs.get(latent).removeAll(set);
        }
        return map;
    }

    private boolean setContained(LatentStructure structure, Set<Node> latentSet, Set<Node> inputSet) {
        for (Node latent : latentSet) {
            if (!structure.getInputs(latent).equals(inputSet)) continue;
            return true;
        }
        return false;
    }

    public boolean equals(Object obj) {
        return super.equals(obj);
    }

    private SortedSet copy(SortedSet orig) {
        TreeSet newset = new TreeSet();
        for (Object o : orig) {
            newset.add(o);
        }
        return newset;
    }

    private TreeMap<TreeSet<Node>, Node> sortMapByValue(Map<Node, SortedSet<Node>> inputMap, List<Node> latents, LatentStructure structure) {
        TreeMap<SortedSet<Node>, Node> sortedInputSets = new TreeMap<SortedSet<Node>, Node>(new Comparator<SortedSet<Node>>(){

            @Override
            public int compare(SortedSet o1, SortedSet o2) {
                int size = o1.size() - o2.size();
                if (size == 0) {
                    if (o1.equals(o2)) {
                        return 0;
                    }
                    return o1.hashCode() - o2.hashCode();
                }
                return size;
            }
        });
        for (Node latent : latents) {
            TreeSet<Node> tempSet = new TreeSet<Node>(new Comparator<Node>(){

                @Override
                public int compare(Node node1, Node node2) {
                    if (node1.getName().length() > node2.getName().length()) {
                        return 1;
                    }
                    if (node1.getName().length() < node2.getName().length()) {
                        return -1;
                    }
                    int n1 = Integer.parseInt(node1.getName().substring(1));
                    int n2 = Integer.parseInt(node2.getName().substring(1));
                    return n1 - n2;
                }
            });
            tempSet.addAll((Collection<Node>)inputMap.get(latent));
            sortedInputSets.put(tempSet, latent);
        }
        return sortedInputSets;
    }

    private Set<Node> getInputParents(Node node, Set inputString, Graph pattern) {
        HashSet<Node> actualInputs = new HashSet<Node>();
        for (Node posInput : pattern.getAdjacentNodes(node)) {
            if (!inputString.contains(posInput.getName())) continue;
            actualInputs.add(posInput);
        }
        return actualInputs;
    }

    public class LatentStructure {
        List<Node> latents = new ArrayList<Node>();
        Map<Node, SortedSet<Node>> inputs = new TreeMap<Node, SortedSet<Node>>();
        Map<Node, SortedSet<Node>> outputs = new TreeMap<Node, SortedSet<Node>>();
        Map<Node, SortedSet<Node>> latentEffects = new TreeMap<Node, SortedSet<Node>>();

        public void addRecord(Node latent, SortedSet<Node> inputs, SortedSet<Node> outputs, SortedSet<Node> latentEffects) {
            if (this.latents.contains(latent)) {
                throw new IllegalArgumentException();
            }
            this.latents.add(latent);
            this.inputs.put(latent, inputs);
            this.outputs.put(latent, outputs);
            this.latentEffects.put(latent, latentEffects);
        }

        public void removeLatent(Node latent) {
            this.latents.remove(latent);
            this.inputs.remove(latent);
            this.outputs.remove(latent);
            this.latentEffects.remove(latent);
        }

        public List<Node> getLatents() {
            return new ArrayList<Node>(this.latents);
        }

        public boolean containsLatent(Node latent) {
            return this.latents.contains(latent);
        }

        public SortedSet<Node> getInputs(Node latent) {
            return new TreeSet<Node>(this.inputs.get(latent));
        }

        public SortedSet<Node> getOutputs(Node latent) {
            return new TreeSet<Node>(this.outputs.get(latent));
        }

        public SortedSet<Node> getLatentEffects(Node latent) {
            return new TreeSet<Node>(this.latentEffects.get(latent));
        }

        public String toString() {
            StringBuilder b = new StringBuilder();
            for (Node node : this.latents) {
                b.append("Latent:" + node + "\n Inputs:" + this.inputs.get(node) + "\n Outputs:" + this.outputs.get(node) + "\n Latent Effects:" + this.latentEffects.get(node) + "\t\n");
            }
            b.append("\n");
            return b.toString();
        }

        public Graph latentStructToEdgeListGraph(LatentStructure structure) {
            EdgeListGraph structureGraph = new EdgeListGraph();
            for (Node latent : this.latents) {
                structureGraph.addNode(latent);
                for (Node input : this.inputs.get(latent)) {
                    structureGraph.addNode(input);
                }
                for (Node output : this.outputs.get(latent)) {
                    structureGraph.addNode(output);
                }
                for (Node input : this.inputs.get(latent)) {
                    structureGraph.addDirectedEdge(input, latent);
                }
                for (Node output : this.outputs.get(latent)) {
                    structureGraph.addDirectedEdge(latent, output);
                }
                if (this.latentEffects.get(latent) == null) continue;
                for (Node latentEff : this.latentEffects.get(latent)) {
                    if (!structureGraph.containsNode(latentEff)) {
                        structureGraph.addNode(latentEff);
                    }
                    structureGraph.addDirectedEdge(latent, latentEff);
                }
            }
            return structureGraph;
        }
    }
}

