/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search.work_in_progress;

import edu.cmu.tetrad.annotation.Experimental;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.ICovarianceMatrix;
import edu.cmu.tetrad.graph.Edge;
import edu.cmu.tetrad.graph.Endpoint;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.search.score.Score;
import edu.cmu.tetrad.search.score.SemBicScore;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

/**
 * A {@link Score} for maximal ancestral graphs (MAGs) that wraps a {@link SemBicScore}.
 *
 * <p>While either the MAG or the variable order is unset ({@code null}), every local score is
 * delegated unchanged to the wrapped {@code SemBicScore}. Once both are set,
 * {@link #localScore(int, int...)} builds head/tail pairs for the target node from the MAG's
 * directed and bidirected edges. It then returns an alternating-sign sum of wrapped local scores
 * over every subset of each head, each subset conditioned together with the head's tail.</p>
 */
@Experimental
public class MagSemBicScore implements Score {
    /** The wrapped score that supplies every underlying local score. */
    private final SemBicScore score;
    /** MAG used to build heads and tails; {@code null} means "delegate directly". */
    private Graph mag;
    /** Order used to sequence the conditioning variables; {@code null} means "delegate directly". */
    private List<Node> order;

    /**
     * Creates the score from a covariance matrix.
     *
     * @param covariances covariance matrix handed to the wrapped {@link SemBicScore}
     * @throws NullPointerException if {@code covariances} is {@code null}
     */
    public MagSemBicScore(ICovarianceMatrix covariances) {
        if (covariances == null) {
            throw new NullPointerException("covariances");
        }
        this.score = new SemBicScore(covariances);
    }

    /**
     * Creates the score from a data set.
     *
     * @param dataSet               data handed to the wrapped {@link SemBicScore}
     * @param precomputeCovariances forwarded unchanged to the {@link SemBicScore} constructor
     * @throws NullPointerException if {@code dataSet} is {@code null}
     */
    public MagSemBicScore(DataSet dataSet, boolean precomputeCovariances) {
        if (dataSet == null) {
            throw new NullPointerException("dataSet");
        }
        this.score = new SemBicScore(dataSet, precomputeCovariances);
    }

    /** @return the MAG currently in use, or {@code null} if none is set */
    public Graph getMag() {
        return this.mag;
    }

    /**
     * Sets the MAG used by {@link #localScore(int, int...)}. The reference is stored, not copied.
     *
     * @param mag the MAG, or {@code null} to fall back to plain delegation
     */
    public void setMag(Graph mag) {
        this.mag = mag;
    }

    /** Clears the MAG so local scores are delegated directly to the wrapped score. */
    public void resetMag() {
        this.mag = null;
    }

    /** @return the variable order currently in use, or {@code null} if none is set */
    public List<Node> getOrder() {
        return this.order;
    }

    /**
     * Sets the variable order used to sequence conditioning variables. The list is stored, not
     * copied.
     *
     * @param order the order, or {@code null} to fall back to plain delegation
     */
    public void setOrder(List<Node> order) {
        this.order = order;
    }

    /** Clears the variable order so local scores are delegated directly to the wrapped score. */
    public void resetOrder() {
        this.order = null;
    }

    /**
     * Returns the local score of variable {@code i} given the variables {@code js}.
     *
     * <p>If the MAG or the order is unset, this delegates to the wrapped score. Otherwise, the
     * members of {@code js} are sequenced by the current order; members absent from the order are
     * dropped. Heads and tails are built for {@code i}. For every subset S of each head (excluding
     * {@code i}), the wrapped local score of {@code i} given {@code tail ∪ S} is added when
     * {@code |head| - |S|} is even and subtracted otherwise.</p>
     *
     * <p>The caller's {@code js} array is not modified.</p>
     *
     * @param i  index of the target variable in {@link #getVariables()}
     * @param js indices of the candidate conditioning variables
     * @return the local score
     */
    @Override
    public double localScore(int i, int ... js) {
        if (this.mag == null || this.order == null) {
            return this.score.localScore(i, js);
        }
        List<Node> variables = this.score.getVariables();
        Node v1 = variables.get(i);

        // Sort a copy: binarySearch needs sorted input, but the caller's array must stay intact.
        int[] sortedJs = js.clone();
        Arrays.sort(sortedJs);

        // Conditioning variables, in the order given by this.order.
        List<Node> mbo = new ArrayList<>();
        for (Node v2 : this.order) {
            if (Arrays.binarySearch(sortedJs, variables.indexOf(v2)) >= 0) {
                mbo.add(v2);
            }
        }

        List<List<Node>> heads = new ArrayList<>();
        List<Set<Node>> tails = new ArrayList<>();
        this.constructHeadsTails(heads, tails, mbo, new ArrayList<>(), new ArrayList<>(), new HashSet<>(), v1);

        double total = 0.0;
        for (int l = 0; l < heads.size(); ++l) {
            List<Node> head = heads.get(l);
            Set<Node> tail = tails.get(l);
            head.remove(v1);
            int h = head.size();
            int max = h + tail.size();
            // Bit k of mask selects head.get(k); this enumerates every subset of the head.
            for (int mask = 0; mask < (1 << h); ++mask) {
                List<Node> condSet = new ArrayList<>(tail);
                for (int k = 0; k < h; ++k) {
                    if ((mask & (1 << k)) != 0) {
                        condSet.add(head.get(k));
                    }
                }
                // Size by the conditioning set (tail plus selected head members), not by the mask.
                int[] parents = new int[condSet.size()];
                for (int k = 0; k < parents.length; ++k) {
                    parents[k] = variables.indexOf(condSet.get(k));
                }
                double term = this.score.localScore(i, parents);
                if ((max - condSet.size()) % 2 == 0) {
                    total += term;
                } else {
                    total -= term;
                }
            }
        }
        return total;
    }

