/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.sem;

import edu.cmu.tetrad.graph.Dag;
import edu.cmu.tetrad.graph.Edge;
import edu.cmu.tetrad.graph.Edges;
import edu.cmu.tetrad.graph.Endpoint;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.GraphNode;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.graph.NodeType;
import edu.cmu.tetrad.graph.SemGraph;
import edu.cmu.tetrad.sem.DistributionType;
import edu.cmu.tetrad.sem.ParamComparison;
import edu.cmu.tetrad.sem.ParamType;
import edu.cmu.tetrad.sem.Parameter;
import edu.cmu.tetrad.sem.ParameterPair;
import edu.cmu.tetrad.util.TetradSerializable;
import edu.cmu.tetrad.util.dist.Split;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;

/**
 * A parametric model (PM) for a structural equation model over a DAG.
 *
 * <p>For every directed edge {@code X --> Y} of the graph (error edges excluded) one
 * coefficient parameter of type {@link ParamType#COEF} is created, named {@code B1, B2, ...}
 * and given a {@code Split(0.5, 1.5)} distribution. Every measured or latent node is also
 * assigned a distribution type (initially {@link DistributionType#NORMAL}) together with
 * {@code type.getNumArgs()} parameters of type {@link ParamType#DIST}, named
 * {@code D1, D2, ...}.
 *
 * <p>Only acyclic graphs without bidirected edges are accepted.
 */
public final class SemPm2 implements TetradSerializable {
    static final long serialVersionUID = 23L;

    /** The graph this PM parameterizes; shared (not copied) by the copy constructor. */
    private SemGraph graph;
    /** All nodes of the graph. */
    private List<Node> nodes;
    /** Every parameter of the model: edge coefficients followed by distribution parameters. */
    private List<Parameter> parameters;
    /** The measured and latent nodes of the graph. */
    private List<Node> variableNodes;
    /** The distribution type currently assigned to each variable node. */
    private Map<Node, DistributionType> distributionTypes;
    /** The distribution parameters currently assigned to each variable node. */
    private HashMap<Node, List<Parameter>> distributionParameters;
    /** Explicit comparisons between parameter pairs; pairs not present compare as NC. */
    private Map<ParameterPair, ParamComparison> paramComparisons = new HashMap<ParameterPair, ParamComparison>();
    /** Counter used to generate coefficient parameter names B1, B2, ... */
    private int bIndex = 0;
    /** Counter used to generate distribution parameter names D1, D2, ... */
    private int dIndex = 0;

    /**
     * Creates a PM for the given graph, after wrapping it in a new {@link SemGraph}.
     *
     * @param graph the graph to parameterize; must be acyclic with no bidirected edges.
     */
    public SemPm2(Graph graph) {
        this(new SemGraph(graph));
    }

    /**
     * Creates a PM for the given SEM graph. The graph is not copied: error terms are hidden
     * on the supplied graph itself via {@code setShowErrorTerms(false)}.
     *
     * @param graph the graph to parameterize.
     * @throws NullPointerException if {@code graph} is null, contains a directed cycle, or
     *                              contains a bidirected edge.
     */
    public SemPm2(SemGraph graph) {
        if (graph == null) {
            throw new NullPointerException("Graph must not be null.");
        }
        // NOTE(review): the next two checks reject invalid arguments rather than nulls;
        // IllegalArgumentException would be idiomatic, but NullPointerException is kept so
        // existing callers that catch it keep working.
        if (graph.existsDirectedCycle()) {
            throw new NullPointerException("Graph must be acyclic. Only the DAG case is considered.");
        }
        for (Edge edge : graph.getEdges()) {
            if (Edges.isBidirectedEdge(edge)) {
                throw new NullPointerException("Graph must not contain bidirected edges. Only the DAG case is considered.");
            }
        }
        this.graph = graph;
        this.graph.setShowErrorTerms(false);
        this.initializeNodes(graph);
        this.initializeVariableNodes();
        this.initializeParams();
    }

    /**
     * Copy constructor. Node lists, the parameter list, all maps and each per-node list of
     * distribution parameters are copied; the graph and the {@link Parameter} objects
     * themselves are shared with {@code semPm2}.
     *
     * @param semPm2 the PM to copy.
     */
    public SemPm2(SemPm2 semPm2) {
        this.graph = semPm2.graph;
        // Keep the node lists read-only, matching what the primary constructor produces.
        this.nodes = Collections.unmodifiableList(new ArrayList<Node>(semPm2.nodes));
        this.parameters = new ArrayList<Parameter>(semPm2.parameters);
        this.variableNodes = Collections.unmodifiableList(new ArrayList<Node>(semPm2.variableNodes));
        this.distributionTypes = new HashMap<Node, DistributionType>(semPm2.distributionTypes);
        // Copy each per-node list so the two PMs never share a mutable list.
        this.distributionParameters = new HashMap<Node, List<Parameter>>();
        for (Map.Entry<Node, List<Parameter>> entry : semPm2.distributionParameters.entrySet()) {
            this.distributionParameters.put(entry.getKey(), new ArrayList<Parameter>(entry.getValue()));
        }
        this.paramComparisons = new HashMap<ParameterPair, ParamComparison>(semPm2.paramComparisons);
        this.bIndex = semPm2.bIndex;
        this.dIndex = semPm2.dIndex;
    }

