/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search;

import cern.colt.matrix.impl.DenseDoubleMatrix2D;
import edu.cmu.tetrad.data.CovarianceMatrix;
import edu.cmu.tetrad.data.DataUtils;
import edu.cmu.tetrad.data.Knowledge;
import edu.cmu.tetrad.data.KnowledgeEdge;
import edu.cmu.tetrad.graph.Edge;
import edu.cmu.tetrad.graph.EdgeListGraph;
import edu.cmu.tetrad.graph.Endpoint;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.graph.NodeType;
import edu.cmu.tetrad.search.DagInPatternIterator;
import edu.cmu.tetrad.search.Ges;
import edu.cmu.tetrad.search.IndTestMimBuild;
import edu.cmu.tetrad.search.MeekRules;
import edu.cmu.tetrad.search.MimAdjacencySearch;
import edu.cmu.tetrad.search.SearchGraphUtils;
import edu.cmu.tetrad.search.SepsetMap;
import edu.cmu.tetrad.sem.SemEstimator;
import edu.cmu.tetrad.sem.SemIm;
import edu.cmu.tetrad.sem.SemOptimizer;
import edu.cmu.tetrad.sem.SemOptimizerEm;
import edu.cmu.tetrad.sem.SemPm;
import edu.cmu.tetrad.util.TetradLogger;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

/**
 * MIMBuild search: builds a measurement model from the required edges of the test's
 * measurement knowledge, then searches for a structural model among the latent variables.
 *
 * <p>The structural search is either GES-based (iterated GES on the latent block of an
 * EM-estimated expected covariance matrix) or PC-based (MIMBuild adjacency search followed
 * by knowledge, collider and Meek-rule orientation), selected by the test's algorithm type.
 *
 * <p>The list of latent nodes is per-search mutable state held in an instance field without
 * synchronization, so one instance should not run searches concurrently.
 */
public final class MimBuild {
    /** Prefix used by {@link #generateLatentNames(int)} for generated latent variable names. */
    public static final String LATENT_PREFIX = "_L";

    /** Latent nodes of the current search; rebuilt by {@code startMeasurementModel}. */
    private final List<Node> latents = new ArrayList<Node>();
    /** Supplies the variables, measurement knowledge, covariance matrix and algorithm type. */
    private final IndTestMimBuild indTest;
    /** Background knowledge consulted when building and orienting the latent structure. */
    private final Knowledge knowledge;
    private final TetradLogger logger = TetradLogger.getInstance();

    /**
     * Creates a MIMBuild search.
     *
     * @param indTest   the MIMBuild test; {@link #search()} throws if this is null
     * @param knowledge background knowledge used when building and orienting the latent graph
     */
    public MimBuild(IndTestMimBuild indTest, Knowledge knowledge) {
        this.indTest = indTest;
        this.knowledge = knowledge;
    }

    /**
     * Runs the search selected by the test's algorithm type and logs the result under the
     * "graph" event.
     *
     * <p>Algorithm types 0 and -1 select the GES-based search; any other value selects the
     * PC-based search.
     *
     * @return the measurement model together with the oriented latent structure
     * @throws NullPointerException if no independence test was supplied
     */
    public Graph search() {
        if (this.getIndTest() == null) {
            throw new NullPointerException("indTest must not be null");
        }
        int type = this.getIndTest().getAlgorithmType();
        Graph graph = (type == 0 || type == -1) ? this.mimBuildGesSearch() : this.mimBuildPcSearch();
        this.logger.log("graph", "\nReturning this graph: " + graph);
        return graph;
    }

    /**
     * Returns display labels for the available tests (a new array on each call).
     * NOTE(review): the mapping from index to the test's own type code is not visible here.
     */
    public static String[] getTestDescriptions() {
        return new String[]{"Gaussian maximum likelihood", "Two-stage least squares"};
    }

    /**
     * Returns display labels for the structural search algorithms (a new array on each call).
     * Index 0 ("GES") matches algorithm type 0 in {@link #search()}; any other type runs PC.
     */
    public static String[] getAlgorithmDescriptions() {
        return new String[]{"GES", "PC"};
    }

    /**
     * Generates {@code total} latent names: {@code _L1, _L2, ..., _L<total>}.
     *
     * @param total number of names to generate; zero or negative yields an empty list
     * @return a new mutable list of names
     */
    public static List<String> generateLatentNames(int total) {
        List<String> names = new ArrayList<String>();
        for (int i = 1; i <= total; ++i) {
            names.add(LATENT_PREFIX + i);
        }
        return names;
    }

