from modules.algorithms.base import CausalDiscovery
from modules.stein import graph_inference
from modules.utils import pretty_evaluate, num_errors

class DAS(CausalDiscovery):
    """
    Causal discovery wrapper that runs the DAS algorithm on observational data.

    Inference is delegated to ``graph_inference`` and reporting to
    ``pretty_evaluate``. Apart from ``K`` and ``noise_std``, the attributes
    read by the methods below (``X``, ``eta_G``, ``alpha``, ``gamma``, ...)
    are not assigned in this class; presumably they are set by
    ``CausalDiscovery.__init__`` -- TODO confirm.
    """

    def __init__(self, X, R, A_truth, **kwargs):
        """
        Initialise the base class, then store the DAS-specific options.

        Args:
            X: NxD matrix of the data
            R: forwarded unchanged to the base constructor; its meaning is
                not visible in this file
            A_truth: adjacency ground truth of the graph (originally
                documented as NxD; presumably DxD -- TODO confirm)
            **kwargs: must contain the keys ``'K'`` and ``'nstd'``; the whole
                mapping is also forwarded to the base constructor

        Raises:
            KeyError: if ``'K'`` or ``'nstd'`` is missing from ``kwargs``
        """
        super().__init__(X, R, A_truth, **kwargs)
        self.K = kwargs['K']
        self.noise_std = kwargs['nstd']

    def algorithm_inference(self):
        """
        Run ``graph_inference`` on the stored data with the DAS settings.

        Returns:
            The value returned by ``graph_inference``, unmodified.
        """
        # Attributes are read in the same order the call receives them.
        data, eta_g = self.X, self.eta_G
        settings = {
            'alpha': self.alpha,
            'gamma': self.gamma,
            'n_cv': self.n_cv,
            'pruning': self.pruning,
            'algorithm': self.algorithm,
        }
        # NOTE(review): SCORE in this file calls the same graph_inference with
        # a different argument list -- confirm this one matches its signature.
        return graph_inference(data, eta_g, **settings)

    def pretty_print(self, A_pred, top_order, tot_time, sid):
        """
        Evaluate a predicted graph against the ground truth.

        Args:
            A_pred: predicted adjacency matrix
            top_order: inferred topological order; scored against
                ``self.A_truth`` with ``num_errors``
            tot_time: total run time, passed through to ``pretty_evaluate``
            sid: SID value, passed through to ``pretty_evaluate``

        Returns:
            The value returned by ``pretty_evaluate``, unmodified.
        """
        order_errors = num_errors(top_order, self.A_truth)
        # NOTE(review): ``self.noise`` is read here although __init__ stores
        # ``self.noise_std``; verify the base class defines ``noise``.
        report_args = (
            self.A_truth,
            A_pred,
            order_errors,
            tot_time,
            sid,
            self.s0,
            self.alpha,
            self.gamma,
            self.n_cv,
            self.noise,
        )
        return pretty_evaluate(*report_args)

      

class SCORE(CausalDiscovery):
    """
    Causal discovery wrapper that runs the SCORE algorithm on observational data.

    Inference is delegated to ``graph_inference`` and reporting to
    ``pretty_evaluate``. Apart from ``pns``, the attributes read by the
    methods below (``X``, ``eta_G``, ``eta_H``, ``cam_cutoff``, ``K``,
    ``sid``, ...) are not assigned in this class; presumably they are set by
    ``CausalDiscovery.__init__`` -- TODO confirm.
    """

    def __init__(self, X, A_truth, **kwargs):
        """
        Initialise the base class, then store the SCORE-specific option.

        Args:
            X: NxD matrix of the data
            A_truth: adjacency ground truth of the graph (originally
                documented as NxD; presumably DxD -- TODO confirm)
            **kwargs: must contain the key ``'pns'``; the whole mapping is
                also forwarded to the base constructor

        Raises:
            KeyError: if ``'pns'`` is missing from ``kwargs``
        """
        # NOTE(review): DAS in this file calls the base constructor with an
        # extra positional ``R``; confirm which call matches the base class.
        super().__init__(X, A_truth, **kwargs)
        self.pns = kwargs['pns']

    def algorithm_inference(self):
        """
        Run ``graph_inference`` on the stored data with the SCORE settings.

        Returns:
            The value returned by ``graph_inference``, unmodified.
        """
        # Attributes are read in the same order the call receives them.
        positional = (self.X, self.eta_G, self.eta_H, self.cam_cutoff)
        options = {
            'pruning': self.pruning,
            'threshold': self.threshold,
            'pns': self.pns,
        }
        return graph_inference(*positional, **options)

    def pretty_print(self, A, tot_time, top_order_err=None, SCORE_time=None):
        """
        Evaluate a predicted graph against the ground truth.

        Args:
            A: predicted adjacency matrix
            tot_time: total run time, passed through to ``pretty_evaluate``
            top_order_err: topological-order error count, passed through
                as given (defaults to ``None``)
            SCORE_time: passed through to ``pretty_evaluate`` as given
                (defaults to ``None``)

        Returns:
            The value returned by ``pretty_evaluate``, unmodified.
        """
        # NOTE(review): ``self.K`` and ``self.sid`` are never assigned in this
        # class; verify the base class provides them.
        positional = (
            self.pruning,
            self.threshold,
            self.A_truth,
            A,
            top_order_err,
            SCORE_time,
            tot_time,
            self.sid,
        )
        return pretty_evaluate(*positional, s0=self.s0, K=self.K)