    /**
     * Returns a simple instance for serialization tests, built over
     * {@link Dag#serializableInstance()}.
     *
     * @return a new PM.
     */
    public static SemPm2 serializableInstance() {
        return new SemPm2(Dag.serializableInstance());
    }

    /**
     * Returns the graph this PM parameterizes (not a copy).
     *
     * @return the SEM graph.
     */
    public SemGraph getGraph() {
        return this.graph;
    }

    /**
     * Returns a fresh copy of the parameter list; modifying it does not affect this PM.
     *
     * @return all parameters of the model.
     */
    public List<Parameter> getParameters() {
        return new ArrayList<Parameter>(this.parameters);
    }

    /**
     * Returns the measured and latent nodes of the model.
     *
     * @return the internal (read-only) list of variable nodes.
     */
    public List<Node> getVariableNodes() {
        return this.variableNodes;
    }

    /**
     * Returns the nodes of type {@link NodeType#ERROR}.
     *
     * @return a new list of error nodes, in graph order.
     */
    public List<Node> getErrorNodes() {
        return this.nodesOfType(this.nodes, NodeType.ERROR);
    }

    /**
     * Returns the variable nodes of type {@link NodeType#MEASURED}.
     *
     * @return a new list of measured nodes, in graph order.
     */
    public List<Node> getMeasuredNodes() {
        return this.nodesOfType(this.getVariableNodes(), NodeType.MEASURED);
    }

    /**
     * Returns the nodes of type {@link NodeType#LATENT}.
     *
     * @return a new list of latent nodes, in graph order.
     */
    public List<Node> getLatentNodes() {
        return this.nodesOfType(this.nodes, NodeType.LATENT);
    }

    /**
     * Returns the first parameter with the given name.
     *
     * @param name the parameter name, e.g. "B1" or "D3"; must not be null.
     * @return the matching parameter, or null if there is none.
     */
    public Parameter getParameter(String name) {
        for (Parameter parameter : this.parameters) {
            if (name.equals(parameter.getName())) {
                return parameter;
            }
        }
        return null;
    }

    /**
     * Returns the variance parameter ({@link ParamType#VAR}) of an exogenous node.
     *
     * <p>NOTE(review): this class only ever creates COEF and DIST parameters, so this lookup
     * currently always returns null.
     *
     * @param node the node.
     * @return the variance parameter, or null if the node is not exogenous or has none.
     */
    public Parameter getVarianceParameter(Node node) {
        if (!this.getGraph().isExogenous(node)) {
            return null;
        }
        node = this.getGraph().getVarNode(node);
        for (Parameter parameter : this.parameters) {
            if (node == parameter.getNodeA() && node == parameter.getNodeB()
                    && parameter.getType() == ParamType.VAR) {
                return parameter;
            }
        }
        return null;
    }

    /**
     * Returns the covariance parameter ({@link ParamType#COVAR}) between two nodes, in
     * either order.
     *
     * <p>NOTE(review): this class only ever creates COEF and DIST parameters, so this lookup
     * currently always returns null.
     *
     * @param nodeA one node.
     * @param nodeB the other node.
     * @return the covariance parameter, or null if there is none.
     */
    public Parameter getCovarianceParameter(Node nodeA, Node nodeB) {
        nodeA = this.getGraph().getVarNode(nodeA);
        nodeB = this.getGraph().getVarNode(nodeB);
        for (Parameter parameter : this.parameters) {
            if (parameter.getType() != ParamType.COVAR) {
                continue;
            }
            Node _nodeA = parameter.getNodeA();
            Node _nodeB = parameter.getNodeB();
            if ((nodeA == _nodeA && nodeB == _nodeB) || (nodeB == _nodeA && nodeA == _nodeB)) {
                return parameter;
            }
        }
        return null;
    }

