from llmlingua import PromptCompressor
import json
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from QA_text_dataset_long_CNN import InferDataset
import os
import argparse
from tqdm import tqdm
import random
import numpy as np
from icecream import ic as pprint

import time


# Per-dataset integer settings keyed by dataset name.
# NOTE(review): presumably the generation-length limit per dataset; nothing in
# the visible part of this file reads it — confirm against the generation stage.
dataset2maxgen = {"NQ": 100, "HQA": 32, "WikiQA": 32}

# Instruction text by name; 'base' and 'short' currently map to the same text.
instructions_map = dict.fromkeys(
    ('base', 'short'),
    'Answer the Question based on the given text. Only give me the answer and do not output any other words.',
)

# NOTE(review): presumably a token count per dataset; not read in the visible
# part of this file — confirm what it measures before relying on it.
dataset2num_tokens = {"NQ": 3265, "HQA": 1568, "WikiQA": 1148}
def seed_everything(TORCH_SEED):
    """Seed every random number source this script touches.

    Seeds Python's ``random``, exports ``PYTHONHASHSEED``, seeds NumPy and
    PyTorch (CPU and all CUDA devices), and switches cuDNN to deterministic
    mode with benchmarking disabled.

    Args:
        TORCH_SEED: The seed value handed to every generator.
    """
    random.seed(TORCH_SEED)
    os.environ['PYTHONHASHSEED'] = str(TORCH_SEED)

    # The remaining generators all take the seed as their only argument.
    for apply_seed in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        apply_seed(TORCH_SEED)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def get_args():
    """Declare this script's command-line options and parse the command line.

    Returns:
        argparse.Namespace: One attribute per option listed below, holding
        either the value given on the command line or the option's default.
    """
    # (flag, keyword arguments for ``add_argument``), in registration order.
    option_specs = (
        ('--data_path', {'type': str, 'default': '/path/to/project/data/NQ/test.jsonl'}),
        ('--model_path', {'type': str, 'default': '/path/to/project/llmodel/Llama-2-7b-hf'}),
        ('--model_gen_path', {'type': str, 'default': '/path/to/project/llmodel/Llama-2-7b-hf'}),
        ('--compression_rate', {'type': int, 'default': 4}),
        ('--output_filename', {'type': str, 'default': '/path/to/project/baselines_outputs/longllmlingua-nochat/NQ_outputs-longllmlingua-4.jsonl'}),
        ('--max_doc_tokens', {'type': int, 'default': 5120}),
        ('--max_num_documents', {'type': int, 'default': 20}),
        ('--min_num_documents', {'type': int, 'default': 20}),
        ('--random_num_documents', {'action': 'store_true'}),
        ('--instruction_name', {'type': str, 'default': 'base'}),
        ('--batch_size', {'type': int, 'default': 1}),
        ('--prompt_filename', {'type': str, 'default': '/path/to/project/baselines_outputs/longllmlingua-nochat/NQ_prompts-longllmlingua-4.jsonl'}),
        ('--seed', {'type': int, 'default': 42}),
    )

    # Register the options.
    parser = argparse.ArgumentParser()
    for flag, add_kwargs in option_specs:
        parser.add_argument(flag, **add_kwargs)

    # Parse the command line.
    return parser.parse_args()

def load_data(args, tokenizer):
    """Build the inference dataset and wrap it in an unshuffled DataLoader.

    Args:
        args: Parsed options; reads ``data_path``, ``max_doc_tokens``,
            ``max_num_documents``, ``min_num_documents``,
            ``random_num_documents``, ``instruction_name`` and ``batch_size``.
        tokenizer: Tokenizer object forwarded to ``InferDataset``.

    Returns:
        torch.utils.data.DataLoader: Iterates the dataset in file order
        (``shuffle=False``) using the dataset's own ``collate_fn``.
    """
    dataset_kwargs = {
        'filepath': args.data_path,
        'tokenizer': tokenizer,
        'max_doc_tokens': args.max_doc_tokens,
        'max_num_documents': args.max_num_documents,
        'min_num_documents': args.min_num_documents,
        'random_num_documents': args.random_num_documents,
        'instruction_name': args.instruction_name,
    }
    dataset = InferDataset(**dataset_kwargs)

    return torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=False,
        collate_fn=dataset.collate_fn,
    )

