/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search;

import cern.colt.matrix.DoubleMatrix1D;
import cern.colt.matrix.DoubleMatrix2D;
import cern.colt.matrix.impl.DenseDoubleMatrix2D;
import edu.cmu.tetrad.data.ColtDataSet;
import edu.cmu.tetrad.data.ContinuousVariable;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.DiscreteVariable;
import edu.cmu.tetrad.data.Knowledge;
import edu.cmu.tetrad.graph.Edge;
import edu.cmu.tetrad.graph.EdgeListGraph;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.regression.RegressionDataset;
import edu.cmu.tetrad.regression.RegressionDatasetGeneralized;
import edu.cmu.tetrad.regression.RegressionResult;
import edu.cmu.tetrad.search.Ges;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

/**
 * Static helpers for time-series data: building lagged data sets, fitting vector
 * autoregressions (VAR), differencing and shifting columns.
 *
 * <p>Lagged variables produced by {@link #createLagData(DataSet, int)} are named
 * {@code <original name> + "." + lag}; lag-0 columns come first, followed by lag 1, etc.
 */
public class TimeSeriesUtils {

    /**
     * Fits a VAR of order {@code numLags} by regressing every lag-0 variable on all lagged
     * variables, and returns the residuals.
     *
     * @param timeSeries the time series, one column per variable, rows in time order.
     * @param numLags    number of lags used as regressors.
     * @return a continuous data set over {@code timeSeries}'s variables holding the residuals;
     *         it has {@code timeSeries.getNumRows() - numLags} rows.
     */
    public static DataSet var(DataSet timeSeries, int numLags) {
        DataSet timeLags = TimeSeriesUtils.createLagData(timeSeries, numLags);
        int numVars = timeSeries.getNumColumns();

        // Columns [0, numVars) of the lag data are lag 0; every later column is a lagged copy.
        List<Node> regressors = new ArrayList<Node>();
        for (int c = numVars; c < timeLags.getNumColumns(); ++c) {
            regressors.add(timeLags.getVariable(c));
        }

        RegressionDataset regression = new RegressionDataset(timeLags);
        DenseDoubleMatrix2D residuals = new DenseDoubleMatrix2D(timeLags.getNumRows(), numVars);
        for (int i = 0; i < numVars; ++i) {
            Node target = timeLags.getVariable(i);
            RegressionResult result = regression.regress(target, regressors);
            residuals.viewColumn(i).assign(result.getResiduals());
        }
        return ColtDataSet.makeContinuousData(timeSeries.getVariables(), residuals);
    }

    /**
     * Like {@link #var(DataSet, int)}, but tolerates variables whose values are all NaN.
     * Such variables are excluded as regressors, and their residual column is filled with NaN.
     *
     * @param timeSeries the time series, one column per variable, rows in time order.
     * @param numLags    number of lags used as regressors.
     * @return a continuous data set over {@code timeSeries}'s variables holding the residuals.
     */
    public static DataSet var2(DataSet timeSeries, int numLags) {
        int numVars = timeSeries.getNumColumns();

        List<Node> missingVariables = new ArrayList<Node>();
        for (int j = 0; j < numVars; ++j) {
            if (isAllNaN(timeSeries, j)) {
                missingVariables.add(timeSeries.getVariable(j));
            }
        }

        DataSet timeLags = TimeSeriesUtils.createLagData(timeSeries, numLags);

        // Lagged column c is a copy of original variable c % numVars. The regressor set does
        // not depend on the target, so it is built once.
        List<Node> regressors = new ArrayList<Node>();
        for (int c = numVars; c < timeLags.getNumColumns(); ++c) {
            if (missingVariables.contains(timeSeries.getVariable(c % numVars))) {
                continue;
            }
            regressors.add(timeLags.getVariable(c));
        }

        RegressionDataset regression = new RegressionDataset(timeLags);
        DenseDoubleMatrix2D residuals = new DenseDoubleMatrix2D(timeLags.getNumRows(), numVars);
        for (int i = 0; i < numVars; ++i) {
            // Test the ORIGINAL variable. The lag-0 node is a distinct node with a suffixed
            // name ("X.0"), so it is not found among timeSeries's variables.
            if (missingVariables.contains(timeSeries.getVariable(i))) {
                residuals.viewColumn(i).assign(Double.NaN);
                continue;
            }
            Node target = timeLags.getVariable(i);
            RegressionResult result = regression.regress(target, regressors);
            residuals.viewColumn(i).assign(result.getResiduals());
        }
        return ColtDataSet.makeContinuousData(timeSeries.getVariables(), residuals);
    }

