/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search;

import edu.cmu.tetrad.data.BoxDataSet;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.DoubleDataBox;
import edu.cmu.tetrad.graph.Edge;
import edu.cmu.tetrad.graph.EdgeListGraph;
import edu.cmu.tetrad.graph.Endpoint;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.GraphNode;
import edu.cmu.tetrad.graph.GraphUtils;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.sem.SemEstimator;
import edu.cmu.tetrad.sem.SemIm;
import edu.cmu.tetrad.sem.SemPm;
import edu.cmu.tetrad.util.Matrix;
import java.util.HashMap;
import java.util.List;

/**
 * A graph paired with a map from its edges to numeric weights (edge coefficients).
 *
 * <p>The wrapped graph is stored by reference, not copied. {@link #addEdge(Node, Node, double)}
 * therefore mutates the graph that was passed in.
 */
public class GraphWithParameters {
    private Graph graph;
    private final HashMap<Edge, Double> weightHash;

    /**
     * Cycle information. Nothing in this class assigns it, so {@link #getCycles()} returns
     * {@code null} unless code elsewhere in the package sets it. Kept package-private for
     * that reason.
     */
    List<List<Integer>> cycles;

    /**
     * Wraps {@code trueCPDAG} and copies edge coefficients from {@code semIm}.
     *
     * <p>Weights are recorded only for edges incident to a node whose adjacencies are all
     * directed (per {@link GraphUtils#allAdjacenciesAreDirected}). Other edges, such as
     * undirected CPDAG edges, receive no weight.
     *
     * @param semIm     the instantiated SEM supplying edge coefficients
     * @param trueCPDAG the graph to wrap (stored by reference)
     */
    public GraphWithParameters(SemIm semIm, Graph trueCPDAG) {
        this(trueCPDAG);
        for (Node node : this.getGraph().getNodes()) {
            if (!GraphUtils.allAdjacenciesAreDirected(node, this.getGraph())) {
                continue;
            }
            // An edge may be visited once from each endpoint; the second put writes the same value.
            for (Edge edge : this.getGraph().getEdges(node)) {
                this.getWeightHash().put(edge, semIm.getEdgeCoef(edge));
            }
        }
    }

    /**
     * Wraps {@code graph} with an initially empty weight map.
     *
     * @param graph the graph to wrap (stored by reference, not copied)
     */
    public GraphWithParameters(Graph graph) {
        this.graph = graph;
        this.weightHash = new HashMap<>();
    }

    /**
     * Builds a graph from a square weight matrix stored in {@code dataSet}.
     *
     * <p>One node is created per matrix row, named after the corresponding variable. For
     * every entry {@code B(i, j)} whose absolute value exceeds 1e-15, a directed edge
     * {@code var(i) --> var(j)} with weight {@code B(i, j)} is added. Nonzero diagonal
     * entries therefore produce self-loop edges.
     *
     * <p>NOTE(review): {@link #getGraphMatrix()} uses the opposite convention. It writes the
     * weight of edge {@code node1 --> node2} at {@code (node2, node1)}. Round-tripping through
     * both methods reverses edge directions. This is left unchanged to preserve existing
     * behavior; confirm which convention callers expect.
     *
     * <p>Assumes the matrix is square with one column per variable. The row count is used as
     * the dimension for both indices.
     *
     * @param dataSet data set whose numeric content is the weight matrix
     */
    public GraphWithParameters(DataSet dataSet) {
        Matrix bMatrix = dataSet.getDoubleData();
        this.graph = new EdgeListGraph();
        this.weightHash = new HashMap<>();
        int n = bMatrix.rows();

        for (int i = 0; i < n; ++i) {
            this.getGraph().addNode(new GraphNode(dataSet.getVariable(i).getName()));
        }

        // Fetched once rather than on every inner-loop iteration.
        List<String> names = dataSet.getVariableNames();
        for (int i = 0; i < n; ++i) {
            for (int j = 0; j < n; ++j) {
                double value = bMatrix.get(i, j);
                // Treat |value| <= 1e-15 as "no edge" (also skips NaN, as the original test did).
                if (!(value > 1.0E-15) && !(value < -1.0E-15)) {
                    continue;
                }
                Node node1 = this.getGraph().getNode(names.get(i));
                Node node2 = this.getGraph().getNode(names.get(j));
                Edge edge = new Edge(node1, node2, Endpoint.TAIL, Endpoint.ARROW);
                this.getGraph().addEdge(edge);
                this.getWeightHash().put(edge, value);
            }
        }
    }