    /**
     * PC-style structural search.
     *
     * <p>Runs the MIMBuild adjacency search, mirrors the latent-latent adjacencies that remain
     * in the full graph into a latent-only graph, and orients that graph. Orientation uses
     * background knowledge, then sepset-based collider orientation, then the Meek rules.
     * Every arrowhead found there is copied back onto the full graph.
     */
    private Graph mimBuildPcSearch() {
        EdgeListGraph graph = new EdgeListGraph(this.getIndTest().getVariableList());
        this.startMeasurementModel(graph);
        MimAdjacencySearch adjacencySearch =
                new MimAdjacencySearch(graph, this.getIndTest(), this.getKnowledge(), this.latents);
        SepsetMap sepsets = adjacencySearch.adjSearch();

        // Copy only the latent pairs that are still adjacent in `graph`. A complete latent
        // graph has no unshielded triples, so the sepsets could never orient a collider. It
        // would also lead to setEndpoint calls on pairs that are not adjacent in `graph`, such
        // as knowledge-forbidden pairs skipped by startMeasurementModel.
        // NOTE(review): presumes adjSearch() removes non-adjacent latent edges from `graph`.
        EdgeListGraph latentGraph = new EdgeListGraph(this.latents);
        for (int i = 0; i < this.latents.size(); ++i) {
            Node a = this.latents.get(i);
            for (int j = i + 1; j < this.latents.size(); ++j) {
                Node b = this.latents.get(j);
                if (graph.isAdjacentTo(a, b)) {
                    latentGraph.addUndirectedEdge(a, b);
                }
            }
        }

        TetradLogger.getInstance().log("info", "Starting PC Orientation.");
        SearchGraphUtils.pcOrientbk(this.getKnowledge(), latentGraph, graph.getNodes());
        SearchGraphUtils.orientCollidersUsingSepsets(sepsets, this.getKnowledge(), latentGraph);
        MeekRules rules = new MeekRules();
        rules.setKnowledge(this.getKnowledge());
        rules.orientImplied(latentGraph);
        TetradLogger.getInstance().log("info", "Finishing PC Orientation");

        // The latent graph shares its Node objects with `graph`, so endpoints map directly.
        for (Node latent : this.latents) {
            for (Node source : latentGraph.getNodesInTo(latent, Endpoint.ARROW)) {
                graph.setEndpoint(source, latent, Endpoint.ARROW);
            }
        }
        return graph;
    }

    /**
     * GES-style structural search.
     *
     * <p>Starts from one DAG of the initial measurement + latent graph and estimates it with
     * an EM SEM optimizer. It then repeats the following steps:
     * <ol>
     *   <li>take the optimizer's expected covariance over all variables;</li>
     *   <li>restrict it to the latents and run GES on that block;</li>
     *   <li>pick one DAG from the GES result, splice it into the measurement model and
     *       re-estimate.</li>
     * </ol>
     * The loop stops as soon as the score no longer improves.
     */
    private Graph mimBuildGesSearch() {
        Graph graph = new EdgeListGraph(this.getIndTest().getVariableList());
        this.startMeasurementModel(graph);

        List<Node> allNodes = graph.getNodes();
        String[] varNames = new String[graph.getNumNodes()];
        for (int i = 0; i < varNames.length; ++i) {
            varNames[i] = allNodes.get(i).toString();
        }
        String[] latentVarNames = new String[this.latents.size()];
        for (int i = 0; i < latentVarNames.length; ++i) {
            latentVarNames[i] = this.latents.get(i).toString();
        }

        CovarianceMatrix covMatrix = this.getIndTest().getCovMatrix();
        int sampleSize = covMatrix.getSampleSize();

        DagInPatternIterator initialDags = new DagInPatternIterator(graph);
        initialDags.setKnowledge(this.getKnowledge());
        graph = initialDags.next();

        SemOptimizerEm optimizer = new SemOptimizerEm();
        SemEstimator estimator = new SemEstimator(covMatrix, new SemPm(graph), (SemOptimizer) optimizer);
        estimator.estimate();
        double newScore = this.scoreModel(estimator.getEstimatedSem());
        double score;

        do {
            score = newScore;

            // The loop only continues after an accepted candidate, so the shared optimizer still
            // holds the expected covariance from the last model it estimated.
            // NOTE(review): assumes that matrix is ordered like `varNames` (the initial graph's
            // node order); confirm against SemOptimizerEm / SemPm variable ordering.
            List<Node> continuousVariables = DataUtils.createContinuousVariables(varNames);
            DenseDoubleMatrix2D expectedCovariance =
                    new DenseDoubleMatrix2D(optimizer.getExpectedCovarianceMatrix());
            CovarianceMatrix expectedCovMatrix =
                    new CovarianceMatrix(continuousVariables, expectedCovariance, sampleSize);
            CovarianceMatrix latentCovMatrix = expectedCovMatrix.getSubmatrix(latentVarNames);

            Ges ges = new Ges(latentCovMatrix);
            ges.setKnowledge(this.getKnowledge());
            Graph structuralPattern = ges.search();
            this.warnIfKnowledgeViolated(structuralPattern, "the GES structural model");
            this.logger.log("info", "GES structural model: " + structuralPattern);

            // Iterate over a copy so the GES result itself stays untouched for the update below.
            DagInPatternIterator structuralDags =
                    new DagInPatternIterator(new EdgeListGraph(structuralPattern));
            structuralDags.setKnowledge(this.getKnowledge());
            Graph structuralDag = structuralDags.next();
            this.warnIfKnowledgeViolated(structuralDag, "the chosen structural DAG");

            Graph candidate = this.getUpdatedGraph(graph, structuralDag);
            this.warnIfKnowledgeViolated(candidate, "the candidate graph");

            estimator = new SemEstimator(covMatrix, new SemPm(candidate), (SemOptimizer) optimizer);
            estimator.estimate();
            newScore = this.scoreModel(estimator.getEstimatedSem());

            if (newScore > score) {
                // NOTE(review): the accepted graph takes its latent edges from the GES result,
                // which may contain undirected edges, not from the DAG that was scored
                // (`candidate`). Confirm that returning the pattern is intended.
                graph = this.getUpdatedGraph(graph, structuralPattern);
                this.warnIfKnowledgeViolated(graph, "the updated graph");
            }
        } while (newScore > score);

        return graph;
    }

