/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search;

import edu.cmu.tetrad.data.AbstractVariable;
import edu.cmu.tetrad.data.BoxDataSet;
import edu.cmu.tetrad.data.ContinuousVariable;
import edu.cmu.tetrad.data.CovarianceMatrix;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.DiscreteVariable;
import edu.cmu.tetrad.data.DoubleDataBox;
import edu.cmu.tetrad.data.Knowledge;
import edu.cmu.tetrad.graph.Edge;
import edu.cmu.tetrad.graph.EdgeListGraph;
import edu.cmu.tetrad.graph.Edges;
import edu.cmu.tetrad.graph.Endpoint;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.graph.TimeLagGraph;
import edu.cmu.tetrad.regression.RegressionDataset;
import edu.cmu.tetrad.regression.RegressionResult;
import edu.cmu.tetrad.search.BDeuScore;
import edu.cmu.tetrad.search.Fges;
import edu.cmu.tetrad.search.Score;
import edu.cmu.tetrad.search.SemBicScore;
import edu.cmu.tetrad.util.Matrix;
import edu.cmu.tetrad.util.RandomUtil;
import edu.cmu.tetrad.util.Vector;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import org.apache.commons.math3.exception.MaxCountExceededException;
import org.apache.commons.math3.linear.BlockRealMatrix;
import org.apache.commons.math3.linear.EigenDecomposition;
import org.apache.commons.math3.util.FastMath;

/**
 * Static helpers for time-series data sets.
 *
 * <p>The helpers cover:
 * <ul>
 *   <li>building lagged data sets and their tiered knowledge;</li>
 *   <li>fitting autoregressive models and returning their residuals;</li>
 *   <li>shifting and differencing series;</li>
 *   <li>building time-lag graphs.</li>
 * </ul>
 *
 * <p>Lagged variables created by {@link #createLagData(DataSet, int)} keep the original name at
 * lag 0 and are named {@code <name>:<lag>} for lags 1..numLags. The name helpers
 * ({@link #getNameNoLag(Object)}, {@link #getLag(String)}) rely on that ':' convention.
 */
public class TimeSeriesUtils {

    /**
     * Fits a full vector autoregression of order {@code numLags} and returns its residuals.
     *
     * <p>Each lag-0 variable is regressed on every lagged column (lags 1..numLags of all series)
     * of {@link #createLagData(DataSet, int)}.
     *
     * @param timeSeries the series, one column per variable, rows in time order.
     * @param numLags    the autoregressive order.
     * @return a data set over {@code timeSeries}' variables holding the residuals; it has as many
     *         rows as the lagged data set ({@code numLags} fewer than the input).
     */
    public static DataSet ar(DataSet timeSeries, int numLags) {
        DataSet timeLags = createLagData(timeSeries, numLags);
        int numVars = timeSeries.getNumColumns();

        // Lagged columns start right after the numVars lag-0 columns.
        List<Node> regressors = new ArrayList<>();
        for (int c = numVars; c < timeLags.getNumColumns(); ++c) {
            regressors.add(timeLags.getVariable(c));
        }

        RegressionDataset regression = new RegressionDataset(timeLags);
        Matrix residuals = new Matrix(timeLags.getNumRows(), numVars);
        for (int i = 0; i < numVars; ++i) {
            RegressionResult result = regression.regress(timeLags.getVariable(i), regressors);
            residuals.assignColumn(i, result.getResiduals());
        }
        return new BoxDataSet(new DoubleDataBox(residuals.toArray()), timeSeries.getVariables());
    }

