/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.sem;

import edu.cmu.tetrad.graph.Dag;
import edu.cmu.tetrad.graph.Edge;
import edu.cmu.tetrad.graph.Edges;
import edu.cmu.tetrad.graph.Endpoint;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.GraphNode;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.graph.NodeType;
import edu.cmu.tetrad.graph.SemGraph;
import edu.cmu.tetrad.sem.ParamComparison;
import edu.cmu.tetrad.sem.ParamType;
import edu.cmu.tetrad.sem.Parameter;
import edu.cmu.tetrad.sem.ParameterPair;
import edu.cmu.tetrad.util.TetradSerializable;
import edu.cmu.tetrad.util.dist.Distribution;
import edu.cmu.tetrad.util.dist.Normal;
import edu.cmu.tetrad.util.dist.SingleValue;
import edu.cmu.tetrad.util.dist.Split;
import edu.cmu.tetrad.util.dist.Uniform;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;

/**
 * Parametric model (PM) for a structural equation model built over a {@link SemGraph}.
 *
 * <p>On construction the graph's nodes are captured, the MEASURED and LATENT nodes are
 * collected as "variable nodes", and one {@link Parameter} is created for:
 * <ul>
 *   <li>each directed (tail-to-arrow), non-error edge: a {@code COEF} parameter named {@code B1, B2, ...};</li>
 *   <li>each variable node: a {@code VAR} parameter named {@code T1, T2, ...};</li>
 *   <li>each bidirected edge: a {@code COVAR} parameter (continuing the {@code T} numbering);</li>
 *   <li>each variable node: a {@code MEAN} parameter named {@code M1, M2, ...}.</li>
 * </ul>
 * Pairwise constraints between parameters are stored as {@link ParamComparison}s.
 */