    /** Logs an "info" message if {@code graph} violates the background knowledge. */
    private void warnIfKnowledgeViolated(Graph graph, String description) {
        if (this.getKnowledge().isViolatedBy(graph)) {
            this.logger.log("info", "Background knowledge is violated by " + description + ".");
        }
    }

    /**
     * Returns a copy of {@code graph} in which all latent-latent edges are replaced by the
     * edges of {@code structuralModel}.
     *
     * <p>Nodes of the structural model are matched to nodes of the copy by their string form.
     * Each structural edge is added with the same endpoints.
     *
     * @param graph           graph holding the measurement model (not modified)
     * @param structuralModel graph over the latent variables whose edges are spliced in
     * @return the updated copy
     */
    private Graph getUpdatedGraph(Graph graph, Graph structuralModel) {
        EdgeListGraph output = new EdgeListGraph(graph);
        List<Edge> latentEdges = new ArrayList<Edge>();
        for (Edge edge : output.getEdges()) {
            if (edge.getNode1().getNodeType() == NodeType.LATENT
                    && edge.getNode2().getNodeType() == NodeType.LATENT) {
                latentEdges.add(edge);
            }
        }
        output.removeEdges(latentEdges);

        for (Edge edge : structuralModel.getEdges()) {
            Node node1 = output.getNode(edge.getNode1().toString());
            Node node2 = output.getNode(edge.getNode2().toString());
            // The latent-latent edges were just removed, so setEndpoint (which modifies an
            // existing edge) would have nothing to act on. Add the edge explicitly, replacing
            // any edge that survived because the nodes were not typed LATENT in the copy.
            Edge existing = output.getEdge(node1, node2);
            if (existing != null) {
                output.removeEdge(existing);
            }
            output.addEdge(new Edge(node1, node2, edge.getEndpoint1(), edge.getEndpoint2()));
        }
        return output;
    }

    /**
     * Scores an estimated SEM as {@code -fml - freeParams * ln(sampleSize)}; higher is better.
     * NOTE(review): if getFml() is the unscaled ML discrepancy, the fit term does not grow
     * with the sample size while the penalty does. Confirm this is the intended BIC variant.
     */
    private double scoreModel(SemIm semIm) {
        double fml = semIm.getFml();
        int freeParams = semIm.getNumFreeParams();
        int sampleSize = semIm.getSampleSize();
        return -fml - (double) freeParams * Math.log(sampleSize);
    }

    /**
     * Adds the measurement model to {@code graph} and records the latent nodes.
     *
     * <p>Each required edge {@code from -> to} of the test's measurement knowledge becomes a
     * directed edge, and its {@code from} node is marked LATENT. Every pair of latents then
     * gets an undirected edge, unless the background knowledge forbids that edge.
     */
    private void startMeasurementModel(Graph graph) {
        // A new graph is built for every search; discard latents recorded by a previous run.
        this.latents.clear();

        Iterator<KnowledgeEdge> required = this.getIndTest().getMeasurements().requiredEdgesIterator();
        while (required.hasNext()) {
            KnowledgeEdge edge = required.next();
            Node latent = graph.getNode(edge.getFrom());
            Node indicator = graph.getNode(edge.getTo());
            graph.addDirectedEdge(latent, indicator);
            if (!this.latents.contains(latent)) {
                this.latents.add(latent);
                latent.setNodeType(NodeType.LATENT);
            }
        }

        for (int i = 0; i < this.latents.size(); ++i) {
            Node a = this.latents.get(i);
            for (int j = i + 1; j < this.latents.size(); ++j) {
                Node b = this.latents.get(j);
                if (!this.getKnowledge().edgeForbidden(a.getName(), b.getName())) {
                    graph.addUndirectedEdge(a, b);
                }
            }
        }
    }

    private IndTestMimBuild getIndTest() {
        return this.indTest;
    }

    private Knowledge getKnowledge() {
        return this.knowledge;
    }
}

