from QA_text_dataset import InferDataset
from streaming_llm.enable_streaming_llm import enable_streaming_llm
from streaming_llm.utils import load, download_url, load_jsonl
from tqdm import tqdm
import sys
import re
import time
import os
import json
import argparse
import torch
import warnings
import random
import numpy as np
from transformers import AutoTokenizer, TextStreamer, GenerationConfig
from attention_sinks import AutoModelForCausalLM

warnings.filterwarnings("ignore")


def load_data(args, tokenizer):
    """Build the (non-shuffled) evaluation DataLoader for ``args.data_path``.

    Args:
        args: Parsed CLI namespace; reads ``data_path``, ``max_doc_tokens``,
            ``max_num_documents``, ``min_num_documents`` and ``batch_size``.
        tokenizer: Tokenizer handed through to ``InferDataset``.

    Returns:
        A ``torch.utils.data.DataLoader`` over an ``InferDataset`` that uses the
        dataset's own ``collate_fn``.
    """
    dataset_options = dict(
        filepath=args.data_path,
        max_doc_tokens=args.max_doc_tokens,
        max_num_documents=args.max_num_documents,
        min_num_documents=args.min_num_documents,
    )
    qa_dataset = InferDataset(tokenizer=tokenizer, **dataset_options)

    # Order is preserved (shuffle=False) so outputs line up with the input file.
    return torch.utils.data.DataLoader(
        qa_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        collate_fn=qa_dataset.collate_fn,
    )


# Maximum number of new tokens to generate per dataset.
dataset2maxgen = dict(
    NQ=100,
    HQA=32,
    WikiQA=32,
)

# Both instruction variants currently share the same wording.
_ANSWER_ONLY_INSTRUCTION = (
    'Answer the Question based on the given text. '
    'Only give me the answer and do not output any other words.'
)
instructions_map = {
    'base': _ANSWER_ONLY_INSTRUCTION,
    'short': _ANSWER_ONLY_INSTRUCTION,
}

# Per-dataset token counts used by main() to size the attention-sink cache
# (divided by --mean_compression_rate). Presumably a reference prompt length
# for each dataset — TODO confirm how these numbers were derived.
dataset2num_tokens = dict(
    NQ=3265,
    HQA=1568,
    WikiQA=1148,
)