    /** @return the penalty discount of the wrapped score */
    public double getPenaltyDiscount() {
        return this.score.getPenaltyDiscount();
    }

    /**
     * Sets the penalty discount of the wrapped score.
     *
     * @param penaltyDiscount the new penalty discount
     */
    public void setPenaltyDiscount(double penaltyDiscount) {
        this.score.setPenaltyDiscount(penaltyDiscount);
    }

    /**
     * Returns {@code localScore(y, z ∪ {x}) - localScore(y, z)}.
     *
     * @param x variable being added
     * @param y target variable
     * @param z existing conditioning variables
     * @return the score difference
     */
    @Override
    public double localScoreDiff(int x, int y, int[] z) {
        return this.localScore(y, this.append(z, x)) - this.localScore(y, z);
    }

    /** @return the sample size of the wrapped score */
    @Override
    public int getSampleSize() {
        return this.score.getSampleSize();
    }

    /** @return the variables of the wrapped score */
    @Override
    public List<Node> getVariables() {
        return this.score.getVariables();
    }

    /**
     * @param bump a score difference
     * @return {@code true} iff {@code bump} is strictly positive
     */
    @Override
    public boolean isEffectEdge(double bump) {
        return bump > 0.0;
    }

    /** @return the maximum degree reported by the wrapped score */
    @Override
    public int getMaxDegree() {
        return this.score.getMaxDegree();
    }

    /**
     * Appends {@code head ∪ {v1}} to {@code heads} and its tail to {@code tails}. It then recurses
     * once per sibling of the current intrinsic set, extending the head with that sibling. Each
     * recursive call only considers conditioning variables that come after the sibling in
     * {@code mbo}.
     *
     * <p>The tail is the intrinsic set minus the head, plus the MAG parents of every intrinsic
     * node.</p>
     */
    private void constructHeadsTails(List<List<Node>> heads, List<Set<Node>> tails, List<Node> mbo, List<Node> head, List<Node> in, Set<Node> an, Node v1) {
        head.add(v1);
        heads.add(head);
        List<Node> sib = new ArrayList<>();
        this.updateAncestors(an, v1);
        this.updateIntrinsics(in, sib, an, v1, mbo);
        Set<Node> tail = new HashSet<>(in);
        head.forEach(tail::remove);
        for (Node v2 : in) {
            tail.addAll(this.mag.getParents(v2));
        }
        tails.add(tail);
        for (Node v2 : sib) {
            this.constructHeadsTails(heads, tails, mbo.subList(mbo.indexOf(v2) + 1, mbo.size()), new ArrayList<>(head), new ArrayList<>(in), new HashSet<>(an), v2);
        }
    }

    /**
     * Adds {@code v1} and all of its MAG ancestors to {@code an}.
     *
     * <p>A node already in {@code an} is not re-expanded. The set is ancestrally closed whenever
     * this method is entered from {@link #constructHeadsTails}, so re-walking is redundant.</p>
     */
    private void updateAncestors(Set<Node> an, Node v1) {
        if (!an.add(v1)) {
            return;
        }
        for (Node v2 : this.mag.getParents(v1)) {
            this.updateAncestors(an, v2);
        }
    }

    /**
     * Adds {@code v1} to the intrinsic set {@code in}. It then scans the remaining members of
     * {@code mbo} for bidirected ({@code <->}) neighbours of any intrinsic node:
     * <ul>
     *   <li>a neighbour in {@code an} is absorbed into {@code in} recursively;</li>
     *   <li>any other neighbour is recorded once in {@code sib}.</li>
     * </ul>
     */
    private void updateIntrinsics(List<Node> in, List<Node> sib, Set<Node> an, Node v1, List<Node> mbo) {
        in.add(v1);
        List<Node> mb = new ArrayList<>(mbo);
        mb.removeAll(in);
        // Snapshot: the recursive call appends to 'in', which would break iteration over a live view.
        for (Node v3 : new ArrayList<>(in)) {
            for (Node v2 : mb) {
                Edge e = this.mag.getEdge(v2, v3);
                if (e == null || e.getEndpoint1() != Endpoint.ARROW || e.getEndpoint2() != Endpoint.ARROW) {
                    continue;
                }
                if (an.contains(v2)) {
                    // An earlier recursive call may already have absorbed v2.
                    if (!in.contains(v2)) {
                        this.updateIntrinsics(in, sib, an, v2, mbo);
                    }
                } else if (!sib.contains(v2)) {
                    // A duplicate sibling would make constructHeadsTails score the same head twice.
                    sib.add(v2);
                }
            }
        }
    }

    /** @return {@code "MAG(" + wrapped score + ")"} */
    @Override
    public String toString() {
        return "MAG(" + this.score + ")";
    }
}

