# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

# from accelerate import init_empty_weights, load_checkpoint_and_dispatch

import fire
import json
import os
import sys
import re
import torch
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM, AutoConfig
from datasets import load_dataset

from accelerate.utils import is_xpu_available

import lm_eval
from lm_eval.tasks import TaskManager
from lm_eval.models.huggingface import HFLM
import json
from models import (
    LlamaForCausalLM, 
    MistralForCausalLM, 
    Qwen3ForCausalLM, 
    TEALLlamaForCausalLM, 
    DeepseekForCausalLM,
    DeepseekV2ForCausalLM,
    DeepseekV2Config,
    DeepseekConfig,
    OlmoeForCausalLM
)
from transformers import AutoModelForCausalLM

# os.environ["TOKENIZERS_PARALLELISM"] = "false"
# os.environ["NCCL_P2P_DISABLE"] = "1"
# os.environ["OMP_NUM_THREADS"] = "1"
# Opt in to executing model-generated code during evaluation: the Hugging Face
# `code_eval` metric (used by code tasks such as humaneval) checks this flag.
os.environ.update({"HF_ALLOW_CODE_EVAL": "1"})

from eval.quantization import quantization_config  as QUANT_CONFIG
from peft import PeftModel

# Hub id of each supported DeepSeek MoE base model -> config class used to
# build its config (so `num_experts_per_tok` can be overridden) before loading.
DEEPSEEK_CONFIG = dict([
    ('deepseek-ai/deepseek-moe-16b-base', DeepseekConfig),
    ('deepseek-ai/DeepSeek-V2-Lite', DeepseekV2Config),
])

