# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

# from accelerate import init_empty_weights, load_checkpoint_and_dispatch

import fire
import json
import os
import sys

import torch
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM, AutoConfig
from datasets import load_dataset


from sparsify import Sae
import json
import torch.nn.functional as F
from models import LlamaForCausalLM 
from models import SAE

def _report_weight_features(model, sae, layer_idx, final_norm=None, show_top_indices=False):
    """Print how many SAE features the down-projection parameters of one layer activate.

    Every parameter of ``model`` whose name contains ``layers.{layer_idx}.`` and
    ``mlp.down`` and either ``proj`` or ``bias`` is passed through the layer's
    post-attention layer norm (and ``final_norm`` when given), encoded with
    ``sae``, and the number of strictly positive pre-activations per row is
    averaged over rows.

    Args:
        model: Model exposing ``named_parameters()`` and
            ``get_post_attention_layernorm(layer_idx)``.
        sae: Sparse autoencoder whose ``encode`` returns a 3-tuple ending in
            the pre-activations.
        layer_idx: Index of the decoder layer to inspect.
        final_norm: Optional extra callable applied after the post-attention
            norm and before encoding.
        show_top_indices: When true, print the top indices returned by
            ``sae.encode`` for every inspected parameter.

    Raises:
        IndexError: If no inspected parameter name contains ``proj``.
    """
    layer_tag = f'layers.{layer_idx}.'
    neuron_data = {}
    # Nothing is back-propagated here, so skip building an autograd graph
    # through the layer norm and the SAE encoder.
    with torch.no_grad():
        for name, param in model.named_parameters():
            in_down_branch = layer_tag in name and 'mlp.down' in name
            if not (in_down_branch and ('proj' in name or 'bias' in name)):
                continue

            # Key is the parameter name without its first dotted component.
            layer_name = '.'.join(name.split('.')[1:])
            if 'proj' in name:
                # NOTE(review): presumably transposes an (out, in) weight so
                # that each row is one neuron's output vector -- confirm.
                neuron = param.data.t().clone()
            else:
                neuron = param.data.clone()
            if neuron.ndim <= 1:
                # Treat a single vector as a batch of one row.
                neuron = neuron.unsqueeze(0)

            normalization = model.get_post_attention_layernorm(layer_idx)
            normalized = normalization(neuron)
            if final_norm is not None:
                normalized = final_norm(normalized)

            _, top_indices, pre_acts = sae.encode(normalized)
            if show_top_indices:
                print(top_indices)

            # Count on the device and move only the small per-row counts to
            # the CPU instead of copying the whole pre-activation matrix.
            active_per_row = (pre_acts > 0).sum(-1)
            neuron_data[layer_name] = active_per_row.cpu().to(torch.bfloat16).mean(0)

            # Release the large intermediates before the next parameter.
            del top_indices, pre_acts, normalized, active_per_row
            torch.cuda.empty_cache()

    bias_data = torch.tensor([value for key, value in neuron_data.items() if 'bias' in key]).mean().item()
    proj_data = [value for key, value in neuron_data.items() if 'proj' in key][0].item()
    print(f'At layer {layer_idx}, the average number of semantic of pretrained neurons is {proj_data}, while that of extra neurons is {bias_data}')


def _report_hidden_features(label, model, tokenizer, saes, texts, add_special_tokens):
    """Print, per layer, the mean number of active SAE features per token.

    Each text is tokenized (truncated to 64 tokens), run through ``model`` with
    ``output_hidden_states=True``, and ``hidden_states[i + 1]`` is encoded with
    ``saes["layers.{i}.mlp"]``. Strictly positive pre-activations are summed
    over all texts and divided by the total number of input tokens.

    NOTE(review): ``hidden_states[i + 1]`` is presumably the output of decoder
    layer ``i`` while the SAE key names the ``mlp`` hookpoint -- confirm that
    this pairing is intended.

    Args:
        label: Prefix printed in front of every result line.
        model: Causal LM called as ``model(**inputs, output_hidden_states=True)``.
        tokenizer: Tokenizer used to encode each text.
        saes: Mapping from ``"layers.{i}.mlp"`` to a sparse autoencoder.
        texts: Iterable of non-empty strings.
        add_special_tokens: Forwarded to the tokenizer.
    """
    totals = {}
    num_tokens = 0
    with torch.no_grad():
        for text in texts:
            inputs = tokenizer(
                text,
                return_tensors="pt",
                truncation=True,
                max_length=64,
                add_special_tokens=add_special_tokens,
            ).to(model.device)
            outputs = model(**inputs, output_hidden_states=True)
            num_tokens += inputs['input_ids'].size(-1)

            # Skip entry 0 so that index i is paired with hidden_states[i + 1].
            for idx, hidden in enumerate(outputs.hidden_states[1:]):
                sae = saes[f"layers.{idx}.mlp"].to('cuda:0', torch.bfloat16)
                # Drop only the batch axis; a bare squeeze() would also drop
                # the sequence axis for a one-token input.
                _, _, pre_acts = sae.encode(hidden.squeeze(0))
                num_acts = (pre_acts > 0).sum().cpu()
                totals[idx] = totals.get(idx, 0) + num_acts

    for idx, num_acts in totals.items():
        print(f'{label}: At layer {idx}, the number of activated SAE features is {num_acts / num_tokens}')


