import os 
import argparse
import json
import torch
from transformers import AutoTokenizer
from transformers.models.llama import modeling_llama
import matplotlib.pyplot as plt
import numpy as np
import gc

from lxt.efficient import monkey_patch

def compute_lrp_mask(relevance, mask_ratio=0.1):
    """Select the dimensions that carry the least relevance.

    The absolute relevance is averaged over axes 0 and 1 of ``relevance`` and
    the ``int(dim_size * mask_ratio)`` entries with the smallest average are
    picked.

    Args:
        relevance: Relevance tensor. Assumed to have three axes so that a
            single axis remains after the reduction -- TODO confirm with
            callers.
        mask_ratio: Fraction of the remaining axis to select.

    Returns:
        CPU tensor holding the indices of the lowest-relevance dimensions.
    """
    # Collapse the two leading axes into one score per dimension.
    per_dim_score = relevance.abs().mean(dim=(0, 1)).detach().cpu()
    num_to_mask = int(per_dim_score.shape[0] * mask_ratio)

    # largest=False makes topk return the smallest entries.
    weakest = torch.topk(per_dim_score, k=num_to_mask, largest=False)
    return weakest.indices

def get_lrp_input_relevance(
    model_path='xxx/llms/meta/Llama-3.1-8B',
    device='cuda:3',
    mask_ratio=0.1,
    prompt="I have 5 cats and 3 dogs. My cats love to play with my",
    output_dir='../output',
    apply_monkey_patch=True,  # keep monkey_patch option
    iteration=0,
    model=None,
    tokenizer=None
):
    """Compute a per-token relevance score for ``prompt``.

    The prompt is embedded, the model is run on the embeddings, and the
    negative entropy of the next-token distribution at the last position,
    ``sum(p * log p)``, is back-propagated. The relevance of a token is the
    sum over the embedding axis of ``gradient * embedding``.

    Args:
        model_path: Unused here; kept for signature compatibility.
        device: Unused here; tensors are moved to ``model.device``.
        mask_ratio: Unused here; kept for signature compatibility.
        prompt: Text to tokenize and score.
        output_dir: Unused here; kept for signature compatibility.
        apply_monkey_patch: Unused here; kept for signature compatibility.
        iteration: Unused here; kept for signature compatibility.
        model: Causal LM exposing ``get_input_embeddings()``, ``device`` and a
            forward call that accepts ``inputs_embeds`` / ``use_cache`` and
            returns an object with ``logits``.
        tokenizer: Callable returning an object with ``input_ids``.

    Returns:
        Detached tensor with the embedding axis summed out, i.e. one value per
        input token.

    Raises:
        ValueError: If ``model`` or ``tokenizer`` is not supplied.
    """
    if model is None or tokenizer is None:
        raise ValueError(
            "get_lrp_input_relevance requires both `model` and `tokenizer`"
        )

    # process input text
    input_ids = tokenizer(prompt, return_tensors="pt", add_special_tokens=True).input_ids.to(model.device)
    input_embeds = model.get_input_embeddings()(input_ids)
    if input_embeds.requires_grad:
        # Non-leaf output of a trainable embedding: ask autograd to keep .grad.
        input_embeds.retain_grad()
    else:
        # Frozen embedding weights yield a leaf that does not require grad;
        # retain_grad() would raise on it, so enable gradients directly.
        input_embeds.requires_grad_(True)

    # execute forward pass
    output_logits = model(inputs_embeds=input_embeds, use_cache=False).logits

    # log_softmax keeps log p finite when p underflows to zero, whereas
    # log(softmax(x)) would give 0 * -inf = NaN and poison every gradient.
    log_probs = torch.nn.functional.log_softmax(output_logits[:, -1, :], dim=-1)
    probs = log_probs.exp()
    # No minus sign on purpose: the objective is sum(p * log p).
    neg_entropy = torch.sum(probs * log_probs)
    neg_entropy.backward()

    # get relevance and immediately detach computation graph
    relevance = (input_embeds.grad * input_embeds).sum(dim=-1).detach()

    # clean all intermediate variables and computation graph
    del output_logits, log_probs, probs, neg_entropy
    input_embeds.grad = None
    torch.cuda.empty_cache()

    return relevance


