/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search;

import cern.colt.list.DoubleArrayList;
import cern.colt.matrix.DoubleMatrix1D;
import cern.colt.matrix.DoubleMatrix2D;
import cern.colt.matrix.impl.DenseDoubleMatrix2D;
import cern.jet.stat.Descriptive;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.Knowledge;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.regression.RegressionDatasetGeneralized;
import edu.cmu.tetrad.regression.RegressionResult;
import edu.cmu.tetrad.search.IndTestFisherZ;
import edu.cmu.tetrad.search.SearchGraphUtils;
import edu.cmu.tetrad.util.RandomUtil;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.TreeMap;

/**
 * Chooses among candidate DAGs by how strongly their regression residuals depart from Gaussianity,
 * voting over bootstrap samples of the data.
 *
 * <p>For each of {@link #getNumSamples()} bootstrap samples, every candidate DAG is scored with
 * {@code getScore}, and the DAG(s) attaining that sample's highest score each receive one vote.
 * The {@link Result} lists every DAG that received at least one vote, ordered by vote count.
 */
public class LingamPatternOld {

    /**
     * Reference value sqrt(2/pi) (0.6366... is 2/pi): the expected absolute value of a standard
     * normal variable. Residual sets whose mean absolute standardized value is far from this
     * contribute more to a DAG's score.
     */
    private static final double GAUSSIAN_MEAN_ABS = Math.sqrt(0.6366197723675814);

    /** Number of bootstrap samples drawn per search; the setter keeps it at least 1. */
    private int numSamples = 15;

    /**
     * Runs the search over the DAGs that {@code SearchGraphUtils.getDagsInPatternMeek} produces
     * for {@code pattern} with empty knowledge.
     *
     * @param pattern the pattern whose DAGs are the candidates
     * @param dataSet data to score the candidates against
     * @return the voting result; see {@link #search(List, DataSet)}
     * @throws IllegalArgumentException see {@link #search(List, DataSet)}
     */
    public Result search(Graph pattern, DataSet dataSet) throws IllegalArgumentException {
        return this.search(SearchGraphUtils.getDagsInPatternMeek(pattern, new Knowledge()), dataSet);
    }

    /**
     * Runs the search over the DAGs that {@code SearchGraphUtils.getDagsInPatternMeek} produces
     * for {@code pattern} with the given {@code knowledge}.
     *
     * @param pattern the pattern whose DAGs are the candidates
     * @param dataSet data to score the candidates against
     * @param knowledge knowledge passed through to the DAG enumeration
     * @return the voting result; see {@link #search(List, DataSet)}
     * @throws IllegalArgumentException see {@link #search(List, DataSet)}
     */
    public Result search(Graph pattern, DataSet dataSet, Knowledge knowledge) throws IllegalArgumentException {
        return this.search(SearchGraphUtils.getDagsInPatternMeek(pattern, knowledge), dataSet);
    }

