import torch
from typing import Optional, Union, List, Callable, Tuple
from torch import nn
from typing import Dict, List, Tuple, Union
from contextlib import contextmanager
import math
import torch
from tqdm import tqdm
from functools import partial
import torch.nn.functional as F

class attr_state_manager():
    def __init__(
        self,
        model_name,
        model,
        tokenizer,
        single_compute_token=50,
        mlp_token_compute_num=None,
        mlp_softmax_temp=0.1,
        decoding_token_compute_num=50,
        mlp_softmax=True,
        ignore_special_token=True,
    ):
        """Store the model handles and attribution settings.

        Args:
            model_name (str): Selects the rotary-embedding helper; "llama" and
                "qwen" are the two values handled here.
            model: Model object; its ``config`` must expose
                ``num_attention_heads``, ``num_key_value_heads``,
                ``num_hidden_layers`` and ``hidden_size``.
            tokenizer: Tokenizer stored for later use with ``model``.
            single_compute_token (int): Number of source tokens the attribution
                methods of this class process per chunk.
            mlp_token_compute_num (int, optional): Forwarded to the MLP
                attribution routine on every call.
            mlp_softmax_temp (float): Forwarded to the MLP attribution routine.
            decoding_token_compute_num (int): Stored as-is.
            mlp_softmax (bool): If true, MLP attribution uses
                ``mlp_attribute_compute_softmax``; otherwise
                ``mlp_attribute_compute_max``.
            ignore_special_token (bool): Stored flag consulted by the
                attribution methods.
        """
        # Handles.
        self.model_name = model_name
        self.model = model
        self.tokenizer = tokenizer

        # Chunking knobs.
        self.single_compute_token = single_compute_token
        self.mlp_token_compute_num = mlp_token_compute_num

        # Geometry read from the model configuration.
        cfg = self.model.config
        self.num_heads = cfg.num_attention_heads
        self.num_key_value_heads = cfg.num_key_value_heads
        self.num_layers = cfg.num_hidden_layers
        self.head_dim = cfg.hidden_size // self.num_heads
        self.hidden_dim = cfg.hidden_size

        self.mlp_softmax_temp = mlp_softmax_temp
        self.ignore_special_token = ignore_special_token
        self.decoding_token_compute_num = decoding_token_compute_num
        self.attention_weight_compute = compute_attention_weights_mha

        # Pick the MLP attribution strategy once, so callers just invoke
        # ``self.mlp_attribute_compute``.
        self.mlp_attribute_compute = (
            self.mlp_attribute_compute_softmax
            if mlp_softmax
            else self.mlp_attribute_compute_max
        )

        # Rotary helper depends on the model family; any other name leaves
        # ``apply_rotary_pos_emb`` unset.
        if model_name == "llama":
            self.apply_rotary_pos_emb = apply_rotary_pos_emb_llama
        elif model_name == "qwen":
            self.apply_rotary_pos_emb = apply_rotary_pos_emb_qwen
        
    def get_states(self, prompt):
        """Run the model on ``prompt`` and cache the captured states on ``self``.

        While the forward pass runs, ``F.scaled_dot_product_attention`` is
        replaced by the wrapper returned from ``capture_sdpa()``; the original
        function is always put back, even if the forward pass raises.

        Args:
            prompt (str): The input prompt for the model.

        Side effects:
            Sets ``self.states``, ``self.masks``, ``self.dropouts``,
            ``self.causals`` (the three values returned by the capture getter)
            and ``self.attribute_len`` (``states[0][0].shape[0]``).
        """
        original_sdpa = F.scaled_dot_product_attention
        my_sdpa, get_captured = capture_sdpa()
        F.scaled_dot_product_attention = my_sdpa
        try:
            # Only the first returned value is used by this class.
            states, _after_attn_list, _hidden_states = get_stream_from_prompt(
                self.tokenizer, self.model, prompt
            )
        finally:
            # The patch is process-global: never leave it installed on failure.
            F.scaled_dot_product_attention = original_sdpa
        attribute_len = states[0][0].shape[0]
        masks, dropouts, causals = get_captured()
        self.states = states
        self.masks = masks
        self.dropouts = dropouts
        self.causals = causals
        self.attribute_len = attribute_len
    
    def get_all_layer_attribute_state(self,prompt):
        """
        Initializes the attribute state at all layers and get the corresponding last layer attribute state
        Args:
            prompt (str): The input prompt for the model.
        Returns:
            all_layer_attribute_state_list (list): A list of attribute states for all layers.
        """
        self.get_states(prompt)
        all_layer_attribute_state_list = []
        for start_layer_idx in range(self.num_layers+1):
            attribute_state = self.compute_attribute_state(start_layer_idx)
            all_layer_attribute_state_list.append(attribute_state)
        return all_layer_attribute_state_list

    def compute_attribute_state(self, start_layer_idx):
        """
        Token-level initialization at ``start_layer_idx``, propagated to the last layer.

        Source tokens are handled in chunks of ``self.single_compute_token``.
        Each chunk gets one component per source token, seeded with that
        token's cached state at ``states[2 * start_layer_idx]``. The components
        are pushed through the attention and MLP attribution routines of every
        remaining layer and finally scaled by the final norm.

        Requires ``get_states`` to have been called first.

        Args:
            start_layer_idx (int): The layer index for the Token-Level Initialization.
        Returns:
            attribute_state (torch.Tensor): CPU tensor; the per-chunk results
                are concatenated on dim 0 and dims 0 and 1 are then swapped.
                Shape: (N, N, D) with N the sequence length and D the hidden size.
        """
        captured = self.states
        sdpa_masks = self.masks
        sdpa_dropouts = self.dropouts
        sdpa_causals = self.causals
        seq_len = self.attribute_len
        chunk = self.single_compute_token
        num_chunks = math.ceil(seq_len / chunk)
        finished_chunks = []
        with torch.no_grad():
            param_dtype = next(self.model.parameters()).dtype
            device = self.model.device
            hidden = self.model.config.hidden_size
            for chunk_no in range(num_chunks):
                first_tok = chunk_no * chunk
                last_tok = min(first_tok + chunk, seq_len)
                seed = captured[2 * start_layer_idx][0]

                # One component per source token in this chunk; component r is
                # non-zero only at its own token position.
                components = torch.zeros(
                    last_tok - first_tok, seq_len, hidden,
                    dtype=param_dtype, device=device,
                )
                token_positions = torch.arange(first_tok, last_tok)
                components[token_positions - first_tok, token_positions, :] = seed[token_positions, :]

                # Later chunks do not contain token 0, so when special tokens
                # are ignored an extra leading component for it is carried
                # along and dropped again before the final norm.
                carries_first_token = self.ignore_special_token and chunk_no != 0
                if carries_first_token:
                    first_token_row = torch.zeros(1, seq_len, hidden, dtype=param_dtype).to(device)
                    first_token_row[0, 0, :] = seed[0, :]
                    components = torch.cat([first_token_row, components], dim=0)

                # Walk the remaining layers: attention first, then MLP.
                for layer_idx in range(start_layer_idx, self.num_layers):
                    components = self.attn_attribute_compute(
                        layer_idx, captured[layer_idx * 2],
                        sdpa_masks, sdpa_dropouts, sdpa_causals, components,
                    )
                    components = self.mlp_attribute_compute(
                        layer_idx, captured[layer_idx * 2 + 1], components,
                        self.mlp_token_compute_num, self.mlp_softmax_temp,
                        ignore_special_token=self.ignore_special_token,
                        cur_num=chunk_no,
                    )

                if carries_first_token:
                    components = components[1:, :, :]

                # Final norm, applied as a per-position scaling of every component.
                final_state = captured[-1][0].to(components.device)
                scaling = get_rmsnorm_scaling(self.model.model.norm, final_state)
                components = components * scaling
                finished_chunks.append(components.clone().detach().cpu())

            stacked = torch.cat(finished_chunks, dim=0)
            result = stacked.transpose(0, 1)
        return result
    
    def from_middle_layer_attribute_state(self, prompt, start_layer_idx):
        """
        Get states from prompt and compute the attribute state from the start layer index.
        """
        self.get_states(prompt)
        states = self.states
        attribute_state = self.compute_attribute_state(start_layer_idx)
        return attribute_state, states

    def get_last_layer_attribute_state(self, prompt):
        """
        Get states from prompt and compute the attribute state from the first layer.
        """
        self.get_states(prompt)
        states = self.states
        attribute_state = self.compute_attribute_state(0)
        return attribute_state, states
    
    def all_layer_to_init_attribute_state(self, prompt):
        """
        Token-level initialization at the embedding layer, recording the
        attribute state after every layer.

        Unlike ``compute_attribute_state`` no final-norm scaling is applied;
        the snapshot taken after each layer's MLP step is returned as-is.

        Args:
            prompt (str): The input prompt for the model.

        Returns:
            attribute_state (torch.Tensor): CPU tensor of per-layer snapshots.
                Shape: (L, N, N, D) with L the number of layers, N the sequence
                length and D the hidden size. Chunks are concatenated on dim 1
                and dims 1 and 2 are then swapped.
            states: The value cached on ``self.states`` by ``get_states``.
        """
        self.get_states(prompt)
        captured = self.states
        sdpa_masks = self.masks
        sdpa_dropouts = self.dropouts
        sdpa_causals = self.causals
        seq_len = self.attribute_len
        chunk = self.single_compute_token
        num_chunks = math.ceil(seq_len / chunk)
        per_chunk = []
        with torch.no_grad():
            param_dtype = next(self.model.parameters()).dtype
            device = self.model.device
            hidden = self.model.config.hidden_size
            for chunk_no in range(num_chunks):
                first_tok = chunk_no * chunk
                last_tok = min(first_tok + chunk, seq_len)
                seed = captured[0][0]

                # One component per source token of the chunk, non-zero only at
                # that token's own position.
                components = torch.zeros(
                    last_tok - first_tok, seq_len, hidden, dtype=param_dtype
                ).to(device)
                token_positions = torch.arange(first_tok, last_tok)
                components[token_positions - first_tok, token_positions, :] = seed[token_positions, :]

                # Chunks after the first carry an extra leading component for
                # token 0 when special tokens are ignored.
                carries_first_token = self.ignore_special_token and chunk_no != 0
                if carries_first_token:
                    first_token_row = torch.zeros(1, seq_len, hidden, dtype=param_dtype).to(device)
                    first_token_row[0, 0, :] = seed[0, :]
                    components = torch.cat([first_token_row, components], dim=0)

                per_layer = []
                for layer_idx in range(self.num_layers):
                    components = self.attn_attribute_compute(
                        layer_idx, captured[layer_idx * 2],
                        sdpa_masks, sdpa_dropouts, sdpa_causals, components,
                    )
                    components = self.mlp_attribute_compute(
                        layer_idx, captured[layer_idx * 2 + 1], components,
                        self.mlp_token_compute_num, self.mlp_softmax_temp,
                        ignore_special_token=self.ignore_special_token,
                        cur_num=chunk_no,
                    )
                    # The carried token-0 row is left out of the snapshot but
                    # kept in the running state for the next layer.
                    snapshot = components[1:, :, :] if carries_first_token else components
                    per_layer.append(snapshot.clone().detach().cpu())
                per_chunk.append(torch.stack(per_layer, dim=0))

            merged = torch.cat(per_chunk, dim=1)
            result = merged.transpose(1, 2)
        return result, captured

    def get_mlp_neuron_attribution_state(self, prompt, start_layer_idx, compute_num=None):
        """
        MLP Neuron Level DePass.
        Initializes one component per MLP neuron of layer ``start_layer_idx``
        (plus a residual component) and propagates them to the last layer.

        Neurons are processed in ``compute_num`` chunks. Within a chunk,
        component 0 is the residual (the hidden state after the MLP minus the
        contributions of the chunk's neurons), so the components always sum to
        the hidden state entering layer ``start_layer_idx + 1``. The residual
        row is kept only for the first chunk.

        Args:
            prompt (str): The input prompt for the model.
            start_layer_idx (int): The layer index for the MLP neuron-Level Initialization.
            compute_num (int, optional): Number of neuron chunks. If None, it will be set to 1.

        Returns:
            attribute_state (torch.Tensor): CPU tensor; chunk results are
                concatenated on dim 0 and dims 0 and 1 are then swapped.
                - N: sequence length (number of tokens)
                - M: 1 residual row followed by one row per MLP neuron
                - D: hidden dimension size (model's hidden state dimension)
                Shape: (N, M, D)
        """
        self.get_states(prompt)
        masks = self.masks
        dropouts = self.dropouts
        causals = self.causals
        states = self.states
        attribute_len = self.attribute_len
        attribute_state_list = []
        with torch.no_grad():
            # Divide the intermediate size by the number of splits
            intermediate_size = self.model.config.intermediate_size
            num_splits = 1 if compute_num is None else compute_num
            chunk_size = math.ceil(intermediate_size / num_splits)

            param_dtype = next(self.model.parameters()).dtype
            target_device = self.model.device
            hidden_size = self.model.config.hidden_size

            # Everything below is independent of the neuron chunk, so it is
            # computed once instead of once per chunk.
            layer = self.model.model.layers[start_layer_idx]
            mlp = layer.mlp
            state_ori = states[2*start_layer_idx+1][0].to(target_device)
            state_after_mlp = states[2*start_layer_idx+2][0].to(target_device)
            x = layer.post_attention_layernorm(state_ori)
            intermediate = mlp.act_fn(mlp.gate_proj(x)) * mlp.up_proj(x)
            # Only drop a leading batch axis; an unconditional squeeze(0) would
            # also collapse the token axis of a single-token prompt.
            if intermediate.dim() == 3:
                intermediate = intermediate.squeeze(0)
            # down_proj.weight is transposed so its rows line up with neurons.
            down_weights_t = mlp.down_proj.weight.T

            for cur_neuron_num in range(num_splits):
                start_neuron_idx = cur_neuron_num * chunk_size
                if start_neuron_idx >= intermediate_size:
                    # ceil-based chunking can leave trailing empty chunks.
                    break
                end_neuron_idx = min((cur_neuron_num + 1) * chunk_size, intermediate_size)

                # Per-neuron output contribution for this chunk only:
                # (tokens, chunk neurons, hidden). Slicing before multiplying
                # avoids materializing the product for all neurons.
                mlp_decompose = (
                    intermediate[:, start_neuron_idx:end_neuron_idx].unsqueeze(-1)
                    * down_weights_t[start_neuron_idx:end_neuron_idx].unsqueeze(0)
                ).to(target_device)

                # Initialization of attribute state at the given layer
                attribute_state = torch.zeros(
                    end_neuron_idx - start_neuron_idx + 1, attribute_len, hidden_size,
                    dtype=param_dtype,
                ).to(target_device)
                # Residual row: post-MLP state minus the chunk's neurons, so
                # that all rows sum to the state fed to the next layer.
                attribute_state[0,:,:] = state_after_mlp - mlp_decompose.sum(1)
                attribute_state[1:,:,:] = mlp_decompose.transpose(0, 1)

                # Compute the last layer attribute state
                for layer_idx in range(start_layer_idx + 1, self.num_layers):
                    # Attn module computation
                    state = states[layer_idx*2]
                    attribute_state = self.attn_attribute_compute(layer_idx, state, masks, dropouts, causals, attribute_state)
                    # MLP module computation
                    state = states[layer_idx*2 + 1]
                    attribute_state = self.mlp_attribute_compute(layer_idx, state, attribute_state,self.mlp_token_compute_num,self.mlp_softmax_temp, ignore_special_token=False, cur_num=0)

                # Keep the residual row only once (from the first chunk).
                if cur_neuron_num != 0:
                    attribute_state = attribute_state[1:,:,:]

                # Last layernorm computation
                state = states[-1][0].to(attribute_state.device)
                ln = self.model.model.norm
                W_diag_elements = get_rmsnorm_scaling(ln, state)
                attribute_state = attribute_state * W_diag_elements
                attribute_state_list.append(attribute_state.clone().detach().cpu())
            attribute_state = torch.cat(attribute_state_list, dim=0)
            attribute_state = attribute_state.transpose(0, 1)
        return attribute_state
    
    def get_layer_module_attribution_state(self, prompt, start_layer_idx, type=None):
        """
        Module-Level DePass.
        Splits the hidden state at layer ``start_layer_idx`` into "state before
        the module" plus the module's contribution(s), then propagates these
        components to the last layer.

        The decomposition does not depend on token chunks, so it is computed
        exactly once; the result has one row per component regardless of the
        prompt length.

        Args:
            prompt (str): The input prompt for the model.
            start_layer_idx (int): The layer index for the Module-Level Initialization.
            type (str, optional): Module type. It can be "attn" or "mlp" or "attn_head".

        Returns:
            attribute_state (torch.Tensor): CPU tensor with dims 0 and 1 swapped
                after propagation.
                - N: sequence length (number of tokens)
                - M: 2 for "attn"/"mlp" (input state, module output);
                     ``num_heads + 1`` for "attn_head" (input state, one row per head)
                - D: hidden dimension size (model's hidden state dimension)
                Shape: (N, M, D)

        Raises:
            ValueError: If ``type`` is not one of the supported module types.
        """
        # Fail before running the (expensive) forward pass.
        if type not in ("attn", "mlp", "attn_head"):
            raise ValueError("type must be 'attn' or 'mlp' or 'attn_head'")

        self.get_states(prompt)
        masks = self.masks
        dropouts = self.dropouts
        causals = self.causals
        states = self.states
        attribute_len = self.attribute_len
        with torch.no_grad():
            num_components = self.num_heads + 1 if type == "attn_head" else 2
            attribute_state = torch.zeros(
                num_components, attribute_len, self.model.config.hidden_size,
                dtype=next(self.model.parameters()).dtype,
            ).to(self.model.device)
            target_device = attribute_state.device

            if type == "attn":
                # Attention module level initialization
                state_ori = states[2*start_layer_idx][0].to(target_device)
                state_after_attn = states[2*start_layer_idx+1][0].to(target_device)
                attribute_state[0,:,:] = state_ori
                attribute_state[1,:,:] = state_after_attn - state_ori
            elif type == "mlp":
                # MLP module level initialization
                state_ori = states[2*start_layer_idx+1][0].to(target_device)
                state_after_mlp = states[2*start_layer_idx+2][0].to(target_device)
                attribute_state[0,:,:] = state_ori
                attribute_state[1,:,:] = state_after_mlp - state_ori
            else:
                # Attention head level initialization
                state_ori = states[2*start_layer_idx][0].to(target_device)
                attn_output = self.attn_head_contribution_compute(
                    start_layer_idx, states[2*start_layer_idx],
                    masks, dropouts, causals, attribute_state,
                ).to(target_device)
                attribute_state[0,:,:] = state_ori
                attribute_state[1:,:,:] = attn_output

            if type != "mlp":
                # Attention-level splits still have to pass through the MLP of
                # the same layer. That step must see the post-attention state
                # (index 2*l + 1) for both "attn" and "attn_head".
                state = states[2*start_layer_idx + 1]
                attribute_state = self.mlp_attribute_compute(start_layer_idx, state, attribute_state,self.mlp_token_compute_num,self.mlp_softmax_temp, ignore_special_token=False, cur_num=0)

            # Compute the last layer attribute state
            for layer_idx in range(start_layer_idx + 1, self.num_layers):
                # Attn module computation
                state = states[layer_idx*2]
                attribute_state = self.attn_attribute_compute(layer_idx, state, masks, dropouts, causals, attribute_state)
                # MLP module computation
                state = states[layer_idx*2 + 1]
                attribute_state = self.mlp_attribute_compute(layer_idx, state, attribute_state,self.mlp_token_compute_num,self.mlp_softmax_temp, ignore_special_token=False, cur_num=0)

            # Last layernorm computation
            state = states[-1][0].to(attribute_state.device)
            ln = self.model.model.norm
            W_diag_elements = get_rmsnorm_scaling(ln, state)
            attribute_state = attribute_state * W_diag_elements
            attribute_state = attribute_state.clone().detach().cpu()
            attribute_state = attribute_state.transpose(0, 1)
        return attribute_state

        
    def get_subspace_attribute_state(self, prompt, start_layer_idx, attribute_state):
        """
        Subspace-Level DePass.
        Propagates a caller-supplied decomposition to the last layer.

        Propagation starts at layer ``start_layer_idx + 1`` and ends with the
        final-norm scaling.

        Args:
            prompt (str): The input prompt for the model.
            start_layer_idx (int): The layer index for the Subspace-Level Initialization.
            attribute_state (torch.Tensor shape: (M, N, D)): The attribute state tensor to be used for the computation.

        Returns:
            attribute_state (torch.Tensor): CPU tensor with dims 0 and 1 of the
                propagated state swapped.
                - N: sequence length (number of tokens)
                - M: subspace number (number of subspaces in the layer)
                - D: hidden dimension size (model's hidden state dimension)
                Shape: (N, M, D)
        """
        self.get_states(prompt)
        sdpa_masks = self.masks
        sdpa_dropouts = self.dropouts
        sdpa_causals = self.causals
        captured = self.states
        with torch.no_grad():
            components = attribute_state.to(self.model.device)
            for layer_idx in range(start_layer_idx + 1, self.num_layers):
                # Attention step, then MLP step, for each remaining layer.
                components = self.attn_attribute_compute(
                    layer_idx, captured[layer_idx * 2],
                    sdpa_masks, sdpa_dropouts, sdpa_causals, components,
                )
                components = self.mlp_attribute_compute(
                    layer_idx, captured[layer_idx * 2 + 1], components,
                    self.mlp_token_compute_num, self.mlp_softmax_temp,
                    ignore_special_token=False, cur_num=0,
                )
            # Final norm as a per-position scaling of every component.
            final_state = captured[-1][0].to(components.device)
            scaling = get_rmsnorm_scaling(self.model.model.norm, final_state)
            on_cpu = (components * scaling).clone().detach().cpu()
            result = on_cpu.transpose(0, 1)
        return result
    
    def model_generate_cite(
            self, 
            prompt,
            inputs: Optional[torch.Tensor] = None,
            max_length: Optional[int] = None,
            min_length: Optional[int] = None,
            do_sample: Optional[bool] = None,
            num_beams: Optional[int] = None,
            temperature: Optional[float] = None,
            top_k: Optional[int] = None,
            top_p: Optional[float] = None,
            typical_p: Optional[float] = None,
            repetition_penalty: Optional[float] = None,
            bad_words_ids: Optional[List[int]] = None,
            force_words_ids: Optional[Union[List[int], List[List[int]]]] = None,
            bos_token_id: Optional[int] = None,
            pad_token_id: Optional[int] = None,
            eos_token_id: Optional[int] = None,
            length_penalty: Optional[float] = None,
            no_repeat_ngram_size: Optional[int] = None,
            encoder_no_repeat_ngram_size: Optional[int] = None,
            num_return_sequences: Optional[int] = None, 
        ):
        """
        Token-Level DePass.
        Generates text from the model using the provided prompt and use DePass to compute the attribute state for new tokens.

        Args:
            prompt (str): The input prompt for the model.
            inputs (torch.Tensor, optional): The input tensor for the model. Defaults to None.
            max_length (int, optional): The maximum length of the generated text. Defaults to None.
            min_length (int, optional): The minimum length of the generated text. Defaults to None.
            do_sample (bool, optional): Whether to use sampling. Defaults to None.
            num_beams (int, optional): The number of beams for beam search. Defaults to None.
            temperature (float, optional): The temperature for sampling. Defaults to None.
            top_k (int, optional): The number of top-k tokens to sample from. Defaults to None.
            top_p (float, optional): The cumulative probability for nucleus sampling. Defaults to None.
            typical_p (float, optional): The typical probability for sampling. Defaults to None.
            repetition_penalty (float, optional): The penalty for repeated tokens. Defaults to None.
            bad_words_ids (List[int], optional): A list of token IDs to avoid. Defaults to None.
            force_words_ids (Union[List[int], List[List[int]]], optional): A list of token IDs to force. Defaults to None.
            bos_token_id (int, optional): The beginning of sequence token ID. Defaults to None.
            pad_token_id (int, optional): The padding token ID. Defaults to None.
            eos_token_id (int, optional): The end of sequence token ID. Defaults to None.
            length_penalty (float, optional): The length penalty for beam search. Defaults to None.
            no_repeat_ngram_size (int, optional): The size of n-grams to avoid repeating. Defaults to None.
            encoder_no_repeat_ngram_size (int, optional): The size of n-grams to avoid repeating in the encoder. Defaults to None.
            num_return_sequences (int, optional): The number of return sequences. Defaults to None.
        

        Returns:
            generated_text (str): The generated text from the model.
            score_list (list): A list of dictionaries containing the token and its corresponding attribute score.
            attribute_state (torch.Tensor): The computed attribute state tensor with shape:
                - N: sequence length (number of tokens)
                - N: sequence length (number of tokens) 
                - D: hidden dimension size (model's hidden state dimension)
                Shape: (N, N, D)
            states (list): A list of tensors containing the states of the model.
        """
        # Get the generated text from the model
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        assert num_return_sequences == 1 or num_return_sequences == None, "num_return_sequences must be 1"
        initial_input_length = inputs["input_ids"].shape[1]
        outputs = self.model.generate(
            inputs["input_ids"],
            max_length=max_length,
            min_length=min_length,
            do_sample=do_sample,
            num_beams=num_beams,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            typical_p=typical_p,
            repetition_penalty=repetition_penalty,
            bad_words_ids=bad_words_ids,
            force_words_ids=force_words_ids,
            bos_token_id=bos_token_id,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            length_penalty=length_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            encoder_no_repeat_ngram_size=encoder_no_repeat_ngram_size,
            num_return_sequences=num_return_sequences
        )
        output_length= outputs.shape[1]
        generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        attribute_state, states = self.get_last_layer_attribute_state(generated_text)
        # Compute the attribute score for the generated tokens
        score_list = []
        for i in range(initial_input_length,output_length):
            dict={}
            token_id= outputs[0][i]
            token = self.tokenizer.decode(token_id)
            attribute_score = self.compute_attribute_score(attribute_state, i-1, token_id)
            dict["token"]=token
            dict["attribute_score"]=attribute_score
            score_list.append(dict)
        return generated_text, score_list, attribute_state, states
        
    def compute_attribute_score(self,attribute_state,token_idx,decode_token_id):
        """
        Given the attribute state, token index, and decoding token ID, compute the attribute score.

        Args:
            attribute_state (torch.Tensor): The attribute state tensor with shape: N, M, D
                - N: sequence length (number of tokens)
                - M: DePass component number
                - D: hidden dimension size (model's hidden state dimension)

        Returns:
            attribute_score (torch.Tensor): The computed attribute score tensor with shape: M
                - M: DePass component number
        """
        lm_head=self.model.lm_head
        attribute_state = attribute_state[token_idx].to(self.model.device)
        attribute_logits = lm_head(attribute_state.to(self.model.device))
        attribute_score=attribute_logits[:,decode_token_id]
        return attribute_score
    
    
    def mlp_attribute_compute_max(self, layer_idx, state, attribute_state, mlp_token_compute_num = None,mlp_softmax_temp = 0.1, ignore_special_token=True, cur_num=None):
        """
        Decomposed MLP module computation.
        Each neuron's output is assigned to the component(s) with the maximum
        gate pre-activation.

        The gate pre-activation of every component is computed, a residual
        entry (full gate minus the sum over components) is appended, and the
        maximum is taken over all entries except the first. The first
        component never receives MLP output, and output won by the residual
        entry is not added to any component.

        Args:
            layer_idx (int): The layer index for the MLP module.
            state (torch.Tensor): Hidden state fed to this layer's MLP. The
                token count is read from dim 1 of its normalized form
                (presumably shape (1, N, D) -- confirm against the caller).
            attribute_state (torch.Tensor): Component tensor, indexed as
                (M, N, D): components on dim 0, tokens on dim 1. It is updated
                in place.
            mlp_token_compute_num (int, optional): Number of tokens processed
                per batch. Defaults to None (all tokens at once).
            mlp_softmax_temp (float, optional): Unused here; accepted so both
                MLP attribution variants share one call signature.
            ignore_special_token (bool, optional): Unused here; accepted for
                the same reason.
            cur_num (int, optional): Unused here; accepted because callers
                pass it to whichever variant is selected.

        Returns:
            attribute_state (torch.Tensor): The same tensor object, with the
                distributed MLP output added. Shape: (M, N, D)
        """
        target_device = attribute_state.device
        layer = self.model.model.layers[layer_idx]
        ln = layer.post_attention_layernorm
        gate_proj = layer.mlp.gate_proj
        # Same device alignment as the softmax variant.
        state = state.to(target_device)
        state_norm = ln(state)
        W_diag_elements = get_rmsnorm_scaling(ln, state).to(target_device)
        attribute_state_norm = attribute_state * W_diag_elements
        gate_ratio = gate_proj(attribute_state_norm).transpose(-2, -1)
        ori_gate = gate_proj(state_norm).transpose(-2, -1)
        # Append the residual gate: whatever the components do not account for.
        gate_ratio = torch.cat([gate_ratio, ori_gate - gate_ratio.sum(0)])
        max_values, _ = torch.max(gate_ratio[1:,:,:], dim=0, keepdim=True)
        # One-hot (ties give several ones) mask of the winning entries, kept in
        # the gate dtype so the einsum below sees matching operand dtypes.
        winners = (gate_ratio[1:,:,:] == max_values).to(gate_ratio.dtype)
        gate_ratio = torch.cat([
            torch.zeros_like(gate_ratio[:1,:,:]),
            winners
        ])
        attribute_mlp_values = torch.zeros_like(attribute_state)
        if mlp_token_compute_num is None:
            mlp_token_compute_num = state_norm.shape[1]
        num_batches = (state_norm.shape[1] + mlp_token_compute_num - 1) // mlp_token_compute_num

        for i in range(num_batches):
            start_idx = i * mlp_token_compute_num
            end_idx = min((i + 1) * mlp_token_compute_num, state_norm.shape[1])
            mlp_slice = self.get_per_ffn2_values(state_norm[:,start_idx:end_idx,:], layer_idx).squeeze(0)
            # Drop the residual entry: its share is not assigned to a component.
            gate_slice = gate_ratio[:-1][:, :, start_idx:end_idx]
            attribute_mlp_values[:, start_idx:end_idx, :] += torch.einsum("qni, knq -> kqi", mlp_slice, gate_slice)
        attribute_state += attribute_mlp_values
        return attribute_state
        
    def mlp_attribute_compute_softmax(self, layer_idx, state, attribute_state, mlp_token_compute_num = None, mlp_softmax_temp=0.1, ignore_special_token=True, cur_num=None):
        """
        Decomposed MLP module computation.
        Each MLP neuron's output is split between the components with a
        temperature softmax over the components' shares of the gate
        pre-activation.

        Args:
            layer_idx (int): The layer index for the MLP module.
            state (torch.Tensor): Input of the MLP block (before the
                post-attention norm). It is indexed here as (B, N, D): tokens are
                sliced along dim 1 and dim 0 is squeezed away
                (presumably B == 1 -- confirm against the caller).
            attribute_state (torch.Tensor): Decomposed state, indexed here as (M, N, D)
                - M: DePass component number (dim 0, summed to compare with ``state``)
                - N: sequence length (dim 1, sliced per token chunk)
                - D: hidden dimension size
            mlp_token_compute_num (int, optional): Number of tokens processed per
                chunk. Defaults to None, meaning all tokens in one chunk.
            mlp_softmax_temp (float, optional): The temperature for softmax. Defaults to 0.1.
            ignore_special_token (bool, optional): When True, the gate share of
                component 0 is zeroed before the softmax. Defaults to True.
            cur_num (int, optional): Unused by this implementation; kept so existing
                callers that pass it keep working.
        Returns:
            torch.Tensor: ``attribute_state`` itself, updated in place with the
            per-component MLP outputs added.
        """
        target_device = attribute_state.device
        layer = self.model.model.layers[layer_idx]
        ln = layer.post_attention_layernorm
        gate_proj = layer.mlp.gate_proj
        state = state.to(target_device)
        state_norm = ln(state)
        W_diag_elements = get_rmsnorm_scaling(ln, state).to(target_device)
        attribute_state_norm = attribute_state * W_diag_elements
        # Per-component gate pre-activations, laid out as (component, neuron, token).
        gate_ratio = gate_proj(attribute_state_norm).transpose(-2, -1)
        ori_gate = gate_proj(state_norm).transpose(-2, -1)
        # Extra last row: what the full gate pre-activation has that the
        # components do not account for.
        gate_ratio = torch.cat([gate_ratio, ori_gate - gate_ratio.sum(0)])
        # Process the special token: drop component 0's share, keep the rest.
        if ignore_special_token:
            residual_gate = gate_ratio[-1:]
            gate_ratio = torch.cat([
                torch.zeros_like(gate_ratio[:1]),
                gate_ratio[1:-1],
                residual_gate
            ], dim=0)
        # Component 0 gets a fixed share of 1.0 at token index 0 for every neuron.
        gate_ratio[0,:,0] = 1.0
        # Snap a numerically negligible residual to exactly zero so it is masked out.
        gate_ratio[-1,:,:] = torch.where(torch.abs(gate_ratio[-1,:,:]) < 1e-5, torch.zeros_like(gate_ratio[-1,:,:]), gate_ratio[-1,:,:])
        # If every row is zero for a (neuron, token) pair, hand it to the residual
        # row so the softmax below never sees an all -inf column.
        zero_cols = (gate_ratio[:,:,:] == 0).all(dim=0)
        gate_ratio[-1,:,:] = torch.where(zero_cols, torch.ones_like(gate_ratio[-1,:,:]),gate_ratio[-1,:,:])
        mask = (gate_ratio != 0).float()
        gate_ratio.masked_fill_(mask == 0, float('-inf'))
        gate_ratio = torch.softmax(gate_ratio / mlp_softmax_temp, dim=0) * mask
        attribute_mlp_values = torch.zeros_like(attribute_state)
        seq_len = state_norm.shape[1]
        if mlp_token_compute_num is None:
            mlp_token_compute_num = seq_len
        num_batches = (seq_len + mlp_token_compute_num - 1) // mlp_token_compute_num

        # Loop-invariant: fetch the down projection once instead of once per chunk.
        weight = layer.mlp.down_proj.weight.to(target_device)
        out_dim = weight.size(0)

        for batch_idx in range(num_batches):
            start_idx = batch_idx * mlp_token_compute_num
            end_idx = min(start_idx + mlp_token_compute_num, seq_len)
            # Neuron activations for this token chunk: (tokens, neurons).
            mlp_gate = self.get_per_ffn2_values(state_norm[:,start_idx:end_idx,:], layer_idx).to(target_device)
            mlp_gate = mlp_gate.squeeze(0)
            # Softmax shares without the residual row: (component, tokens, neurons).
            gate_slice = gate_ratio[:-1][:, :, start_idx:end_idx].to(target_device)
            gate_slice_trans = gate_slice.permute(0, 2, 1)
            weighted = mlp_gate.unsqueeze(0) * gate_slice_trans
            num_comp, num_tok, num_neurons = weighted.shape
            # Flatten (component, token) so a single matmul applies down_proj to all of them.
            weighted = weighted.reshape(num_comp * num_tok, num_neurons).to(dtype=weight.dtype)
            output = (weighted @ weight.T).view(num_comp, num_tok, out_dim)
            attribute_mlp_values[:, start_idx:end_idx, :] += output
        attribute_state += attribute_mlp_values
        return attribute_state
    
    
    def attn_attribute_compute(self,layer_idx,state,masks,dropouts,causals,attribute_state):
        """
        Decomposed Attention module computation.

        Args:
            layer_idx (int): The layer index for the Attention module.
            state (torch.Tensor): Input of the decoder layer (before the input
                norm). The attention-weight helper unpacks it as (B, N, D) and only
                batch element 0 of the resulting weights is used.
            masks (list): Attention masks, indexed by ``layer_idx``.
            dropouts (list): Dropout probabilities, indexed by ``layer_idx``.
            causals (list): Causal flags, indexed by ``layer_idx``.
            attribute_state (torch.Tensor): Decomposed state, contracted here as (M, N, D)
                - M: DePass component number
                - N: sequence length (number of tokens)
                - D: hidden dimension size (model's hidden state dimension)
        Returns:
            torch.Tensor: ``attribute_state`` itself, updated in place with each
            component's attention output added.
        """
        target_device = attribute_state.device
        layer = self.model.model.layers[layer_idx]
        v_proj = layer.self_attn.v_proj
        o_proj = layer.self_attn.o_proj
        ln = layer.input_layernorm
        state_norm = ln(state)
        attn_weight = self.attention_weight_compute(state_norm, layer.self_attn, layer_idx, masks[layer_idx], dropouts[layer_idx], causals[layer_idx],self.apply_rotary_pos_emb,self.model_name)
        # Move the scaling next to attribute_state so the product below cannot hit
        # a cross-device mismatch when the layer lives on another device.
        W_diag_elements = get_rmsnorm_scaling(ln, state).to(target_device)
        attribute_state_norm = attribute_state * W_diag_elements
        # Only the weight matrices are used here; any projection bias is not attributed.
        v_proj_weight = v_proj.weight.view(self.num_key_value_heads, self.head_dim, self.hidden_dim).to(target_device)
        o_proj_weight = o_proj.weight.view(self.hidden_dim, self.num_heads, self.head_dim).to(target_device)
        attn_weight = attn_weight[0].to(target_device)
        # (component, kv_head, token, head_dim)
        attribute_values = torch.einsum("kqi,nhi->knqh", attribute_state_norm, v_proj_weight)
        attribute_values = repeat_kv(attribute_values, self.num_heads // self.num_key_value_heads)
        # (component, token, head, hidden): value pushed through the output projection per head.
        vo_attribute = torch.einsum("knqh,jnh->kqnj", attribute_values, o_proj_weight)
        # Mix key-token contributions with the attention weights and sum over heads.
        attribute_state += torch.einsum("iknd, nqk -> iqd", vo_attribute, attn_weight)
        return attribute_state
    
    def get_per_ffn2_values(self, state, layer_idx):
        """
        Get neuron-level values for the MLP module.

        Args:
            layer_idx (int): Index of the decoder layer whose MLP is used.
            state (torch.Tensor): Input passed to both ``gate_proj`` and ``up_proj``.
        Returns:
            torch.Tensor: ``act_fn(gate_proj(state)) * up_proj(state)``.
        """
        mlp = self.model.model.layers[layer_idx].mlp
        activated_gate = mlp.act_fn(mlp.gate_proj(state))
        up_values = mlp.up_proj(state)
        return activated_gate * up_values
 
    def attn_head_contribution_compute(self,layer_idx,state,masks,dropouts,causals,attribute_state):
        """
        Get head-level values for the Attention module.

        Args:
            layer_idx (int): The layer index for the Attention module.
            state (torch.Tensor): Input of the decoder layer (before the input norm).
            masks (list): Attention masks, indexed by ``layer_idx``.
            dropouts (list): Dropout probabilities, indexed by ``layer_idx``.
            causals (list): Causal flags, indexed by ``layer_idx``.
            attribute_state (torch.Tensor): Only its device is read; results are
                placed on that device.
        Returns:
            torch.Tensor: Per-head attention output for batch element 0, laid out
            as (head, token, hidden).
        """
        out_device = attribute_state.device
        block = self.model.model.layers[layer_idx]
        value_proj = block.self_attn.v_proj
        output_proj = block.self_attn.o_proj
        # Normalise, then run the attention-weight computation on the projection's device.
        normed = block.input_layernorm(state).to(value_proj.weight.device)
        weights = self.attention_weight_compute(
            normed,
            block.self_attn,
            layer_idx,
            masks[layer_idx],
            dropouts[layer_idx],
            causals[layer_idx],
            self.apply_rotary_pos_emb,
            self.model_name,
        )
        w_v = value_proj.weight.view(self.num_key_value_heads, self.head_dim, self.hidden_dim).to(out_device)
        w_o = output_proj.weight.view(self.hidden_dim, self.num_heads, self.head_dim).to(out_device)
        weights = weights[0].to(out_device)
        normed = normed.to(out_device)
        groups = self.num_heads // self.num_key_value_heads
        head_values = repeat_kv(torch.einsum("kqi,nhi->knqh", normed, w_v), groups)
        projected = torch.einsum("knqh,jnh->kqnj", head_values, w_o)
        # Heads are kept as a separate axis; only the first batch element is returned.
        per_head = torch.einsum("iknd, nqk -> inqd", projected, weights)
        return per_head[0]
    
def capture_sdpa():
    # Capture the scaled dot product attention function
    original_sdpa = F.scaled_dot_product_attention
    masks, dropouts, causals = [], [], [] 
    def my_sdpa(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
        masks.append(attn_mask)
        dropouts.append(dropout_p)
        causals.append(is_causal)

        return original_sdpa(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)

    return my_sdpa, lambda: (masks, dropouts, causals)

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat every head along dim 1 ``n_rep`` times, keeping repeats of the same
    head adjacent.

    Args:
        hidden_states (torch.Tensor): 4-D tensor (dim0, heads, length, head_dim).
        n_rep (int): Number of copies of each head.
    Returns:
        torch.Tensor: Tensor of shape (dim0, heads * n_rep, length, head_dim); the
        input itself when ``n_rep`` is 1.
    """
    dim0, kv_heads, length, width = hidden_states.shape
    if n_rep != 1:
        # Insert a repeat axis right after the head axis, broadcast it, then fold
        # it into the head axis.
        tiled = hidden_states.unsqueeze(2).expand(dim0, kv_heads, n_rep, length, width)
        hidden_states = tiled.reshape(dim0, kv_heads * n_rep, length, width)
    return hidden_states


def get_rmsnorm_scaling(norm_layer, hidden_states):
    """
    Compute the scaling factors for RMSNorm.

    The result is ``weight * rsqrt(mean(hidden_states ** 2, dim=-1) + eps)``,
    i.e. the element-wise factor that multiplies ``hidden_states``.

    Args:
        norm_layer: Object exposing ``weight`` and ``variance_epsilon``.
        hidden_states (torch.Tensor): Un-normalised input; the mean is taken over
            its last dimension.
    Returns:
        torch.Tensor: Scaling factors that broadcast against ``hidden_states``.
    """
    weight = norm_layer.weight.to(hidden_states.device)
    input_dtype = hidden_states.dtype
    # Squaring in half precision overflows/loses accuracy quickly (fp16 tops out
    # near 6.5e4), which would turn the scale into 0. Do the statistics in
    # float32 for half-precision inputs and cast the factor back afterwards.
    # NOTE(review): this is meant to mirror the float32 variance computation of
    # the Hugging Face RMSNorm modules -- confirm against the model in use.
    if input_dtype in (torch.float16, torch.bfloat16):
        hidden_states = hidden_states.to(torch.float32)
    variance = hidden_states.pow(2).mean(dim=-1, keepdim=True)
    scale_factor = torch.rsqrt(variance + norm_layer.variance_epsilon)
    return weight * scale_factor.to(input_dtype)

def rotate_half(x):
    """
    Split the last dimension at ``size // 2`` and return ``[-second, first]``
    concatenated along that dimension.
    """
    split_at = x.shape[-1] // 2
    first, second = x[..., :split_at], x[..., split_at:]
    return torch.cat([second.neg(), first], dim=-1)

def apply_rotary_pos_emb_llama(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """
    Apply rotary position embedding to query and key tensors.

    Args:
        q, k (torch.Tensor): Query and key tensors.
        cos, sin (torch.Tensor): Rotary tables; an axis is inserted at
            ``unsqueeze_dim`` so they broadcast against ``q`` and ``k``.
        position_ids: Accepted for signature parity; not used here.
        unsqueeze_dim (int): Axis at which ``cos``/``sin`` are unsqueezed.
    Returns:
        tuple: ``(q_embed, k_embed)``.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)

    def _rotate(t):
        return (t * cos_b) + (rotate_half(t) * sin_b)

    return _rotate(q), _rotate(k)

def apply_rotary_pos_emb_qwen(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """
    Apply rotary position embedding to query and key tensors, selecting the
    rotary table rows with ``position_ids`` first.

    Args:
        q, k (torch.Tensor): Query and key tensors.
        cos, sin (torch.Tensor): Rotary tables, indexed by ``position_ids`` and
            then unsqueezed at ``unsqueeze_dim``.
        position_ids: Index applied to ``cos``/``sin``. With the default ``None``
            the indexing only prepends a new leading axis (no rows are selected).
        unsqueeze_dim (int): Axis at which the indexed tables are unsqueezed.
    Returns:
        tuple: ``(q_embed, k_embed)``.
    """
    cos_sel, sin_sel = (
        table[position_ids].unsqueeze(unsqueeze_dim) for table in (cos, sin)
    )

    def _rotate(t):
        return (t * cos_sel) + (rotate_half(t) * sin_sel)

    return _rotate(q), _rotate(k)

def compute_attention_weights_mha(state_norm, attn_layer, layer_idx, attn_mask, dropout_p, is_causal, apply_rotary_pos_emb,model_name):
    """
    Recompute the post-softmax attention weights of one attention layer.

    Args:
        state_norm (torch.Tensor): Normalised layer input, unpacked as
            (batch, seq_len, hidden) after projection.
        attn_layer: Attention module exposing ``q_proj``, ``k_proj``, ``num_heads``,
            ``num_key_value_heads``, ``head_dim`` and ``rotary_emb``.
        layer_idx (int): Unused; kept for interface compatibility.
        attn_mask (torch.Tensor or None): Boolean mask (True = attend) or additive
            float mask; must be None when ``is_causal`` is True.
        dropout_p (float): Dropout probability applied to the weights.
        is_causal (bool): Whether to apply a lower-triangular causal mask.
        apply_rotary_pos_emb (Callable): Called as ``f(q, k, cos, sin)``.
        model_name (str): ``"llama"`` or ``"qwen"``; selects how ``rotary_emb`` is called.
    Returns:
        torch.Tensor: Attention weights of shape (batch, num_heads, seq_len, seq_len).
    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """
    query_states = attn_layer.q_proj(state_norm)
    key_states = attn_layer.k_proj(state_norm)
    bsz, seq_len, _ = query_states.shape
    position_ids = torch.arange(seq_len, dtype=torch.long, device=state_norm.device).unsqueeze(0)
    query_states = query_states.view(bsz, seq_len, attn_layer.num_heads, attn_layer.head_dim).transpose(1, 2)
    key_states = key_states.view(bsz, seq_len, attn_layer.num_key_value_heads, attn_layer.head_dim).transpose(1, 2)
    if model_name == "llama":
        cos, sin = attn_layer.rotary_emb(state_norm, position_ids)
    elif model_name == "qwen":
        cos, sin = attn_layer.rotary_emb(state_norm, seq_len)
    else:
        # Without this guard cos/sin would be unbound and fail with an opaque error below.
        raise ValueError(f"Unsupported model_name {model_name!r}; expected 'llama' or 'qwen'.")
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
    key_states = repeat_kv(key_states, attn_layer.num_heads // attn_layer.num_key_value_heads)
    L, S = query_states.size(-2), key_states.size(-2)
    scale_factor = 1 / math.sqrt(query_states.size(-1))
    attn_bias = torch.zeros(L, S, dtype=query_states.dtype, device=query_states.device)
    if is_causal:
        assert attn_mask is None
        temp_mask = torch.ones(L, S, dtype=torch.bool, device=query_states.device).tril(diagonal=0)
        attn_bias.masked_fill_(~temp_mask, float("-inf"))
    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            # Out-of-place so a mask with extra batch/head dimensions can broadcast;
            # the in-place variant cannot grow the (L, S) bias.
            attn_bias = attn_bias.masked_fill(~attn_mask, float("-inf"))
        else:
            attn_bias = attn_mask + attn_bias
    attn_weight = query_states @ key_states.transpose(-2, -1) * scale_factor
    attn_weight += attn_bias
    attn_weight = torch.softmax(attn_weight, dim=-1)
    # Dropout is always applied in training mode here; it is a no-op when dropout_p is 0.
    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
    return attn_weight

        
@contextmanager
def hook_manager(model: nn.Module):
    """
    Context manager that records per-layer tensors during forward passes.

    For every block in ``model.model.layers`` three hooks are registered: a
    forward pre-hook on the block, a forward hook on ``block.self_attn`` and a
    forward hook on the block. All hooks are removed on exit, including when
    registration itself fails part-way through.

    Yields:
        tuple: ``(block_inputs, attn_intermediates, final_outputs)`` -- dicts
        keyed by layer index, filled in while the model runs:
            - ``block_inputs[i]``: first positional input of block ``i``;
            - ``attn_intermediates[i]``: that input plus the self-attention output
              (the first element when the module returns a tuple);
            - ``final_outputs[i]``: the raw output of block ``i``.
    """
    block_inputs: Dict[int, torch.Tensor] = {}
    attn_intermediates: Dict[int, torch.Tensor] = {}
    final_outputs: Dict[int, torch.Tensor] = {}
    hook_handles: List[torch.utils.hooks.RemovableHandle] = []

    def get_block_pre_hook(layer_idx: int):
        def hook(module: nn.Module, inputs: Tuple[torch.Tensor]):
            block_inputs[layer_idx] = inputs[0]
        return hook

    def get_self_attn_hook(layer_idx: int):
        def hook(module: nn.Module, inputs: Tuple[torch.Tensor], output: Union[torch.Tensor, Tuple[torch.Tensor, ...]]):
            attn_out = output[0] if isinstance(output, tuple) else output
            block_input = block_inputs.get(layer_idx)
            if block_input is None:
                # No recorded block input: fall back to zeros (allocated only when needed).
                block_input = torch.zeros_like(attn_out)
            elif block_input.device != attn_out.device:
                block_input = block_input.to(attn_out.device)
            # The addition already produces a new tensor, so no clone is required.
            attn_intermediates[layer_idx] = block_input + attn_out
        return hook

    def get_block_post_hook(layer_idx: int):
        def hook(module: nn.Module, inputs: Tuple[torch.Tensor], output: torch.Tensor):
            final_outputs[layer_idx] = output
        return hook

    try:
        # Registration happens inside the try so that a failure on a later block
        # does not leave hooks attached to the earlier ones.
        for idx, block in enumerate(model.model.layers):
            hook_handles.append(block.register_forward_pre_hook(get_block_pre_hook(idx)))
            hook_handles.append(block.self_attn.register_forward_hook(get_self_attn_hook(idx)))
            hook_handles.append(block.register_forward_hook(get_block_post_hook(idx)))
        yield block_inputs, attn_intermediates, final_outputs
    finally:
        for h in hook_handles:
            h.remove()

def get_stream_from_prompt(tokenizer, model, prompt: str) -> Tuple:
    """
    Run the model on ``prompt`` and collect the intermediate hidden states.

    Args:
        tokenizer: Object whose ``encode(prompt)`` returns the token ids.
        model: Model exposing ``device``, ``model.layers`` and ``model.norm``; it is
            called with ``output_hidden_states=True``.
        prompt (str): The input text.
    Returns:
        tuple: ``(stream, attn_intermediate_list, hidden_states)`` where
            - ``stream`` interleaves, per layer, ``hidden_states[i]`` and the
              post-attention intermediate of layer ``i``, followed by the
              (detached) input of ``model.model.norm``;
            - ``attn_intermediate_list`` holds the post-attention intermediates in
              layer order;
            - ``hidden_states`` is ``outputs.hidden_states`` as returned by the model.
    """
    input_ids = tokenizer.encode(prompt)
    input_ts = torch.tensor(input_ids, dtype=torch.int64, device=model.device).unsqueeze(0)
    pre_norm_hidden = {}

    def capture_pre_norm_hook(module, inputs, output):
        pre_norm_hidden['last'] = inputs[0].detach()

    handle = model.model.norm.register_forward_hook(capture_pre_norm_hook)
    try:
        with hook_manager(model) as (block_inputs, attn_intermediates, final_outputs):
            outputs = model(input_ts, output_hidden_states=True)
    finally:
        # Always detach the norm hook, even when the forward pass raises.
        handle.remove()
    hidden_states = outputs.hidden_states
    attn_intermediate_list = [attn_intermediates[i] for i in sorted(attn_intermediates)]
    stream = []
    for i, post_attn in enumerate(attn_intermediate_list):
        stream.append(hidden_states[i])
        stream.append(post_attn)
    stream.append(pre_norm_hidden['last'])

    return stream, attn_intermediate_list, hidden_states


def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0,
        is_causal=False, scale=None, enable_gqa=False) -> torch.Tensor:
    """
    Explicit (non-fused) scaled dot-product attention.

    Args:
        query, key, value (torch.Tensor): Attention inputs; the last two
            dimensions are (length, head_dim).
        attn_mask (torch.Tensor, optional): Boolean mask (True = attend) or an
            additive float mask. Must be None when ``is_causal`` is True.
        dropout_p (float): Dropout probability applied to the attention weights.
        is_causal (bool): Apply a lower-triangular causal mask.
        scale (float, optional): Score scaling; defaults to ``1 / sqrt(head_dim)``.
        enable_gqa (bool): Repeat key/value heads (dim -3) to match the query heads.
    Returns:
        torch.Tensor: ``softmax(q @ k^T * scale + bias) @ v``.
    """
    L, S = query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
    if is_causal:
        assert attn_mask is None
        # Build the mask on the query's device; a CPU mask would not combine with
        # a bias that lives on an accelerator.
        temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))

    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            # Out-of-place so masks with leading batch/head dimensions broadcast.
            attn_bias = attn_bias.masked_fill(attn_mask.logical_not(), float("-inf"))
        else:
            attn_bias = attn_mask + attn_bias

    if enable_gqa:
        key = key.repeat_interleave(query.size(-3)//key.size(-3), -3)
        value = value.repeat_interleave(query.size(-3)//value.size(-3), -3)

    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight += attn_bias
    attn_weight = torch.softmax(attn_weight, dim=-1)
    # Dropout is always applied in training mode here; it is a no-op when dropout_p is 0.
    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
    return attn_weight @ value

def from_states_to_probs(states, lm_head, tokenizer, topk=5):
    """Convert a hidden state into its top-k next-token candidates.

    Args:
        states: torch.Tensor, shape=(d, ) (a single leading batch axis is squeezed away)
        lm_head: callable mapping ``states`` to vocabulary logits
        tokenizer: object whose ``decode`` turns a token id into text
        topk: int, number of candidates to return (capped at the vocabulary size)
    Returns:
        token_probs: list of ``(token_id, token_text, probability)`` tuples,
        ordered from most to least probable
    """
    logits = lm_head(states)
    # Stay in torch: a NumPy round-trip adds nothing and fails for dtypes NumPy
    # does not support (e.g. bfloat16).
    traj_log_probs = logits.log_softmax(dim=-1).squeeze().detach().cpu()
    k = min(topk, traj_log_probs.shape[-1])
    topk_indices = torch.topk(traj_log_probs, k=k)
    probs = torch.exp(traj_log_probs[topk_indices.indices])
    token_probs = []
    for idx, prob in zip(topk_indices.indices, probs):
        token = tokenizer.decode(idx)
        token_probs.append((idx.item(), token, prob.item()))
    return token_probs