def get_lrp_masks(
    model_path='xxx/llms/meta/Llama-3.1-8B',
    device='cuda:3',
    mask_ratios=None,
    input_ids=None,
    output_dir='../output',
    apply_monkey_patch=True,  # keep monkey_patch option
    iteration=0,
    model=None,
    tokenizer=None,
    input_relevance=None,
    label_context=None
):
    """Build per-layer neuron masks from gradient-times-activation scores.

    Forward hooks are attached to the attention (q/k/v/o) and MLP
    (gate/up/down) projections of every decoder layer, one forward/backward
    pass is run on ``input_ids``, and ``compute_lrp_scores_for_sample`` turns
    the retained activations and gradients into five score vectors per layer.
    For every vector, the ``mask_ratios[layer * 5 + component]`` fraction of
    entries with the smallest absolute score is set to 0 in an all-ones mask.

    Args:
        model_path: Checkpoint to load; only used when ``model`` is None.
        device: ``device_map`` used when the model is loaded here.
        mask_ratios: Sequence with five ratios per layer, layer-major, in the
            order Ind1..Ind5. Required.
        input_ids: Token id tensor indexed as ``[batch, position]``. Required.
        output_dir: Directory that is created if missing; nothing is written
            to it by this function.
        apply_monkey_patch: Apply ``lxt.efficient.monkey_patch`` to
            ``modeling_llama`` unless ``model`` already carries the
            ``_monkey_patched`` marker.
        iteration: Unused here; kept for signature compatibility.
        model: Already-loaded causal LM. Loaded from ``model_path`` when None.
        tokenizer: Not used by the computation; loaded alongside the model
            only when the model itself is loaded here.
        input_relevance: Unused here; kept for signature compatibility.
        label_context: When None the objective is ``sum(p * log p)`` over all
            positions but the last. Otherwise a next-token cross-entropy is
            used, starting at position ``label_context`` if it is an int and
            at 0 for any other value.

    Returns:
        Tuple ``(masks, lrp_scores, activations)``: ``masks`` maps
        ``(layer_idx, "Ind1".."Ind5")`` to a 1-D CPU tensor of ones and zeros;
        the other two are the sample-averaged lists returned by
        ``compute_lrp_scores_for_sample``.

    Raises:
        ValueError: If ``input_ids`` or ``mask_ratios`` is missing, or if
            ``mask_ratios`` does not hold five entries per layer.
    """
    # Fail fast: both are dereferenced unconditionally further down, and
    # discovering that after loading a checkpoint is needlessly expensive.
    if input_ids is None:
        raise ValueError("get_lrp_masks requires `input_ids`")
    if mask_ratios is None:
        raise ValueError("get_lrp_masks requires `mask_ratios`")

    # create output directory
    os.makedirs(output_dir, exist_ok=True)

    # apply monkey_patch (if needed)
    if apply_monkey_patch and (model is None or not hasattr(model, '_monkey_patched')):
        print("apply monkey_patch to modeling_llama...")
        from lxt.efficient import monkey_patch
        monkey_patch(modeling_llama, verbose=True)

    # Load only when no model was supplied. The tokenizer is never used below,
    # so a missing tokenizer must not discard a caller-provided model.
    if model is None:
        print(f"load model: {model_path}")
        # use monkey_patched version of modeling_llama.LlamaForCausalLM
        model = modeling_llama.LlamaForCausalLM.from_pretrained(
            model_path,
            device_map=device,
            torch_dtype=torch.float32  # use float32 for more accurate gradients
        )
        if tokenizer is None:
            tokenizer = AutoTokenizer.from_pretrained(model_path)
        model.eval()
        model._monkey_patched = True  # mark model as monkey_patched

    # clear cache
    torch.cuda.empty_cache()

    # get model hidden and intermediate dimensions
    hidden_size = model.config.hidden_size
    intermediate_size = (
        model.config.intermediate_size
        if hasattr(model.config, 'intermediate_size')
        else 4 * hidden_size
    )
    num_layers = len(model.model.layers)

    # Checked before the forward/backward pass rather than after it.
    if len(mask_ratios) != num_layers * 5:
        raise ValueError(
            f"mask_ratios must have {num_layers * 5} entries "
            f"(5 per layer), got {len(mask_ratios)}"
        )

    # Name and width of the five maskable vectors of a layer, in score order.
    components = [
        ("Ind1", hidden_size),        # Attention Input
        ("Ind2", hidden_size),        # Attention Output
        ("Ind3", hidden_size),        # MLP Input
        ("Ind4", intermediate_size),  # MLP Middle
        ("Ind5", hidden_size),        # MLP Output
    ]

    # init ModuleOutputs to store module inputs and outputs
    outputs_container = ModuleOutputs()

    def _make_hook(idx, name):
        # A factory pins idx/name per hook; closing over the loop variables
        # directly would leave every hook with the final values.
        def _hook(mod, inp, out):
            hook_module_output(mod, inp, out, idx, name, outputs_container)
        return _hook

    hooks = []
    try:
        # Registration happens inside the try so that a failure partway
        # through still removes the hooks that were already attached.
        for layer_idx, layer in enumerate(model.model.layers):
            hooked_groups = (
                (layer.self_attn, ("q_proj", "k_proj", "v_proj", "o_proj")),
                (layer.mlp, ("gate_proj", "up_proj", "down_proj")),
            )
            for parent, module_names in hooked_groups:
                for module_name in module_names:
                    module = getattr(parent, module_name)
                    hooks.append(
                        module.register_forward_hook(_make_hook(layer_idx, module_name))
                    )

        input_ids = input_ids.to(model.device)

        accumulated_lrp_scores = None
        accumulated_activations = None
        num_samples = 1
        for i in range(num_samples):
            # clear previous input and output data
            outputs_container.inputs = {}
            outputs_container.outputs = {}

            # Re-embed on every pass: backward() frees the graph behind the
            # previous embedding output.
            base_embeds = model.get_input_embeddings()(input_ids)

            if i == 0:
                input_embeds = base_embeds
            else:
                # Extra samples are perturbed with Gaussian noise scaled to
                # 1% of the embedding value range.
                x_max = base_embeds.max().item()
                x_min = base_embeds.min().item()
                noise_scale = 0.01
                noise = torch.randn_like(base_embeds) * noise_scale * (x_max - x_min)
                input_embeds = base_embeds + noise

            output_logits = model(inputs_embeds=input_embeds, use_cache=False).logits

            if label_context is None:
                # log_softmax keeps log p finite when p underflows to zero;
                # log(softmax(x)) would produce 0 * -inf = NaN there.
                log_probs = torch.nn.functional.log_softmax(
                    output_logits[:, :-1, :], dim=-1
                )
                # No minus sign on purpose: the objective is sum(p * log p).
                neg_entropy = torch.sum(log_probs.exp() * log_probs)
                neg_entropy.backward()
            else:
                # label context is expected token
                start_pos = label_context if isinstance(label_context, int) else 0
                # calculate loss for answer part
                shift_logits = output_logits[:, start_pos:-1, :]
                shift_labels = input_ids[:, start_pos + 1:]
                loss = torch.nn.functional.cross_entropy(
                    shift_logits.reshape(-1, shift_logits.size(-1)),
                    shift_labels.reshape(-1)
                )
                loss.backward()

            # NOTE(review): backward() also fills .grad on every trainable
            # parameter and nothing here clears it -- confirm whether callers
            # want model.zero_grad() afterwards.
            current_lrp_scores, current_activations = compute_lrp_scores_for_sample(
                model, outputs_container
            )

            if accumulated_lrp_scores is None:
                # Copy so later in-place accumulation cannot alias the
                # arrays returned for this sample.
                accumulated_lrp_scores = []
                accumulated_activations = []
                for score_array, activation_array in zip(current_lrp_scores, current_activations):
                    accumulated_lrp_scores.append(score_array.copy())
                    accumulated_activations.append(activation_array.copy())
            else:
                for j, (score_array, activation_array) in enumerate(
                    zip(current_lrp_scores, current_activations)
                ):
                    if j < len(accumulated_lrp_scores):  # prevent index errors
                        accumulated_lrp_scores[j] += score_array
                        accumulated_activations[j] += activation_array
                    else:
                        accumulated_lrp_scores.append(score_array.copy())
                        accumulated_activations.append(activation_array.copy())

        if accumulated_lrp_scores is not None:
            for j in range(len(accumulated_lrp_scores)):
                accumulated_lrp_scores[j] /= num_samples
                accumulated_activations[j] /= num_samples

        lrp_scores = accumulated_lrp_scores
        activations = accumulated_activations

        # create mask dictionary
        masks = {}
        for layer_idx in range(num_layers):
            for comp_idx, (comp_name, size) in enumerate(components):
                score_idx = layer_idx * 5 + comp_idx

                # Default is "keep everything"; entries are zeroed only when a
                # score vector of the right width is available.
                mask = torch.ones(size, device='cpu')

                if score_idx >= len(lrp_scores):
                    print(f"Warning: Missing LRP score for layer {layer_idx}, {comp_name}")
                else:
                    # use magnitude to find the least important neurons
                    lrp_score = torch.from_numpy(np.abs(lrp_scores[score_idx]))
                    if len(lrp_score) != size:
                        print(f"Warning: Component size mismatch for layer {layer_idx}, {comp_name}")
                    else:
                        # Clamp so a ratio above 1 masks everything instead of
                        # making topk raise.
                        mask_size = min(int(size * mask_ratios[score_idx]), size)
                        if mask_size > 0:
                            _, bottom_indices = torch.topk(lrp_score, k=mask_size, largest=False)
                            # set least important neurons to 0 (masked)
                            mask[bottom_indices] = 0

                masks[(layer_idx, comp_name)] = mask

    finally:
        # remove hooks
        for hook in hooks:
            hook.remove()

    return masks, lrp_scores, activations