    /**
     * Like {@link #ar(DataSet, int)}, but tolerates variables whose values are all NaN.
     *
     * <p>A variable is treated as missing when every one of its values is NaN (vacuously true
     * when there are no rows). Lags of missing variables are left out of every regression.
     * A missing variable's residual column is filled with NaN.
     *
     * @param timeSeries the series, one column per variable, rows in time order.
     * @param numLags    the autoregressive order.
     * @return the residual data set over {@code timeSeries}' variables.
     */
    public static DataSet ar2(DataSet timeSeries, int numLags) {
        int numVars = timeSeries.getNumColumns();

        boolean[] missing = new boolean[numVars];
        for (int col = 0; col < numVars; ++col) {
            missing[col] = isAllNaN(timeSeries, col);
        }

        DataSet timeLags = createLagData(timeSeries, numLags);

        // Lagged column c holds series c % numVars; the regressor set is the same for every
        // target, so it is built once.
        List<Node> regressors = new ArrayList<>();
        for (int c = numVars; c < timeLags.getNumColumns(); ++c) {
            if (!missing[c % numVars]) {
                regressors.add(timeLags.getVariable(c));
            }
        }

        RegressionDataset regression = new RegressionDataset(timeLags);
        Matrix residuals = new Matrix(timeLags.getNumRows(), numVars);
        for (int i = 0; i < numVars; ++i) {
            if (missing[i]) {
                // Lag-0 column i of the lagged data is series i, so column i is the output slot.
                for (int r = 0; r < residuals.rows(); ++r) {
                    residuals.set(r, i, Double.NaN);
                }
                continue;
            }
            RegressionResult result = regression.regress(timeLags.getVariable(i), regressors);
            residuals.assignColumn(i, result.getResiduals());
        }
        return new BoxDataSet(new DoubleDataBox(residuals.toArray()), timeSeries.getVariables());
    }

    /**
     * Returns true when every value in column {@code col} of {@code data} is NaN.
     * This is vacuously true for a data set with no rows.
     */
    private static boolean isAllNaN(DataSet data, int col) {
        for (int r = 0; r < data.getNumRows(); ++r) {
            if (!Double.isNaN(data.getDouble(r, col))) {
                return false;
            }
        }
        return true;
    }

    /**
     * Returns the entries of {@code parents} whose variable in {@code dataSet} is not in
     * {@code missingVariables}, preserving order.
     *
     * @param dataIndex unused; kept for signature compatibility.
     */
    private static int[] eliminateMissing(int[] parents, int dataIndex, DataSet dataSet, List<Node> missingVariables) {
        int[] kept = new int[parents.length];
        int count = 0;
        for (int k : parents) {
            if (!missingVariables.contains(dataSet.getVariable(k))) {
                kept[count++] = k;
            }
        }
        int[] result = new int[count];
        System.arraycopy(kept, 0, result, 0, count);
        return result;
    }

