/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search;

import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.ICovarianceMatrix;
import edu.cmu.tetrad.graph.EdgeListGraph;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.GraphNode;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.graph.NodeType;
import edu.cmu.tetrad.sem.SemEstimator;
import edu.cmu.tetrad.sem.SemIm;
import edu.cmu.tetrad.sem.SemPm;
import java.util.ArrayList;
import java.util.List;

/**
 * Searches for a pure measurement model over a fixed set of observed variables.
 *
 * <p>The search starts from a single cluster that contains every variable. Each round does
 * two things:
 * <ol>
 *   <li>It purifies the clusters by greedily removing variables.</li>
 *   <li>It gathers every variable no longer in any cluster into one extra cluster.</li>
 * </ol>
 * Rounds repeat until the fitted model's p-value is no longer below {@code alpha}, or until a
 * round discards nothing.
 *
 * <p>Each candidate model is a graph with one latent node per non-empty cluster. The latents
 * are pairwise joined by bidirected edges, and each latent has a directed edge to every
 * variable in its cluster. The graph is estimated with {@link SemEstimator}, using either the
 * covariance matrix or the data set, depending on which constructor was used.
 */
public class Washdown {
    /** Covariance matrix to estimate from; {@code null} when constructed from a data set. */
    private final ICovarianceMatrix cov;
    /** Data set to estimate from; {@code null} when constructed from a covariance matrix. */
    private final DataSet dataSet;
    /** Variables taken from the covariance matrix or data set at construction time. */
    private final List<Node> variables;
    /** Significance threshold that fitted p-values are compared against. */
    private final double alpha;

    /**
     * Creates a search that estimates models from a covariance matrix.
     *
     * @param cov   the covariance matrix; its variables are the ones clustered
     * @param alpha the p-value threshold used to stop purification and the search
     */
    public Washdown(ICovarianceMatrix cov, double alpha) {
        this.cov = cov;
        this.dataSet = null;
        this.variables = cov.getVariables();
        this.alpha = alpha;
    }

    /**
     * Creates a search that estimates models from a data set.
     *
     * @param data  the data set; its variables are the ones clustered
     * @param alpha the p-value threshold used to stop purification and the search
     */
    public Washdown(DataSet data, double alpha) {
        this.cov = null;
        this.dataSet = data;
        this.variables = data.getVariables();
        this.alpha = alpha;
    }

    /**
     * Runs the search and returns the resulting measurement model.
     *
     * @return a graph with one latent per non-empty final cluster, the latents pairwise
     *     connected by bidirected edges, and directed edges from each latent to its cluster's
     *     variables; an empty graph if there are no variables
     */
    public Graph search() {
        if (this.variables.isEmpty()) {
            return new EdgeListGraph();
        }

        List<List<Node>> clusters = new ArrayList<>();
        clusters.add(new ArrayList<>(this.variables));

        while (true) {
            clusters = this.purify(clusters);
            List<Node> discards = this.getDiscards(clusters, this.variables);

            // Only add a discards cluster when something was actually discarded; an empty
            // cluster contributes nothing to the fit and would become a childless latent.
            if (!discards.isEmpty()) {
                clusters.add(discards);
            }

            double pValue = this.pValue(clusters);
            System.out.println("\nSearch PValue = " + pValue + " clusters = " + clusters + "\n");

            // Written as !(p < alpha) so that a NaN p-value also ends the search, matching the
            // original do/while condition. If nothing was discarded, the next round would start
            // from the same clusters and (assuming estimation is deterministic) repeat this one
            // forever, so stop here as well.
            if (!(pValue < this.alpha) || discards.isEmpty()) {
                break;
            }
        }

        return this.pureMeasurementModel(this.removeEmpty(clusters));
    }

    /**
     * Collects the variables that do not appear in any cluster.
     *
     * @param clusters  the current clusters
     * @param variables the full variable list to check
     * @return the unclustered variables, in the order they appear in {@code variables}
     */
    private List<Node> getDiscards(List<List<Node>> clusters, List<Node> variables) {
        List<Node> discards = new ArrayList<>();
        for (Node node : variables) {
            if (!this.isClustered(node, clusters)) {
                discards.add(node);
            }
        }
        return discards;
    }