public final class SemPm
implements TetradSerializable {
    static final long serialVersionUID = 23L;
    /** The graph this model parameterizes; never null. */
    private SemGraph graph;
    /** Unmodifiable snapshot of all nodes of {@link #graph}, taken at construction. */
    private List<Node> nodes;
    /** All parameters (COEF, VAR, COVAR, MEAN) in creation order. */
    private List<Parameter> parameters;
    /** The MEAN parameters only (also contained in {@link #parameters}). May be null for old serialized copies. */
    private List<Parameter> means;
    /** Nodes of type MEASURED or LATENT, in graph order. */
    private List<Node> variableNodes;
    /** Not assigned by any code in this class; retained for serialized-form compatibility. */
    private List<Node> exogenousNodes;
    /** Explicit comparisons between parameter pairs; absent pairs are {@link ParamComparison#NC}. */
    private Map<ParameterPair, ParamComparison> paramComparisons = new HashMap<ParameterPair, ParamComparison>();
    /** Last index used for "T" (variance/covariance) parameter names. */
    private int tIndex = 0;
    /** Last index used for "M" (mean) parameter names. */
    private int mIndex = 0;
    /** Last index used for "B" (coefficient) parameter names. */
    private int bIndex = 0;
    private Distribution coefDistribution = new Split(0.5, 1.5);

    /**
     * Creates a PM for the given graph by wrapping it in a {@link SemGraph}.
     *
     * @param graph the graph to parameterize
     */
    public SemPm(Graph graph) {
        this(new SemGraph(graph));
    }

    /**
     * Creates a PM for the given SEM graph. Error terms on the graph are hidden as a side effect.
     *
     * @param graph the SEM graph to parameterize; must not be null
     * @throws NullPointerException if {@code graph} is null
     */
    public SemPm(SemGraph graph) {
        if (graph == null) {
            throw new NullPointerException("Graph must not be null.");
        }
        this.graph = graph;
        this.graph.setShowErrorTerms(false);
        this.initializeNodes(graph);
        this.initializeVariableNodes();
        this.initializeParams();
    }

    /**
     * Copy constructor.
     *
     * <p>The graph and the {@link Parameter} objects themselves are shared with {@code semPm}.
     * All lists and the comparison map are copied, and every name counter is carried over, so
     * newly minted parameter names in the copy do not collide with existing ones.
     *
     * @param semPm the PM to copy; must not be null
     * @throws NullPointerException if {@code semPm} is null
     */
    public SemPm(SemPm semPm) {
        if (semPm == null) {
            throw new NullPointerException("SemPm must not be null.");
        }
        this.graph = semPm.graph;
        this.nodes = Collections.unmodifiableList(new ArrayList<Node>(semPm.nodes));
        this.parameters = Collections.unmodifiableList(new ArrayList<Parameter>(semPm.parameters));
        this.means = semPm.means == null
                ? null
                : Collections.unmodifiableList(new ArrayList<Parameter>(semPm.means));
        this.variableNodes = Collections.unmodifiableList(new ArrayList<Node>(semPm.variableNodes));
        this.exogenousNodes = semPm.exogenousNodes == null
                ? null
                : Collections.unmodifiableList(new ArrayList<Node>(semPm.exogenousNodes));
        this.paramComparisons = new HashMap<ParameterPair, ParamComparison>(semPm.paramComparisons);
        this.tIndex = semPm.tIndex;
        // mIndex was previously not copied, so a copy would restart mean names at M1.
        this.mIndex = semPm.mIndex;
        this.bIndex = semPm.bIndex;
        this.coefDistribution = semPm.coefDistribution;
    }

    /**
     * Returns an example instance used by serialization tests.
     *
     * @return a PM built over {@link Dag#serializableInstance()}
     */
    public static SemPm serializableInstance() {
        return new SemPm(Dag.serializableInstance());
    }

    /** @return the graph this PM parameterizes */
    public SemGraph getGraph() {
        return this.graph;
    }

    /** @return all parameters in creation order (COEF, VAR, COVAR, MEAN) */
    public List<Parameter> getParameters() {
        return this.parameters;
    }

    /** @return the MEASURED and LATENT nodes of the graph */
    public List<Node> getVariableNodes() {
        return this.variableNodes;
    }

    /** @return a new list of the graph nodes whose type is ERROR */
    public List<Node> getErrorNodes() {
        List<Node> errorNodes = new ArrayList<Node>();
        for (Node node : this.nodes) {
            if (node.getNodeType() == NodeType.ERROR) {
                errorNodes.add(node);
            }
        }
        return errorNodes;
    }

    /** @return a new list of the variable nodes whose type is MEASURED */
    public List<Node> getMeasuredNodes() {
        List<Node> measuredNodes = new ArrayList<Node>();
        for (Node variable : this.getVariableNodes()) {
            if (variable.getNodeType() == NodeType.MEASURED) {
                measuredNodes.add(variable);
            }
        }
        return measuredNodes;
    }

    /** @return a new list of the graph nodes whose type is LATENT */
    public List<Node> getLatentNodes() {
        List<Node> latentNodes = new ArrayList<Node>();
        for (Node node : this.nodes) {
            if (node.getNodeType() == NodeType.LATENT) {
                latentNodes.add(node);
            }
        }
        return latentNodes;
    }

    /**
     * Looks up a parameter by name.
     *
     * @param name the parameter name, e.g. "B1"; must not be null
     * @return the first parameter with that name, or null if none exists
     */
    public Parameter getParameter(String name) {
        for (Parameter parameter : this.getParameters()) {
            if (name.equals(parameter.getName())) {
                return parameter;
            }
        }
        return null;
    }

    /**
     * Looks up a parameter by its (ordered) endpoint nodes.
     *
     * <p>Both nodes are first mapped through {@link SemGraph#getVarNode(Node)}.
     * NOTE(review): presumably this maps an error node to its variable; confirm in SemGraph.
     *
     * @return the first parameter whose nodeA/nodeB are identical to the mapped nodes, or null
     */
    public Parameter getParameter(Node nodeA, Node nodeB) {
        Node a = this.getGraph().getVarNode(nodeA);
        Node b = this.getGraph().getVarNode(nodeB);
        for (Parameter parameter : this.parameters) {
            if (parameter.getNodeA() == a && parameter.getNodeB() == b) {
                return parameter;
            }
        }
        return null;
    }

    /**
     * Returns the VAR parameter for a node. For an ERROR node, the parameter of its first
     * child in the graph is returned instead.
     *
     * @return the matching VAR parameter, or null if none exists
     */
    public Parameter getVarianceParameter(Node node) {
        if (node.getNodeType() == NodeType.ERROR) {
            node = this.getGraph().getChildren(node).get(0);
        }
        for (Parameter parameter : this.parameters) {
            if (parameter.getType() == ParamType.VAR
                    && parameter.getNodeA() == node
                    && parameter.getNodeB() == node) {
                return parameter;
            }
        }
        return null;
    }

    /**
     * Returns the COVAR parameter between two nodes, in either order. ERROR nodes are replaced
     * by their first child in the graph before matching.
     *
     * @return the matching COVAR parameter, or null if none exists
     */
    public Parameter getCovarianceParameter(Node nodeA, Node nodeB) {
        if (nodeA.getNodeType() == NodeType.ERROR) {
            nodeA = this.getGraph().getChildren(nodeA).get(0);
        }
        if (nodeB.getNodeType() == NodeType.ERROR) {
            nodeB = this.getGraph().getChildren(nodeB).get(0);
        }
        for (Parameter parameter : this.parameters) {
            if (parameter.getType() != ParamType.COVAR) {
                continue;
            }
            Node a = parameter.getNodeA();
            Node b = parameter.getNodeB();
            // Covariance is symmetric, so accept either orientation.
            if ((nodeA == a && nodeB == b) || (nodeB == a && nodeA == b)) {
                return parameter;
            }
        }
        return null;
    }

    /**
     * Returns the COEF parameter for the directed edge {@code nodeA -> nodeB}. Matching is by
     * node identity and is direction-sensitive.
     *
     * @return the matching COEF parameter, or null if none exists
     */
    public Parameter getCoefficientParameter(Node nodeA, Node nodeB) {
        for (Parameter parameter : this.parameters) {
            if (parameter.getType() == ParamType.COEF
                    && parameter.getNodeA() == nodeA
                    && parameter.getNodeB() == nodeB) {
                return parameter;
            }
        }
        return null;
    }

    /**
     * Returns the MEAN parameter for a node.
     *
     * @return the matching MEAN parameter, or null if none exists
     */
    public Parameter getMeanParameter(Node node) {
        for (Parameter parameter : this.parameters) {
            if (parameter.getType() == ParamType.MEAN
                    && parameter.getNodeA() == node
                    && parameter.getNodeB() == node) {
                return parameter;
            }
        }
        return null;
    }

    /** @return the {@code toString()} of each MEASURED variable node, in variable-node order */
    public String[] getMeasuredVarNames() {
        List<String> varNamesList = new ArrayList<String>();
        for (Node semPmVar : this.getVariableNodes()) {
            if (semPmVar.getNodeType() == NodeType.MEASURED) {
                varNamesList.add(semPmVar.toString());
            }
        }
        return varNamesList.toArray(new String[0]);
    }

    /**
     * Returns the stored comparison between two parameters, checking both orderings.
     *
     * @return the stored comparison, or {@link ParamComparison#NC} if none is stored
     * @throws NullPointerException if either parameter is null
     */
    public ParamComparison getParamComparison(Parameter a, Parameter b) {
        if (a == null || b == null) {
            throw new NullPointerException();
        }
        ParameterPair pair1 = new ParameterPair(a, b);
        if (this.paramComparisons.containsKey(pair1)) {
            return this.paramComparisons.get(pair1);
        }
        ParameterPair pair2 = new ParameterPair(b, a);
        if (this.paramComparisons.containsKey(pair2)) {
            return this.paramComparisons.get(pair2);
        }
        return ParamComparison.NC;
    }

    /**
     * Sets the comparison between two parameters. Any existing entry for either ordering is
     * removed first. {@link ParamComparison#NC} is represented by the absence of an entry.
     *
     * @throws NullPointerException if any argument is null
     */
    public void setParamComparison(Parameter a, Parameter b, ParamComparison comparison) {
        if (a == null || b == null || comparison == null) {
            throw new NullPointerException();
        }
        ParameterPair pair1 = new ParameterPair(a, b);
        ParameterPair pair2 = new ParameterPair(b, a);
        this.paramComparisons.remove(pair2);
        this.paramComparisons.remove(pair1);
        if (comparison != ParamComparison.NC) {
            this.paramComparisons.put(pair1, comparison);
        }
    }

    /** @return the VAR, COVAR and COEF parameters that are not fixed (MEAN parameters are excluded) */
    public List<Parameter> getFreeParameters() {
        List<Parameter> freeParameters = new ArrayList<Parameter>();
        for (Parameter parameter : this.getParameters()) {
            ParamType type = parameter.getType();
            boolean estimable = type == ParamType.VAR || type == ParamType.COVAR || type == ParamType.COEF;
            if (estimable && !parameter.isFixed()) {
                freeParameters.add(parameter);
            }
        }
        return freeParameters;
    }

    /**
     * Degrees of freedom: the number of distinct covariance-matrix entries among measured
     * variables, {@code m(m+1)/2}, minus the number of free parameters.
     */
    public int getDof() {
        int m = this.getMeasuredNodes().size();
        return m * (m + 1) / 2 - this.getFreeParameters().size();
    }

    @Override
    public String toString() {
        StringBuilder buf = new StringBuilder();
        buf.append("\nParameters:\n");
        for (Parameter parameter : this.parameters) {
            buf.append("\n").append(parameter);
        }
        buf.append("\n\nNodes: ");
        buf.append(this.nodes);
        buf.append("\n\nVariable nodes: ");
        buf.append(this.variableNodes);
        return buf.toString();
    }

    /** Captures an unmodifiable view of the graph's node list. */
    private void initializeNodes(SemGraph graph) {
        this.nodes = Collections.unmodifiableList(graph.getNodes());
    }

    /** Collects the MEASURED and LATENT nodes, in node order, into {@link #variableNodes}. */
    private void initializeVariableNodes() {
        List<Node> varNodes = new ArrayList<Node>();
        for (Node node : this.nodes) {
            NodeType type = node.getNodeType();
            if (type == NodeType.MEASURED || type == NodeType.LATENT) {
                varNodes.add(node);
            }
        }
        this.variableNodes = Collections.unmodifiableList(varNodes);
    }

    /**
     * Creates the COEF, VAR, COVAR and MEAN parameters (in that order) with their default
     * distributions.
     *
     * @throws IllegalStateException if the graph contains an edge from a node to itself
     */
    private void initializeParams() {
        List<Parameter> parameters = new ArrayList<Parameter>();
        List<Parameter> means = new ArrayList<Parameter>();

        // Sort a private copy so the graph's own edge list is never reordered, and so the sort
        // cannot fail if getEdges() returns an unmodifiable list. Edges are ordered by
        // (node1 name, node2 name), which makes B-parameter numbering deterministic. The previous
        // comparator compared o1.node1 against o2.node2 as its tie-breaker, which violated the
        // Comparator contract.
        List<Edge> edges = new ArrayList<Edge>(this.graph.getEdges());
        Collections.sort(edges, new Comparator<Edge>() {
            @Override
            public int compare(Edge o1, Edge o2) {
                int compareFirst = o1.getNode1().getName().compareTo(o2.getNode1().getName());
                if (compareFirst != 0) {
                    return compareFirst;
                }
                return o1.getNode2().getName().compareTo(o2.getNode2().getName());
            }
        });

        // One COEF parameter per directed, non-error edge.
        for (Edge edge : edges) {
            if (edge.getNode1() == edge.getNode2()) {
                throw new IllegalStateException("There should not be any edges from a node to itself in a SemGraph: " + edge);
            }
            boolean directed = edge.getEndpoint1() == Endpoint.TAIL && edge.getEndpoint2() == Endpoint.ARROW;
            if (SemGraph.isErrorEdge(edge) || !directed) {
                continue;
            }
            Parameter coef = new Parameter(this.newBName(), ParamType.COEF, edge.getNode1(), edge.getNode2());
            coef.setDistribution(new Split(0.5, 1.5));
            parameters.add(coef);
        }

        // One VAR parameter per variable node.
        for (Node node : this.getVariableNodes()) {
            Parameter variance = new Parameter(this.newTName(), ParamType.VAR, node, node);
            variance.setDistribution(new Uniform(1.0, 3.0));
            parameters.add(variance);
        }

        // One COVAR parameter per bidirected edge, attached to the corresponding variable nodes.
        for (Edge edge : edges) {
            if (!Edges.isBidirectedEdge(edge)) {
                continue;
            }
            Node node1 = this.getGraph().getVarNode(edge.getNode1());
            Node node2 = this.getGraph().getVarNode(edge.getNode2());
            Parameter covariance = new Parameter(this.newTName(), ParamType.COVAR, node1, node2);
            covariance.setDistribution(new SingleValue(0.2));
            parameters.add(covariance);
        }

        // One MEAN parameter per variable node. Each is recorded in both lists, so that the
        // means field is no longer left permanently empty.
        for (Node node : this.getVariableNodes()) {
            Parameter mean = new Parameter(this.newMName(), ParamType.MEAN, node, node);
            mean.setDistribution(new Normal(0.0, 1.0));
            parameters.add(mean);
            means.add(mean);
        }

        this.parameters = Collections.unmodifiableList(parameters);
        this.means = Collections.unmodifiableList(means);
    }

    /** @return the next variance/covariance parameter name: T1, T2, ... */
    private String newTName() {
        return "T" + ++this.tIndex;
    }

    /** @return the next mean parameter name: M1, M2, ... */
    private String newMName() {
        return "M" + ++this.mIndex;
    }

    /** @return the next coefficient parameter name: B1, B2, ... */
    private String newBName() {
        return "B" + ++this.bIndex;
    }

    /**
     * Deserializes and validates the required fields.
     *
     * @throws NullPointerException if a required field is missing
     * @throws IllegalStateException if a name counter is negative
     */
    private void readObject(ObjectInputStream s) throws IOException, ClassNotFoundException {
        s.defaultReadObject();
        if (this.graph == null) {
            throw new NullPointerException();
        }
        if (this.nodes == null) {
            throw new NullPointerException();
        }
        if (this.parameters == null) {
            throw new NullPointerException();
        }
        if (this.variableNodes == null) {
            throw new NullPointerException();
        }
        if (this.paramComparisons == null) {
            throw new NullPointerException();
        }
        if (this.tIndex < 0) {
            throw new IllegalStateException("TIndex out of range: " + this.tIndex);
        }
        if (this.mIndex < 0) {
            throw new IllegalStateException("MIndex out of range: " + this.mIndex);
        }
        if (this.bIndex < 0) {
            throw new IllegalStateException("BIndex out of range: " + this.bIndex);
        }
    }

    /**
     * Stores a coefficient distribution.
     *
     * <p>NOTE(review): no code in this class reads {@code coefDistribution}. Coefficient
     * parameters are created in the constructor with their own {@code Split(0.5, 1.5)}, so
     * calling this does not affect existing parameters. Confirm whether other code relies on it.
     *
     * @param distribution the distribution to store
     */
    public void setCoefDistribution(Distribution distribution) {
        this.coefDistribution = distribution;
    }
}