class ModuleOutputs:
    """Store for tensors captured by forward hooks.

    ``inputs`` and ``outputs`` are two-level dicts keyed first by layer index
    and then by module name. Tensors that require grad are asked to retain
    their gradient so that, after a backward pass, relevance can be read back
    as ``tensor * tensor.grad``.
    """

    def __init__(self):
        self.outputs = {}
        self.inputs = {}

    @staticmethod
    def _retain_if_differentiable(tensor):
        # Objects without a truthy requires_grad flag are stored untouched.
        if getattr(tensor, 'requires_grad', False):
            tensor.retain_grad()

    @staticmethod
    def _relevance_from(table, layer_idx, module_name):
        # Returns None both for an unknown key and for a missing gradient.
        try:
            tensor = table[layer_idx][module_name]
        except KeyError:
            return None
        if tensor.grad is None:
            return None
        # element-wise activation * gradient, moved off the graph and to CPU
        return (tensor * tensor.grad).float().detach().cpu()

    @staticmethod
    def _activation_from(table, layer_idx, module_name):
        try:
            tensor = table[layer_idx][module_name]
        except KeyError:
            return None
        return tensor.float().detach().cpu()

    def save_output(self, layer_idx, module_name, output):
        """Record ``output`` under ``(layer_idx, module_name)``."""
        self.outputs.setdefault(layer_idx, {})[module_name] = output
        self._retain_if_differentiable(output)

    def save_input(self, layer_idx, module_name, input):
        """Record the first element of ``input``; an empty ``input`` stores nothing."""
        per_layer = self.inputs.setdefault(layer_idx, {})
        if len(input) == 0:
            return
        first = input[0]  # usually first element is actual input
        per_layer[module_name] = first
        self._retain_if_differentiable(first)

    def get_output_relevance(self, layer_idx, module_name):
        """Return ``output * output.grad`` as a float CPU tensor, or None."""
        return self._relevance_from(self.outputs, layer_idx, module_name)

    def get_input_relevance(self, layer_idx, module_name):
        """Return ``input * input.grad`` as a float CPU tensor, or None."""
        return self._relevance_from(self.inputs, layer_idx, module_name)

    def get_input_activation(self, layer_idx, module_name):
        """Return the stored input as a detached float CPU tensor, or None."""
        return self._activation_from(self.inputs, layer_idx, module_name)

    def get_output_activation(self, layer_idx, module_name):
        """Return the stored output as a detached float CPU tensor, or None."""
        return self._activation_from(self.outputs, layer_idx, module_name)