    /** Returns true if every value in column {@code col} is NaN (also true when there are no rows). */
    private static boolean isAllNaN(DataSet data, int col) {
        for (int row = 0; row < data.getNumRows(); ++row) {
            if (!Double.isNaN(data.getDouble(row, col))) {
                return false;
            }
        }
        return true;
    }

    /**
     * Returns the entries of {@code parents} whose variable in {@code dataSet} is not in
     * {@code missingVariables}, in their original order.
     *
     * <p>NOTE(review): not called anywhere in this class; {@code dataIndex} is unused.
     */
    private static int[] eliminateMissing(int[] parents, int dataIndex, DataSet dataSet, List<Node> missingVariables) {
        int[] kept = new int[parents.length];
        int count = 0;
        for (int k : parents) {
            if (!missingVariables.contains(dataSet.getVariable(k))) {
                kept[count++] = k;
            }
        }
        int[] result = new int[count];
        System.arraycopy(kept, 0, result, 0, count);
        return result;
    }

    /**
     * Runs GES over the lagged data and collapses the resulting graph onto the original
     * variables. Each lag-0 variable is then regressed on its parents in the searched graph.
     *
     * @param timeSeries the time series, one column per variable, rows in time order.
     * @param numLags    number of lags to construct.
     * @return the residuals (over {@code timeSeries}'s variables) and the collapsed graph.
     */
    public static VarResult structuralVar(DataSet timeSeries, int numLags) {
        DataSet timeLags = TimeSeriesUtils.createLagData(timeSeries, numLags);

        // createLagData puts each lag in its own tier. Marking every tier forbidden-within
        // presumably restricts the search to edges between different lags.
        Knowledge knowledge = new Knowledge(timeLags.getKnowledge());
        for (int tier = 0; tier <= numLags; ++tier) {
            knowledge.setTierForbiddenWithin(tier, true);
        }
        Ges search = new Ges(timeLags);
        search.setKnowledge(knowledge);
        Graph graph = search.search();

        EdgeListGraph collapsedVarGraph = new EdgeListGraph(timeSeries.getVariables());
        for (Edge edge : graph.getEdges()) {
            Node node1 = collapsedVarGraph.getNode(stripLagSuffix(edge.getNode1().getName()));
            Node node2 = collapsedVarGraph.getNode(stripLagSuffix(edge.getNode2().getName()));
            Edge collapsed = new Edge(node1, node2, edge.getEndpoint1(), edge.getEndpoint2());
            if (!collapsedVarGraph.containsEdge(collapsed)) {
                collapsedVarGraph.addEdge(collapsed);
            }
        }

        int numVars = timeSeries.getNumColumns();
        DenseDoubleMatrix2D residuals = new DenseDoubleMatrix2D(timeLags.getNumRows(), numVars);
        RegressionDatasetGeneralized regression = new RegressionDatasetGeneralized(timeLags);
        for (int i = 0; i < numVars; ++i) {
            Node target = timeLags.getVariable(i);
            List<Node> regressors = new ArrayList<Node>(graph.getParents(target));
            RegressionResult result = regression.regress(target, regressors);
            residuals.viewColumn(i).assign(result.getResiduals());
        }
        return new VarResult(ColtDataSet.makeContinuousData(timeSeries.getVariables(), residuals), collapsedVarGraph);
    }

    /**
     * Removes the {@code "." + lag} suffix that createLagData appends (e.g. "X.2" becomes "X").
     * The last '.' is used so that variable names which themselves contain '.' survive.
     */
    private static String stripLagSuffix(String laggedName) {
        return laggedName.substring(0, laggedName.lastIndexOf('.'));
    }

