import torch
import argparse
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
from peft.tuners.lora import Linear as LoraLinear
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
import collections

class ForgetDataModule:
    """Load a JSON "forget" dataset and expose it as chat-formatted text batches.

    Each record is rendered through the tokenizer's chat template into a single
    ``"text"`` field; the original columns are dropped.
    """

    def __init__(self, tokenizer, raw_forget_path: str):
        """Read ``raw_forget_path`` as a JSON dataset and format every record.

        Args:
            tokenizer: Tokenizer providing ``apply_chat_template``, ``eos_token``
                and a batch-encoding ``__call__``.
            raw_forget_path: Path handed to ``load_dataset("json", ...)`` as the
                ``train`` data file.
        """
        self.tokenizer = tokenizer
        raw_dataset = load_dataset(
            "json",
            data_files={"train": raw_forget_path},
            split="train",
        )
        # Replace every original column with the single rendered "text" column.
        self.forget_set = raw_dataset.map(
            self._format_for_chat,
            remove_columns=raw_dataset.column_names,
        )

    def _format_for_chat(self, e):
        """Render one record as a two-turn chat string followed by the EOS token.

        The user turn is taken from ``"prompt"`` (falling back to
        ``"instruction"`` when missing or falsy), the reply from ``"response"``
        (falling back to ``"output"``). Returns ``{"text": rendered_string}``.
        """
        user_turn = e.get("prompt", "") or e.get("instruction", "")
        reply_turn = e.get("response", "") or e.get("output", "")
        # NOTE(review): the reply role is the literal "model"; whether a given
        # chat template accepts that role depends on the tokenizer -- confirm
        # before using this with a different model family.
        conversation = [
            {"role": "user", "content": user_turn},
            {"role": "model", "content": reply_turn},
        ]
        rendered = self.tokenizer.apply_chat_template(
            conversation,
            tokenize=False,
            add_generation_prompt=False,
        )
        return {"text": rendered + self.tokenizer.eos_token}

    def get_dataloader(self, batch_size):
        """Return a ``DataLoader`` over the formatted set that tokenizes per batch.

        Batches are padded, truncated to 2048 tokens and returned as PyTorch
        tensors. No special tokens are added by the tokenizer call itself.
        """
        def _tokenize_batch(records):
            texts = [record['text'] for record in records]
            return self.tokenizer(
                texts,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=2048,
                add_special_tokens=False,
            )

        return DataLoader(
            self.forget_set,
            batch_size=batch_size,
            collate_fn=_tokenize_batch,
        )
# Module-level buffers written by the hooks below and read by main().
# Latest detached forward output per hooked module, keyed by the name given to
# get_activation_hook; each forward pass overwrites the previous entry.
activations = {}
# Latest detached gradient w.r.t. the first output of each hooked module, keyed
# by the name given to get_gradient_hook; each backward pass overwrites it.
gradients = {}
def get_activation_hook(name):
    """Build a forward hook that records a module's output under ``name``.

    The returned callable takes ``(module, inputs, output)`` and stores
    ``output.detach()`` in the module-level ``activations`` dict, replacing any
    earlier entry for the same name.
    """
    def _store_output(module, inputs, output):
        # Detach so the stored tensor does not keep the autograd graph alive.
        activations[name] = output.detach()

    return _store_output
def get_gradient_hook(name):
    """Build a backward hook that records the output gradient under ``name``.

    The returned callable takes ``(module, grad_input, grad_output)`` and stores
    the detached first element of ``grad_output`` in the module-level
    ``gradients`` dict, replacing any earlier entry for the same name.
    """
    def _store_grad(module, grad_input, grad_output):
        # Only the gradient of the first output is kept.
        first_output_grad = grad_output[0]
        gradients[name] = first_output_grad.detach()

    return _store_grad

