/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.algcomparison.algorithm;

import edu.cmu.tetrad.algcomparison.algorithm.Algorithm;
import edu.cmu.tetrad.data.BootstrapSampler;
import edu.cmu.tetrad.data.DataModel;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.DataType;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.GraphUtils;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.util.ForkJoinPoolInstance;
import edu.cmu.tetrad.util.Parameters;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveAction;
import org.apache.commons.math3.util.FastMath;

/**
 * Wraps another {@link Algorithm} and selects the value of one of its numeric parameters with a
 * StARS-style stability search.
 *
 * <p>The search works as follows:
 * <ul>
 *   <li>{@code numSubsamples} subsamples are drawn without replacement from the data. Each has
 *       {@code (int) (percentSubsampleSize * rows)} rows.</li>
 *   <li>Candidate values lambda run from {@code low} to {@code high} in steps of 0.5. For each
 *       one, the wrapped algorithm is run on every subsample and the mean edge instability D is
 *       computed.</li>
 *   <li>The candidate with the largest D that is still strictly below {@code StARS.cutoff} is
 *       selected.</li>
 *   <li>The wrapped algorithm is then run once on the full data with that value.</li>
 * </ul>
 */
public class StARS implements Algorithm {
    private static final long serialVersionUID = 23L;

    /** Distance between consecutive candidate lambda values. */
    private static final double LAMBDA_STEP = 0.5;

    private final double low;
    private final double high;
    private final String parameter;
    private final Algorithm algorithm;

    /**
     * @param algorithm the algorithm whose parameter is tuned
     * @param parameter the name of the numeric parameter to tune
     * @param low       lower end of the candidate range (inclusive)
     * @param high      upper end of the candidate range (inclusive when reachable in 0.5 steps)
     * @throws IllegalArgumentException unless {@code low < high}; NaN bounds are also rejected
     */
    public StARS(Algorithm algorithm, String parameter, double low, double high) {
        // Written as !(low < high) so that NaN bounds fail the check too.
        if (!(low < high)) {
            throw new IllegalArgumentException("Must have low < high");
        }
        this.algorithm = algorithm;
        this.low = low;
        this.high = high;
        this.parameter = parameter;
    }

    /**
     * Computes the mean edge instability of {@code algorithm} over {@code samples} for one value
     * of the tuned parameter.
     *
     * <p>This sets {@code paramName} to {@code paramValue} in {@code params}, so it mutates
     * {@code params}. It then runs the algorithm on every subsample in parallel. The result is
     * the mean, over all unordered node pairs, of 2θ(1 − θ), where θ is the fraction of the
     * resulting graphs in which the pair is adjacent.</p>
     *
     * @param samples non-empty list of subsamples
     * @return the mean instability; NaN (0/0) when the graphs have fewer than two nodes
     */
    private static double getD(Parameters params, String paramName, double paramValue,
                               List<DataSet> samples, Algorithm algorithm) {
        params.set(paramName, (Object) paramValue);

        // One slot per subsample. Each fork/join leaf writes only its own indices, so no shared
        // mutable collection is touched concurrently. Completion of invoke()/join() makes the
        // writes visible to this thread.
        final Graph[] results = new Graph[samples.size()];
        // All graphs are re-expressed over the first subsample's variables so that Node
        // objects compare consistently across graphs.
        final List<Node> referenceVariables = samples.get(0).getVariables();

        class StabilityAction extends RecursiveAction {
            private final int chunk;
            private final int from;
            private final int to;

            StabilityAction(int chunk, int from, int to) {
                this.chunk = chunk;
                this.from = from;
                this.to = to;
            }

            @Override
            protected void compute() {
                if (this.to - this.from <= this.chunk) {
                    for (int s = this.from; s < this.to; s++) {
                        Graph graph = algorithm.search(samples.get(s), params);
                        results[s] = GraphUtils.replaceNodes(graph, referenceVariables);
                    }
                } else {
                    int mid = (this.from + this.to) >>> 1;
                    StabilityAction left = new StabilityAction(this.chunk, this.from, mid);
                    StabilityAction right = new StabilityAction(this.chunk, mid, this.to);
                    left.fork();
                    right.compute();
                    left.join();
                }
            }
        }

        ForkJoinPool pool = ForkJoinPoolInstance.getInstance().getPool();
        pool.invoke(new StabilityAction(1, 0, samples.size()));

        // Bound the pair loop by the graph's own node list, since that is the list we index.
        List<Node> nodes = results[0].getNodes();
        int numNodes = nodes.size();
        double sum = 0.0;
        int pairs = 0;

        for (int i = 0; i < numNodes; i++) {
            for (int j = i + 1; j < numNodes; j++) {
                Node x = nodes.get(i);
                Node y = nodes.get(j);

                int adjacentCount = 0;
                for (Graph graph : results) {
                    if (graph.isAdjacentTo(x, y)) {
                        adjacentCount++;
                    }
                }

                double theta = adjacentCount / (double) results.length;
                sum += 2.0 * theta * (1.0 - theta);
                pairs++;
            }
        }

        return sum / pairs;
    }