# hook function
def hook_module_output(module, input, output, layer_idx, module_name, outputs_container):
    """Forward-hook body that records a module's input and output.

    ``input`` is handed to ``outputs_container.save_input`` unchanged. For the
    output, only the first element is recorded when ``output`` is a tuple;
    otherwise ``output`` itself is recorded. ``module`` is not used.
    """
    outputs_container.save_input(layer_idx, module_name, input)

    # handle tuple output case
    primary_output = output[0] if isinstance(output, tuple) else output
    outputs_container.save_output(layer_idx, module_name, primary_output)

# calculate LRP scores
def compute_lrp_scores_for_sample(model, outputs_container, input_relevance=None):
    """Collect five relevance and activation vectors per layer.

    For each layer, in this order, the container is queried for:

    1. the input of ``q_proj`` (attention input),
    2. the output of ``o_proj`` (attention output),
    3. the input of ``gate_proj`` (MLP input),
    4. the input of ``down_proj`` (MLP middle),
    5. the output of ``down_proj`` (MLP output).

    Each relevance tensor is optionally multiplied by ``input_relevance`` and
    then, like the matching activation, averaged over axes 0 and 1 and
    converted to a NumPy array.

    Args:
        model: Object exposing ``config.hidden_size`` and ``model.layers``.
        outputs_container: Object exposing the ``get_{input,output}_relevance``
            and ``get_{input,output}_activation`` accessors.
        input_relevance: Optional factor multiplied into every relevance
            tensor before averaging.

    Returns:
        ``(all_lrp_scores, all_activations)``, two flat lists with five
        entries per layer.

    Raises:
        ValueError: If a relevance is unavailable, or if the length of a
            score differs from ``hidden_size`` (the ``down_proj`` input is
            not length-checked).
    """
    hidden_size = model.config.hidden_size
    num_layers = len(model.model.layers)

    from_input = (
        outputs_container.get_input_relevance,
        outputs_container.get_input_activation,
    )
    from_output = (
        outputs_container.get_output_relevance,
        outputs_container.get_output_activation,
    )

    # (label used in error messages, accessors, module name, expected length)
    # A length of None means "no size check" for that component.
    plan = (
        ("q_in", from_input, "q_proj", hidden_size),
        ("o_out", from_output, "o_proj", hidden_size),
        ("gate_in", from_input, "gate_proj", hidden_size),
        ("down_in", from_input, "down_proj", None),
        ("down_out", from_output, "down_proj", hidden_size),
    )

    # build LRP scores according to pruning vector index order
    all_lrp_scores = []
    all_activations = []

    for layer_idx in range(num_layers):
        for label, (read_relevance, read_activation), module_name, expected_len in plan:
            rel = read_relevance(layer_idx, module_name)
            act = read_activation(layer_idx, module_name)

            if rel is None:
                raise ValueError(f"{label}_rel is None for layer {layer_idx}")

            if input_relevance is not None:
                rel = rel * input_relevance

            # reduce the two leading axes away, leaving one value per neuron
            score = rel.mean(dim=(0, 1)).numpy()
            activation = act.mean(dim=(0, 1)).numpy()

            if expected_len is not None and len(score) != expected_len:
                raise ValueError(f"{label}_lrp size mismatch for layer {layer_idx}")

            all_lrp_scores.append(score)
            all_activations.append(activation)

    return all_lrp_scores, all_activations


