from endo_regime_pcmci.endo_cit import EndoCIT
from endo_regime_pcmci.mask_generator import MaskGenerator

import numpy as np

class PersistentEndoCIT(EndoCIT):
    """EndoCIT variant that filters samples with a persistent-regime mask.

    When a context variable appears in the conditioning set ``Z`` (and not in
    ``X``/``Y``) and masking is enabled, samples are filtered with
    ``MaskGenerator('persistent').generate_lagged_mask_persistent`` and the
    context rows are removed before running the CI test. Otherwise the test is
    delegated on the pooled data to ``mixed_cit`` (context variable in X/Y)
    or ``cont_cit``.
    """

    def __init__(self, mixed_cit, cont_cit, context_vars, context_values=None):
        super().__init__(mixed_cit=mixed_cit, cont_cit=cont_cit, context_vars=context_vars, context_values=context_values)
        self.mask_generator = MaskGenerator('persistent')

    def _contains_context(self, variables):
        """Return True if any entry's first element is a context variable.

        Args:
            variables (list of tuple or None): Entries whose first element is
                compared to ``self.context_vars``. None is treated as empty.

        Returns:
            bool: True if at least one entry refers to a context variable.
        """
        if variables is None:
            return False
        return any(
            context_var == item[0]
            for context_var in self.context_vars
            for item in variables
        )

    def check_context_in_Z(self, Z):
        """
        Checks if any of the context variables are present in Z.

        Args:
            Z (list of tuple or None): Tuples whose first element is compared to
                context variables. None is treated as an empty list.

        Returns:
            bool: True if any context variable is found as the first element in any item of Z, False otherwise.
        """
        return self._contains_context(Z)

    def check_context_in_X_Y(self, X, Y):
        """
        Checks if any of the context variables are present in X or Y.

        Args:
            X (list of tuple or None): Tuples whose first element is compared to context variables.
            Y (list of tuple or None): Tuples whose first element is compared to context variables.

        Returns:
            bool: True if any context variable matches the first element of any tuple in X or Y, False otherwise.
        """
        return self._contains_context(X) or self._contains_context(Y)

    def get_masked_data(self, X, Y, Z, tau_max=0):
        """
        Constructs and returns masked data arrays for conditional independence testing.

        Parameters
        ----------
        X : list of tuple
            ``(variable, lag)`` entries for the first variable set.
        Y : list of tuple
            ``(variable, lag)`` entries for the second variable set.
        Z : list of tuple
            ``(variable, lag)`` entries for the conditioning set.
        tau_max : int, optional
            Maximum time lag passed to the mask generator (default is 0). The
            data array itself is constructed with a lag window of ``2 * tau_max``.

        Returns
        -------
        masked_array : np.ndarray or None
            Data with masked samples (columns) and context-variable rows removed.
            None if the result has fewer than 2 rows or fewer than 3 columns.
        xyz : np.ndarray or None
            Role code per remaining row (0 = X, 1 = Y, 2 = Z), or None as above.
        XYZ : tuple or None
            Cleaned variable lists returned by ``construct_array``; the context
            entries are deleted in place from its last (Z) list. None as above.
        masked_type_array : np.ndarray or None
            Type array masked and trimmed exactly like ``masked_array``; None as above.

        Notes
        -----
        - Context variables specified in `self.context_vars` are identified in `Z` and removed from the data arrays.
        - Masking is performed using `self.mask_generator.generate_lagged_mask_persistent`.
        - If the masked array has fewer than 2 rows or 3 columns, all outputs are set to None.
        """
        # NOTE(review): the array is built with a 2 * tau_max lag window,
        # presumably so the persistent mask can look back further than the
        # test lag itself -- confirm against MaskGenerator.
        array, xyz, XYZ, type_array = self.mixed_cit.dataframe.construct_array(
            X=X, Y=Y, Z=Z,
            tau_max=2 * tau_max,
            return_cleaned_xyz=True,
            do_checks=True,
            remove_overlaps=False,
            verbosity=self.verbosity)

        # Rows are laid out as X, then Y, then Z, so the Z entry at position i
        # sits in row len(X) + len(Y) + i.
        # NOTE(review): indices come from the caller's Z, not the cleaned Z in
        # XYZ[-1]; they agree only if construct_array dropped no duplicates.
        offset = len(X) + len(Y)
        context_idxs = []
        current_values = []
        for j, context_var in enumerate(self.context_vars):
            for i, z in enumerate(Z):
                if z[0] == context_var:
                    context_idxs.append(offset + i)
                    current_values.append(self.context_values[j])

        mask = self.mask_generator.generate_lagged_mask_persistent(
            data=array,
            context_idxs=context_idxs,
            tau_max=tau_max,
            values=current_values)

        # Keep only the samples (columns) the mask does not flag.
        keep = np.logical_not(np.asarray(mask).ravel())
        masked_array = array[:, keep]
        masked_type_array = type_array[:, keep]

        # Remove the context rows from the data, the type array, the role
        # vector and the cleaned Z list. Deduplicate so a context variable
        # listed twice does not remove an unrelated row; deleting from the
        # highest index down keeps the remaining Z positions valid.
        rows_to_drop = sorted(set(context_idxs), reverse=True)
        drop = np.asarray(rows_to_drop, dtype=int)
        masked_array = np.delete(masked_array, drop, axis=0)
        masked_type_array = np.delete(masked_type_array, drop, axis=0)
        # Delete exactly the context rows; slicing off the tail breaks when
        # no context row is found (xyz[:-0] would be empty).
        xyz = np.delete(xyz, drop)
        for idx in rows_to_drop:
            del XYZ[-1][idx - offset]

        if masked_array.shape[0] < 2 or masked_array.shape[1] <= 2:
            return None, None, None, None

        return masked_array, xyz, XYZ, masked_type_array

    @staticmethod
    def _split_by_role(data, xyz):
        """Split rows of ``data`` by role code in ``xyz`` (0 = X, 1 = Y, 2 = Z).

        Returns the transposed X, Y and Z row blocks; the Z block is None when
        no row has role 2.
        """
        x_part = data[xyz == 0, :].T
        y_part = data[xyz == 1, :].T
        z_part = data[xyz == 2, :]
        z_part = z_part.T if z_part.shape[0] > 0 else None
        return x_part, y_part, z_part

    def run_context_cond_ind(self, X, Y, Z, tau_max=0, alpha_or_thresh=None):
        """
        Runs a conditional independence test between X and Y given Z on masked data.

        The data come from :meth:`get_masked_data`. The mixed tester is used when
        any entry of the masked type array equals 1; otherwise the continuous
        tester is used.

        Parameters
        ----------
        X : list of tuple
            ``(variable, lag)`` entries of the first variable set.
        Y : list of tuple
            ``(variable, lag)`` entries of the second variable set.
        Z : list of tuple
            ``(variable, lag)`` entries of the conditioning set.
        tau_max : int, optional
            The maximum time lag to consider when masking the data (default is 0).
        alpha_or_thresh : float or None, optional
            The significance level (alpha) or threshold forwarded to the tester.

        Returns
        -------
        val : float
            The test statistic value from the conditional independence test.
        pval : float
            The p-value associated with the test statistic.
        dependent : bool
            Whether the test indicates dependence (True) or independence (False) between X and Y given Z.

        Raises
        ------
        ValueError
            If the masked array is None or has fewer than 8 columns.
        """
        array, xyz, XYZ, type_array = self.get_masked_data(X=X, Y=Y, Z=Z, tau_max=tau_max)

        if array is None or array.shape[1] < 8:
            raise ValueError(f'Input array is None or has fewer than 8 columns: array.shape={None if array is None else array.shape}')

        x_array, y_array, z_array = self._split_by_role(array, xyz)
        x_type_array, y_type_array, z_type_array = self._split_by_role(type_array, xyz)

        # A type code of 1 (presumably a discrete variable -- confirm) anywhere
        # in the masked data requires the mixed tester.
        tester = self.mixed_cit if np.any(type_array == 1) else self.cont_cit
        val, pval, dependent = tester.run_test_raw(
            x=x_array, y=y_array, z=z_array,
            x_type=x_type_array, y_type=y_type_array, z_type=z_type_array,
            alpha_or_thres=alpha_or_thresh)

        return val, pval, dependent

    def run_test(self, X, Y, Z, tau_max, alpha_or_thres=0.05):
        """Run the appropriate CI test of X and Y given Z.

        Dispatch rules:
        - masking enabled, a context variable in Z and none in X/Y:
          masked test via :meth:`run_context_cond_ind`;
        - otherwise, a context variable in X or Y: ``mixed_cit.run_test`` on pooled data;
        - otherwise: ``cont_cit.run_test`` on pooled data.

        Args:
            X, Y, Z (list of tuple): ``(variable, lag)`` entries; Z may be None.
            tau_max (int): Maximum time lag.
            alpha_or_thres (float): Significance level or threshold (default 0.05).

        Returns:
            tuple: ``(val, pval, dependent)`` from the selected tester.
        """
        context_in_xy = self.check_context_in_X_Y(X=X, Y=Y)

        if self.use_mask and not context_in_xy and self.check_context_in_Z(Z):
            val, pval, dependent = self.run_context_cond_ind(
                X=X, Y=Y,
                Z=Z, tau_max=tau_max,
                alpha_or_thresh=alpha_or_thres)
        elif context_in_xy:
            # Context variable is itself tested: pooled data with a mixed test.
            val, pval, dependent = self.mixed_cit.run_test(
                X=X, Y=Y,
                Z=Z, tau_max=tau_max,
                alpha_or_thres=alpha_or_thres)
        else:
            # Run on pooled data with the continuous tester.
            val, pval, dependent = self.cont_cit.run_test(
                X=X, Y=Y,
                Z=Z, tau_max=tau_max,
                alpha_or_thres=alpha_or_thres)

        return val, pval, dependent