    /**
     * Maps a grid value to the value actually assigned to the tuned parameter.
     *
     * <p>When {@code logScale} is set, the result is {@code 10^value}; otherwise it is
     * {@code value}. Either way it is rounded to 9 decimal places, presumably to strip
     * floating-point noise from the grid.</p>
     */
    private static double getValue(double value, Parameters parameters) {
        if (parameters.getBoolean("logScale")) {
            return (double) FastMath.round(FastMath.pow(10.0, value) * 1.0E9) / 1.0E9;
        }
        return (double) FastMath.round(value * 1.0E9) / 1.0E9;
    }

    /**
     * Selects the tuned parameter's value by the stability search described on the class, then
     * runs the wrapped algorithm on {@code dataSet} with it.
     *
     * <p>Reads {@code percentSubsampleSize}, {@code StARS.cutoff}, {@code numSubsamples} and
     * {@code logScale} from {@code parameters}. The caller's {@code parameters} object is not
     * modified; all changes go to a private copy.</p>
     *
     * @throws IllegalArgumentException if {@code numSubsamples < 1}
     */
    @Override
    public Graph search(DataModel dataSet, Parameters parameters) {
        DataSet data = (DataSet) dataSet;
        double percentageB = parameters.getDouble("percentSubsampleSize");
        double beta = parameters.getDouble("StARS.cutoff");
        int numSubsamples = parameters.getInt("numSubsamples");

        if (numSubsamples < 1) {
            throw new IllegalArgumentException("numSubsamples must be at least 1: " + numSubsamples);
        }

        // Private copy: getD() writes the candidate value into it, and the final run uses it.
        Parameters _parameters = new Parameters(parameters);

        int sampleSize = (int) (percentageB * data.getNumRows());
        List<DataSet> samples = new ArrayList<>(numSubsamples);
        for (int i = 0; i < numSubsamples; i++) {
            BootstrapSampler sampler = new BootstrapSampler();
            sampler.setWithoutReplacements(true);
            samples.add(sampler.sample(data, sampleSize));
        }

        double maxD = Double.NEGATIVE_INFINITY;
        double _lambda = Double.NaN;

        // Compute each grid point from its index instead of accumulating, so there is no drift.
        for (int k = 0; this.low + k * LAMBDA_STEP <= this.high; k++) {
            double lambda = this.low + k * LAMBDA_STEP;
            // Evaluate stability at the same (possibly log-scaled) value that the final search
            // will use.
            double D = getD(_parameters, this.parameter, getValue(lambda, parameters), samples,
                    this.algorithm);
            System.out.println("lambda = " + lambda + " D = " + D);

            if (D > maxD && D < beta) {
                maxD = D;
                _lambda = lambda;
            }
        }

        // NOTE(review): if no candidate has D < beta, _lambda is still NaN here, and the
        // parameter is set from getValue(NaN). Confirm this fallback is intended.
        double selected = getValue(_lambda, parameters);
        System.out.println("FINAL: lambda = " + _lambda + " D = " + maxD);
        System.out.println(this.parameter + " = " + selected);
        _parameters.set(this.parameter, (Object) selected);
        return this.algorithm.search(dataSet, _parameters);
    }

    /** Delegates to the wrapped algorithm. */
    @Override
    public Graph getComparisonGraph(Graph graph) {
        return this.algorithm.getComparisonGraph(graph);
    }

    @Override
    public String getDescription() {
        return "StARS for " + this.algorithm.getDescription() + " parameter = " + this.parameter;
    }

    /** Delegates to the wrapped algorithm. */
    @Override
    public DataType getDataType() {
        return this.algorithm.getDataType();
    }

    /**
     * Returns the wrapped algorithm's parameters plus the ones this wrapper uses.
     *
     * <p>The result is a fresh list, so the wrapped algorithm's own list is never mutated.
     * Names already present are not added twice.</p>
     */
    @Override
    public List<String> getParameters() {
        List<String> parameters = new ArrayList<>(this.algorithm.getParameters());

        // "StARS.percentageB" and "StARS.tolerance" are not read by this class. They are kept
        // for backward compatibility. "percentSubsampleSize" and "logScale" are what search()
        // and getValue() actually read.
        String[] extra = {
                "depth", "verbose", "StARS.percentageB", "StARS.tolerance", "StARS.cutoff",
                "numSubsamples", "percentSubsampleSize", "logScale"
        };
        for (String name : extra) {
            if (!parameters.contains(name)) {
                parameters.add(name);
            }
        }
        return parameters;
    }
}