def main(
    model_name,
    performance_dir: str='./performance/',
    peft_model: str=None,
    quantization: str = None, # Options: 4bit, 8bit
    max_new_tokens =256, #The maximum numbers of tokens to generate
    min_new_tokens:int=0, #The minimum numbers of tokens to generate
    prompt_file: str=None,
    seed: int=42, #seed value for reproducibility
    safety_score_threshold: float=0.5,
    do_sample: bool=True, #Whether or not to use sampling ; use greedy decoding otherwise.
    use_cache: bool=True,  #[optional] Whether or not the model should use the past last key/values attentions (if applicable to the model) to speed up decoding.
    top_p: float=1.0, # [optional] If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation.
    temperature: float=1.0, # [optional] The value used to modulate the next token probabilities.
    top_k: int=50, # [optional] The number of highest probability vocabulary tokens to keep for top-k-filtering.
    repetition_penalty: float=1.0, #The parameter for repetition penalty. 1.0 means no penalty.
    length_penalty: int=1, #[optional] Exponential penalty to the length that is used with beam-based generation.
    enable_azure_content_safety: bool=False, # Enable safety check with Azure content safety api
    enable_sensitive_topics: bool=False, # Enable check for sensitive topics using AuditNLG APIs
    enable_saleforce_content_safety: bool=True, # Enable safety check with Saleforce safety flan t5
    use_fast_kernels: bool = False, # Enable using SDPA from PyTorch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels
    enable_llamaguard_content_safety: bool = False,
    num_extra_neurons: int = 1,
    pruning_ratio: float = 0.5,
    model_size: str = '1b',
    layer_idx: int = 15,
    input_dependent: bool = False,
    **kwargs
):
    """Report SAE feature activation counts for a pruned Llama model.

    The model is loaded from ``model_name``, ``set_bias()`` is called on it and
    it is moved to ``cuda:0``. Depending on ``model_size``:

    * ``'8b'``: for layers 23 and 29, print the average number of SAE features
      activated by the down-projection weights and biases.
    * ``'1b'``: do the same for every layer, then run up to 50 non-empty
      WikiText-2 test texts through the model three times and print the mean
      number of active SAE features per token and layer. The three passes are
      labelled ``SPON`` (model as loaded), ``TEAL`` (after
      ``model.bias_initialization()``) and ``FULL`` (the unmodified
      ``meta-llama/Llama-3.2-1B-Instruct`` checkpoint).

    Any other ``model_size`` only loads the model and returns.

    Args:
        model_name: Checkpoint path or hub id passed to ``from_pretrained``.
        num_extra_neurons: Forwarded to ``LlamaForCausalLM.from_pretrained``.
        pruning_ratio: The model receives ``1 - pruning_ratio``.
        model_size: ``'1b'`` or ``'8b'``; selects the SAE release and the
            analysis that is run.
        input_dependent: Forwarded to ``LlamaForCausalLM.from_pretrained``.

    All remaining parameters (including ``layer_idx`` and ``**kwargs``) are
    accepted for command-line compatibility but are not read by this function.
    """
    # Forward the CLI options instead of hard-coded literals; with the default
    # values this is the same call as before (1, 1 - 0.5, False).
    model = LlamaForCausalLM.from_pretrained(
        model_name,
        num_extra_neurons=num_extra_neurons,
        pruning_ratio=1 - pruning_ratio,
        input_dependent=input_dependent,
        torch_dtype=torch.bfloat16,
    )
    model.set_bias()
    model = model.to('cuda:0')

    if model_size == '8b':
        for idx in [23, 29]:
            sae = Sae.load_from_hub(
                "EleutherAI/sae-llama-3.1-8b-64x", hookpoint=f"layers.{idx}.mlp"
            ).to('cuda:0', torch.bfloat16)
            _report_weight_features(model, sae, idx)

    elif model_size == '1b':
        for idx in range(model.config.num_hidden_layers):
            sae = Sae.load_from_hub(
                "EleutherAI/sae-Llama-3.2-1B-131k", hookpoint=f"layers.{idx}.mlp"
            ).to('cuda:0', torch.bfloat16)
            # Layer 15 additionally goes through the model's final norm.
            # NOTE(review): presumably because it is the last decoder layer
            # of the 1B model -- confirm.
            final_norm = model.model.norm if idx == 15 else None
            _report_weight_features(
                model, sae, idx, final_norm=final_norm, show_top_indices=True
            )

        num_samples = 50
        dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
        texts = [t for t in dataset[:num_samples]["text"] if t.strip()]

        tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
        saes = Sae.load_many("EleutherAI/sae-Llama-3.2-1B-131k")

        # NOTE(review): only this first pass tokenizes with special tokens;
        # kept as in the original, confirm that the difference is intended.
        _report_hidden_features('SPON', model, tokenizer, saes, texts, add_special_tokens=True)
        print('*'*20)

        model.bias_initialization()
        _report_hidden_features('TEAL', model, tokenizer, saes, texts, add_special_tokens=False)
        print('*'*20)

        # Reference run on the unmodified checkpoint.
        model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B-Instruct", torch_dtype=torch.bfloat16).to('cuda:0')
        _report_hidden_features('FULL', model, tokenizer, saes, texts, add_special_tokens=False)
        
