/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.sem;

import cern.colt.matrix.DoubleMatrix2D;
import cern.colt.matrix.impl.SparseDoubleMatrix2D;
import edu.cmu.tetrad.data.ColtDataSet;
import edu.cmu.tetrad.data.ContinuousVariable;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.graph.NodeType;
import edu.cmu.tetrad.graph.SemGraph;
import edu.cmu.tetrad.sem.DistributionType;
import edu.cmu.tetrad.sem.ParamType;
import edu.cmu.tetrad.sem.Parameter;
import edu.cmu.tetrad.sem.Sem2DistributionMapping;
import edu.cmu.tetrad.sem.Sem2Mapping;
import edu.cmu.tetrad.sem.Sem2MatrixMapping;
import edu.cmu.tetrad.sem.SemPm2;
import edu.cmu.tetrad.util.MatrixUtils;
import edu.cmu.tetrad.util.TetradSerializable;
import edu.cmu.tetrad.util.dist.Beta;
import edu.cmu.tetrad.util.dist.Distribution;
import edu.cmu.tetrad.util.dist.GaussianPower;
import edu.cmu.tetrad.util.dist.Normal;
import edu.cmu.tetrad.util.dist.SingleValue;
import edu.cmu.tetrad.util.dist.Split;
import edu.cmu.tetrad.util.dist.Uniform;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.rmi.MarshalledObject;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;

/**
 * An instantiated structural equation model built on a {@link SemPm2}.
 * <p>
 * Coefficient parameters are stored in a sparse edge-coefficient matrix whose rows and
 * columns are indexed by position in {@link #getVariableNodes()}. Distribution parameters
 * are stored inside per-node {@link Distribution} objects. Every parameter is read and
 * written through a {@link Sem2Mapping}. The free and fixed parameter lists are kept
 * index-aligned with the free and fixed mapping lists, respectively.
 */
