import torch
import numpy as np

class MaskSelecter:
    """Build 0/1 gradient masks and keep running statistics on how much passes.

    The strategy is chosen once via ``mask_select_mode``:

    * ``"deck"``            - top-p by gradient magnitude blended with the mean
                              magnitude of the other neurons in the same cluster.
    * ``"deck_no_cluster"`` - same idea, but neighbours are weighted by a
                              neuron-by-neuron graph instead of hard clusters.
    * ``"random"``          - every element is kept independently with
                              probability ``p``.
    * ``"highest"``/``"gmt"`` - top-p by gradient magnitude alone.
    * ``"normal"``          - keep everything (all-ones mask).
    """

    def __init__(self, mask_select_mode, neuron_dim=None, device=None):
        """
        Args:
            mask_select_mode: name of the selection strategy (see class docstring)
            neuron_dim: optional size of the identity matrix to pre-build for "deck_no_cluster"
            device: optional device for the pre-built identity matrix
        """
        self.mask_select_method = mask_select_mode
        # Always define the attribute so that "deck_no_cluster" can build the
        # identity lazily instead of failing with AttributeError when the
        # selector was constructed without neuron_dim/device.
        self.eye_tensor = None
        if neuron_dim is not None and device is not None:
            self.eye_tensor = torch.eye(neuron_dim, device=device).to(torch.bfloat16)
        self.mask_pass_num = 0
        self.mask_total = 0

    @staticmethod
    def _orient_gradients(gradients, neuron_dim):
        """Return (gradients viewed as (embedding_size, neuron_dim), transposed_flag).

        When both dimensions equal ``neuron_dim`` the first check wins, i.e. the
        input is treated as (neuron_dim, embedding_size) and transposed.
        """
        if gradients.shape[0] == neuron_dim:
            return torch.transpose(gradients, 0, 1), True
        if gradients.shape[1] == neuron_dim:
            return gradients, False
        raise ValueError("The shape of gradients is not correct!")

    @staticmethod
    def _top_fraction_mask(score, p):
        """Return a bfloat16 mask shaped like ``score`` with 1 at its top ``int(p * numel)`` entries."""
        # reshape (not view) so non-contiguous scores are accepted.
        flat_score = score.reshape(-1)
        keep = int(p * flat_score.numel())
        _, top_inds = torch.topk(flat_score, k=keep)
        # The mask is built flat and contiguous, so indices line up with
        # flat_score regardless of the memory layout of ``score``.
        flat_mask = torch.zeros(flat_score.numel(), dtype=torch.bfloat16, device=score.device)
        flat_mask[top_inds] = 1
        return flat_mask.view(score.shape)

    def _identity(self, neuron_dim, device, dtype):
        """Return a (neuron_dim, neuron_dim) identity on ``device``/``dtype``, caching it on the instance."""
        cached = self.eye_tensor
        if cached is None or cached.shape[0] != neuron_dim:
            cached = torch.eye(neuron_dim, device=device).to(torch.bfloat16)
            self.eye_tensor = cached
        return cached.to(device=device, dtype=dtype)

    def _deck_select_mask(
        self,
        gradients,
        cluster_tensor,
        p,
        neighbor_p,
        neuron_dim,
        ):

        """
        The DECK method to select gradient mask for each parameter section in MLP.
        Args:

            gradients: the gradients of a parameter section in MLP, shape: (neuron_dim * embedding_size) or (embedding_size * neuron_dim)
            cluster_tensor: cluster membership matrix, shape: (neuron_dim, cluster_num); presumably one-hot per row - confirm with the caller
            p: fraction of entries (ranked by score) that receive mask 1, the rest get 0
            neighbor_p: weight of the other neurons in the same cluster; (1 - neighbor_p) is the weight of the neuron's own gradient
            neuron_dim: the number of hidden_neurons in MLP, to distinguish up_proj and down_proj

        Returns:
            masks: bfloat16 0/1 gradient mask, same shape as the gradients passed in

        Raises:
            ValueError: if neither dimension of gradients equals neuron_dim
        """
        # TODO: log necessary data for analysis (how the mask influence the tuning process)
        gradients, transposed = self._orient_gradients(gradients, neuron_dim)
        magnitude = torch.abs(gradients)  # shape: (embedding_size, neuron_dim)

        # Use the gradients' dtype for the matmul operand (matmul requires
        # matching dtypes); for bfloat16 gradients this is what was used before.
        membership = cluster_tensor.to(device=gradients.device, dtype=magnitude.dtype)
        co_cluster = membership @ membership.T  # shape: (neuron_dim, neuron_dim)
        co_cluster.fill_diagonal_(0)  # remove self connection

        # co_cluster is symmetric, so dividing column j by row-sum j turns each
        # column into an average over neuron j's cluster neighbours.
        neighbor_count = co_cluster.sum(dim=1)
        # A neuron that is alone in its cluster has no neighbours (sum == 0).
        # Dividing by zero would yield NaN scores, and topk ranks NaN above
        # every number, so such neurons would always be selected. Give them a
        # zero neighbour contribution instead.
        neighbor_count = torch.where(
            neighbor_count == 0, torch.ones_like(neighbor_count), neighbor_count
        )
        co_cluster = co_cluster / neighbor_count

        cluster_contribution = magnitude @ co_cluster  # shape: (embedding_size, neuron_dim)
        score = cluster_contribution * neighbor_p + magnitude * (1 - neighbor_p)

        masks = self._top_fraction_mask(score, p)
        if transposed:
            masks = torch.transpose(masks, 0, 1)
        return masks

    def _random_select_mask(self, gradients, p):
        """
        Randomly select a mask for the gradients, p is the probability of an element being selected.
        Args:
            gradients: gradients of the model, shape: (neuron_dim * embedding_size) or (embedding_size * neuron_dim)
            p: the probability of an element being selected, float
        Returns:
            mask: bfloat16 0/1 mask, shape: same as gradients
        """
        # Sample in float32: bfloat16 keeps only 8 significand bits, so uniform
        # samples drawn in that dtype are coarsely quantised and the realised
        # keep-rate drifts away from p, most visibly for small p.
        draws = torch.rand_like(gradients, dtype=torch.float32)
        # Return bfloat16 like the other selectors (previously int64).
        return (draws < p).to(torch.bfloat16)

    def _highest_select_mask(self, gradients, p):
        """
        Select the fraction p of elements with the largest gradient magnitude.
        Args:
            gradients: gradients of the model, shape: (neuron_dim * embedding_size) or (embedding_size * neuron_dim)
            p: fraction of elements to keep, float
        Returns:
            mask: bfloat16 0/1 mask, shape: same as gradients
        """
        # Rank by magnitude, consistent with the DECK selectors; ranking the
        # signed values would never keep large negative gradients.
        return self._top_fraction_mask(torch.abs(gradients), p)

    def _deck_no_cluster_select_mask(
        self,
        gradients,
        neuron_activation_graph,
        p,
        neighbor_p,
        neuron_dim,
        ):
        """
        The DECK method without hard clusters: neighbours are weighted by a neuron graph.
        Args:

            gradients: the gradients of a parameter section, shape: (neuron_dim * embedding_size) or (embedding_size * neuron_dim)
            neuron_activation_graph: the projected coactivation graph of the MLP layer, (neuron_dim, neuron_dim); numpy array or torch tensor
            p: fraction of entries (ranked by score) that receive mask 1
            neighbor_p: weight of the graph neighbours; (1 - neighbor_p) is the weight of the neuron's own gradient
            neuron_dim: the number of hidden_neurons in MLP, to distinguish up_proj and down_proj
        Returns:
            mask: bfloat16 0/1 mask, same shape as the gradients passed in
        Raises:
            ValueError: if neither dimension of gradients equals neuron_dim
        """
        gradients, transposed = self._orient_gradients(gradients, neuron_dim)
        magnitude = torch.abs(gradients)  # shape: (embedding_size, neuron_dim)

        # as_tensor accepts both numpy arrays and tensors (from_numpy rejects tensors).
        graph = torch.as_tensor(neuron_activation_graph).to(
            device=gradients.device, dtype=magnitude.dtype
        )
        # L1-normalise each column so it sums to 1 (all-zero columns stay zero).
        normalized_graph = torch.nn.functional.normalize(graph, p=1, dim=0)
        identity = self._identity(neuron_dim, gradients.device, magnitude.dtype)
        transformation = normalized_graph * neighbor_p + identity * (1 - neighbor_p)
        score = magnitude @ transformation

        # select top k
        masks = self._top_fraction_mask(score, p)
        if transposed:
            masks = torch.transpose(masks, 0, 1)

        return masks

    def select_mask(
        self,
        gradients,
        p,
        neuron_activation_graph=None,
        cluster_tensor=None,
        neighbor_p=None,
        neuron_dim=14336
        ):
        """
        Build a gradient mask with the configured strategy and update the pass statistics.
        Args:
            gradients: gradient tensor to mask
            p: fraction (or, for "random", probability) of elements to keep; unused for "normal"
            neuron_activation_graph: required for "deck_no_cluster"
            cluster_tensor: required for "deck"
            neighbor_p: required for "deck" and "deck_no_cluster"
            neuron_dim: number of hidden neurons, used by the DECK strategies to orient gradients
        Returns:
            mask: 0/1 tensor with the same shape as gradients
        Raises:
            ValueError: on an unknown strategy or when a strategy's required input is missing
        """
        method = self.mask_select_method
        if method == "deck":
            if cluster_tensor is None or neighbor_p is None:
                raise ValueError("cluster_tensor and neighbor_p are required for 'deck'")
            mask = self._deck_select_mask(
                gradients,
                cluster_tensor,
                p,
                neighbor_p,
                neuron_dim,
                )
        elif method == "deck_no_cluster":
            if neuron_activation_graph is None or neighbor_p is None:
                raise ValueError(
                    "neuron_activation_graph and neighbor_p are required for 'deck_no_cluster'"
                )
            mask = self._deck_no_cluster_select_mask(
                gradients,
                neuron_activation_graph,
                p,
                neighbor_p,
                neuron_dim,
                )
        elif method == "random":
            mask = self._random_select_mask(
                gradients,
                p,
                )
        elif method == "highest" or method == "gmt":
            mask = self._highest_select_mask(
                gradients,
                p,
                )
        elif method == "normal":
            mask = torch.ones_like(gradients)
        else:
            raise ValueError("Invalid mask_select_method")

        # Masks are 0/1, so counting non-zeros equals summing them, but it
        # yields an exact integer. torch.sum on a bfloat16 mask returns a
        # bfloat16 scalar, which cannot represent most counts above 256.
        self.mask_pass_num += torch.count_nonzero(mask).item()
        self.mask_total += gradients.numel()
        return mask

    def reset_mask_pass_num(self):
        """Reset the accumulated pass/total counters to zero."""
        self.mask_pass_num = 0
        self.mask_total = 0

    def get_mask_pass_ratio(self):
        """Return passed / total elements since the last reset (0 if nothing was recorded, 1 for "normal")."""
        if self.mask_total == 0:
            return 0
        elif self.mask_select_method == "normal":
            return 1
        else:
            return self.mask_pass_num / self.mask_total