    /**
     * Fits a structural VAR: searches the lagged data with FGES and regresses each lag-0
     * variable on its parents in the resulting graph.
     *
     * <p>The lagged data's tier knowledge is used, with edges forbidden within every tier
     * 0..numLags. Discrete data uses a {@link BDeuScore}; continuous data uses a
     * {@link SemBicScore} with penalty discount 2.0. Mixed data is rejected.
     *
     * @param timeSeries the series, one column per variable, rows in time order.
     * @param numLags    the number of lags.
     * @return the residuals together with the search graph collapsed onto the un-lagged
     *         variables. An edge between {@code A:i} and {@code B:j} becomes an edge between
     *         {@code A} and {@code B} with the same endpoints.
     * @throws IllegalArgumentException if the data set is mixed.
     */
    public static VarResult structuralVar(DataSet timeSeries, int numLags) {
        DataSet timeLags = createLagData(timeSeries, numLags);
        int numVars = timeSeries.getNumColumns();

        Knowledge knowledge = timeLags.getKnowledge().copy();
        for (int tier = 0; tier <= numLags; ++tier) {
            knowledge.setTierForbiddenWithin(tier, true);
        }

        Score score;
        if (timeLags.isDiscrete()) {
            score = new BDeuScore(timeLags);
        } else if (timeLags.isContinuous()) {
            SemBicScore semBicScore = new SemBicScore(new CovarianceMatrix(timeLags));
            semBicScore.setPenaltyDiscount(2.0);
            score = semBicScore;
        } else {
            throw new IllegalArgumentException("Mixed data set");
        }

        Fges search = new Fges(score);
        search.setKnowledge(knowledge);
        Graph graph = search.search();

        // Lagged names are "<name>:<lag>" (see createLagData), so strip everything from the
        // first ':' to get back to the un-lagged variable.
        EdgeListGraph collapsedVarGraph = new EdgeListGraph(timeSeries.getVariables());
        for (Edge edge : graph.getEdges()) {
            Node node1 = collapsedVarGraph.getNode(getNameNoLag(edge.getNode1().getName()));
            Node node2 = collapsedVarGraph.getNode(getNameNoLag(edge.getNode2().getName()));
            Edge collapsed = new Edge(node1, node2, edge.getEndpoint1(), edge.getEndpoint2());
            if (!collapsedVarGraph.containsEdge(collapsed)) {
                collapsedVarGraph.addEdge(collapsed);
            }
        }

        Matrix residuals = new Matrix(timeLags.getNumRows(), numVars);
        RegressionDataset regression = new RegressionDataset(timeLags);
        for (int i = 0; i < numVars; ++i) {
            Node target = timeLags.getVariable(i);
            List<Node> regressors = new ArrayList<>(graph.getParents(target));
            RegressionResult result = regression.regress(target, regressors);
            residuals.assignColumn(i, result.getResiduals());
        }
        return new VarResult(new BoxDataSet(new DoubleDataBox(residuals.toArray()), timeSeries.getVariables()), collapsedVarGraph);
    }

    /**
     * Shifts each column of {@code data} in time by its own amount.
     *
     * <p>Let {@code max} and {@code min} be the largest and smallest shift. Output row {@code i},
     * column {@code j} holds input row {@code i + (max - shifts[j])}, column {@code j}. The
     * output therefore has {@code rows - (max - min)} rows. Extra entries in {@code shifts}
     * beyond the number of columns still widen {@code max - min}.
     *
     * @param data   the data set to shift.
     * @param shifts one shift per column.
     * @return the shifted data set over {@code data}'s variables.
     * @throws IllegalArgumentException if there are fewer shifts than columns, or if the range
     *                                  of shifts exceeds the number of rows.
     */
    public static DataSet createShiftedData(DataSet data, int[] shifts) {
        Matrix source = data.getDoubleData();
        if (shifts.length < source.columns()) {
            throw new IllegalArgumentException("Expected a shift for each of the " + source.columns()
                    + " columns but got " + shifts.length + ".");
        }

        // With no shifts at all the range is 0 (avoids MIN_VALUE - MAX_VALUE overflow).
        int min = 0;
        int max = 0;
        if (shifts.length > 0) {
            min = shifts[0];
            max = shifts[0];
            for (int shift : shifts) {
                min = Math.min(min, shift);
                max = Math.max(max, shift);
            }
        }

        int shiftRange = max - min;
        if (shiftRange > source.rows()) {
            throw new IllegalArgumentException("Range of shifts greater than sample size.");
        }

        int shiftedDataLength = source.rows() - shiftRange;
        Matrix shiftedData = new Matrix(shiftedDataLength, source.columns());
        for (int j = 0; j < shiftedData.columns(); ++j) {
            int offset = max - shifts[j];
            for (int i = 0; i < shiftedDataLength; ++i) {
                shiftedData.set(i, j, source.get(i + offset, j));
            }
        }
        return new BoxDataSet(new DoubleDataBox(shiftedData.toArray()), data.getVariables());
    }