def main():
    """Score LoRA ``lora_A`` output units on a forget set and save the top-k per layer.

    Steps:
      1. Parse CLI arguments and load the base model plus its LoRA adapter.
      2. Attach forward and full-backward hooks to every
         ``<LoraLinear name>.lora_A.default`` module.
      3. For each sample (batch size 1), run a language-modelling forward and
         backward pass with ``labels = input_ids`` and accumulate
         ``|activation * gradient|`` summed over the first two dimensions.
      4. Pick the ``--top_k_neurons`` highest-scoring indices along the last
         dimension for each layer and ``torch.save`` the
         ``{layer_name: indices}`` mapping to ``--output_file``.

    Hooks are always removed and the module-level hook buffers are always
    cleared, even if the analysis loop raises.
    """
    parser = argparse.ArgumentParser(description="Analyze neuron activations across all LoRA layers.")
    parser.add_argument('--model_folder', type=str, default='/root/autodl-tmp/model/gemma-2-9b-it')
    parser.add_argument('--lora_adapter_path', type=str, required=True)
    parser.add_argument('--raw_forget_path', type=str, required=True)
    parser.add_argument('--output_file', type=str, default='harmful_neurons.pt')
    parser.add_argument('--top_k_neurons', type=int, default=8,
                        help="The absolute number of top harmful neurons to identify in each layer.")

    args = parser.parse_args()
    # Fail fast: a negative k would otherwise only blow up inside torch.topk,
    # after the whole (expensive) attribution pass has already run.
    if args.top_k_neurons < 0:
        parser.error("--top_k_neurons must be non-negative")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    print("Loading model...")
    # Model folders containing "gemma" are loaded with eager attention; every
    # other model uses flash-attention 2.
    attn_implementation = "eager" if "gemma" in args.model_folder else "flash_attention_2"
    base_model = AutoModelForCausalLM.from_pretrained(
        args.model_folder,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        attn_implementation=attn_implementation,
    )

    model = PeftModel.from_pretrained(base_model, args.lora_adapter_path)

    model.eval()
    # Make the embedding outputs require grad so that gradients flow back to
    # the hooked modules.
    model.enable_input_require_grads()

    tokenizer = AutoTokenizer.from_pretrained(args.model_folder)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    print("Identifying all LoRA layers to hook...")
    target_layers = [
        f"{name}.lora_A.default"
        for name, module in model.named_modules()
        if isinstance(module, LoraLinear)
    ]
    print(f"Found {len(target_layers)} target LoRA layers to analyze.")

    modules_dict = dict(model.named_modules())
    hooks = []
    # layer name -> running float32 sum of per-unit attribution scores
    neuron_attributions = {}
    try:
        for name in target_layers:
            module = modules_dict.get(name)
            if module is None:
                print(f"Warning: Could not find module {name} to attach hook.")
                continue
            hooks.append(module.register_forward_hook(get_activation_hook(name)))
            hooks.append(module.register_full_backward_hook(get_gradient_hook(name)))

        data_module = ForgetDataModule(tokenizer, args.raw_forget_path)
        dataloader = data_module.get_dataloader(batch_size=1)

        print("Calculating neuron attributions...")
        for batch in tqdm(dataloader, desc="Analyzing samples"):
            # Drop tensors captured on the previous step so a layer whose hook
            # does not fire this time is skipped instead of being scored with
            # stale data from an earlier sample.
            activations.clear()
            gradients.clear()

            batch = {k: v.to(device) for k, v in batch.items()}
            model.zero_grad()
            loss = model(**batch, labels=batch['input_ids']).loss
            loss.backward()

            for name in target_layers:
                if name not in activations or name not in gradients:
                    continue
                # Accumulate in float32: a running sum kept in a low-precision
                # dtype stops absorbing small contributions once it grows.
                # Assumes hooked outputs are (batch, seq, rank) -- TODO confirm.
                attribution = (
                    activations[name].float() * gradients[name].float()
                ).abs().sum(dim=(0, 1))
                if name in neuron_attributions:
                    neuron_attributions[name] += attribution
                else:
                    neuron_attributions[name] = attribution
    finally:
        # Always detach the hooks and release the captured tensors.
        for hook in hooks:
            hook.remove()
        activations.clear()
        gradients.clear()

    if not neuron_attributions:
        print("Warning: no attributions were collected; the output file will contain an empty mapping.")

    harmful_neuron_indices = {}
    print("Identifying top-K harmful neurons...")
    for name, scores in neuron_attributions.items():
        num_neurons = scores.shape[-1]
        # Never request more indices than the layer has units.
        k = min(args.top_k_neurons, num_neurons)

        top_k_indices = torch.topk(scores, k).indices
        harmful_neuron_indices[name] = top_k_indices.cpu()
        print(f"  Layer '{name}': Identified top {k} harmful neurons (rank dimension).")

    torch.save(harmful_neuron_indices, args.output_file)
    print(f"\nHarmful neuron indices saved to '{args.output_file}'")


# Run the analysis only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()