public final class SemIm2
implements TetradSerializable {
    static final long serialVersionUID = 23L;
    /** Copy of the PM this IM was built from. */
    private final SemPm2 semPm2;
    /** Unmodifiable variable-node list; its order defines the edgeCoef indices. */
    private final List<Node> variableNodes;
    /** Unmodifiable list of measured nodes. */
    private final List<Node> measuredNodes;
    /** Free parameters, index-aligned with freeMappings. */
    private List<Parameter> freeParameters;
    /** Fixed parameters, index-aligned with fixedMappings. */
    private List<Parameter> fixedParameters;
    /** Edge coefficients; entry (i, j) is written by the mapping for parameter nodeA=i, nodeB=j. */
    private DoubleMatrix2D edgeCoef;
    /** Distribution object for each non-error node of the graph. */
    private Map<Node, Distribution> distributions;
    private int sampleSize;
    private List<Sem2Mapping> freeMappings;
    private List<Sem2Mapping> fixedMappings;
    private double[] standardErrors;
    private boolean parameterBoundsEnforced = true;
    private boolean estimated = false;
    /** Lazily created in initialValue; not serialized. */
    private transient Distribution coefDistribution;
    private transient Distribution varDistribution;
    private transient Distribution covarDistribution;

    /**
     * Builds an IM for the given PM and initializes every coefficient parameter.
     *
     * @param semPm2 the PM to instantiate; a copy of it is stored.
     * @throws NullPointerException if {@code semPm2} is null.
     */
    public SemIm2(SemPm2 semPm2) {
        if (semPm2 == null) {
            throw new NullPointerException("Sem PM must not be null.");
        }
        this.semPm2 = new SemPm2(semPm2);
        // Take the node lists from the stored copy. processCoefficientParameters resolves
        // nodes through this.semPm2's graph and then looks them up in these lists, so both
        // must come from the same PM.
        this.variableNodes = Collections.unmodifiableList(this.semPm2.getVariableNodes());
        this.measuredNodes = Collections.unmodifiableList(this.semPm2.getMeasuredNodes());
        int numVars = this.variableNodes.size();
        this.edgeCoef = new SparseDoubleMatrix2D(numVars, numVars);
        this.freeParameters = new ArrayList<Parameter>();
        this.freeMappings = new ArrayList<Sem2Mapping>();
        this.fixedParameters = new ArrayList<Parameter>();
        this.fixedMappings = new ArrayList<Sem2Mapping>();
        this.distributions = new HashMap<Node, Distribution>();
        this.processCoefficientParameters();
        this.processDistributionParameters();
        this.initializeCoefValues();
    }

    /**
     * Deep-copies another IM by round-tripping it through Java serialization.
     *
     * @param semIm the IM to copy.
     * @throws NullPointerException if {@code semIm} is null.
     * @throws RuntimeException if the IM cannot be serialized or deserialized.
     */
    public SemIm2(SemIm2 semIm) {
        if (semIm == null) {
            throw new NullPointerException("Sem IM must not be null.");
        }
        try {
            SemIm2 _semIm = new MarshalledObject<SemIm2>(semIm).get();
            this.semPm2 = _semIm.semPm2;
            this.variableNodes = _semIm.variableNodes;
            this.measuredNodes = _semIm.measuredNodes;
            this.freeParameters = _semIm.freeParameters;
            this.fixedParameters = _semIm.fixedParameters;
            // The cloned mappings write into the clone's own matrix and distribution objects.
            // Adopt those exact objects (they are already private to this copy). Copying them
            // would leave the mappings updating state this instance never reads.
            this.edgeCoef = _semIm.edgeCoef;
            this.distributions = _semIm.distributions;
            this.sampleSize = _semIm.sampleSize;
            this.freeMappings = _semIm.freeMappings;
            this.fixedMappings = _semIm.fixedMappings;
            this.standardErrors = _semIm.standardErrors;
            this.parameterBoundsEnforced = _semIm.parameterBoundsEnforced;
            this.estimated = _semIm.estimated;
            // NOTE(review): each Sem2MatrixMapping was built with a reference to its owning
            // SemIm2, which here is the temporary clone. If the mapping consults that owner
            // (e.g. for bounds enforcement), it will see the clone, not this instance. Confirm.
        }
        catch (IOException e) {
            throw new RuntimeException("SemIm could not be deep cloned.", e);
        }
        catch (ClassNotFoundException e) {
            throw new RuntimeException("SemIm could not be deep cloned.", e);
        }
    }

    /**
     * Builds a new IM over {@code graph} and copies parameter values from {@code semIm}.
     * A value is copied wherever a parameter of the same type connects nodes with the same
     * names in both models.
     *
     * @param semIm the IM whose values are copied.
     * @param graph the graph for the new IM.
     * @return the new IM, with {@code semIm}'s sample size.
     */
    public static SemIm2 retainValues(SemIm2 semIm, SemGraph graph) {
        SemPm2 newSemPm2 = new SemPm2(graph);
        SemIm2 newSemIm = new SemIm2(newSemPm2);
        SemGraph oldGraph = semIm.getSemPm2().getGraph();
        for (Parameter p1 : newSemIm.getSemPm2().getParameters()) {
            Node nodeA = oldGraph.getNode(p1.getNodeA().getName());
            Node nodeB = oldGraph.getNode(p1.getNodeB().getName());
            for (Parameter p2 : semIm.getSemPm2().getParameters()) {
                if (p2.getNodeA() != nodeA || p2.getNodeB() != nodeB || p2.getType() != p1.getType()) continue;
                double value = semIm.getParamValue(p2);
                // Fixed parameters must go through setFixedParamValue; setParamValue rejects them.
                if (newSemIm.getFixedParameters().contains(p1)) {
                    newSemIm.setFixedParamValue(p1, value);
                } else {
                    newSemIm.setParamValue(p1, value);
                }
            }
        }
        newSemIm.sampleSize = semIm.sampleSize;
        return newSemIm;
    }

    /** Returns an instance built from {@code SemPm2.serializableInstance()}, for serialization tests. */
    public static SemIm2 serializableInstance() {
        return new SemIm2(SemPm2.serializableInstance());
    }

    /** Returns the PM copy this IM was built from. */
    public SemPm2 getSemPm2() {
        return this.semPm2;
    }

    /** Returns the current free-parameter values, in free-parameter order. */
    public double[] getFreeParamValues() {
        double[] paramValues = new double[this.freeMappings().size()];
        for (int i = 0; i < this.freeMappings().size(); ++i) {
            Sem2Mapping mapping = this.freeMappings().get(i);
            paramValues[i] = mapping.getValue();
        }
        return paramValues;
    }

    /**
     * Sets all free-parameter values, in free-parameter order.
     *
     * @throws IllegalArgumentException if {@code params.length} differs from the number of free parameters.
     */
    public void setFreeParamValues(double[] params) {
        if (params.length != this.getNumFreeParams()) {
            throw new IllegalArgumentException("The array provided must be of the same length as the number of free parameters.");
        }
        for (int i = 0; i < this.freeMappings().size(); ++i) {
            Sem2Mapping mapping = this.freeMappings().get(i);
            mapping.setValue(params[i]);
        }
    }

    /**
     * Returns the current value of a free or fixed parameter.
     *
     * @throws NullPointerException if {@code parameter} is null.
     * @throws IllegalArgumentException if the parameter is not part of this model.
     */
    public double getParamValue(Parameter parameter) {
        if (parameter == null) {
            throw new NullPointerException();
        }
        int freeIndex = this.getFreeParameters().indexOf(parameter);
        if (freeIndex != -1) {
            return this.freeMappings.get(freeIndex).getValue();
        }
        int fixedIndex = this.getFixedParameters().indexOf(parameter);
        if (fixedIndex != -1) {
            return this.fixedMappings.get(fixedIndex).getValue();
        }
        throw new IllegalArgumentException("Not a parameter in this model: " + parameter);
    }

    /**
     * Sets the value of a free parameter.
     *
     * @throws IllegalArgumentException if the parameter is not a free parameter of this model.
     */
    public void setParamValue(Parameter parameter, double value) {
        int index = this.getFreeParameters().indexOf(parameter);
        if (index == -1) {
            throw new IllegalArgumentException("That parameter cannot be set in this model: " + parameter);
        }
        this.freeMappings.get(index).setValue(value);
    }

    /**
     * Sets the value of a fixed parameter.
     *
     * @throws IllegalArgumentException if the parameter is not a fixed parameter of this model.
     */
    public void setFixedParamValue(Parameter parameter, double value) {
        int index = this.getFixedParameters().indexOf(parameter);
        if (index == -1) {
            throw new IllegalArgumentException("Not a fixed parameter in this model: " + parameter);
        }
        this.fixedMappings.get(index).setValue(value);
    }

    /** Returns the value of the coefficient parameter for the edge x -> y. */
    public double getCoef(Node x, Node y) {
        Parameter parameter = this.getSemPm2().getCoefficientParameter(x, y);
        return this.getParamValue(parameter);
    }

    /**
     * Sets the value of the coefficient parameter for the edge x -> y. Only the coefficient
     * parameter is consulted, matching {@link #getCoef(Node, Node)}, so this never writes a
     * variance or covariance by accident.
     *
     * @throws IllegalArgumentException if there is no coefficient parameter for x -> y or it is not free.
     */
    public void setCoef(Node x, Node y, double value) {
        Parameter parameter = this.getSemPm2().getCoefficientParameter(x, y);
        if (parameter == null) {
            throw new IllegalArgumentException("There is no parameter in model for an edge from " + x + " to " + y + ".");
        }
        this.setParamValue(parameter, value);
    }

    /** Returns the distribution for {@code node}, or null if there is none (e.g. for error nodes). */
    public Distribution getDistribution(Node node) {
        return this.distributions.get(node);
    }

    /**
     * Sets the free parameter attached to the node pair. The variance parameter is tried
     * first when {@code nodeA == nodeB}, then the coefficient, then the covariance.
     *
     * @throws IllegalArgumentException if no such parameter exists or it is not free.
     */
    public void setParamValue(Node nodeA, Node nodeB, double value) {
        Parameter parameter = null;
        if (nodeA == nodeB) {
            parameter = this.getSemPm2().getVarianceParameter(nodeA);
        }
        if (parameter == null) {
            parameter = this.getSemPm2().getCoefficientParameter(nodeA, nodeB);
        }
        if (parameter == null) {
            parameter = this.getSemPm2().getCovarianceParameter(nodeA, nodeB);
        }
        if (parameter == null) {
            throw new IllegalArgumentException("There is no parameter in model for an edge from " + nodeA + " to " + nodeB + ".");
        }
        if (!this.getFreeParameters().contains(parameter)) {
            throw new IllegalArgumentException("Not a free parameter in this model: " + parameter);
        }
        this.setParamValue(parameter, value);
    }

    /** Returns the live internal free-parameter list. Do not modify it; it is index-aligned with the free mappings. */
    public List<Parameter> getFreeParameters() {
        return this.freeParameters;
    }

    public int getNumFreeParams() {
        return this.getFreeParameters().size();
    }

    /** Returns the live internal fixed-parameter list. Do not modify it; it is index-aligned with the fixed mappings. */
    public List<Parameter> getFixedParameters() {
        return this.fixedParameters;
    }

    public int getNumFixedParams() {
        return this.getFixedParameters().size();
    }

    public List<Node> getVariableNodes() {
        return this.variableNodes;
    }

    public List<Node> getMeasuredNodes() {
        return this.measuredNodes;
    }

    public int getSampleSize() {
        return this.sampleSize;
    }

    /** Returns a copy of the edge-coefficient matrix, indexed by position in {@link #getVariableNodes()}. */
    public DoubleMatrix2D getCoefMatrix() {
        return this.edgeCoef().copy();
    }

    /** Assigns an initial value to every COEF parameter, fixed and free. */
    private void initializeCoefValues() {
        Parameter parameter;
        for (Sem2Mapping fixedMapping : this.fixedMappings) {
            parameter = fixedMapping.getParameter();
            if (parameter.getType() != ParamType.COEF) continue;
            fixedMapping.setValue(this.initialValue(parameter));
        }
        for (Sem2Mapping freeMapping : this.freeMappings) {
            parameter = freeMapping.getParameter();
            if (parameter.getType() != ParamType.COEF) continue;
            freeMapping.setValue(this.initialValue(parameter));
        }
    }

    public boolean isParameterBoundsEnforced() {
        return this.parameterBoundsEnforced;
    }

    public void setParameterBoundsEnforced(boolean parameterBoundsEnforced) {
        this.parameterBoundsEnforced = parameterBoundsEnforced;
    }

    public boolean isEstimated() {
        return this.estimated;
    }

    public void setEstimated(boolean estimated) {
        this.estimated = estimated;
    }

    /**
     * Simulates data from the linear model. Nodes are visited in the graph's tier order.
     * Each value is a draw from {@code distribution} plus the coefficient-weighted sum of
     * the node's non-error parents. The same {@code distribution} is used for every node;
     * the per-node distributions returned by {@link #getDistribution(Node)} are not used here.
     *
     * @param sampleSize   number of rows to generate; must be non-negative.
     * @param distribution source of the additive noise term.
     * @return a data set with one continuous column per variable node.
     * @throws IllegalArgumentException if {@code sampleSize} is negative.
     * @throws NullPointerException if {@code distribution} is null.
     */
    public DataSet simulateData(int sampleSize, Distribution distribution) {
        if (sampleSize < 0) {
            throw new IllegalArgumentException("Sample size must be non-negative: " + sampleSize);
        }
        if (distribution == null) {
            throw new NullPointerException("Distribution must not be null.");
        }
        // The coefficient matrix is indexed by this list, so use it for column indices too.
        List<Node> nodes = this.getVariableNodes();
        List<Node> variables = new ArrayList<Node>(nodes.size());
        for (Node node : nodes) {
            variables.add(new ContinuousVariable(node.getName()));
        }
        ColtDataSet dataSet = new ColtDataSet(sampleSize, variables);
        SemGraph graph = this.getSemPm2().getGraph();
        int[] order = this.simulationOrder(graph, nodes);
        int[][] parents = this.parentIndices(graph, nodes);
        // Read the matrix directly; getCoefMatrix() copies it and must not be called per cell.
        DoubleMatrix2D coef = this.edgeCoef();
        for (int row = 0; row < sampleSize; ++row) {
            for (int col : order) {
                double value = distribution.nextRandom();
                for (int parent : parents[col]) {
                    value += dataSet.getDouble(row, parent) * coef.get(parent, col);
                }
                dataSet.setDouble(row, col, value);
            }
        }
        return dataSet;
    }

    /** Returns indices into {@code nodes} in the graph's tier order, skipping tier entries not in {@code nodes}. */
    private int[] simulationOrder(SemGraph graph, List<Node> nodes) {
        List<Node> tierOrdering = graph.getTierOrdering();
        int[] buffer = new int[tierOrdering.size()];
        int count = 0;
        for (Node node : tierOrdering) {
            int index = nodes.indexOf(node);
            // Tier orderings may include nodes (e.g. error nodes) that have no data column.
            if (index != -1) {
                buffer[count++] = index;
            }
        }
        int[] order = new int[count];
        System.arraycopy(buffer, 0, order, 0, count);
        return order;
    }

    /** For each node, returns the indices in {@code nodes} of its parents that are not error nodes. */
    private int[][] parentIndices(SemGraph graph, List<Node> nodes) {
        int[][] result = new int[nodes.size()][];
        for (int i = 0; i < nodes.size(); ++i) {
            // Copy before filtering so the list handed back by the graph is never modified.
            List<Node> parents = new ArrayList<Node>(graph.getParents(nodes.get(i)));
            Iterator<Node> it = parents.iterator();
            while (it.hasNext()) {
                if (it.next().getNodeType() == NodeType.ERROR) {
                    it.remove();
                }
            }
            result[i] = new int[parents.size()];
            for (int j = 0; j < parents.size(); ++j) {
                result[i][j] = nodes.indexOf(parents.get(j));
            }
        }
        return result;
    }

    @Override
    public String toString() {
        Sem2Mapping iMapping;
        int i;
        StringBuilder buf = new StringBuilder();
        buf.append("\nSem");
        buf.append("\n\n\tVariable nodes:\n");
        buf.append("\t");
        buf.append(this.getVariableNodes());
        buf.append("\n\n\tmeasuredNodes:\n");
        buf.append("\t");
        buf.append(this.getMeasuredNodes());
        buf.append("\n\n\tedgeCoef:\n");
        buf.append(MatrixUtils.toString(this.edgeCoef().toArray()));
        buf.append("\n\n\tsampleSize = ");
        buf.append("\t");
        buf.append(this.sampleSize);
        buf.append("\n\n\tfree mappings:\n");
        for (i = 0; i < this.freeMappings.size(); ++i) {
            iMapping = this.freeMappings.get(i);
            buf.append("\n\t");
            buf.append(i);
            buf.append(". ");
            buf.append(iMapping);
        }
        buf.append("\n\n\tfixed mappings:\n");
        for (i = 0; i < this.fixedMappings.size(); ++i) {
            iMapping = this.fixedMappings.get(i);
            buf.append("\n\t");
            buf.append(i);
            buf.append(". ");
            buf.append(iMapping);
        }
        return buf.toString();
    }

    /** Creates a matrix mapping for every COEF parameter and files it as free or fixed. */
    private void processCoefficientParameters() {
        SemGraph graph = this.getSemPm2().getGraph();
        for (Parameter parameter : this.getSemPm2().getParameters()) {
            if (parameter.getType() != ParamType.COEF) continue;
            Node nodeA = graph.getVarNode(parameter.getNodeA());
            Node nodeB = graph.getVarNode(parameter.getNodeB());
            int i = this.getVariableNodes().indexOf(nodeA);
            int j = this.getVariableNodes().indexOf(nodeB);
            Sem2MatrixMapping mapping = new Sem2MatrixMapping(this, parameter, this.edgeCoef(), i, j);
            if (parameter.isFixed()) {
                this.fixedParameters.add(parameter);
                this.fixedMappings.add(mapping);
                continue;
            }
            this.freeParameters.add(parameter);
            this.freeMappings.add(mapping);
        }
    }

    /**
     * Creates a default distribution for each non-error node and registers that
     * distribution's parameters as free parameters.
     */
    private void processDistributionParameters() {
        SemGraph graph = this.getSemPm2().getGraph();
        for (Node node : graph.getNodes()) {
            if (node.getNodeType() == NodeType.ERROR) continue;
            DistributionType _type = this.semPm2.getDistributionType(node);
            Distribution distribution = this.getDefaultDistribution(_type);
            this.distributions.put(node, distribution);
            List<Parameter> _parameters = this.semPm2.getDistributionParameters(node);
            for (int i = 0; i < _parameters.size(); ++i) {
                Sem2DistributionMapping mapping = new Sem2DistributionMapping(distribution, i, _parameters.get(i));
                this.freeParameters.add(_parameters.get(i));
                this.freeMappings.add(mapping);
            }
        }
    }

    /**
     * Returns the parameter's starting value. If the parameter is marked for random
     * initialization, the value is instead drawn from a type-specific distribution: COEF,
     * VAR, or anything else (treated as covariance).
     */
    private double initialValue(Parameter parameter) {
        if (this.coefDistribution == null) {
            this.coefDistribution = new Split(0.5, 1.5);
        }
        if (this.varDistribution == null) {
            this.varDistribution = new Uniform(1.0, 3.0);
        }
        if (this.covarDistribution == null) {
            this.covarDistribution = new SingleValue(0.2);
        }
        if (parameter.isInitializedRandomly()) {
            if (parameter.getType() == ParamType.COEF) {
                return this.coefDistribution.nextRandom();
            }
            if (parameter.getType() == ParamType.VAR) {
                return this.varDistribution.nextRandom();
            }
            return this.covarDistribution.nextRandom();
        }
        return parameter.getStartingValue();
    }

    private List<Sem2Mapping> freeMappings() {
        return this.freeMappings;
    }

    /**
     * Returns a new default distribution for the given type.
     *
     * @throws IllegalArgumentException if the type is null or unsupported.
     */
    private Distribution getDefaultDistribution(DistributionType distributionType) {
        if (distributionType == DistributionType.NORMAL) {
            return new Normal(0.0, 1.0);
        }
        if (distributionType == DistributionType.UNIFORM) {
            return new Uniform(-1.0, 1.0);
        }
        if (distributionType == DistributionType.BETA) {
            return new Beta(0.2, 0.3);
        }
        if (distributionType == DistributionType.GAUSSIAN_POWER) {
            return new GaussianPower(2.0);
        }
        throw new IllegalArgumentException("Unsupported distribution type: " + distributionType);
    }

    private DoubleMatrix2D edgeCoef() {
        return this.edgeCoef;
    }

    /** Restores serialized state and rejects streams missing required fields. */
    private void readObject(ObjectInputStream s) throws IOException, ClassNotFoundException {
        s.defaultReadObject();
        if (this.semPm2 == null) {
            throw new NullPointerException();
        }
        if (this.variableNodes == null) {
            throw new NullPointerException();
        }
        if (this.measuredNodes == null) {
            throw new NullPointerException();
        }
        if (this.freeParameters == null) {
            throw new NullPointerException();
        }
        if (this.freeMappings == null) {
            throw new NullPointerException();
        }
        if (this.fixedParameters == null) {
            throw new NullPointerException();
        }
        if (this.fixedMappings == null) {
            throw new NullPointerException();
        }
        if (this.sampleSize < 0) {
            throw new IllegalArgumentException("Sample size out of range: " + this.sampleSize);
        }
    }
}