    /**
     * For each variable, regresses its lag-0 value on its own lag-1 value only.
     *
     * @param timeSeries the series, one column per variable, rows in time order.
     * @return one coefficient per variable, taken from {@code getCoef()[1]} of each regression.
     *         Index 0 is presumably the intercept — confirm against RegressionResult.
     */
    public static double[] getSelfLoopCoefs(DataSet timeSeries) {
        DataSet timeLags = createLagData(timeSeries, 1);
        int numVars = timeSeries.getNumColumns();
        RegressionDataset regression = new RegressionDataset(timeLags);
        double[] coefs = new double[numVars];
        for (int j = 0; j < numVars; ++j) {
            Node target = timeLags.getVariable(j);
            Node selfLoop = timeLags.getVariable(j + numVars);
            List<Node> regressors = Collections.singletonList(selfLoop);
            RegressionResult result = regression.regress(target, regressors);
            coefs[j] = result.getCoef()[1];
        }
        return coefs;
    }

    /**
     * Fits the same full VAR as {@link #ar(DataSet, int)} and summarizes its coefficients.
     *
     * <p>Despite the name, the value returned is the <em>mean of the squared</em> entries of every
     * regression's {@code getCoef()} array, including the entry at index 0. It is NaN when
     * there are no coefficients.
     *
     * @param timeSeries the series, one column per variable, rows in time order.
     * @param numLags    the autoregressive order.
     * @return the mean squared coefficient.
     */
    public static double sumOfArCoefficients(DataSet timeSeries, int numLags) {
        DataSet timeLags = createLagData(timeSeries, numLags);
        int numVars = timeSeries.getNumColumns();

        List<Node> regressors = new ArrayList<>();
        for (int c = numVars; c < timeLags.getNumColumns(); ++c) {
            regressors.add(timeLags.getVariable(c));
        }

        RegressionDataset regression = new RegressionDataset(timeLags);
        double sum = 0.0;
        int n = 0;
        for (int i = 0; i < numVars; ++i) {
            RegressionResult result = regression.regress(timeLags.getVariable(i), regressors);
            for (double v : result.getCoef()) {
                sum += v * v;
                ++n;
            }
        }
        return sum / (double) n;
    }

    /**
     * Applies first differencing {@code d} times.
     *
     * <p>Each pass replaces row {@code i} with {@code row(i+1) - row(i)}, which removes one row.
     * {@code d == 0} returns {@code data} itself; a negative {@code d} returns an undifferenced
     * copy.
     *
     * @param data the data set to difference.
     * @param d    the number of differencing passes.
     * @return the differenced data set over {@code data}'s variables.
     * @throws IllegalArgumentException if {@code d} exceeds the number of rows.
     */
    public static DataSet difference(DataSet data, int d) {
        if (d == 0) {
            return data;
        }
        Matrix current = data.getDoubleData();
        if (d > current.rows()) {
            throw new IllegalArgumentException("Cannot difference " + current.rows()
                    + " rows " + d + " times.");
        }
        for (int pass = 0; pass < d; ++pass) {
            Matrix next = new Matrix(current.rows() - 1, current.columns());
            for (int i = 1; i < current.rows(); ++i) {
                for (int j = 0; j < current.columns(); ++j) {
                    next.set(i - 1, j, current.get(i, j) - current.get(i - 1, j));
                }
            }
            current = next;
        }
        return new BoxDataSet(new DoubleDataBox(current.toArray()), data.getVariables());
    }