def _resolve_model(model_name, task_name, performance_dir, *, full_model, baseline,
                   pruning_ratio, num_extra_neurons, num_experts_per_tok):
    """Map ``model_name`` to ``(destination_file, model_id, causal_lm_cls)``.

    ``destination_file`` is the path the lm-eval results JSON is written to,
    ``model_id`` is the Hugging Face Hub id of the base model, and
    ``causal_lm_cls`` is the class used to load the checkpoint.

    Families are matched by substring of ``model_name`` (exact equality for the
    plain ``olmoe`` / ``deepseekv2`` / ``deepseek`` base models). The order of
    the checks matters, e.g. ``'llama-8b-peft'`` is tested before ``'llama-8b'``
    and ``'deepseekv2'`` before ``'deepseek'``.

    Raises:
        ValueError: if ``model_name`` matches none of the known families.
    """
    # Last ``_``-separated component of the checkpoint name; embedded in filenames.
    sparsity_at_training = model_name.split('_')[-1]
    experts_suffix = '' if num_experts_per_tok is None else f'-{num_experts_per_tok}'
    plain_baseline = f'{task_name}_baseline.json'
    ratio_baseline = f'{task_name}_{pruning_ratio}_baseline.json'

    def dense(subdir, baseline_filename):
        # Filename scheme shared by the dense families; only the baseline
        # filename differs between them.
        if full_model:
            filename = f'{task_name}.json'
        elif baseline:
            filename = baseline_filename
        else:
            filename = (f'{task_name}_{pruning_ratio}_{num_extra_neurons}'
                        f'_extra_neurons_{sparsity_at_training}.json')
        return os.path.join(performance_dir, subdir, filename)

    def moe(subdir, dir_tag, file_tag):
        # A single task goes to <dir_tag>/<num_experts_per_tok>/<task>.json;
        # a collection of tasks goes to one "<file_tag>[-k]-all_tasks.json".
        if isinstance(task_name, str):
            filename = f'{dir_tag}/{num_experts_per_tok}/{task_name}.json'
        else:
            filename = f'{file_tag}{experts_suffix}-all_tasks.json'
        return os.path.join(performance_dir, subdir, filename)

    if 'llama-1b' in model_name:
        return (dense('llama-1b', plain_baseline),
                "meta-llama/Llama-3.2-1B-Instruct", LlamaForCausalLM)

    # --- Mixture-of-experts models --------------------------------------
    if model_name == 'olmoe':
        return (moe('olmoe', 'BASE', 'Base'),
                "allenai/OLMoE-1B-7B-0125-Instruct", AutoModelForCausalLM)
    if 'olmoe' in model_name:
        return (moe('olmoe', 'SPON', 'SPON'),
                "allenai/OLMoE-1B-7B-0125-Instruct", OlmoeForCausalLM)
    if model_name == 'deepseekv2':
        return (moe('deepseekv2', 'BASE', 'Base'),
                "deepseek-ai/DeepSeek-V2-Lite", AutoModelForCausalLM)
    if 'deepseekv2' in model_name:
        return (moe('deepseekv2', 'SPON', 'SPON'),
                "deepseek-ai/DeepSeek-V2-Lite", DeepseekV2ForCausalLM)
    if model_name == 'deepseek':
        return (os.path.join(performance_dir, 'deepseek', 'all_tasks.json'),
                "deepseek-ai/deepseek-moe-16b-base", AutoModelForCausalLM)
    if 'deepseek' in model_name:
        filename = (f'SPON-{task_name}.json' if isinstance(task_name, str)
                    else 'SPON-all_tasks.json')
        return (os.path.join(performance_dir, 'deepseek', filename),
                "deepseek-ai/deepseek-moe-16b-base", DeepseekForCausalLM)

    # --- Dense models ----------------------------------------------------
    if 'llama-3b' in model_name:
        return (dense('llama-3b', plain_baseline),
                "meta-llama/Llama-3.2-3B-Instruct", LlamaForCausalLM)
    if 'llama-8b-peft' in model_name:
        peft_baseline = f'{task_name}_{pruning_ratio}_baseline_{sparsity_at_training}.json'
        return (dense('llama-8b-peft', peft_baseline),
                "meta-llama/Llama-3.1-8B-Instruct", LlamaForCausalLM)
    if 'teal-llama-8b' in model_name:
        return (dense('teal-llama-8b', plain_baseline),
                "meta-llama/Llama-3.1-8B-Instruct", TEALLlamaForCausalLM)
    if 'lora-llama-8b' in model_name:
        # LoRA runs are keyed by rank (num_extra_neurons) only.
        return (os.path.join(performance_dir, 'lora-llama-8b',
                             f'{task_name}_rank_{num_extra_neurons}.json'),
                "meta-llama/Llama-3.1-8B-Instruct", LlamaForCausalLM)
    if 'llama-70b' in model_name:
        return (dense('llama-70b', ratio_baseline),
                "meta-llama/Llama-3.3-70B-Instruct", LlamaForCausalLM)
    if 'llama-8b' in model_name:
        return (dense('llama-8b', ratio_baseline),
                "meta-llama/Llama-3.1-8B-Instruct", LlamaForCausalLM)
    if 'llama-7b' in model_name:
        return (dense('llama-7b', plain_baseline),
                "meta-llama/Llama-2-7b-hf", LlamaForCausalLM)
    if 'mistral-7b' in model_name:
        return (dense('mistral-7b', plain_baseline),
                "mistralai/Mistral-7B-Instruct-v0.2", MistralForCausalLM)
    if 'mistralv3-7b-peft' in model_name:
        return (dense('mistralv3-7b-peft', plain_baseline),
                "mistralai/Mistral-7B-Instruct-v0.3", MistralForCausalLM)
    if 'mistralv3-7b' in model_name:
        return (dense('mistralv3-7b', ratio_baseline),
                "mistralai/Mistral-7B-Instruct-v0.3", MistralForCausalLM)
    if 'qwen3-32b' in model_name:
        # NOTE(review): the Hub repo is usually spelled "Qwen/Qwen3-32B" - confirm.
        return (dense('qwen3-32b', ratio_baseline),
                "Qwen/Qwen3-32b", Qwen3ForCausalLM)
    if 'qwen3-8b' in model_name:
        return (dense('qwen3-8b', ratio_baseline),
                "Qwen/Qwen3-8B", Qwen3ForCausalLM)
    if 'qwen3-1.7b' in model_name:
        return (dense('qwen3-1.7b', plain_baseline),
                "Qwen/Qwen3-1.7B", Qwen3ForCausalLM)

    raise ValueError(
        f"Unrecognised model_name {model_name!r}: it must match one of the "
        "supported families (llama-*, teal-llama-8b, lora-llama-8b, mistral*, "
        "qwen3-*, olmoe*, deepseek*)."
    )


