from endo_regime_pcmci.endo_cit import EndoCIT
from endo_regime_pcmci.mask_generator import MaskGenerator


from tigramite.pcmci import PCMCI
import numpy as np


class SparseEndoCIT(EndoCIT):
    """Context-aware conditional independence test with lag-based sample masking.

    ``run_test`` dispatches each test to one of three strategies:

    * a context variable appears in ``X`` or ``Y`` -> ``mixed_cit.run_test``;
    * masking is enabled and a context variable in ``Z`` is a lagged parent of
      a node in ``X``/``Y`` -> the test runs on the samples kept by a sparse
      context mask, with that context variable removed from ``Z``;
    * otherwise -> ``cont_cit.run_test`` on the pooled data.

    Attribute lookups that fail on the instance are forwarded to ``mixed_cit``
    (see ``__getattr__``).

    NOTE(review): ``use_mask``, ``verbosity``, ``context_vars`` and
    ``context_values`` are read but not set in this class; presumably
    ``EndoCIT`` (or the ``mixed_cit`` delegation) provides them -- confirm.
    """

    def __init__(self, mixed_cit, cont_cit, context_vars, context_values=None):
        super().__init__(mixed_cit=mixed_cit, cont_cit=cont_cit,
                         context_vars=context_vars, context_values=context_values)
        self.mask_generator = MaskGenerator('sparse')
        # Indexed by variable; each entry is an iterable of (parent_var, lag)
        # pairs. While None, check_context_in_Z finds nothing, so no masking.
        self.lagged_parents = None

    def set_lagged_parents(self, lagged_parents):
        """Store the lagged-parent map consulted by ``check_context_in_Z``."""
        self.lagged_parents = lagged_parents

    def check_context_in_Z(self, X, Y, Z):
        """
        Collect context variables in Z that are lagged parents of a node in X or Y.

        Args:
            X (list): Nodes (tuples whose first element is the variable) of the
                first variable set.
            Y (list): Nodes of the second variable set.
            Z (list): Nodes of the conditioning set.

        Returns:
            list: ``(context_var, lag)`` tuples. ``lag`` is taken from the matching
            entry of ``self.lagged_parents``, not from ``Z``. The list may contain
            duplicates. It is empty when no lagged parents have been set.
        """
        if self.lagged_parents is None:
            return []
        # list() concatenation: ``X + Y`` would add elementwise for array inputs.
        xy_nodes = list(X) + list(Y)
        found = []
        for node in Z:
            var = node[0]
            if var not in self.context_vars:
                continue
            for target in xy_nodes:
                for parent in self.lagged_parents[target[0]]:
                    if parent[0] == var:
                        found.append((var, parent[1]))
        return found

    def check_context_in_X_Y(self, X, Y):
        """
        Tell whether any context variable is the variable of a node in X or Y.

        Args:
            X (list of tuple): Nodes whose first element is the variable.
            Y (list of tuple): Nodes whose first element is the variable.

        Returns:
            bool: True if some node's first element equals a context variable,
            False otherwise.
        """
        nodes = list(X) + list(Y)
        return any(node[0] == context_var
                   for context_var in self.context_vars
                   for node in nodes)

    def get_masked_data(self, X, Y, Z, context_lags, tau_max=0):
        """
        Construct the data array for a CI test restricted to a context regime.

        Parameters
        ----------
        X, Y, Z : list
            Nodes (tuples whose first element is the variable) of the two tested
            sets and of the conditioning set.
        context_lags : list of tuple
            ``(context_var, lag)`` pairs, e.g. as returned by
            ``check_context_in_Z``. Only the first pair per context variable
            (in list order) is used.
        tau_max : int, optional
            Maximum lag passed to the mask generator. The data array itself is
            built with ``2 * tau_max``. Default is 0.

        Returns
        -------
        masked_array : np.ndarray or None
            Data rows (one per node) restricted to the columns the mask does not
            flag, with the context-variable rows removed.
        xyz : np.ndarray or None
            Role label of each remaining row (0 = X, 1 = Y, 2 = Z).
        XYZ : tuple or None
            Cleaned node lists from ``construct_array``. The removed context
            nodes are deleted from its last (Z) list in place.
        masked_type_array : np.ndarray or None
            The type array, masked and reduced the same way.

        All four are None when fewer than 2 rows or fewer than 3 columns remain.

        Raises
        ------
        ValueError
            If a context variable has to be masked but ``self.context_values`` is
            None.

        Notes
        -----
        - Row positions come from the cleaned node lists returned by
          ``construct_array``. With ``remove_overlaps=True`` it may drop
          duplicate or overlapping Z nodes, so rows need not line up with the
          caller's ``Z``.
        - ``self.context_values`` is assumed to be aligned with
          ``self.context_vars`` (one value per context variable).
          NOTE(review): confirm against ``EndoCIT``.
        - Masking is performed with
          ``self.mask_generator.generate_lagged_mask_sparse``.
        """
        array, xyz, XYZ, type_array = self.mixed_cit.dataframe.construct_array(
            X=X, Y=Y, Z=Z,
            tau_max=2 * tau_max,
            return_cleaned_xyz=True,
            do_checks=True,
            remove_overlaps=True,
            verbosity=self.verbosity)

        cleaned_Z = XYZ[-1]
        # Rows are ordered X, Y, Z, so Z rows start after the cleaned X and Y.
        z_offset = len(XYZ[0]) + len(XYZ[1])
        context_order = list(self.context_vars)

        context_idxs = []
        context_lags_vals = []
        current_values = []
        handled = set()
        for entry in context_lags:
            var, lag = entry[0], entry[1]
            if var in handled:
                continue
            z_pos = next((i for i, node in enumerate(cleaned_Z) if node[0] == var),
                         None)
            if z_pos is None:
                continue
            if self.context_values is None:
                raise ValueError(
                    "context_values must be provided to mask on context variables")
            handled.add(var)
            context_idxs.append(z_offset + z_pos)
            context_lags_vals.append(lag)
            # Pick the value by the variable's position in context_vars. Using the
            # pair's position in context_lags broke when a variable had several
            # lags or when the pairs arrived in a different order.
            current_values.append(self.context_values[context_order.index(var)])

        mask = self.mask_generator.generate_lagged_mask_sparse(
            data=array,
            context_idxs=context_idxs,
            context_lags=context_lags_vals,
            tau_max=tau_max,
            values=current_values)

        # Keep only the columns the mask does not flag.
        keep = np.logical_not(np.asarray(mask).flatten())
        drop_rows = np.asarray(context_idxs, dtype=int)
        masked_array = np.delete(array[:, keep], drop_rows, axis=0)
        masked_type_array = np.delete(type_array[:, keep], drop_rows, axis=0)
        # Delete labels by index. Trimming the tail with xyz[:-n] emptied xyz
        # entirely when n == 0.
        xyz = np.delete(xyz, drop_rows)
        for idx in sorted(context_idxs, reverse=True):
            del cleaned_Z[idx - z_offset]

        if masked_array.shape[0] < 2 or masked_array.shape[1] <= 2:
            return None, None, None, None

        return masked_array, xyz, XYZ, masked_type_array

    @staticmethod
    def _split_by_role(arr, xyz):
        """Split rows of ``arr`` by role label into transposed (x, y, z) blocks.

        z is None when no row carries label 2.
        """
        z_rows = arr[xyz == 2, :]
        return (arr[xyz == 0, :].T,
                arr[xyz == 1, :].T,
                z_rows.T if z_rows.shape[0] > 0 else None)

    def run_context_cond_ind(self, X, Y, Z, context_lags, tau_max=0, alpha_or_thresh=None):
        """
        Run a CI test of X and Y given Z on the context-masked samples only.

        Parameters
        ----------
        X, Y, Z : list
            Nodes (tuples whose first element is the variable) of the two tested
            sets and of the conditioning set.
        context_lags : list of tuple
            ``(context_var, lag)`` pairs forwarded to ``get_masked_data``.
        tau_max : int, optional (default=0)
            Maximum lag forwarded to ``get_masked_data``.
        alpha_or_thresh : float or None, optional
            Significance level or threshold passed to ``run_test_raw``.

        Returns
        -------
        val : float or None
            Test statistic.
        pval : float or None
            p-value.
        dependent : bool or None
            Dependence decision from the underlying test.

        All three are None when too little data survives masking.
        """
        array, xyz, XYZ, type_array = self.get_masked_data(
            X=X, Y=Y, Z=Z,
            context_lags=context_lags,
            tau_max=tau_max)

        if array is None:
            return None, None, None

        x_array, y_array, z_array = self._split_by_role(array, xyz)
        x_type_array, y_type_array, z_type_array = self._split_by_role(type_array, xyz)

        # The mixed test is used whenever any entry of type 1 (presumably a
        # discrete column) remains after masking; otherwise the continuous test.
        cit = self.mixed_cit if np.any(type_array == 1) else self.cont_cit
        val, pval, dependent = cit.run_test_raw(
            x=x_array, y=y_array, z=z_array,
            x_type=x_type_array, y_type=y_type_array, z_type=z_type_array,
            alpha_or_thres=alpha_or_thresh)

        return val, pval, dependent

    def run_test(self, X, Y, Z, tau_max, alpha_or_thres=0.05):
        """
        Test X independent of Y given Z, choosing the strategy from the context.

        1. A context variable is in X or Y: ``mixed_cit.run_test``.
        2. Otherwise, if ``use_mask`` is set and ``check_context_in_Z`` finds
           context lags: ``run_context_cond_ind`` on the masked data.
        3. Otherwise: ``cont_cit.run_test`` on the pooled data.

        Returns:
            tuple: ``(val, pval, dependent)``. In case 2, all three may be None.
        """
        context_in_xy = self.check_context_in_X_Y(X=X, Y=Y)

        if self.use_mask and not context_in_xy:
            # Order-preserving de-duplication. list(set(...)) gave a hash-seed
            # dependent order, which made the lag/value picked per context
            # variable in get_masked_data nondeterministic across runs.
            context_lags = list(dict.fromkeys(self.check_context_in_Z(X, Y, Z)))
            if context_lags:
                val, pval, dependent = self.run_context_cond_ind(
                    X=X, Y=Y, Z=Z,
                    context_lags=context_lags,
                    tau_max=tau_max,
                    alpha_or_thresh=alpha_or_thres)
                return val, pval, dependent

        cit = self.mixed_cit if context_in_xy else self.cont_cit
        val, pval, dependent = cit.run_test(
            X=X, Y=Y, Z=Z, tau_max=tau_max,
            alpha_or_thres=alpha_or_thres)
        return val, pval, dependent

    def __getattr__(self, name):
        """Forward unknown attributes to ``mixed_cit``.

        ``mixed_cit`` and ``cont_cit`` themselves raise AttributeError here. This
        avoids infinite recursion when they are looked up before being set.
        """
        if name == "mixed_cit":
            raise AttributeError("mixed_cit attribute not found")
        elif name == "cont_cit":
            raise AttributeError("cont_cit attribute not found")
        return getattr(self.mixed_cit, name)