    /**
     * Returns the coefficient parameter for the directed edge {@code nodeA --> nodeB}.
     *
     * @param nodeA the tail node of the edge.
     * @param nodeB the head node of the edge.
     * @return the coefficient parameter, or null if there is none.
     */
    public Parameter getCoefficientParameter(Node nodeA, Node nodeB) {
        for (Parameter parameter : this.parameters) {
            if (nodeA == parameter.getNodeA() && nodeB == parameter.getNodeB()
                    && parameter.getType() == ParamType.COEF) {
                return parameter;
            }
        }
        return null;
    }

    /**
     * Returns the mean parameter ({@link ParamType#MEAN}) of a node.
     *
     * <p>NOTE(review): this class only ever creates COEF and DIST parameters, so this lookup
     * currently always returns null.
     *
     * @param node the node.
     * @return the mean parameter, or null if there is none.
     */
    public Parameter getMeanParameter(Node node) {
        for (Parameter parameter : this.parameters) {
            if (node == parameter.getNodeA() && node == parameter.getNodeB()
                    && parameter.getType() == ParamType.MEAN) {
                return parameter;
            }
        }
        return null;
    }

    /**
     * Returns the string forms of the measured variable nodes.
     *
     * @return the measured variable names, in graph order.
     */
    public String[] getMeasuredVarNames() {
        List<String> names = new ArrayList<String>();
        for (Node measured : this.getMeasuredNodes()) {
            names.add(measured.toString());
        }
        return names.toArray(new String[0]);
    }

    /**
     * Replaces the distribution of a variable node. The node's old distribution parameters
     * are removed and {@code type.getNumArgs()} fresh DIST parameters are created.
     *
     * @param node a measured or latent node of this model.
     * @param type the new distribution type.
     * @throws NullPointerException     if {@code node} or {@code type} is null.
     * @throws IllegalArgumentException if {@code node} is not a variable node of this model.
     */
    public void setDistributionType(Node node, DistributionType type) {
        if (node == null || type == null) {
            throw new NullPointerException("Node and distribution type must not be null.");
        }
        // Validate up front so an invalid call cannot leave the PM half-updated.
        if (!this.distributionTypes.containsKey(node)) {
            throw new IllegalArgumentException("Not a variable node of this model: " + node);
        }
        this.removeDistribution(node);
        this.addExogenousDistribution(node, type);
    }

    /**
     * Returns the distribution type assigned to a node.
     *
     * @param node the node.
     * @return the distribution type, or null if the node has none.
     */
    public DistributionType getDistributionType(Node node) {
        return this.distributionTypes.get(node);
    }

    /**
     * Returns the distribution parameters assigned to a node.
     *
     * <p>NOTE(review): this is the PM's internal list; callers should not modify it.
     *
     * @param node the node.
     * @return the distribution parameters, or null if the node has none.
     */
    public List<Parameter> getDistributionParameters(Node node) {
        return this.distributionParameters.get(node);
    }

    /**
     * Returns the comparison recorded between two parameters, regardless of argument order.
     *
     * @param a one parameter.
     * @param b the other parameter.
     * @return the recorded comparison, or {@link ParamComparison#NC} if none was set.
     * @throws NullPointerException if either parameter is null.
     */
    public ParamComparison getParamComparison(Parameter a, Parameter b) {
        if (a == null || b == null) {
            throw new NullPointerException();
        }
        // Stored values are never null (setParamComparison rejects null), so a null lookup
        // result means "not recorded in this order".
        ParamComparison comparison = this.paramComparisons.get(new ParameterPair(a, b));
        if (comparison == null) {
            comparison = this.paramComparisons.get(new ParameterPair(b, a));
        }
        return comparison != null ? comparison : ParamComparison.NC;
    }

    /**
     * Records a comparison between two parameters, replacing any comparison previously
     * recorded for the pair in either order. Setting {@link ParamComparison#NC} just clears it.
     *
     * @param a          one parameter.
     * @param b          the other parameter.
     * @param comparison the comparison to record.
     * @throws NullPointerException if any argument is null.
     */
    public void setParamComparison(Parameter a, Parameter b, ParamComparison comparison) {
        if (a == null || b == null || comparison == null) {
            throw new NullPointerException();
        }
        ParameterPair pair1 = new ParameterPair(a, b);
        ParameterPair pair2 = new ParameterPair(b, a);
        this.paramComparisons.remove(pair2);
        this.paramComparisons.remove(pair1);
        if (comparison != ParamComparison.NC) {
            this.paramComparisons.put(pair1, comparison);
        }
    }