# Script entry point: expose ``main`` as a command-line interface via Fire,
# which maps command-line flags onto the function's parameters.
if __name__ == '__main__':
    fire.Fire(component=main)
    
    
# import torch
# from models import LlamaForCausalLMWithPruning
# model = LlamaForCausalLMWithPruning.from_pretrained('/data/haotian/neural-pruning/outputs/squad_sparsity')
# from transformers import AutoTokenizer
# tokenizer = AutoTokenizer.from_pretrained('/data/haotian/neural-pruning/outputs/squad_sparsity')
# from datasets import load_dataset
# squad = load_dataset('rajpurkar/squad')
# model = model.to('cuda:0', torch.bfloat16)
# for instance in squad['validation']:
#     context = instance['context']
#     question = instance['question']
#     answer = instance['answers']['text'][0]

#     prompt = f'USER: {context}\n{question} Answer the question by only using a single word or a single phrase.\nASSISTANT:'
#     tokens = torch.tensor(tokenizer.encode(prompt)).long().unsqueeze(0).to(model.device)
#     attention_mask = torch.ones_like(tokens)

#     outputs = model.generate(
#         input_ids=tokens,
#         attention_mask=attention_mask,
#         max_new_tokens=256,
#         do_sample=True,
#         top_p=1.0,
#         temperature=1.0,
#         use_cache=True,
#         top_k=50,
#         repetition_penalty=1.0,
#         length_penalty=1
#     )
#     output_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
#     print(output_text)
#     outputs = model(input_ids=tokens, attention_mask=attention_mask)
#     masking = outputs.masking
#     break
# activated, total = 0, 0
# for m in masking:
#     for k, v in m.items():
#         activated += v.sum().item()
#         total += torch.ones_like(v).sum().item()
# print(1- activated / total)

# import torch
# from transformers import LlamaForCausalLM, AutoTokenizer
# from datasets import load_dataset
# data = load_dataset('rajpurkar/squad')["validation"]
# model = LlamaForCausalLM.from_pretrained('meta-llama/Llama-3.2-1B-Instruct')
# tokenizer = AutoTokenizer.from_pretrained('meta-llama/Llama-3.2-1B-Instruct')
# input_texts = []
# for instance in data:
#     context = instance['context']
#     question = instance['question']
#     answer = instance['answers']['text'][0]
#     input_texts.append(f'USER: {context}\n{question} Answer the question by using a single word or a single phrase.\nASSISTANT: {answer}')
# from inference import Perplexity
# perplexity = Perplexity()
# model = model.to(torch.bfloat16)
# results = perplexity.compute(model=model,tokenizer=tokenizer, add_start_token=False, predictions=input_texts)

# from models import LlamaForCausalLMWithPruning
# model = LlamaForCausalLMWithPruning.from_pretrained('/data/haotian/neural-pruning/outputs/wiki_sparsity_l2')
# from transformers import AutoTokenizer
# tokenizer = AutoTokenizer.from_pretrained('/data/haotian/neural-pruning/outputs/wiki_sparsity_l2')
# model.seqlen = 4096
# import torch
# model = model.to(torch.device("cuda:0"))
# from lib.eval import eval_ppl
# ppl_test = eval_ppl(None, model, tokenizer)
# print(f"wikitext perplexity {ppl_test}")