    /** Returns true if {@code node} is contained in at least one of {@code clusters}. */
    private boolean isClustered(Node node, List<List<Node>> clusters) {
        for (List<Node> cluster : clusters) {
            if (cluster.contains(node)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Greedily removes variables from the clusters until the fitted p-value exceeds
     * {@code alpha}.
     *
     * <p>Each pass tries removing every remaining candidate variable. It keeps the single
     * removal with the lowest score from {@link #gof(List)}, provided that score is lower than
     * the best seen so far. If no removal lowers the score, the clusters are returned as they
     * are.
     *
     * @param clusters the clusters to purify; not modified
     * @return the purified clusters, which contain no clusters emptied by a removal
     */
    private List<List<Node>> purify(List<List<Node>> clusters) {
        List<Node> candidates = new ArrayList<>(this.variables);
        double bestGof = this.gof(clusters);
        System.out.println("Purify Best GOF = " + bestGof + " clusters = " + clusters);

        while (!(this.pValue(clusters) > this.alpha)) {
            Node bestNode = null;

            for (Node node : candidates) {
                List<List<Node>> trial = this.removeVar(node, clusters);
                double trialGof = this.gof(trial);
                // Reuse trialGof rather than re-estimating the model just for logging.
                System.out.println("     GOF = " + trialGof + "P value = " + this.pValue(trial) + " clusters = " + trial);

                // Lower score is treated as better; a NaN score is never selected.
                if (trialGof < bestGof) {
                    bestGof = trialGof;
                    bestNode = node;
                }
            }

            if (bestNode == null) {
                return clusters;
            }

            clusters = this.removeVar(bestNode, clusters);
            candidates.remove(bestNode);
        }

        return clusters;
    }

    /**
     * Returns a copy of {@code clusters} with {@code node} removed from every cluster.
     * Clusters that are empty after the removal are dropped.
     *
     * @param node     the variable to remove
     * @param clusters the clusters to copy; not modified
     * @return a new list of new cluster lists
     */
    private List<List<Node>> removeVar(Node node, List<List<Node>> clusters) {
        List<List<Node>> result = new ArrayList<>();
        for (List<Node> cluster : clusters) {
            List<Node> copy = new ArrayList<>(cluster);
            copy.remove(node);
            // Check the copy, not the original, so clusters emptied by this removal are dropped.
            if (copy.isEmpty()) {
                continue;
            }
            result.add(copy);
        }
        return result;
    }

    /**
     * Scores the clusters by the BIC score of the estimated measurement model.
     *
     * @param clusters the clusters to score; empty clusters are ignored
     * @return {@link SemIm#getBicScore()} of the estimated model
     */
    private double gof(List<List<Node>> clusters) {
        return this.estimate(clusters).getBicScore();
    }

    /**
     * Computes the p-value of the estimated measurement model for the clusters.
     *
     * @param clusters the clusters to evaluate; empty clusters are ignored
     * @return {@link SemIm#getPValue()} of the estimated model
     */
    private double pValue(List<List<Node>> clusters) {
        return this.estimate(clusters).getPValue();
    }

    /**
     * Builds the measurement model for the non-empty clusters and estimates it.
     * The estimate uses the covariance matrix if one was supplied, otherwise the data set.
     */
    private SemIm estimate(List<List<Node>> clusters) {
        Graph graph = this.pureMeasurementModel(this.removeEmpty(clusters));
        SemPm pm = new SemPm(graph);
        SemEstimator estimator = this.cov != null
                ? new SemEstimator(this.cov, pm)
                : new SemEstimator(this.dataSet, pm);
        return estimator.estimate();
    }

    /**
     * Returns a new list containing the non-empty clusters, in their original order.
     * The cluster lists themselves are shared, not copied.
     */
    private List<List<Node>> removeEmpty(List<List<Node>> clusters) {
        List<List<Node>> result = new ArrayList<>();
        for (List<Node> cluster : clusters) {
            if (!cluster.isEmpty()) {
                result.add(cluster);
            }
        }
        return result;
    }

    /**
     * Builds a measurement graph with one latent per cluster.
     *
     * <p>Latents are named {@code L0, L1, ...} in cluster order and are pairwise connected by
     * bidirected edges. Each latent has a directed edge to every variable in its cluster.
     *
     * @param clusters the clusters; each element gets its own latent, even if it is empty
     * @return the constructed graph
     */
    private Graph pureMeasurementModel(List<List<Node>> clusters) {
        EdgeListGraph graph = new EdgeListGraph();
        List<Node> latents = new ArrayList<>();

        for (int i = 0; i < clusters.size(); i++) {
            GraphNode latent = new GraphNode("L" + i);
            latent.setNodeType(NodeType.LATENT);
            latents.add(latent);
            graph.addNode(latent);
        }

        for (int i = 0; i < latents.size(); i++) {
            for (int j = i + 1; j < latents.size(); j++) {
                graph.addBidirectedEdge(latents.get(i), latents.get(j));
            }
        }

        for (int i = 0; i < clusters.size(); i++) {
            Node latent = latents.get(i);
            for (Node measured : clusters.get(i)) {
                graph.addNode(measured);
                graph.addDirectedEdge(latent, measured);
            }
        }

        return graph;
    }
}

