/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetradapp.editor;

import edu.cmu.tetrad.data.ContinuousVariable;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.regression.RegressionDataset;
import edu.cmu.tetrad.regression.RegressionResult;
import edu.cmu.tetrad.util.Matrix;
import edu.cmu.tetrad.util.RandomUtil;
import edu.cmu.tetrad.util.StatUtils;
import java.awt.geom.Point2D;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Vector;
import org.apache.commons.math3.util.FastMath;

/**
 * Data model behind a two-variable scatter plot of a {@link DataSet}.
 *
 * <p>Points can be restricted to the rows whose values on one or more continuous conditioning
 * variables lie strictly inside a {@code (low, high)} interval (see
 * {@link #addConditioningVariable(String, double, double)}). The plotted points, their ranges, the
 * sample size and the correlation statistics are all computed over those conditioned rows. The
 * regression statistics come from a {@link RegressionDataset} built on the whole data set, so
 * conditioning is not applied to them.
 */
public class ScatterPlot {
    /** Name of the variable on the horizontal axis. */
    private final String x;
    /** Name of the variable on the vertical axis. */
    private final String y;
    /** Value reported by {@link #isIncludeLine()}; presumably whether to draw a regression line. */
    private final boolean includeLine;
    /** The data being plotted. */
    private final DataSet dataSet;
    /** Conditioning variables mapped to their {@code {low, high}} open interval. */
    private final Map<Node, double[]> continuousIntervals = new HashMap<>();

    /**
     * Creates a scatter plot model with no conditioning variables.
     *
     * @param dataSet     the data to plot
     * @param includeLine flag returned by {@link #isIncludeLine()}
     * @param x           name of the horizontal-axis variable
     * @param y           name of the vertical-axis variable
     */
    public ScatterPlot(DataSet dataSet, boolean includeLine, String x, String y) {
        this.dataSet = dataSet;
        this.x = x;
        this.y = y;
        this.includeLine = includeLine;
    }

    /**
     * Regresses y on x. The {@link RegressionDataset} is built on the whole data set, so the
     * conditioning intervals are not applied here.
     */
    private RegressionResult getRegressionResult() {
        List<Node> regressors = new ArrayList<>();
        regressors.add(this.dataSet.getVariable(this.x));
        Node target = this.dataSet.getVariable(this.y);
        // No console output here: this runs on every coefficient/intercept lookup.
        return new RegressionDataset(this.dataSet).regress(target, regressors);
    }

    /**
     * Returns the correlation of x and y (as computed by {@code StatUtils.correlation}) over the
     * conditioned rows, clamped to [-1, 1].
     *
     * <p>Uses the same rows as {@link #getSievedValues()}, so the correlation agrees with the sample
     * size used by {@link #getCorrelationPValue()}.
     *
     * @return the clamped correlation coefficient (NaN propagates unchanged)
     */
    public double getCorrelationCoeff() {
        List<Integer> rows = getConditionedRows();
        double correlation = StatUtils.correlation(values(this.x, rows), values(this.y, rows));
        // Floating-point error can push the estimate just outside [-1, 1]; clamp so that the
        // Fisher z transform stays well defined.
        if (correlation > 1.0) {
            correlation = 1.0;
        } else if (correlation < -1.0) {
            correlation = -1.0;
        }
        return correlation;
    }

    /**
     * Returns a two-sided p-value for the correlation based on Fisher's z statistic.
     *
     * <p>The statistic is referred to {@code RandomUtil.normalCdf(0.0, 1.0, |z|)}. A perfect
     * correlation (infinite z) yields 0. At least four conditioned points are needed for the
     * result to be meaningful.
     *
     * @return the p-value
     */
    public double getCorrelationPValue() {
        double fisherZ = fisherz(getCorrelationCoeff());
        if (Double.isInfinite(fisherZ)) {
            return 0.0;
        }
        return 2.0 * (1.0 - RandomUtil.getInstance().normalCdf(0.0, 1.0, FastMath.abs(fisherZ)));
    }

    /**
     * Fisher's z statistic {@code sqrt(n - 3) * atanh(r)}, with n the number of conditioned rows.
     * The value is infinite when |r| == 1 and n > 3.
     */
    private double fisherz(double r) {
        return 0.5 * FastMath.sqrt((double) getSampleSize() - 3.0)
                * (FastMath.log(1.0 + r) - FastMath.log(1.0 - r));
    }

    /** Returns the smallest x among the sieved points, or +infinity if there are none. */
    public double getXmin() {
        double min = Double.POSITIVE_INFINITY;
        for (Point2D.Double point : getSievedValues()) {
            min = FastMath.min(min, point.getX());
        }
        return min;
    }

    /** Returns the smallest y among the sieved points, or +infinity if there are none. */
    public double getYmin() {
        double min = Double.POSITIVE_INFINITY;
        for (Point2D.Double point : getSievedValues()) {
            min = FastMath.min(min, point.getY());
        }
        return min;
    }

    /** Returns the largest x among the sieved points, or -infinity if there are none. */
    public double getXmax() {
        double max = Double.NEGATIVE_INFINITY;
        for (Point2D.Double point : getSievedValues()) {
            max = FastMath.max(max, point.getX());
        }
        return max;
    }

    /** Returns the largest y among the sieved points, or -infinity if there are none. */
    public double getYmax() {
        double max = Double.NEGATIVE_INFINITY;
        for (Point2D.Double point : getSievedValues()) {
            max = FastMath.max(max, point.getY());
        }
        return max;
    }