def _try_move_to_cuda(model):
    """Best-effort move of ``model`` onto ``cuda:0``.

    If ``.to()`` raises, the model is returned unchanged. Presumably this
    happens for quantized or ``device_map``-dispatched models; confirm.
    """
    try:
        return model.to("cuda:0")
    except Exception:  # deliberate best effort; unlike a bare except, Ctrl-C still propagates
        return model


def _load_model_and_tokenizer(model_name, model_id, causal_lm_cls, bnb_config, *,
                              full_model, baseline, pruning_ratio, num_extra_neurons,
                              lora, num_experts_per_tok):
    """Load the model to evaluate together with a matching tokenizer.

    Args:
        model_name: the plain MoE base name, or the path/repo of a modified checkpoint.
        model_id: Hub id of the base model (from ``_resolve_model``).
        causal_lm_cls: class used to load non-base checkpoints.
        bnb_config: bitsandbytes config, or None for unquantized bf16 loading.

    Returns:
        ``(model, tokenizer)``.
    """
    # The model classes receive the inverted CLI value.
    model_pruning_ratio = 1.0 - pruning_ratio
    # bf16 weights unless a bitsandbytes quantization config was requested.
    precision = ({'torch_dtype': torch.bfloat16} if bnb_config is None
                 else {'quantization_config': bnb_config})

    if 'deepseek' in model_name or 'olmoe' in model_name:
        is_base = model_name in ('deepseek', 'deepseekv2', 'olmoe')
        # Base models load from the Hub id; modified checkpoints load from model_name.
        source = model_id if is_base else model_name
        if 'deepseek' in model_name:
            config = DEEPSEEK_CONFIG[model_id].from_pretrained(source)
        else:
            # NOTE(review): OLMoE checkpoints take their config from the base Hub id
            # even when weights come from model_name - confirm this is intended.
            config = AutoConfig.from_pretrained(model_id)
        if num_experts_per_tok is not None:
            config.num_experts_per_tok = num_experts_per_tok
        model_cls = AutoModelForCausalLM if is_base else causal_lm_cls
        model = model_cls.from_pretrained(
            source, config=config, device_map='cuda:0',
            quantization_config=bnb_config, trust_remote_code=True,
        )
        tokenizer = AutoTokenizer.from_pretrained(source, trust_remote_code=True)
        return model, tokenizer

    if full_model:
        model = AutoModelForCausalLM.from_pretrained(
            model_id, device_map='auto', quantization_config=bnb_config)
        tokenizer = AutoTokenizer.from_pretrained(model_id)  # was never loaded before
        return _try_move_to_cuda(model), tokenizer

    if baseline:
        model = causal_lm_cls.from_pretrained(
            model_id, pruning_ratio=model_pruning_ratio, device_map="cuda:0", **precision)
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        if 'teal' not in model_name:
            model.bias_initialization()
        return _try_move_to_cuda(model), tokenizer

    if 'peft' in model_name:
        model = causal_lm_cls.from_pretrained(
            model_id, pruning_ratio=model_pruning_ratio, device_map='auto',
            quantization_config=bnb_config,
        )
        model = PeftModel.from_pretrained(model, model_name)
        # The adapter sits on the base model, so use the base tokenizer (was never loaded before).
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        return model, tokenizer

    load_kwargs = dict(num_extra_neurons=num_extra_neurons,
                       pruning_ratio=model_pruning_ratio,
                       device_map='auto', **precision)
    if 'llama' in model_name:
        # ``lora`` is only forwarded for Llama checkpoints.
        load_kwargs['lora'] = lora
    model = causal_lm_cls.from_pretrained(model_name, **load_kwargs)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return _try_move_to_cuda(model), tokenizer


def _needs_unsafe_code(tasks):
    """Return True if any requested task belongs to the humaneval family.

    These tasks execute model-generated code, and lm-eval requires an explicit
    ``confirm_run_unsafe_code=True`` to run them.
    """
    task_list = list(tasks) if isinstance(tasks, (list, tuple)) else [tasks]
    return any(str(task).startswith('humaneval') for task in task_list)