def get_longllmlingua_compression(
    documents_list,
    question,
    llm_lingua,
    compression_rate
):
    """Compress a document list with the LongLLMLingua settings used here.

    Args:
        documents_list: Context passed as the first positional argument of
            ``compress_prompt``.
        question: Question passed as the ``question`` keyword.
        llm_lingua: Object exposing ``compress_prompt`` (the script passes a
            ``PromptCompressor``).
        compression_rate: Compression factor; the compressor receives its
            reciprocal as ``rate`` (e.g. 4 -> 0.25).

    Returns:
        Whatever ``compress_prompt`` returns, unchanged. The caller in this
        file reads its 'compressed_prompt', 'origin_tokens' and
        'compressed_tokens' entries.
    """
    # Special parameters for LongLLMLingua.
    longllmlingua_options = {
        'condition_in_question': "after_condition",
        'reorder_context': "sort",
        'dynamic_context_compression_ratio': 0.3,  # or 0.4
        'condition_compare': True,
        'context_budget': "+100",
        'rank_method': "longllmlingua",
    }

    return llm_lingua.compress_prompt(
        documents_list,
        question=question,
        rate=1.0 / compression_rate,
        **longllmlingua_options,
    )

# Parse the command line and fix every RNG seed before anything is loaded.
args = get_args()
seed_everything(args.seed)

# The compressor and the tokenizer are both loaded from --model_path.
# No generation model is loaded in this section.
model_path = args.model_path
llm_lingua = PromptCompressor(model_name=model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

data_loader = load_data(args, tokenizer)

# NOTE(review): the total_* lists are created here but not filled anywhere in
# the visible code — confirm whether anything still needs them.
total_prompts, total_questions, total_answers, total_generations = [], [], [], []

# Per-example results collected by the compression loop.
prompts, questions, answers = [], [], []
original_num_tokens, compressed_num_tokens = [], []

# Second-to-last "/"-separated component of --data_path
# (e.g. ".../NQ/test.jsonl" -> "NQ").
dataset_name = args.data_path.rsplit("/", 2)[-2]
instruction_text = instructions_map['base']

# mean_ratio accumulates per-example compression ratios.
# NOTE(review): count is only referenced from commented-out debug code.
mean_ratio, count = 0, 10
# First stage of a two-stage run: compress every prompt and collect the
# results so they can be saved; nothing is generated in this loop.
# NOTE(review): only element 0 of batch['prompt'] / batch['question'] is passed
# to the compressor, so this loop assumes --batch_size 1 — confirm before
# raising the batch size.
num_ratios = 0  # number of examples that contributed to mean_ratio
print('compression_rate: ',  1.0 / args.compression_rate)
for batch in tqdm(data_loader):
    try:
        compressed_prompt = get_longllmlingua_compression(
            batch['prompt'][0],  # A list of documents, which form the context
            batch['question'][0],  # The input question
            llm_lingua,
            args.compression_rate,
        )
        # Read every field before touching the result lists: if any lookup
        # fails, no list has been extended, so they always stay the same
        # length and the zipped output records stay aligned.
        prompt = compressed_prompt['compressed_prompt'] + batch['question'][0]
        batch_question = batch['question']
        batch_answers = batch['answers']
        origin_tokens = compressed_prompt['origin_tokens']
        compressed_tokens = compressed_prompt['compressed_tokens']
    except Exception as exc:
        # Best effort: report the failing batch and move on. Catching
        # Exception (not a bare except) lets Ctrl-C / SystemExit stop the run.
        print('compression failed: ', repr(exc))
        pprint(batch)
        continue

    prompts.append(prompt)
    questions.append(batch_question)
    answers.append(batch_answers)
    original_num_tokens.append(origin_tokens)
    compressed_num_tokens.append(compressed_tokens)

    # A zero-token original has no meaningful ratio: keep the record but leave
    # it out of the average instead of dividing by zero.
    if origin_tokens:
        ratio = compressed_tokens / origin_tokens
        print('ratio: ', ratio)
        mean_ratio += ratio
        num_ratios += 1

# Average over the examples that produced a ratio. Dividing by
# len(data_loader) would count failed batches as ratio 0 and would raise on an
# empty loader.
print('mean_ratio: ', mean_ratio / num_ratios if num_ratios else float('nan'))

# Make sure the destination directory exists so the collected results are not
# lost to a FileNotFoundError. A bare filename has no directory part, and
# os.makedirs('') would raise, hence the guard.
output_dir = os.path.dirname(args.prompt_filename)
if output_dir:
    os.makedirs(output_dir, exist_ok=True)

# Save the first-stage results as JSON Lines.
with open(args.prompt_filename, 'w', encoding='utf-8') as f:
    for x, y, z, o, c in zip(questions, prompts, answers, original_num_tokens, compressed_num_tokens):
        item = {'question': x, 'prompt': y, 'answers': z, 'original_num_tokens': o, 'compressed_num_tokens': c}
        json_line = json.dumps(item, ensure_ascii=False)
        f.write(json_line + '\n')  # one JSON object per line