def main(args):
    """Run attention-sink generation over a QA dataset and write JSON Lines results.

    The attention-sink cache is sized per dataset as
    ``int(dataset2num_tokens[name] / args.mean_compression_rate)`` tokens. The
    first ``attention_sink_size`` of them are passed as sink tokens, and the
    remainder is passed as ``attention_sink_window_size`` to ``attention_sinks``.

    Args:
        args: Parsed CLI namespace. Reads ``model_name_or_path``, ``data_path``,
            ``mean_compression_rate`` and ``output_filename`` here, plus the
            fields consumed by ``load_data``.

    Side effects:
        Loads the model onto ``cuda:0`` and streams tokens to stdout via
        ``TextStreamer``. Writes one JSON object per example to
        ``args.output_filename`` with keys ``question``, ``generation``,
        ``answers``, ``original_num_tokens`` and ``compressed_num_tokens``.

    Raises:
        ValueError: If the dataset name (the parent directory of
            ``data_path``) is not configured, or if ``mean_compression_rate``
            is non-positive or leaves no room for a sliding window.
    """
    model_name_or_path = args.model_name_or_path

    # The dataset name is the parent directory of the data file, e.g.
    # "data/NQ/test.jsonl" -> "NQ". os.path copes with platform separators,
    # which a plain split("/") does not.
    dataset_name = os.path.basename(
        os.path.dirname(os.path.normpath(args.data_path))
    )
    if dataset_name not in dataset2num_tokens or dataset_name not in dataset2maxgen:
        raise ValueError(
            f"Unknown dataset {dataset_name!r} derived from data_path "
            f"{args.data_path!r}; expected one of {sorted(dataset2num_tokens)}."
        )
    if args.mean_compression_rate <= 0:
        raise ValueError(
            f"mean_compression_rate must be positive, got {args.mean_compression_rate!r}."
        )

    attention_sink_size = 4
    # Total tokens given to attention_sinks (sink tokens + sliding window).
    cache_budget = int(dataset2num_tokens[dataset_name] / args.mean_compression_rate)
    attention_sink_window_size = cache_budget - attention_sink_size
    if attention_sink_window_size <= 0:
        raise ValueError(
            f"mean_compression_rate={args.mean_compression_rate!r} leaves a cache of "
            f"{cache_budget} tokens for dataset {dataset_name!r}, which does not "
            f"exceed the {attention_sink_size} attention-sink tokens."
        )

    # Load the chosen model and corresponding tokenizer
    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        # for efficiency:
        device_map="cuda:0",
        torch_dtype=torch.float16,
        # `attention_sinks`-specific arguments:
        attention_sink_size=attention_sink_size,
        attention_sink_window_size=attention_sink_window_size,
    )
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
    # Reuse EOS as the pad token (the tokenizer may not define one).
    tokenizer.pad_token_id = tokenizer.eos_token_id

    data_loader = load_data(
        args,
        tokenizer
    )

    total_generations = []
    total_questions = []
    total_answers = []
    total_original_num_tokens = []
    total_compressed_num_tokens = []

    max_gen_len = dataset2maxgen[dataset_name]
    # Greedy decoding; the config is loop-invariant so it is built once.
    generation_config = GenerationConfig(
        # use_cache=True is required, the rest can be changed up.
        use_cache=True,
        min_new_tokens=1,
        max_new_tokens=max_gen_len,
        do_sample=False,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )

    for batch in tqdm(data_loader):
        # Generate for every example in the batch so the question/answer
        # lists stay aligned with the generations (the previous code only
        # used batch['prompt'][0] but appended the whole batch's metadata).
        # Assumes batch['question'] and batch['answers'] are per-example
        # sequences aligned with batch['prompt'] — TODO confirm collate_fn.
        for prompt, question, answers in zip(
            batch['prompt'], batch['question'], batch['answers']
        ):
            input_ids = tokenizer.encode(prompt, return_tensors="pt")
            original_num_tokens = input_ids.shape[-1]
            input_ids = input_ids.to(model.device)

            with torch.no_grad():
                # A TextStreamer prints tokens as they're being generated
                streamer = TextStreamer(tokenizer)
                generated_tokens = model.generate(
                    input_ids,
                    generation_config=generation_config,
                    streamer=streamer,
                )

            # Decode only the newly generated ids. Slicing the decoded full
            # text by len(prompt) breaks whenever decode() does not reproduce
            # the prompt string exactly (skipped special tokens, spacing).
            new_tokens = generated_tokens[0][original_num_tokens:]
            output_text = tokenizer.decode(new_tokens, skip_special_tokens=True)

            total_generations.append(output_text)
            total_questions.append(question)
            total_answers.append(answers)
            total_original_num_tokens.append(original_num_tokens)
            # A prompt shorter than the budget is kept in full, so the cache
            # never holds more tokens than the prompt provided.
            total_compressed_num_tokens.append(min(original_num_tokens, cache_budget))

    with open(args.output_filename, 'w', encoding='utf-8') as f:
        for question, generation, answers, original, compressed in zip(
            total_questions,
            total_generations,
            total_answers,
            total_original_num_tokens,
            total_compressed_num_tokens,
        ):
            item = {'question': question, 'generation': generation, 'answers': answers,
                    'original_num_tokens': original, 'compressed_num_tokens': compressed}
            # One JSON object per line (JSON Lines).
            f.write(json.dumps(item, ensure_ascii=False) + '\n')

def seed_everything(TORCH_SEED):
    """Seed Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        TORCH_SEED: Integer seed applied to every random number generator.
            It is also exported as ``PYTHONHASHSEED``, which only affects
            subprocesses started afterwards.
    """
    os.environ['PYTHONHASHSEED'] = str(TORCH_SEED)

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(TORCH_SEED)

    # Trade cuDNN autotuning speed for reproducible kernels.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True



if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Run attention-sink generation over a QA dataset and save JSON Lines results."
    )
    parser.add_argument(
        "--model_name_or_path", type=str, default="/path/to/project/llmodel/Llama-2-7b-chat-hf",
        help="Hugging Face model id or local checkpoint directory.",
    )
    # NOTE(review): the next three flags are not read by main(); they are kept
    # so existing command lines keep parsing.
    parser.add_argument("--enable_streaming", action="store_true")
    parser.add_argument("--start_size", type=int, default=4)
    parser.add_argument("--recent_size", type=int, default=1000)

    parser.add_argument(
        '--data_path', type=str, default="data/NQ/test.jsonl",
        help="Input JSONL file; its parent directory name selects the dataset (NQ, HQA, WikiQA).",
    )
    # float (not int) so fractional rates such as 2.5 are accepted; integer
    # values give the same cache budget as before.
    parser.add_argument(
        '--mean_compression_rate', type=float, default=4,
        help="Divisor applied to the per-dataset token count to size the attention-sink cache.",
    )
    parser.add_argument('--output_filename', type=str, default='NQ_outputs-streaming-4-as-test.jsonl')
    parser.add_argument('--seed', type=int, default=42)

    parser.add_argument('--max_doc_tokens', type=int, default=512)
    parser.add_argument('--max_num_documents', type=int, default=20)
    parser.add_argument('--min_num_documents', type=int, default=20)
    # NOTE(review): --instruction_name is not read by main().
    parser.add_argument('--instruction_name', type=str, default='base')
    parser.add_argument('--batch_size', type=int, default=1)

    args = parser.parse_args()

    # Forced on regardless of the CLI flag; nothing in main() reads it.
    args.enable_streaming = True

    seed_everything(args.seed)

    main(args)