import numpy as np
from itertools import product

class MaskGenerator:
    """Build boolean per-sample masks from the values of context variables.

    Every ``generate_*`` method returns a boolean array of shape
    ``(1, data.shape[1])``. Entry ``t`` is ``True`` when at least one
    context variable does not take its required value (at the time/lags
    the method inspects), and ``False`` otherwise. Time indices a method
    does not evaluate are left ``False``.

    ``data`` is indexed as ``data[variable, time]`` throughout.
    """

    _VALID_TYPES = ('sparse', 'persistent', 'naive')

    def __init__(self, type='persistent'):
        """Store and validate the mask type.

        Args:
            type: One of ``'sparse'``, ``'persistent'`` or ``'naive'``.
                The name shadows the ``type`` builtin; it is kept so that
                existing keyword callers (``type=...``) keep working.

        Raises:
            ValueError: If ``type`` is not a supported value.
        """
        if type not in self._VALID_TYPES:
            raise ValueError("Type must be either 'sparse', 'persistent' or 'naive'.")
        self.type = type

    def generate_lagged_mask_sparse(self, data, context_idxs, context_lags, tau_max=0, values=None):
        """Mask samples where any context variable, at its own lag, differs from its value.

        For each evaluated time index ``t`` the mask is ``True`` if
        ``data[context_idxs[i], t - |context_lags[i]|] != values[i]`` for
        at least one ``i``.

        Args:
            data: 2-D array indexed as ``data[variable, time]``.
            context_idxs: Row indices of the context variables in ``data``.
            context_lags: One lag per context variable (sign is ignored).
            tau_max: Maximum lag; ``|tau_max|`` determines how many samples
                at each end of the series are skipped.
            values: Required value for each context variable.

        Returns:
            Boolean array of shape ``(1, data.shape[1])``. Only indices in
            ``range(|tau_max| + 1, data.shape[1] - |tau_max| - 1)`` are
            evaluated; all other entries stay ``False``.

        Raises:
            ValueError: If ``values`` is None.
        """
        if values is None:
            raise ValueError("Values must be provided for lagged context variables.")
        n_samples = data.shape[1]
        tau_abs = abs(tau_max)
        data_masks = np.zeros((1, n_samples), dtype=bool)
        lags = [abs(lag) for lag in context_lags]
        # NOTE(review): assumes |context_lags[i]| <= |tau_max| + 1; a larger lag
        # makes ``t - lag`` negative, which NumPy wraps to the series end.
        # NOTE(review): the upper bound also skips the last |tau_max| + 1
        # samples even though only past values are read — confirm intent.
        for t in range(tau_abs + 1, n_samples - tau_abs - 1):
            # OR across context variables: a single mismatch masks sample t.
            # (Previously each variable overwrote the last, so only the final
            # context variable counted.)
            data_masks[0, t] = any(
                data[idx, t - lags[i]] != values[i]
                for i, idx in enumerate(context_idxs)
            )
        return data_masks

    def generate_lagged_mask_persistent(self, data, context_idxs, tau_max=0, values=None):
        """Mask samples where any context variable differs from its value at any lag 0..|tau_max|.

        For each evaluated time index ``t`` the mask is ``True`` if
        ``data[context_idxs[i], t - tau] != values[i]`` for at least one
        ``i`` and one ``tau`` in ``0 .. |tau_max|``.

        Args:
            data: 2-D array indexed as ``data[variable, time]``.
            context_idxs: Row indices of the context variables in ``data``.
            tau_max: Maximum lag; its absolute value is used, so negative
                values behave like their positive counterpart.
            values: Required value for each context variable.

        Returns:
            Boolean array of shape ``(1, data.shape[1])``. Only indices in
            ``range(|tau_max| + 1, data.shape[1] - |tau_max| - 1)`` are
            evaluated; all other entries stay ``False``.

        Raises:
            ValueError: If ``values`` is None.
        """
        if values is None:
            raise ValueError("Values must be provided for lagged context variables.")
        n_samples = data.shape[1]
        # Use |tau_max| for the start too (it was plain tau_max): with a
        # negative tau_max, ``t - tau`` went negative and wrapped around.
        tau_abs = abs(tau_max)
        data_masks = np.zeros((1, n_samples), dtype=bool)
        for t in range(tau_abs + 1, n_samples - tau_abs - 1):
            data_masks[0, t] = any(
                data[idx, t - tau] != values[i]
                for i, idx in enumerate(context_idxs)
                for tau in range(tau_abs + 1)
            )
        return data_masks

    def generate_naive_mask(self, data, context_idxs, tau_max=0, values=None):
        """Mask samples where any context variable differs from its value at the same time.

        Args:
            data: 2-D array indexed as ``data[variable, time]``.
            context_idxs: Row indices of the context variables in ``data``.
            tau_max: Accepted for signature parity with the lagged variants;
                not used.
            values: Required value for each context variable.

        Returns:
            Boolean array of shape ``(1, data.shape[1])``; entry ``t`` is
            ``True`` if ``data[context_idxs[i], t] != values[i]`` for at
            least one ``i``.

        Raises:
            ValueError: If ``values`` is None.
        """
        if values is None:
            raise ValueError("Values must be provided for context variables.")
        data_masks = np.zeros((1, data.shape[1]), dtype=bool)
        # OR each context variable's mismatch row into the mask (vectorised
        # over time). Previously each variable overwrote the last.
        for i, idx in enumerate(context_idxs):
            data_masks[0] |= data[idx, :] != values[i]
        return data_masks