def main(
    model_name,
    task_name,
    performance_dir: str='./performance/',
    peft_model: str=None,
    quantization: str = None,  # Options: 4bit, 8bit
    max_new_tokens =256,  # The maximum number of tokens to generate
    min_new_tokens:int=0,  # The minimum number of tokens to generate
    prompt_file: str=None,
    seed: int=42,  # Seed value for reproducibility
    safety_score_threshold: float=0.5,
    do_sample: bool=True,  # Whether to sample; use greedy decoding otherwise.
    use_cache: bool=True,  # [optional] Whether the model should reuse past key/value attentions (if applicable) to speed up decoding.
    top_p: float=1.0,  # [optional] If < 1, keep only the smallest set of most probable tokens whose probabilities add up to top_p.
    temperature: float=1.0,  # [optional] Value used to modulate the next-token probabilities.
    top_k: int=50,  # [optional] Number of highest-probability tokens kept for top-k filtering.
    repetition_penalty: float=1.0,  # Repetition penalty; 1.0 means no penalty.
    length_penalty: int=1,  # [optional] Exponential length penalty used with beam-based generation.
    enable_azure_content_safety: bool=False,  # Enable safety check with the Azure content safety API
    enable_sensitive_topics: bool=False,  # Enable check for sensitive topics using AuditNLG APIs
    enable_saleforce_content_safety: bool=True,  # Enable safety check with Salesforce safety flan t5
    use_fast_kernels: bool = False,  # Enable SDPA from PyTorch Accelerated Transformers (Flash Attention / xFormers kernels)
    enable_llamaguard_content_safety: bool = False,
    prune_method: str = None,
    num_extra_neurons: int = 1,
    pruning_ratio: float = 0.5,
    baseline: bool = False,
    full_model: bool = False,
    batch_size: int = 32,
    lora: bool = False,
    num_experts_per_tok: int = None,
    max_length=4096,
    **kwargs
):
    """Evaluate a (possibly pruned, PEFT or MoE-modified) model with lm-eval and save the scores.

    Args:
        model_name: checkpoint path/repo, or one of the plain MoE base names
            ``olmoe`` / ``deepseekv2`` / ``deepseek``. Its substrings select the model family.
        task_name: one lm-eval task name, or a tuple/list of them.
        performance_dir: root directory for result JSON files. Missing
            subdirectories are created.
        quantization: ``'4bit'`` or ``'8bit'`` to load with bitsandbytes.
        pruning_ratio: passed to the model classes as ``1.0 - pruning_ratio``.
        num_extra_neurons, lora: forwarded to the custom model classes.
        baseline, full_model: select the baseline / unmodified-model loading path.
        num_experts_per_tok: overrides the MoE routing top-k when given.
        batch_size, max_length: passed to lm-eval's HFLM wrapper.

    The generation, safety and prompt parameters are not used by this lm-eval
    path. They are accepted so existing command lines keep working.

    Raises:
        ValueError: if ``model_name`` matches no supported model family.
    """
    # Resolve (and validate) the output location before any expensive loading.
    destination_file, model_id, causal_lm_cls = _resolve_model(
        model_name, task_name, performance_dir,
        full_model=full_model, baseline=baseline, pruning_ratio=pruning_ratio,
        num_extra_neurons=num_extra_neurons, num_experts_per_tok=num_experts_per_tok,
    )
    # makedirs also creates nested MoE dirs (e.g. olmoe/BASE/<k>/) that os.mkdir missed.
    os.makedirs(os.path.dirname(destination_file), exist_ok=True)

    quant_config = QUANT_CONFIG()
    if quantization in ['4bit', '8bit']:
        bnb_config = quant_config.create_bnb_config(quantization)
    else:
        bnb_config = None

    model, tokenizer = _load_model_and_tokenizer(
        model_name, model_id, causal_lm_cls, bnb_config,
        full_model=full_model, baseline=baseline, pruning_ratio=pruning_ratio,
        num_extra_neurons=num_extra_neurons, lora=lora,
        num_experts_per_tok=num_experts_per_tok,
    )

    tasks = list(task_name) if isinstance(task_name, tuple) else task_name
    lm = HFLM(model, tokenizer=tokenizer, batch_size=batch_size, max_length=max_length)
    results = lm_eval.simple_evaluate(
        model=lm,
        tasks=tasks,
        # NOTE(review): this dict is passed positionally. In the lm-eval
        # releases I know, TaskManager's first parameter is `verbosity`, not
        # `metadata` - confirm it is interpreted as intended.
        task_manager=TaskManager({'pretrained': model_id, 'backend': 'causal', 'batch_size': batch_size}),
        batch_size=batch_size,
        confirm_run_unsafe_code=_needs_unsafe_code(tasks),
    )

    with open(destination_file, 'w') as fp:
        json.dump(results['results'], fp)