    /**
     * Builds the lagged ("time lag") data set for {@code data}.
     *
     * <p><b>Columns.</b> The columns are ordered lag-major: lag {@code k} of variable
     * {@code col} is column {@code col + k * numVars}. Lag 0 keeps the original name; lag
     * {@code k >= 1} is named {@code <name>:<k>}.
     *
     * <p><b>Rows.</b> The result has {@code rows - numLags} rows. Row {@code r} at lag {@code k}
     * holds original row {@code r + numLags - k}. Continuous columns are copied with
     * getDouble/setDouble, discrete ones with getInt/setInt.
     *
     * <p><b>Knowledge.</b> The attached knowledge puts lag {@code k} in tier
     * {@code numLags - k}, so the oldest lag is tier 0. Default-to-knowledge-layout is
     * enabled, and the new nodes are given grid centers for display.
     *
     * @param data    the series, one column per variable, rows in time order.
     * @param numLags the number of lags to create, in {@code 0..rows}.
     * @return the lagged data set.
     * @throws IllegalArgumentException if {@code numLags} is negative or exceeds the row count.
     * @throws IllegalStateException    if a variable is neither continuous nor discrete.
     */
    public static DataSet createLagData(DataSet data, int numLags) {
        int numRows = data.getNumRows();
        if (numLags < 0) {
            throw new IllegalArgumentException("numLags must be non-negative: " + numLags);
        }
        if (numLags > numRows) {
            throw new IllegalArgumentException("numLags (" + numLags
                    + ") exceeds the number of rows (" + numRows + ").");
        }

        List<Node> variables = data.getVariables();
        int numVars = variables.size();
        int laggedRows = numRows - numLags;

        Knowledge knowledge = new Knowledge();
        Node[][] laggedNodes = new Node[numLags + 1][numVars];
        List<Node> newVariables = new ArrayList<>((numLags + 1) * numVars);

        for (int lag = 0; lag <= numLags; ++lag) {
            for (int col = 0; col < numVars; ++col) {
                Node node = variables.get(col);
                String name = lag == 0 ? node.getName() : node.getName() + ":" + lag;

                AbstractVariable laggedNode;
                if (node instanceof ContinuousVariable) {
                    laggedNode = new ContinuousVariable(name);
                } else if (node instanceof DiscreteVariable) {
                    laggedNode = new DiscreteVariable((DiscreteVariable) node);
                    laggedNode.setName(name);
                } else {
                    throw new IllegalStateException("Node must be either continuous or discrete");
                }

                newVariables.add(laggedNode);
                laggedNode.setCenter(80 * col + 50, 80 * (numLags - lag) + 50);
                laggedNodes[lag][col] = laggedNode;

                // The tier comes straight from the loop's lag. Re-parsing the generated name
                // breaks when the source name itself contains ':'.
                knowledge.addToTier(numLags - lag, name);
            }
        }

        BoxDataSet laggedData = new BoxDataSet(new DoubleDataBox(laggedRows, newVariables.size()), newVariables);
        for (int lag = 0; lag <= numLags; ++lag) {
            int rowOffset = numLags - lag;
            for (int col = 0; col < numVars; ++col) {
                int targetCol = col + lag * numVars;
                boolean continuous = laggedNodes[lag][col] instanceof ContinuousVariable;
                for (int row = 0; row < laggedRows; ++row) {
                    if (continuous) {
                        laggedData.setDouble(row, targetCol, data.getDouble(row + rowOffset, col));
                    } else {
                        laggedData.setInt(row, targetCol, data.getInt(row + rowOffset, col));
                    }
                }
            }
        }

        knowledge.setDefaultToKnowledgeLayout(true);
        laggedData.setKnowledge(knowledge);
        return laggedData;
    }

    /**
     * Returns a copy of {@code data} with an extra continuous column named "Time".
     * The column holds the values 1, 2, ..., rows.
     */
    public static DataSet addIndex(DataSet data) {
        data = data.copy();
        ContinuousVariable timeVar = new ContinuousVariable("Time");
        data.addVariable(timeVar);
        int c = data.getColumn(timeVar);
        for (int r = 0; r < data.getNumRows(); ++r) {
            data.setDouble(r, c, r + 1);
        }
        return data;
    }

