from modules.stein import graph_inference
from modules.utils import pretty_evaluate, num_errors

class CausalDiscovery:
    """
    Run causal-discovery inference (DAS algorithm) on observational data and
    print an evaluation of the result against a known ground-truth graph.

    Hyper-parameters are supplied as keyword arguments and stored as instance
    attributes. Only some of them are consumed by the methods of this class;
    the others (``R``, ``cam_cutoff``, ``sid``, ``K``, ``noise_std``) are kept
    on the instance, presumably for external callers that read them — TODO
    confirm.
    """

    # Keyword arguments that ``__init__`` requires; extra keys are ignored.
    _REQUIRED_KWARGS = (
        'algorithm', 'gamma', 'alpha', 'cv', 'd', 's0', 'eta_G',
        'cam_cutoff', 'pruning', 'sid', 'noise_type', 'K', 'nstd',
    )

    def __init__(self, X, R, A_truth, **kwargs):
        """
        Args:
            X: NxD matrix of the data.
            R: stored as ``self.R``; not used by the methods of this class.
            A_truth: adjacency matrix of the ground-truth graph (the original
                docstring said NxD; presumably DxD — TODO confirm).
            **kwargs: hyper-parameters. Every key in ``_REQUIRED_KWARGS`` must
                be present. Note the renames: ``cv`` -> ``self.n_cv``,
                ``noise_type`` -> ``self.noise``, ``nstd`` -> ``self.noise_std``.

        Raises:
            KeyError: if any required keyword argument is missing; all missing
                keys are listed in the message.
        """
        # Report every missing key at once instead of failing on the first one.
        missing = [key for key in self._REQUIRED_KWARGS if key not in kwargs]
        if missing:
            raise KeyError(
                f"Missing required keyword arguments: {', '.join(missing)}"
            )

        self.X = X
        self.R = R
        self.A_truth = A_truth

        self.algorithm = kwargs['algorithm']
        self.gamma = kwargs['gamma']
        self.alpha = kwargs['alpha']
        self.n_cv = kwargs['cv']
        self.d = kwargs['d']
        self.s0 = kwargs['s0']
        self.eta_G = kwargs['eta_G']
        self.cam_cutoff = kwargs['cam_cutoff']
        self.pruning = kwargs['pruning']
        self.sid = kwargs['sid']
        self.noise = kwargs['noise_type']
        self.K = kwargs['K']
        self.noise_std = kwargs['nstd']

    def algorithm_inference(self):
        """
        Run ``graph_inference`` on ``self.X`` with the stored hyper-parameters.

        Returns:
            Whatever ``graph_inference`` returns; ``inference`` unpacks it as
            ``(A, top_order, order_time, tot_time)``.
        """
        return graph_inference(
            self.X,
            self.eta_G,
            alpha=self.alpha,
            gamma=self.gamma,
            n_cv=self.n_cv,
            pruning=self.pruning,
            algorithm=self.algorithm,
        )

    def pretty_print(self, A_pred, top_order, tot_time, sid):
        """
        Build an evaluation summary of a predicted graph.

        Args:
            A_pred: predicted adjacency matrix.
            top_order: inferred topological order; it is scored against
                ``self.A_truth`` with ``num_errors``.
            tot_time: total inference time, forwarded to ``pretty_evaluate``.
            sid: flag forwarded to ``pretty_evaluate`` (presumably whether to
                compute the SID metric — TODO confirm).

        Returns:
            The value returned by ``pretty_evaluate``, which ``inference``
            prints.
        """
        top_order_err = num_errors(top_order, self.A_truth)
        return pretty_evaluate(
            self.A_truth, A_pred, top_order_err, tot_time, sid,
            self.s0, self.alpha, self.gamma, self.n_cv, self.noise,
            self.algorithm,
        )

    def inference(self, sid_max_d=200):
        """
        Run the algorithm, print its evaluation, and return the raw results.

        Args:
            sid_max_d: the ``sid`` flag passed to the evaluation is True only
                when ``self.d`` is at most this value. The default of 200
                matches the previously hard-coded threshold. This is presumably
                because SID is costly on large graphs — TODO confirm.

        Returns:
            The 4-tuple ``(A, top_order, order_time, tot_time)`` obtained from
            ``algorithm_inference``.
        """
        A, top_order, order_time, tot_time = self.algorithm_inference()
        # NOTE(review): the ``sid`` keyword stored in ``self.sid`` is not
        # consulted here; the flag depends only on graph size. Confirm intent.
        sid = self.d <= sid_max_d
        print(self.pretty_print(A, top_order, tot_time, sid))
        return A, top_order, order_time, tot_time