def create_full_pass_masks(model, type='full', ratio=0.1):
    """Create a mask for every (layer, component) pair without using scores.

    Args:
        model: Either a wrapper exposing ``model.layers`` or an object
            exposing ``layers`` directly; ``config.hidden_size`` is required.
        type: ``'full'`` yields all-ones masks (nothing masked). ``'random'``
            draws each entry from a Bernoulli distribution with keep
            probability ``1 - ratio``. Any other value yields an empty dict.
        ratio: Expected masked fraction for ``type='random'``.

    Returns:
        Dict mapping ``(layer_idx, "Ind1".."Ind5")`` to a 1-D CPU tensor.

    Raises:
        ValueError: If no ``layers`` attribute can be located on ``model``.
    """
    # Locate the layer list: wrapped (LlamaForCausalLM-style) first, then the
    # bare (LlamaModel-style) layout.
    absent = object()
    layers = getattr(getattr(model, 'model', absent), 'layers', absent)
    if layers is absent:
        layers = getattr(model, 'layers', absent)
    if layers is absent:
        raise ValueError("cannot determine model structure, check model type")

    num_layers = len(layers)
    hidden_size = model.config.hidden_size

    # Width of the MLP middle vector: config value, else the first layer's
    # gate_proj width, else the conventional 4x hidden size.
    if hasattr(model.config, 'intermediate_size'):
        intermediate_size = model.config.intermediate_size
    elif hasattr(layers[0].mlp, 'gate_proj'):
        intermediate_size = layers[0].mlp.gate_proj.out_features
    else:
        intermediate_size = 4 * hidden_size

    component_sizes = {
        "Ind1": hidden_size,        # attention input
        "Ind2": hidden_size,        # attention output+residual
        "Ind3": hidden_size,        # MLP input
        "Ind4": intermediate_size,  # MLP middle activation
        "Ind5": hidden_size         # MLP output+residual
    }

    if type == 'full':
        def build(size):
            return torch.ones(size, device='cpu')
    elif type == 'random':
        def build(size):
            keep_probability = torch.ones(size, device='cpu') * (1 - ratio)
            return torch.bernoulli(keep_probability)
    else:
        # Unrecognised types produce no masks at all.
        return {}

    # Layer-major, component-minor: the same order in which random draws are
    # consumed.
    return {
        (layer_idx, comp_name): build(size)
        for layer_idx in range(num_layers)
        for comp_name, size in component_sizes.items()
    }


if __name__ == "__main__":
    # NOTE(review): get_lrp_masks is called without arguments, so `input_ids`
    # and `mask_ratios` keep their None defaults. The function needs both, so
    # this entry point cannot run to completion as written -- supply real
    # values (e.g. parsed from the command line) before relying on it.
    get_lrp_masks()