import json
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from QA_text_dataset_baecon import InferDataset
import os
import argparse
from tqdm import tqdm
import random
import numpy as np

from chat import apply_chat_template

from icecream import ic as pprint

# Run on the default CUDA device when one is visible, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Upper bound on newly generated tokens per answer, keyed by dataset name
# (the script derives the name from the parent directory of --data_path).
dataset2maxgen = dict(
    NQ=100,
    HQA=32,
    WikiQA=32,
)


def seed_everything(TORCH_SEED):
    """Seed every RNG the script relies on and force deterministic cuDNN.

    Args:
        TORCH_SEED: integer seed applied to Python's ``random``, NumPy's
            global RNG, and PyTorch (CPU plus every CUDA device).
    """
    seed = TORCH_SEED
    # Only subprocesses started later see this; the running interpreter's
    # hash seed is fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seeder in (random.seed, np.random.seed,
                   torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.deterministic, cudnn.benchmark = True, False


def get_args(argv=None):
    """Parse the command-line options for the QA inference run.

    Args:
        argv: optional list of argument strings to parse instead of the
            process command line; ``None`` (the default) makes argparse read
            ``sys.argv[1:]`` as before.

    Returns:
        argparse.Namespace holding the options below.
    """
    parser = argparse.ArgumentParser(
        description="Run QA inference and record generations plus "
                    "memory-size statistics.")

    # Add arguments
    parser.add_argument('--data_path', type=str,
                        help='input data file; its parent directory name '
                             'selects the dataset (NQ, HQA or WikiQA)')
    parser.add_argument('--model_path', type=str,
                        help='model name or path passed to from_pretrained')
    parser.add_argument('--compression_rate', type=int,
                        help='passed to the model as ultragist_ratio=[rate]')
    parser.add_argument('--output_filename', type=str,
                        help='JSONL file for generations; sizes are written '
                             'to <output_filename>_size')
    parser.add_argument('--max_doc_tokens', type=int,
                        help='forwarded to InferDataset')
    parser.add_argument('--max_num_documents', type=int,
                        help='forwarded to InferDataset')
    parser.add_argument('--min_num_documents', type=int,
                        help='forwarded to InferDataset')
    parser.add_argument('--random_num_documents', action='store_true',
                        help='forwarded to InferDataset')
    parser.add_argument('--instruction_name', type=str,
                        help='forwarded to InferDataset')
    # Generation handles one prompt per batch, so default to 1. Without a
    # default the value was None, which makes DataLoader disable automatic
    # batching and hand collate_fn individual samples.
    parser.add_argument('--batch_size', type=int, default=1,
                        help='DataLoader batch size; only 1 is supported')
    parser.add_argument('--seed', type=int, default=42, help='random seed')

    # Parse arguments
    return parser.parse_args(argv)


def load_data(
    args,
    model,
    tokenizer
):
    """Build the inference DataLoader for the QA file at ``args.data_path``.

    The dataset is iterated in file order (``shuffle=False``) and batched
    with the dataset's own ``collate_fn``.

    Args:
        args: parsed CLI namespace; reads data_path, max_doc_tokens,
            max_num_documents, min_num_documents, random_num_documents,
            instruction_name and batch_size.
        model: model handed through to ``InferDataset``.
        tokenizer: tokenizer handed through to ``InferDataset``.

    Returns:
        torch.utils.data.DataLoader over an ``InferDataset``.
    """
    dataset_options = dict(
        filepath=args.data_path,
        max_doc_tokens=args.max_doc_tokens,
        max_num_documents=args.max_num_documents,
        min_num_documents=args.min_num_documents,
        random_num_documents=args.random_num_documents,
        instruction_name=args.instruction_name,
    )
    qa_dataset = InferDataset(model=model, tokenizer=tokenizer, **dataset_options)

    return torch.utils.data.DataLoader(
        qa_dataset,
        batch_size=args.batch_size,
        collate_fn=qa_dataset.collate_fn,
        shuffle=False,
    )