    /**
     * Builds a randomized {@link TimeLagGraph} from a directed graph.
     *
     * <p>The graph is built in three steps:
     * <ol>
     *   <li>Every node of {@code _graph} is added, together with a lag-1 self edge
     *       {@code X:1 -> X}.</li>
     *   <li>Every edge of {@code _graph} is copied between the lag-0 nodes.</li>
     *   <li>For each lag 1..numLags and each ordered pair of distinct lag-0 nodes
     *       {@code (A, B)}, an edge from {@code A} at that lag to {@code B} at lag 0 is added
     *       when {@code RandomUtil.nextUniform(0, 1) <= 0.15}.</li>
     * </ol>
     *
     * @param _graph  the contemporaneous graph; every edge must be directed.
     * @param numLags the maximum lag of the result.
     * @return the new time-lag graph.
     * @throws IllegalArgumentException if {@code _graph} has an edge that is not directed.
     */
    public static TimeLagGraph graphToLagGraph(Graph _graph, int numLags) {
        TimeLagGraph graph = new TimeLagGraph();
        graph.setMaxLag(numLags);

        for (Node node : _graph.getNodes()) {
            ContinuousVariable graphNode = new ContinuousVariable(node.getName());
            graphNode.setNodeType(node.getNodeType());
            graph.addNode(graphNode);
            Node from = graph.getNode(node.getName(), 1);
            Node to = graph.getNode(node.getName(), 0);
            graph.addEdge(new Edge(from, to, Endpoint.TAIL, Endpoint.ARROW));
        }

        for (Edge edge : _graph.getEdges()) {
            if (!Edges.isDirectedEdge(edge)) {
                throw new IllegalArgumentException("Expected only directed edges but found: " + edge);
            }
            Node from = graph.getNode(edge.getNode1().getName(), 0);
            Node to = graph.getNode(edge.getNode2().getName(), 0);
            graph.addEdge(new Edge(from, to, Endpoint.TAIL, Endpoint.ARROW));
        }

        for (int lag = 1; lag <= numLags; ++lag) {
            for (Node node1 : graph.getLag0Nodes()) {
                Node from = graph.getNode(node1.getName(), lag);
                for (Node node2 : graph.getLag0Nodes()) {
                    // Skip self pairs before drawing, so the random stream is consumed only for
                    // distinct pairs.
                    if (node1.getName().equals(node2.getName())) {
                        continue;
                    }
                    if (RandomUtil.getInstance().nextUniform(0.0, 1.0) <= 0.15) {
                        Node to = graph.getNode(node2.getName(), 0);
                        graph.addEdge(new Edge(from, to, Endpoint.TAIL, Endpoint.ARROW));
                    }
                }
            }
        }
        return graph;
    }

    /**
     * Returns {@code obj.toString()} truncated before its first ':', i.e. the name without its
     * lag suffix. The string is returned unchanged if it has no ':'.
     */
    public static String getNameNoLag(Object obj) {
        String text = obj.toString();
        int colon = text.indexOf(':');
        return colon == -1 ? text : text.substring(0, colon);
    }

    /**
     * Returns the first character of {@code s} as a string.
     *
     * @throws StringIndexOutOfBoundsException if {@code s} is empty.
     */
    public static String getPrefix(String s) {
        return s.substring(0, 1);
    }

    /**
     * Returns the longest suffix of {@code s} that {@link Integer#parseInt(String)} accepts.
     * For example, {@code "X12"} gives 12 and {@code "12"} gives 12. Returns 0 when no suffix
     * parses, e.g. {@code "X"}.
     *
     * @throws IllegalArgumentException if {@code s} is empty.
     */
    public static int getIndex(String s) {
        if (s.isEmpty()) {
            throw new IllegalArgumentException("Not integer suffix.");
        }
        int value = 0;
        for (int i = s.length() - 1; i >= 0; --i) {
            try {
                value = Integer.parseInt(s.substring(i));
            } catch (NumberFormatException e) {
                // The suffix starting here no longer parses; keep the last one that did.
                return value;
            }
        }
        // The whole string parses as an integer.
        return value;
    }

    /**
     * Returns the lag encoded after the first ':' in {@code s}, or 0 if there is no ':'.
     *
     * @throws NumberFormatException if the text after ':' is not an integer.
     */
    public static int getLag(String s) {
        int colon = s.indexOf(':');
        if (colon == -1) {
            return 0;
        }
        return Integer.parseInt(s.substring(colon + 1));
    }