    /**
     * Returns the (x, y) points for every row that satisfies all conditioning intervals, in row
     * order. A new vector is built on each call.
     */
    public Vector<Point2D.Double> getSievedValues() {
        return pairs(this.x, this.y);
    }

    /** Number of rows that satisfy all conditioning intervals. */
    private int getSampleSize() {
        return getConditionedRows().size();
    }

    /** Returns the name of the horizontal-axis variable. */
    public String getXvar() {
        return this.x;
    }

    /** Returns the name of the vertical-axis variable. */
    public String getYvar() {
        return this.y;
    }

    /** Returns the {@code includeLine} flag given to the constructor. */
    public boolean isIncludeLine() {
        return this.includeLine;
    }

    /**
     * Returns coefficient index 1 of the y-on-x regression (presumably the slope on x). The
     * regression runs over the whole data set.
     */
    public double getRegressionCoeff() {
        return getRegressionResult().getCoef()[1];
    }

    /** Returns coefficient index 0 of the y-on-x regression, i.e. the intercept. */
    public double getRegressionIntercept() {
        return getRegressionResult().getCoef()[0];
    }

    /** Returns the data set being plotted. */
    public DataSet getDataSet() {
        return this.dataSet;
    }

    /**
     * Restricts the plot to rows whose value on {@code variable} lies strictly between
     * {@code low} and {@code high}.
     *
     * @param variable name of a continuous variable in the data set
     * @param low      exclusive lower bound
     * @param high     exclusive upper bound
     * @throws IllegalArgumentException if {@code low >= high} (or either is NaN), the variable is
     *     unknown or not continuous, or it is already a conditioning variable
     */
    public void addConditioningVariable(String variable, double low, double high) {
        if (!(low < high)) {
            throw new IllegalArgumentException("Low must be less than high: " + low + " >= " + high);
        }
        Node node = this.dataSet.getVariable(variable);
        if (node == null) {
            throw new IllegalArgumentException("Unknown variable: " + variable);
        }
        if (!(node instanceof ContinuousVariable)) {
            throw new IllegalArgumentException("Variable must be continuous.");
        }
        if (this.continuousIntervals.containsKey(node)) {
            throw new IllegalArgumentException("Please remove conditioning variable first.");
        }
        this.continuousIntervals.put(node, new double[]{low, high});
    }

    /**
     * Removes the conditioning interval for {@code variable}.
     *
     * @throws IllegalArgumentException if {@code variable} is not currently a conditioning variable
     */
    public void removeConditioningVariable(String variable) {
        Node node = this.dataSet.getVariable(variable);
        if (!this.continuousIntervals.containsKey(node)) {
            throw new IllegalArgumentException("Not a conditioning node: " + variable);
        }
        this.continuousIntervals.remove(node);
    }

    /** Removes every conditioning interval. */
    public void removeConditioningVariables() {
        this.continuousIntervals.clear();
    }

    /**
     * Returns the number of rows that satisfy all conditioning intervals. The values of
     * {@code target} are read for those rows but do not affect the count.
     */
    public int getN(String target) {
        return values(target, getConditionedRows()).length;
    }

    /** Returns the values of {@code variable} for every row of the data set, ignoring conditioning. */
    public double[] getContinuousData(String variable) {
        int column = columnOf(variable);
        double[] data = new double[this.dataSet.getNumRows()];
        for (int row = 0; row < data.length; row++) {
            data[row] = this.dataSet.getDouble(row, column);
        }
        return data;
    }

    /** Column index of the variable named {@code variable}. */
    private int columnOf(String variable) {
        return this.dataSet.getColumn(this.dataSet.getVariable(variable));
    }

    /** Values of {@code variable} at the given rows, in the order given. */
    private double[] values(String variable, List<Integer> rows) {
        int column = columnOf(variable);
        double[] data = new double[rows.size()];
        for (int i = 0; i < data.length; i++) {
            data[i] = this.dataSet.getDouble(rows.get(i), column);
        }
        return data;
    }

    /** Indices, in increasing order, of rows that lie strictly inside every conditioning interval. */
    private List<Integer> getConditionedRows() {
        // Resolve each conditioning column once instead of once per row.
        int numConditions = this.continuousIntervals.size();
        int[] columns = new int[numConditions];
        double[][] ranges = new double[numConditions][];
        int k = 0;
        for (Map.Entry<Node, double[]> entry : this.continuousIntervals.entrySet()) {
            columns[k] = this.dataSet.getColumn(entry.getKey());
            ranges[k] = entry.getValue();
            k++;
        }

        List<Integer> rows = new ArrayList<>();
        for (int row = 0; row < this.dataSet.getNumRows(); row++) {
            if (isInsideAllIntervals(row, columns, ranges)) {
                rows.add(row);
            }
        }
        return rows;
    }

    /** True if {@code row}'s value in each column lies strictly inside the matching range. */
    private boolean isInsideAllIntervals(int row, int[] columns, double[][] ranges) {
        for (int k = 0; k < columns.length; k++) {
            double value = this.dataSet.getDouble(row, columns[k]);
            if (!(value > ranges[k][0] && value < ranges[k][1])) {
                return false;
            }
        }
        return true;
    }

    /** Builds (x, y) points for the conditioned rows; rows are computed once for both columns. */
    private Vector<Point2D.Double> pairs(String xName, String yName) {
        List<Integer> rows = getConditionedRows();
        double[] xs = values(xName, rows);
        double[] ys = values(yName, rows);
        Vector<Point2D.Double> points = new Vector<>(rows.size());
        for (int i = 0; i < xs.length; i++) {
            points.add(new Point2D.Double(xs[i], ys[i]));
        }
        return points;
    }
}