def generate_answer(
    prompt,
    model,
    tokenizer,
    eos_token_id,
    dataset_name
):
    """Greedily generate an answer for a single prompt.

    Args:
        prompt: a one-element sequence holding the prompt text (what the
            DataLoader yields with batch_size=1), or the prompt string itself.
        model: causal LM used for generation.
        tokenizer: tokenizer matching ``model``.
        eos_token_id: list of token ids that stop generation. The same ids
            are passed as ``begin_suppress_tokens`` so generation cannot stop
            on its very first token.
        dataset_name: key into ``dataset2maxgen`` selecting max_new_tokens.

    Returns:
        (generation, org_size): a list with the decoded answer string, and the
        number of prompt tokens after the chat template is applied.

    Raises:
        ValueError: if ``prompt`` does not hold exactly one prompt.
        KeyError: if ``dataset_name`` is not in ``dataset2maxgen``.
    """
    # Only one prompt per call is supported: the caller resets the model's
    # memory once per batch. Previously extra prompts were silently dropped,
    # misaligning generations with questions, and a bare string was cut
    # down to its first character.
    if isinstance(prompt, str):
        text = prompt
    elif len(prompt) == 1:
        text = prompt[0]
    else:
        raise ValueError(
            f"generate_answer expects exactly one prompt, got {len(prompt)}; "
            "run with --batch_size 1"
        )

    # Look this up before the expensive tokenize/generate work.
    max_new_tokens = dataset2maxgen[dataset_name]

    chat_prompt = apply_chat_template(
        "llama-2",
        messages=[{'role': 'user', 'content': text}],
        tokenizer=tokenizer,
        add_generation_prompt=True,
    ).raw

    inputs = tokenizer(
        chat_prompt,
        padding=True,
        return_tensors="pt"
    ).to(device)

    # Prompt length in tokens; generated tokens start right after it.
    org_size = inputs["input_ids"].shape[1]
    outputs = model.generate(
        **inputs,
        do_sample=False,
        top_p=1,
        temperature=1,
        max_new_tokens=max_new_tokens,
        eos_token_id=eos_token_id,
        begin_suppress_tokens=eos_token_id,
    )
    # Keep only the newly generated continuation.
    new_tokens = outputs[:, org_size:]
    generation = tokenizer.batch_decode(new_tokens, skip_special_tokens=True)
    pprint(generation)
    return generation, org_size


args = get_args()

seed_everything(args.seed)

tokenizer = AutoTokenizer.from_pretrained(
    args.model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    args.model_path,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    # Model-specific kwarg consumed by the remote code; value from
    # --compression_rate wrapped in a one-element list.
    ultragist_ratio=[args.compression_rate]
).to(device)

# Stop generation for QA tasks when "\n" appears: the stop list is the
# model's EOS id(s) plus the id of the newline token.
eos_token_id = getattr(
    getattr(model, "generation_config", None), "eos_token_id", None)
if eos_token_id is None:
    # No generation_config, or one without an EOS id: use the tokenizer's.
    eos_token_id = tokenizer.eos_token_id
if eos_token_id is None:
    eos_token_id = []
elif isinstance(eos_token_id, int):
    eos_token_id = [eos_token_id]
else:
    # Copy so the append below cannot mutate model.generation_config in place.
    eos_token_id = list(eos_token_id)
# Take the last id: presumably some (e.g. SentencePiece) tokenizers emit a
# leading space piece before the newline piece.
newline_token_id = tokenizer.encode("\n", add_special_tokens=False)[-1]
if newline_token_id not in eos_token_id:
    eos_token_id.append(newline_token_id)

data_loader = load_data(
    args,
    model,
    tokenizer
)

model = model.eval()

import traceback  # used only to log failing batches below