    /**
     * Scores every candidate DAG on bootstrap samples of the data and tallies, per sample, which
     * DAG(s) scored highest.
     *
     * <p>Each bootstrap sample draws {@code rows / 2} rows with replacement. In every sample, each
     * DAG whose score equals that sample's maximum gets one vote, so tied DAGs are all credited.
     * The candidate DAGs and one progress line per sample are printed to standard output.
     *
     * @param dags candidate DAGs; every node name must match a variable name in {@code dataSet}
     * @param dataSet data to score against
     * @return the DAGs with at least one vote, ordered by descending vote count (equal counts keep
     *     input order), together with their vote counts and the number of samples drawn
     * @throws IllegalArgumentException if {@code dags} is empty, if a DAG node has no matching
     *     variable in {@code dataSet}, or if no DAG receives any vote
     */
    public Result search(List<Graph> dags, DataSet dataSet) throws IllegalArgumentException {
        if (dags.isEmpty()) {
            throw new IllegalArgumentException("No input dags");
        }

        System.out.println("DAGS GIVEN TO LINGAM PATTERN:");
        for (int i = 0; i < dags.size(); ++i) {
            System.out.println("#" + i + ": " + dags.get(i));
        }

        DoubleMatrix2D data = dataSet.getDoubleData();
        List<Node> variables = dataSet.getVariables();
        double[][] scores = this.scoreBootstrapSamples(dags, data, variables);
        final TreeMap<Integer, Integer> highestCounts = countWins(scores, this.getNumSamples());

        // TreeMap keys come out in ascending DAG index and Collections.sort is stable, so DAGs
        // with equal vote counts keep their input order.
        List<Integer> outputIndices = new ArrayList<Integer>(highestCounts.keySet());
        Collections.sort(outputIndices, new Comparator<Integer>() {
            @Override
            public int compare(Integer o1, Integer o2) {
                // Descending by vote count.
                return highestCounts.get(o2).compareTo(highestCounts.get(o1));
            }
        });

        List<Graph> outputDags = new ArrayList<Graph>(outputIndices.size());
        List<Integer> outputCounts = new ArrayList<Integer>(outputIndices.size());
        for (Integer index : outputIndices) {
            outputDags.add(dags.get(index));
            outputCounts.add(highestCounts.get(index));
        }

        // Happens only when no score in any sample is comparable (e.g. every score is NaN).
        if (outputDags.isEmpty()) {
            throw new IllegalArgumentException("No output dags");
        }
        return new Result(outputDags, outputCounts, this.numSamples);
    }

    /**
     * Draws {@link #getNumSamples()} bootstrap samples and scores every DAG on each.
     *
     * @return {@code scores[dagIndex][sampleIndex]}
     */
    private double[][] scoreBootstrapSamples(List<Graph> dags, DoubleMatrix2D data, List<Node> variables) {
        int bootstrapSize = data.rows() / 2;
        int samples = this.getNumSamples();
        double[][] scores = new double[dags.size()][samples];
        for (int k = 0; k < samples; ++k) {
            System.out.println("Sample " + k);
            DoubleMatrix2D sample = this.getBootstrapSample(data, bootstrapSize);
            for (int i = 0; i < dags.size(); ++i) {
                scores[i][k] = this.getScore(dags.get(i), sample, variables);
            }
        }
        return scores;
    }

    /**
     * Counts, for each DAG index, the samples in which that DAG's score equals the sample maximum.
     * DAGs that never attain the maximum are absent from the returned map.
     *
     * @param scores {@code scores[dagIndex][sampleIndex]}
     * @param samples number of sample columns to examine
     * @return map from DAG index to its number of winning samples
     */
    private static TreeMap<Integer, Integer> countWins(double[][] scores, int samples) {
        TreeMap<Integer, Integer> wins = new TreeMap<Integer, Integer>();
        for (int k = 0; k < samples; ++k) {
            double maxScore = Double.NEGATIVE_INFINITY;
            for (double[] dagScores : scores) {
                if (dagScores[k] > maxScore) {
                    maxScore = dagScores[k];
                }
            }
            for (int j = 0; j < scores.length; ++j) {
                if (scores[j][k] == maxScore) {
                    Integer count = wins.get(j);
                    wins.put(j, count == null ? 1 : count + 1);
                }
            }
        }
        return wins;
    }