    /**
     * Adds the directed edge {@code node1 --> node2} to the wrapped graph and records its weight.
     *
     * @param node1  tail node
     * @param node2  head node
     * @param weight coefficient to associate with the new edge
     */
    public void addEdge(Node node1, Node node2, double weight) {
        Edge edge = new Edge(node1, node2, Endpoint.TAIL, Endpoint.ARROW);
        this.getGraph().addEdge(edge);
        this.getWeightHash().put(edge, weight);
    }

    /**
     * Adds the directed edge {@code nodeName1 --> nodeName2}, looking nodes up by name.
     *
     * @param nodeName1 name of the tail node
     * @param nodeName2 name of the head node
     * @param weight    coefficient to associate with the new edge
     * @throws IllegalArgumentException if either name does not resolve to a node in the graph
     *                                  (assumes {@code Graph.getNode} returns {@code null} for
     *                                  unknown names)
     */
    public void addEdge(String nodeName1, String nodeName2, double weight) {
        Node node1 = this.getGraph().getNode(nodeName1);
        Node node2 = this.getGraph().getNode(nodeName2);
        if (node1 == null) {
            throw new IllegalArgumentException("No node named '" + nodeName1 + "' in the graph.");
        }
        if (node2 == null) {
            throw new IllegalArgumentException("No node named '" + nodeName2 + "' in the graph.");
        }
        this.addEdge(node1, node2, weight);
    }

    /**
     * Lists every edge of the graph followed by its recorded weight, one edge per line.
     * Edges without a recorded weight print {@code null}.
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        for (Edge edge : this.getGraph().getEdges()) {
            sb.append(edge).append("   ").append(this.getWeightHash().get(edge)).append("\n");
        }
        return sb.toString();
    }

    /**
     * Fits a linear SEM with the structure of {@code graph} to {@code dataSet}.
     *
     * <p>The returned instance carries the estimated coefficients, subject to the directed-edge
     * restriction described on {@link #GraphWithParameters(SemIm, Graph)}.
     *
     * @param dataSet data to estimate from
     * @param graph   structure of the SEM
     * @return the graph paired with the estimated edge coefficients
     */
    public static GraphWithParameters regress(DataSet dataSet, Graph graph) {
        SemPm semPmEstDag = new SemPm(graph);
        SemEstimator estimatorEstDag = new SemEstimator(dataSet, semPmEstDag);
        estimatorEstDag.estimate();
        SemIm semImEstDag = estimatorEstDag.getEstimatedSem();
        return new GraphWithParameters(semImEstDag, graph);
    }

    /**
     * Returns the weighted adjacency matrix as an n-by-n data set whose variables are the
     * graph's nodes.
     *
     * <p>For each edge, the weight is stored at row {@code index(edge.getNode2())} and column
     * {@code index(edge.getNode1())}. For edges created by this class that means the (i, j)
     * entry holds the weight of edge {@code j --> i}. Absent edges are 0.
     *
     * @return the weight matrix wrapped as a data set
     * @throws IllegalStateException if some edge of the graph has no recorded weight (the
     *                               original code failed here with a bare NullPointerException)
     */
    public DataSet getGraphMatrix() {
        // Fetch the node list once instead of twice per edge.
        List<Node> nodes = this.getGraph().getNodes();
        int n = this.getGraph().getNumNodes();
        Matrix matrix = new Matrix(n, n);
        for (Edge edge : this.getGraph().getEdges()) {
            Double weight = this.getWeightHash().get(edge);
            if (weight == null) {
                throw new IllegalStateException("No weight recorded for edge " + edge + ".");
            }
            int node1Index = nodes.indexOf(edge.getNode1());
            int node2Index = nodes.indexOf(edge.getNode2());
            matrix.set(node2Index, node1Index, weight);
        }
        return new BoxDataSet(new DoubleDataBox(matrix.toArray()), nodes);
    }

    /**
     * Returns the cycle list, or {@code null} if nothing has assigned it.
     *
     * @return the package-private {@code cycles} field
     */
    public List<List<Integer>> getCycles() {
        return this.cycles;
    }

    /**
     * Returns the wrapped graph.
     *
     * @return the graph (the live instance, not a copy)
     */
    public Graph getGraph() {
        return this.graph;
    }

    /**
     * Returns the edge-to-weight map.
     *
     * @return the live internal map; mutations affect this instance
     */
    public HashMap<Edge, Double> getWeightHash() {
        return this.weightHash;
    }
}