with torch.no_grad():

    # Dataset name = parent directory of the data file,
    # e.g. "data/NQ/test.jsonl" -> "NQ"; it selects max_new_tokens.
    dataset_name = args.data_path.split("/")[-2]
    pprint(dataset_name)
    # Fail fast: an unknown name would make every batch raise KeyError inside
    # the loop below and silently yield an all-empty output file.
    if dataset_name not in dataset2maxgen:
        raise ValueError(
            f"unknown dataset {dataset_name!r} derived from --data_path; "
            f"expected one of {sorted(dataset2maxgen)}"
        )

    # Per-record results. All five lists are kept the same length so the
    # zip() at write time pairs each question with its own values.
    total_generations = []
    total_questions = []
    total_answers = []
    total_original_num_tokens = []
    total_compressed_num_tokens = []

    # Per-batch memory statistics (0 for batches that failed).
    ultragist_size_list = []
    raw_size_list = []
    sink_size_list = []
    org_size_list = []

    # Failed batches are logged, one JSON line each, to
    # "<compression_rate><output basename>_error" next to the output file.
    out_dir, out_base = os.path.split(args.output_filename)
    error_filename = os.path.join(
        out_dir, str(args.compression_rate) + out_base + '_error')
    # Truncate on this run's first failure, append for later ones.
    error_file_mode = 'w'

    for batch in tqdm(data_loader):
        try:
            model.memory.reset()
            pprint(batch['question'])
            batch_generation, org_size = generate_answer(
                batch['prompt'],
                model,
                tokenizer,
                eos_token_id,
                dataset_name
            )
            compressed_memory = model.memory.get_memory()
            # Presumably the cached sequence length of the first layer's keys
            # (dim 2 of a (batch, heads, seq, head_dim) tensor) -- confirm.
            compressed_len = compressed_memory[0][0].shape[2]
            pprint(compressed_len)
            ultragist_size, raw_size, sink_size = model.memory.get_memory_size()
            pprint(batch['answers'])
        except Exception as e:
            # Best effort: log the failure, record placeholders that keep every
            # list aligned, and move on. Assumes collate_fn yields a list per key.
            pprint(e)
            traceback.print_exc()
            num_items = len(batch['question'])
            total_questions += batch['question']
            total_generations += [''] * num_items
            total_answers += batch['answers']
            total_original_num_tokens += [0] * num_items
            total_compressed_num_tokens += [[0, 0, 0, 0] for _ in range(num_items)]
            ultragist_size_list.append(0)
            raw_size_list.append(0)
            sink_size_list.append(0)
            org_size_list.append(0)
            with open(error_filename, error_file_mode, encoding='utf-8') as f:
                # Write the failing batch. default=str stops values json cannot
                # encode natively from raising here and aborting the whole run.
                f.write(json.dumps(batch, ensure_ascii=False, default=str) + '\n')
            error_file_mode = 'a'
        else:
            # Record results only after every step succeeded, so a late failure
            # can never leave a batch half-recorded.
            ultragist_size_list.append(ultragist_size)
            raw_size_list.append(raw_size)
            sink_size_list.append(sink_size)
            org_size_list.append(org_size)

            total_generations += batch_generation
            total_questions += batch['question']
            total_answers += batch['answers']
            total_original_num_tokens.append(org_size)
            total_compressed_num_tokens.append(
                [ultragist_size, raw_size, sink_size, compressed_len])

# One JSON object per line.
with open(args.output_filename, 'w', encoding='utf-8') as f:
    for question, generation, answers, n_orig, n_comp in zip(
            total_questions, total_generations, total_answers,
            total_original_num_tokens, total_compressed_num_tokens):
        item = {'question': question, 'generation': generation,
                'answers': answers, 'original_num_tokens': n_orig,
                'compressed_num_tokens': n_comp}
        f.write(json.dumps(item, ensure_ascii=False) + '\n')

# One JSON object per batch with the memory-size statistics.
with open(args.output_filename + '_size', 'w', encoding='utf-8') as f:
    for ug_size, r_size, s_size, o_size in zip(
            ultragist_size_list, raw_size_list, sink_size_list, org_size_list):
        item = {'ultragist_size': ug_size, 'raw_size': r_size,
                'sink_size': s_size, 'org_size': o_size}
        f.write(json.dumps(item, ensure_ascii=False) + '\n')