    /**
     * Shifts each column of {@code data} relative to the others and trims the rows that no
     * longer line up. In the result, column j starts at original row {@code max(shifts) - shifts[j]}.
     *
     * @param data   the data; read via {@code getDoubleData()}.
     * @param shifts one shift per column of {@code data}.
     * @return a continuous data set with {@code rows - (max(shifts) - min(shifts))} rows.
     * @throws IllegalArgumentException if {@code shifts.length} differs from the column count,
     *                                  or if the shift range exceeds the number of rows.
     */
    public static DataSet createShiftedData(DataSet data, int[] shifts) {
        DoubleMatrix2D data2 = data.getDoubleData();
        if (shifts.length != data2.columns()) {
            throw new IllegalArgumentException("Expected one shift per column (" + data2.columns()
                    + " columns), got " + shifts.length + " shifts.");
        }

        // Start min/max from the first element so an empty array yields a range of 0
        // (rather than an overflowing MIN_VALUE - MAX_VALUE).
        int min = 0;
        int max = 0;
        for (int i = 0; i < shifts.length; ++i) {
            if (i == 0 || shifts[i] < min) {
                min = shifts[i];
            }
            if (i == 0 || shifts[i] > max) {
                max = shifts[i];
            }
        }
        int shiftRange = max - min;
        if (shiftRange > data2.rows()) {
            throw new IllegalArgumentException("Range of shifts greater than sample size.");
        }

        int shiftedDataLength = data2.rows() - shiftRange;
        DenseDoubleMatrix2D shiftedData = new DenseDoubleMatrix2D(shiftedDataLength, data2.columns());
        for (int j = 0; j < data2.columns(); ++j) {
            int offset = max - shifts[j];
            for (int i = 0; i < shiftedDataLength; ++i) {
                shiftedData.set(i, j, data2.get(i + offset, j));
            }
        }
        return ColtDataSet.makeContinuousData(data.getVariables(), shiftedData);
    }

    /**
     * For each variable, regresses its lag-0 value on its own lag-1 value and returns the
     * coefficient of the lag-1 regressor.
     *
     * @param timeSeries the time series, one column per variable, rows in time order.
     * @return one coefficient per column of {@code timeSeries}.
     */
    public static double[] getSelfLoopCoefs(DataSet timeSeries) {
        DataSet timeLags = TimeSeriesUtils.createLagData(timeSeries, 1);
        int numVars = timeSeries.getNumColumns();
        // The regression depends only on the lag data, so it is created once.
        RegressionDatasetGeneralized regression = new RegressionDatasetGeneralized(timeLags);
        double[] coefs = new double[numVars];
        for (int j = 0; j < numVars; ++j) {
            Node target = timeLags.getVariable(j);
            Node selfLoop = timeLags.getVariable(j + numVars);
            List<Node> regressors = Collections.singletonList(selfLoop);
            RegressionResult result = regression.regress(target, regressors);
            // Index 1 is taken as the single regressor's coefficient (index 0 presumably the intercept).
            coefs[j] = result.getCoef()[1];
        }
        return coefs;
    }

    /**
     * Fits a VAR of order {@code numLags} (each lag-0 variable on all lagged variables) and
     * returns the mean, over all targets and all coefficients, of the squared coefficients.
     *
     * <p>NOTE(review): despite the name, this returns a mean of squares, not a sum. Every
     * entry of {@code getCoef()} is included; if index 0 is the intercept (as getSelfLoopCoefs
     * assumes), it is counted too. Returns NaN when {@code timeSeries} has no columns.
     */
    public static double sumOfArCoefficients(DataSet timeSeries, int numLags) {
        DataSet timeLags = TimeSeriesUtils.createLagData(timeSeries, numLags);
        int numVars = timeSeries.getNumColumns();

        List<Node> regressors = new ArrayList<Node>();
        for (int c = numVars; c < timeLags.getNumColumns(); ++c) {
            regressors.add(timeLags.getVariable(c));
        }

        RegressionDatasetGeneralized regression = new RegressionDatasetGeneralized(timeLags);
        double sum = 0.0;
        int n = 0;
        for (int i = 0; i < numVars; ++i) {
            Node target = timeLags.getVariable(i);
            double[] coef = regression.regress(target, regressors).getCoef();
            for (double c : coef) {
                sum += c * c;
                ++n;
            }
        }
        return sum / (double) n;
    }

    /**
     * Applies first differencing {@code d} times. Each pass replaces row i with
     * {@code row(i) - row(i - 1)} and drops one row.
     *
     * @param data the data; read via {@code getDoubleData()}.
     * @param d    the differencing order; 0 returns {@code data} itself.
     * @return a continuous data set with {@code data.getNumRows() - d} rows.
     * @throws IllegalArgumentException if {@code d} is negative or exceeds the number of rows.
     */
    public static DataSet difference(DataSet data, int d) {
        if (d < 0) {
            throw new IllegalArgumentException("Differencing order must be non-negative: " + d);
        }
        if (d == 0) {
            return data;
        }
        DoubleMatrix2D current = data.getDoubleData();
        if (d > current.rows()) {
            throw new IllegalArgumentException("Differencing order " + d
                    + " exceeds the number of rows " + current.rows() + ".");
        }
        for (int pass = 1; pass <= d; ++pass) {
            DenseDoubleMatrix2D next = new DenseDoubleMatrix2D(current.rows() - 1, current.columns());
            for (int i = 1; i < current.rows(); ++i) {
                for (int j = 0; j < current.columns(); ++j) {
                    next.set(i - 1, j, current.get(i, j) - current.get(i - 1, j));
                }
            }
            current = next;
        }
        return ColtDataSet.makeContinuousData(data.getVariables(), current);
    }