if __name__ == "__main__":
    # Expose `main` as a command-line interface: each of its parameters can be
    # supplied as a flag, e.g. `--model_name ... --task_name ...`.
    fire.Fire(component=main)
    
    
# import torch
# from models import LlamaForCausalLMWithPruning
# model = LlamaForCausalLMWithPruning.from_pretrained('/data/haotian/neural-pruning/outputs/squad_sparsity')
# from transformers import AutoTokenizer
# tokenizer = AutoTokenizer.from_pretrained('/data/haotian/neural-pruning/outputs/squad_sparsity')
# from datasets import load_dataset
# squad = load_dataset('rajpurkar/squad')
# model = model.to('cuda:0', torch.bfloat16)
# for instance in squad['validation']:
#     context = instance['context']
#     question = instance['question']
#     answer = instance['answers']['text'][0]

#     prompt = f'USER: {context}\n{question} Answer the question by only using a single word or a single phrase.\nASSISTANT:'
#     tokens = torch.tensor(tokenizer.encode(prompt)).long().unsqueeze(0).to(model.device)
#     attention_mask = torch.ones_like(tokens)

#     outputs = model.generate(
#         input_ids=tokens,
#         attention_mask=attention_mask,
#         max_new_tokens=256,
#         do_sample=True,
#         top_p=1.0,
#         temperature=1.0,
#         use_cache=True,
#         top_k=50,
#         repetition_penalty=1.0,
#         length_penalty=1
#     )
#     output_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
#     print(output_text)
#     outputs = model(input_ids=tokens, attention_mask=attention_mask)
#     masking = outputs.masking
#     break
# activated, total = 0, 0
# for m in masking:
#     for k, v in m.items():
#         activated += v.sum().item()
#         total += torch.ones_like(v).sum().item()
# print(1- activated / total)

# import torch
# from transformers import LlamaForCausalLM, AutoTokenizer
# from datasets import load_dataset
# data = load_dataset('rajpurkar/squad')["validation"]
# model = LlamaForCausalLM.from_pretrained('meta-llama/Llama-3.2-1B-Instruct')
# tokenizer = AutoTokenizer.from_pretrained('meta-llama/Llama-3.2-1B-Instruct')
# input_texts = []
# for instance in data:
#     context = instance['context']
#     question = instance['question']
#     answer = instance['answers']['text'][0]
#     input_texts.append(f'USER: {context}\n{question} Answer the question by using a single word or a single phrase.\nASSISTANT: {answer}')
# from inference import Perplexity
# perplexity = Perplexity()
# model = model.to(torch.bfloat16)
# results = perplexity.compute(model=model,tokenizer=tokenizer, add_start_token=False, predictions=input_texts)

# from models import LlamaForCausalLMWithPruning
# model = LlamaForCausalLMWithPruning.from_pretrained('/data/haotian/neural-pruning/outputs/wiki_sparsity_l2')
# from transformers import AutoTokenizer
# tokenizer = AutoTokenizer.from_pretrained('/data/haotian/neural-pruning/outputs/wiki_sparsity_l2')
# model.seqlen = 4096
# import torch
# model = model.to(torch.device("cuda:0"))
# from lib.eval import eval_ppl
# ppl_test = eval_ppl(None, model, tokenizer)
# print(f"wikitext perplexity {ppl_test}")