    /**
     * Computes the non-Gaussianity score of {@code dag} on {@code data}.
     *
     * <p>For each DAG node, the matching data variable is regressed on the variables matching the
     * node's parents. The residuals are standardized using their mean and the standard deviation
     * derived from {@code Descriptive.variance(size, sum, sumOfSquares)}. The squared difference
     * between the mean absolute standardized residual and sqrt(2/pi) is then summed over all
     * nodes. Larger scores mean residuals further from that Gaussian reference value.
     *
     * @throws IllegalArgumentException if a node or parent name is not among {@code variables}
     */
    private double getScore(Graph dag, DoubleMatrix2D data, List<Node> variables) {
        RegressionDatasetGeneralized regression = new RegressionDatasetGeneralized(data, variables);
        double score = 0.0;
        for (Node node : dag.getNodes()) {
            Node target = this.getVariable(variables, node.getName());
            List<Node> regressors = new ArrayList<Node>();
            for (Node parent : dag.getParents(node)) {
                regressors.add(this.getVariable(variables, parent.getName()));
            }

            RegressionResult result = regression.regress(target, regressors);
            DoubleArrayList residuals = new DoubleArrayList(result.getResiduals().toArray());
            double mean = Descriptive.mean(residuals);
            double std = Descriptive.standardDeviation(Descriptive.variance(
                    residuals.size(), Descriptive.sum(residuals), Descriptive.sumOfSquares(residuals)));
            // Replace each residual by its absolute standardized value.
            for (int r = 0; r < residuals.size(); ++r) {
                residuals.set(r, Math.abs((residuals.get(r) - mean) / std));
            }

            double diff = Descriptive.mean(residuals) - GAUSSIAN_MEAN_ABS;
            score += diff * diff;
        }
        return score;
    }

    /**
     * Returns the first variable whose name equals {@code name}.
     *
     * @throws IllegalArgumentException if no variable has that name
     */
    private Node getVariable(List<Node> variables, String name) {
        for (Node node : variables) {
            if (name.equals(node.getName())) {
                return node;
            }
        }
        throw new IllegalArgumentException("DAG node not found among data set variables: " + name);
    }

    /**
     * Returns a copy of {@code sampleSize} rows drawn uniformly with replacement from
     * {@code dataSet}, keeping all columns in their original order.
     */
    private DoubleMatrix2D getBootstrapSample(DoubleMatrix2D dataSet, int sampleSize) {
        int actualSampleSize = dataSet.rows();
        int[] rows = new int[sampleSize];
        for (int i = 0; i < rows.length; ++i) {
            rows[i] = RandomUtil.getInstance().nextInt(actualSampleSize);
        }
        int[] cols = new int[dataSet.columns()];
        for (int i = 0; i < cols.length; ++i) {
            cols[i] = i;
        }
        return dataSet.viewSelection(rows, cols).copy();
    }

    /** Returns the number of bootstrap samples drawn per search (default 15). */
    public int getNumSamples() {
        return this.numSamples;
    }

    /**
     * Sets the number of bootstrap samples drawn per search.
     *
     * @throws IllegalArgumentException if {@code numSamples < 1}
     */
    public void setNumSamples(int numSamples) {
        if (numSamples < 1) {
            throw new IllegalArgumentException("Must use at least one sample: " + numSamples);
        }
        this.numSamples = numSamples;
    }

    /**
     * Outcome of a search: the DAGs that won at least one bootstrap sample, their parallel vote
     * counts, and the number of samples drawn.
     */
    public static class Result {
        private List<Graph> dags;
        private List<Integer> counts;
        private int numSamples;

        /**
         * @param dags winning DAGs, in reporting order
         * @param counts vote count for each entry of {@code dags}, at the same index
         * @param numSamples number of bootstrap samples that were voted on
         */
        public Result(List<Graph> dags, List<Integer> counts, int numSamples) {
            this.setDags(dags);
            this.setCounts(counts);
            this.numSamples = numSamples;
        }

        /** Returns the winning DAGs, in reporting order. */
        public List<Graph> getDags() {
            return this.dags;
        }

        public void setDags(List<Graph> dags) {
            this.dags = dags;
        }

        /** Returns the vote counts, index-aligned with {@link #getDags()}. */
        public List<Integer> getCounts() {
            return this.counts;
        }

        public void setCounts(List<Integer> counts) {
            this.counts = counts;
        }

        /** Returns the number of bootstrap samples that were voted on. */
        public int getNumSamples() {
            return this.numSamples;
        }

        /** Lists each DAG with its index and vote count. */
        @Override
        public String toString() {
            StringBuilder buf = new StringBuilder();
            for (int i = 0; i < this.dags.size(); ++i) {
                buf.append("#" + i).append("\n");
                buf.append(this.dags.get(i));
                buf.append(this.counts.get(i)).append(" votes\n");
            }
            return buf.toString();
        }
    }
}