    /**
     * Builds a data set holding the variables at lags 0..numLags side by side.
     *
     * <p>Column {@code col + lag * numVars} is variable {@code col} at lag {@code lag}, named
     * {@code name + "." + lag}. Row r of that column holds original row {@code r + numLags - lag}.
     * Each lag is placed in knowledge tier {@code numLags - lag}. The input data set and its
     * variables are not modified.
     *
     * @param data    the time series, one column per variable, rows in time order.
     * @param numLags the maximum lag.
     * @return the lagged data set, with {@code data.getNumRows() - numLags} rows.
     * @throws IllegalArgumentException if {@code numLags} is negative or exceeds the number of rows.
     * @throws IllegalStateException    if a variable is neither continuous nor discrete.
     */
    public static DataSet createLagData(DataSet data, int numLags) {
        if (numLags < 0) {
            throw new IllegalArgumentException("numLags must be non-negative: " + numLags);
        }
        if (numLags > data.getNumRows()) {
            throw new IllegalArgumentException("numLags (" + numLags + ") exceeds the number of rows ("
                    + data.getNumRows() + ").");
        }

        List<Node> variables = data.getVariables();
        int dataSize = variables.size();
        int laggedRows = data.getNumRows() - numLags;
        Knowledge knowledge = new Knowledge();
        Node[][] laggedNodes = new Node[numLags + 1][dataSize];
        List<Node> newVariables = new ArrayList<Node>((numLags + 1) * dataSize + 1);

        for (int lag = 0; lag <= numLags; ++lag) {
            for (int col = 0; col < dataSize; ++col) {
                Node node = variables.get(col);
                String laggedName = node.getName() + "." + lag;
                Node laggedNode;
                if (node instanceof ContinuousVariable) {
                    laggedNode = new ContinuousVariable(laggedName);
                } else if (node instanceof DiscreteVariable) {
                    // Rename only the copy; the caller's variable must keep its name.
                    DiscreteVariable copy = new DiscreteVariable((DiscreteVariable) node);
                    copy.setName(laggedName);
                    laggedNode = copy;
                } else {
                    throw new IllegalStateException("Node must be either continuous or discrete");
                }
                newVariables.add(laggedNode);
                laggedNode.setCenter(80 * col + 50, 80 * (numLags - lag) + 50);
                laggedNodes[lag][col] = laggedNode;
                knowledge.addToTier(numLags - lag, laggedNode.getName());
            }
        }

        ColtDataSet laggedData = new ColtDataSet(laggedRows, newVariables);
        for (int lag = 0; lag <= numLags; ++lag) {
            for (int col = 0; col < dataSize; ++col) {
                int laggedCol = col + lag * dataSize;
                boolean continuous = laggedNodes[lag][col] instanceof ContinuousVariable;
                for (int row = 0; row < laggedRows; ++row) {
                    int sourceRow = row + numLags - lag;
                    if (continuous) {
                        laggedData.setDouble(row, laggedCol, data.getDouble(sourceRow, col));
                    } else {
                        laggedData.setInt(row, laggedCol, data.getInt(sourceRow, col));
                    }
                }
            }
        }
        knowledge.setDefaultToKnowledgeLayout(true);
        laggedData.setKnowledge(knowledge);
        return laggedData;
    }

    /** Result of {@link TimeSeriesUtils#structuralVar}: the residuals and the collapsed graph. */
    public static class VarResult {
        private final DataSet residuals;
        private final Graph collapsedVarGraph;

        public VarResult(DataSet dataSet, Graph collapsedVarGraph) {
            this.residuals = dataSet;
            this.collapsedVarGraph = collapsedVarGraph;
        }

        /** Returns the residual data set. */
        public DataSet getResiduals() {
            return this.residuals;
        }

        /** Returns the search graph collapsed onto the original (unlagged) variables. */
        public Graph getCollapsedVarGraph() {
            return this.collapsedVarGraph;
        }
    }
}

