from QA_text_dataset_CNN import InferDataset
from tqdm import tqdm
import sys
import re
import time
import os
import json
import argparse
import torch
import warnings
import random
import numpy as np
from transformers import AutoTokenizer, TextStreamer, GenerationConfig, AutoModelForCausalLM
from snapkv.monkeypatch.monkeypatch import replace_llama

# Suppress every Python warning for the whole process: no category or module
# filter is passed, so the "ignore" action applies to all warnings.
warnings.filterwarnings("ignore")

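# Optional: uncomment the line below to pin the run to a single GPU through
# CUDA_VISIBLE_DEVICES (`os` is already imported at the top of this file).
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"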

def load_data(
    args,
    tokenizer
):
    """Build the DataLoader that feeds the inference loop.

    Args:
        args: Parsed CLI namespace. Reads ``data_path``, ``max_doc_tokens``,
            ``max_num_documents``, ``min_num_documents`` and ``batch_size``.
        tokenizer: Tokenizer handed through to ``InferDataset`` unchanged.

    Returns:
        A ``torch.utils.data.DataLoader`` over an ``InferDataset`` that keeps
        the file order (``shuffle=False``) and batches with the dataset's own
        ``collate_fn``.
    """
    dataset_kwargs = {
        "filepath": args.data_path,
        "tokenizer": tokenizer,
        "max_doc_tokens": args.max_doc_tokens,
        "max_num_documents": args.max_num_documents,
        "min_num_documents": args.min_num_documents,
    }
    infer_set = InferDataset(**dataset_kwargs)

    # Order is preserved so that written results line up with the input file.
    return torch.utils.data.DataLoader(
        infer_set,
        batch_size=args.batch_size,
        shuffle=False,
        collate_fn=infer_set.collate_fn,
    )

# Generation budget per dataset: main() passes the looked-up value to
# generation as ``max_new_tokens``.
dataset2maxgen = dict(NQ=100, HQA=32, WikiQA=32)

# Instruction strings keyed by variant name. Both variants currently carry
# exactly the same wording.
# NOTE(review): nothing in the visible part of this file reads this table;
# presumably it is selected through --instruction_name elsewhere -- confirm.
instructions_map = dict.fromkeys(
    ('base', 'short'),
    'Answer the Question based on the given text. Only give me the answer and do not output any other words.',
)

# Per-dataset token count that main() divides by --mean_compression_rate to
# derive the SnapKV ``max_capacity_prompt``.
# NOTE(review): presumably the average prompt length of each dataset -- confirm.
dataset2num_tokens = dict(NQ=3265, HQA=1568, WikiQA=1148)

def load_model_and_tokenizer(path, compress=False, window_size=None, max_capacity_prompt=None, kernel_size=None, pooling=None):
    """Load a causal LM and its tokenizer, optionally configured for SnapKV.

    Args:
        path: Model name or local directory passed to ``from_pretrained``.
        compress: When true, the four SnapKV settings below are written onto
            the attention config of every decoder layer.
        window_size: SnapKV ``window_size`` setting.
        max_capacity_prompt: SnapKV ``max_capacity_prompt`` setting.
        kernel_size: SnapKV ``kernel_size`` setting.
        pooling: SnapKV ``pooling`` setting.

    Returns:
        ``(model, tokenizer)``: the model in eval mode, loaded in float16 with
        ``device_map="auto"``; the tokenizer with its pad token id set to its
        EOS token id.
    """
    tokenizer = AutoTokenizer.from_pretrained(path)
    # NOTE(review): newer transformers releases deprecate `use_flash_attention_2`
    # in favour of `attn_implementation="flash_attention_2"` -- confirm against
    # the pinned transformers version before switching.
    model = AutoModelForCausalLM.from_pretrained(
        path,
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        device_map="auto",
        use_cache=True,
        use_flash_attention_2=True
    )

    if compress:
        snapkv_settings = {
            "window_size": window_size,
            "max_capacity_prompt": max_capacity_prompt,
            "kernel_size": kernel_size,
            "pooling": pooling,
        }
        # Write the same SnapKV settings onto every decoder layer's attention config.
        for decoder_layer in model.model.layers:
            attn_config = decoder_layer.self_attn.config
            for setting_name, setting_value in snapkv_settings.items():
                setattr(attn_config, setting_name, setting_value)

    model = model.eval()
    # Reuse EOS as the padding token.
    tokenizer.pad_token_id = tokenizer.eos_token_id

    return model, tokenizer

def main(args):
    """Run greedy QA generation over a dataset and write the results as JSON lines.

    The dataset name is taken from the directory that contains
    ``args.data_path`` (e.g. ``data/NQ/test.jsonl`` -> ``NQ``) and selects both
    the generation budget (``dataset2maxgen``) and the token count used to
    derive the SnapKV capacity (``dataset2num_tokens``).

    Note that ``args.max_capacity_prompt`` is always overwritten with
    ``int(dataset2num_tokens[name] / args.mean_compression_rate)``, so a value
    supplied by the caller is ignored.

    Args:
        args: Parsed CLI namespace. Reads ``model_name_or_path``, ``data_path``,
            ``compress``, ``window_size``, ``kernel_size``, ``pooling``,
            ``mean_compression_rate``, ``output_filename`` and the fields
            consumed by ``load_data``.

    Raises:
        KeyError: If the dataset name derived from ``args.data_path`` is not a
            known dataset.
    """
    model_name_or_path = args.model_name_or_path

    # Parent directory of the data file names the dataset. os.path copes with
    # repeated separators and platform-specific separators.
    dataset_name = os.path.basename(os.path.dirname(args.data_path))
    if dataset_name not in dataset2num_tokens or dataset_name not in dataset2maxgen:
        raise KeyError(
            f"Unknown dataset {dataset_name!r} derived from data_path {args.data_path!r}; "
            f"expected one of {sorted(dataset2num_tokens)}"
        )

    # Apply SnapKV monkey patching before the model is instantiated.
    if args.compress:
        replace_llama()

    # Capacity is derived from the dataset token count and the compression
    # rate; this deliberately replaces whatever was in args.max_capacity_prompt.
    args.max_capacity_prompt = int(dataset2num_tokens[dataset_name] / args.mean_compression_rate)

    # Load model and tokenizer with SnapKV configuration if enabled
    model, tokenizer = load_model_and_tokenizer(
        model_name_or_path,
        compress=args.compress,
        window_size=args.window_size,
        max_capacity_prompt=args.max_capacity_prompt,
        kernel_size=args.kernel_size,
        pooling=args.pooling
    )

    # Load dataset
    data_loader = load_data(args, tokenizer)

    # The generation settings do not change between examples, so build them once.
    generation_config = GenerationConfig(
        use_cache=True,
        min_new_tokens=1,
        max_new_tokens=dataset2maxgen[dataset_name],
        do_sample=False,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )

    records = []
    for batch in tqdm(data_loader):
        # Handle every example of the batch, one prompt at a time, so nothing
        # is dropped when batch_size > 1.
        # NOTE(review): assumes the collate_fn returns parallel per-example
        # sequences under 'prompt', 'question' and 'answers' (the previous code
        # indexed element 0 of each) -- confirm against InferDataset.collate_fn.
        for prompt, question, answers in zip(batch['prompt'], batch['question'], batch['answers']):
            input_ids = tokenizer.encode(prompt, return_tensors="pt")
            original_num_tokens = input_ids.shape[1]
            input_ids = input_ids.to(model.device)

            with torch.no_grad():
                # A TextStreamer prints tokens as they're being generated
                streamer = TextStreamer(tokenizer)
                generated_tokens = model.generate(
                    input_ids,
                    generation_config=generation_config,
                    streamer=streamer,
                )

            # Decode only the continuation. Slicing the decoded full sequence by
            # len(prompt) is unreliable because detokenisation does not always
            # reproduce the prompt text character for character.
            new_token_ids = generated_tokens[0][original_num_tokens:]
            output_text = tokenizer.decode(new_token_ids, skip_special_tokens=True)

            # With SnapKV the prompt cache is capped at max_capacity_prompt; a
            # prompt that is already shorter is kept whole (per the SnapKV
            # reference implementation), so report the smaller of the two.
            if args.compress:
                compressed_num_tokens = min(original_num_tokens, args.max_capacity_prompt)
            else:
                compressed_num_tokens = original_num_tokens

            records.append({
                'question': question,
                'generation': output_text,
                'answers': answers,
                'original_num_tokens': original_num_tokens,
                'compressed_num_tokens': compressed_num_tokens,
            })

    # Save results, one JSON object per line, in input order.
    with open(args.output_filename, 'w', encoding='utf-8') as f:
        for record in records:
            f.write(json.dumps(record, ensure_ascii=False) + '\n')

def seed_everything(TORCH_SEED):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random``, NumPy, and PyTorch (CPU and all CUDA devices)
    with the same value, exports it as ``PYTHONHASHSEED``, and switches cuDNN
    to deterministic mode with benchmarking disabled.

    Args:
        TORCH_SEED: Integer seed applied to all generators.
    """
    os.environ['PYTHONHASHSEED'] = str(TORCH_SEED)

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(TORCH_SEED)

    # Deterministic kernels, and no autotuning that could vary between runs.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

if __name__ == "__main__":
    # Script entry point: parse CLI options, seed the RNGs, run inference.
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name_or_path", type=str, default="/path/to/project/llmodel/Llama-2-7b-chat-hf")

    # SnapKV compression parameters
    # Compression stays ON by default. The script used to force
    # `args.compress = True` after parsing, which made --compress a no-op and
    # the uncompressed baseline unreachable. Existing invocations (with or
    # without --compress) behave as before; --no_compress now turns it off.
    parser.add_argument("--compress", action="store_true", default=True, help="Use SnapKV compression (default: enabled)")
    parser.add_argument("--no_compress", dest="compress", action="store_false", help="Disable SnapKV compression and run the uncompressed baseline")
    parser.add_argument("--window_size", type=int, default=64, help="SnapKV window size")
    # NOTE: main() recomputes max_capacity_prompt from the dataset token count
    # and --mean_compression_rate, so the value given here is overwritten.
    parser.add_argument("--max_capacity_prompt", type=int, default=100, help="SnapKV max capacity prompt")
    parser.add_argument("--kernel_size", type=int, default=13, help="SnapKV kernel size")
    parser.add_argument("--pooling", type=str, default="maxpool", choices=["maxpool", "avgpool"], help="SnapKV pooling method")
    parser.add_argument("--mean_compression_rate", type=float, default=4, help="SnapKV compression rate")

    # Dataset parameters
    # The directory that contains the data file names the dataset (e.g. "NQ").
    parser.add_argument('--data_path', type=str, default="data/NQ/test.jsonl")
    parser.add_argument('--output_filename', type=str, default='NQ_outputs-snapkv.jsonl')
    parser.add_argument('--seed', type=int, default=42)

    parser.add_argument('--max_doc_tokens', type=int, default=512)
    parser.add_argument('--max_num_documents', type=int, default=20)
    parser.add_argument('--min_num_documents', type=int, default=20)
    # NOTE(review): --instruction_name is parsed but not read anywhere in the
    # visible part of this file -- confirm whether it is still needed.
    parser.add_argument('--instruction_name', type=str, default='base')
    parser.add_argument('--batch_size', type=int, default=1)

    args = parser.parse_args()

    seed_everything(args.seed)
    main(args)