    /**
     * Builds tier knowledge for a graph whose node names follow the {@code <name>:<lag>}
     * convention.
     *
     * <p>With {@code numLags} the largest lag present, a node at lag {@code k} goes to tier
     * {@code numLags - k}. Nodes are added in order of lag, then name prefix
     * ({@link #getPrefix(String)}), then integer suffix ({@link #getIndex(String)}). The
     * graph's own node list is not modified.
     *
     * @param graph the lagged graph.
     * @return the knowledge; it is empty for a graph with no nodes.
     */
    public static Knowledge getKnowledge(Graph graph) {
        Knowledge knowledge = new Knowledge();

        // Sort a copy so the graph's node list is never reordered as a side effect.
        List<Node> variables = new ArrayList<>(graph.getNodes());
        if (variables.isEmpty()) {
            return knowledge;
        }

        int numLags = Integer.MIN_VALUE;
        for (Node node : variables) {
            numLags = Math.max(numLags, getLag(node.getName()));
        }

        Collections.sort(variables, new Comparator<Node>() {
            @Override
            public int compare(Node o1, Node o2) {
                int lag1 = getLag(o1.getName());
                int lag2 = getLag(o2.getName());
                if (lag1 != lag2) {
                    return Integer.compare(lag1, lag2);
                }
                String name1 = getNameNoLag(o1);
                String name2 = getNameNoLag(o2);
                int byPrefix = getPrefix(name1).compareTo(getPrefix(name2));
                if (byPrefix != 0) {
                    return byPrefix;
                }
                return Integer.compare(getIndex(name1), getIndex(name2));
            }
        });

        for (Node node : variables) {
            knowledge.addToTier(numLags - getLag(node.getName()), node.getName());
        }
        return knowledge;
    }

    /**
     * Returns whether every eigenvalue of {@code mat} has modulus strictly less than 1.
     *
     * <p>If the eigen decomposition fails to converge ({@link MaxCountExceededException}), the
     * condition cannot be verified and {@code false} is returned. An eigenvalue whose modulus
     * is NaN also yields {@code false}.
     *
     * @param mat a square matrix.
     * @return true only if all eigenvalue moduli were computed and are below 1.
     */
    public static boolean allEigenvaluesAreSmallerThanOneInModulus(Matrix mat) {
        double[] realEigenvalues;
        double[] imagEigenvalues;
        try {
            EigenDecomposition dec = new EigenDecomposition(new BlockRealMatrix(mat.toArray()));
            realEigenvalues = dec.getRealEigenvalues();
            imagEigenvalues = dec.getImagEigenvalues();
        } catch (MaxCountExceededException e) {
            // Did not converge. Returning true on empty eigenvalue arrays would wrongly report
            // success, so report failure instead.
            return false;
        }
        for (int i = 0; i < realEigenvalues.length; ++i) {
            double modulus = FastMath.hypot(realEigenvalues[i], imagEigenvalues[i]);
            // Written as !(x < 1) so a NaN modulus counts as a failure.
            if (!(modulus < 1.0)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Result of {@link #structuralVar(DataSet, int)}. It holds the residual data set and the
     * search graph collapsed onto the un-lagged variables.
     */
    public static class VarResult {
        private final DataSet residuals;
        private final Graph collapsedVarGraph;

        /**
         * @param dataSet           the residual data set.
         * @param collapsedVarGraph the collapsed graph.
         */
        public VarResult(DataSet dataSet, Graph collapsedVarGraph) {
            this.residuals = dataSet;
            this.collapsedVarGraph = collapsedVarGraph;
        }

        /** Returns the residual data set. */
        public DataSet getResiduals() {
            return this.residuals;
        }

        /** Returns the graph collapsed onto the un-lagged variables. */
        public Graph getCollapsedVarGraph() {
            return this.collapsedVarGraph;
        }
    }
}