    /**
     * Returns a multi-line description listing parameters, nodes, variable nodes and
     * distribution types.
     */
    @Override
    public String toString() {
        StringBuilder buf = new StringBuilder();
        buf.append("\nSEM PM:");
        buf.append("\n\tParameters:");
        for (Parameter parameter : this.parameters) {
            buf.append("\n\t\t").append(parameter);
        }
        buf.append("\n\tNodes: ");
        buf.append(this.nodes);
        buf.append("\n\tVariable nodes: ");
        buf.append(this.variableNodes);
        buf.append("\n\tDistributions: ");
        buf.append(this.distributionTypes);
        return buf.toString();
    }

    /** Returns, in order, the elements of {@code source} whose node type is {@code type}. */
    private List<Node> nodesOfType(List<Node> source, NodeType type) {
        List<Node> result = new ArrayList<Node>();
        for (Node node : source) {
            if (node.getNodeType() == type) {
                result.add(node);
            }
        }
        return result;
    }

    /**
     * Removes a node's distribution type and distribution parameters. A node with no
     * distribution is left untouched (previously this threw an accidental NPE).
     */
    private void removeDistribution(Node node) {
        this.distributionTypes.remove(node);
        List<Parameter> oldParameters = this.distributionParameters.remove(node);
        if (oldParameters != null) {
            this.parameters.removeAll(oldParameters);
        }
    }

    /** Stores a read-only view of the graph's nodes. */
    private void initializeNodes(SemGraph graph) {
        this.nodes = Collections.unmodifiableList(graph.getNodes());
    }

    /** Collects the measured and latent nodes into a read-only list. */
    private void initializeVariableNodes() {
        List<Node> varNodes = new ArrayList<Node>();
        for (Node node : this.nodes) {
            NodeType type = node.getNodeType();
            if (type == NodeType.MEASURED || type == NodeType.LATENT) {
                varNodes.add(node);
            }
        }
        this.variableNodes = Collections.unmodifiableList(varNodes);
    }

    /**
     * Creates one COEF parameter per directed (tail-to-arrow), non-error edge, then gives
     * every variable node a NORMAL distribution with its DIST parameters.
     *
     * @throws IllegalStateException if the graph contains an edge from a node to itself.
     */
    private void initializeParams() {
        this.parameters = new ArrayList<Parameter>();
        for (Edge edge : this.graph.getEdges()) {
            if (edge.getNode1() == edge.getNode2()) {
                throw new IllegalStateException("There should not be any edges from a node to itself in a SemGraph: " + edge);
            }
            boolean directed = edge.getEndpoint1() == Endpoint.TAIL && edge.getEndpoint2() == Endpoint.ARROW;
            if (SemGraph.isErrorEdge(edge) || !directed) {
                continue;
            }
            Parameter param = new Parameter(this.newBName(), ParamType.COEF, edge.getNode1(), edge.getNode2());
            param.setDistribution(new Split(0.5, 1.5));
            this.parameters.add(param);
        }
        this.distributionTypes = new HashMap<Node, DistributionType>();
        this.distributionParameters = new HashMap<Node, List<Parameter>>();
        for (Node node : this.getVariableNodes()) {
            this.addExogenousDistribution(node, DistributionType.NORMAL);
        }
    }

    /**
     * Assigns {@code type} to {@code node} and creates {@code type.getNumArgs()} fresh DIST
     * parameters for it, appending them to the parameter list.
     */
    private void addExogenousDistribution(Node node, DistributionType type) {
        this.distributionTypes.put(node, type);
        // The parameter count must follow the requested type, not always NORMAL's.
        int numArgs = type.getNumArgs();
        List<Parameter> nodeParameters = new ArrayList<Parameter>(numArgs);
        for (int i = 0; i < numArgs; ++i) {
            Parameter parameter = new Parameter(this.newDName(), ParamType.DIST, node, node);
            nodeParameters.add(parameter);
            this.parameters.add(parameter);
        }
        this.distributionParameters.put(node, nodeParameters);
    }

    /** Returns the next unused coefficient name: B1, B2, ... */
    private String newBName() {
        return "B" + ++this.bIndex;
    }

    /** Returns the next unused distribution-parameter name: D1, D2, ... */
    private String newDName() {
        return "D" + ++this.dIndex;
    }

    /**
     * Deserializes the default fields and rejects streams missing required state.
     *
     * @throws NullPointerException if the graph, nodes, parameters, variable nodes or
     *                              parameter comparisons are missing.
     */
    private void readObject(ObjectInputStream s) throws IOException, ClassNotFoundException {
        s.defaultReadObject();
        if (this.graph == null) {
            throw new NullPointerException();
        }
        if (this.nodes == null) {
            throw new NullPointerException();
        }
        if (this.parameters == null) {
            throw new NullPointerException();
        }
        if (this.variableNodes == null) {
            throw new NullPointerException();
        }
        if (this.paramComparisons == null) {
            throw new NullPointerException();
        }
    }
}

