from llmlingua import PromptCompressor
import json
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from QA_text_dataset_CNN import InferDataset
import os
import argparse
from tqdm import tqdm
import random
import numpy as np
from icecream import ic as pprint

import time


# Per-dataset generation-length budget, presumably max new tokens for the
# answer generator (not referenced in this script).
dataset2maxgen = dict(
    NQ=100,
    HQA=32,
    WikiQA=32,
)

# Instruction texts keyed by instruction name; 'base' and 'short' are
# currently the same sentence.
instructions_map = dict(
    base='Answer the Question based on the given text. Only give me the answer and do not output any other words.',
    short='Answer the Question based on the given text. Only give me the answer and do not output any other words.',
)

# Per-dataset token counts (meaning not visible here -- presumably a token
# budget/average prompt length; not referenced in this script).
dataset2num_tokens = dict(
    NQ=3265,
    HQA=1568,
    WikiQA=1148,
)
def seed_everything(TORCH_SEED):
    """Seed every random number generator the script touches.

    Seeds Python's ``random`` module, NumPy's global generator and torch
    (CPU plus all CUDA devices). It also exports ``PYTHONHASHSEED`` and
    switches cuDNN to deterministic, non-benchmarking kernels.

    Args:
        TORCH_SEED: the seed value applied to every generator.
    """
    # Python-level randomness. Note: the interpreter reads PYTHONHASHSEED at
    # startup, so setting it here only influences child processes.
    random.seed(TORCH_SEED)
    os.environ['PYTHONHASHSEED'] = str(TORCH_SEED)
    np.random.seed(TORCH_SEED)

    # torch: CPU generator first, then every CUDA device.
    for torch_seeder in (torch.manual_seed, torch.cuda.manual_seed_all):
        torch_seeder(TORCH_SEED)

    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.deterministic, cudnn.benchmark = True, False

def get_args(argv=None):
    """Parse the command-line options for the LLMLingua-2 compression run.

    Args:
        argv: optional list of argument strings to parse. ``None`` (the
            default) parses ``sys.argv[1:]``, exactly as before; passing a
            list makes the function usable from tests or other scripts.

    Returns:
        argparse.Namespace holding every option below.
    """
    parser = argparse.ArgumentParser()

    # Register the arguments.
    parser.add_argument('--data_path', type=str, default='/path/to/project/data/NQ/test.jsonl',
                        help='input JSONL file; its parent directory name is read as the dataset name')
    parser.add_argument('--model_path', type=str, default='/path/to/project/llmodel/llmlingua-2-xlm-roberta-large-meetingbank',
                        help='LLMLingua-2 compressor checkpoint (also used to load the dataset tokenizer)')
    parser.add_argument('--model_gen_path', type=str, default='/path/to/project/llmodel/Llama-2-7b-hf',
                        help='generator model path (not read by this script)')
    # float (not int) so fractional factors such as 2.5 are accepted; the
    # compressor keeps roughly 1 / compression_rate of the tokens.
    parser.add_argument('--compression_rate', type=float, default=4,
                        help='target compression factor; rate passed to the compressor is 1 / this value')
    parser.add_argument('--output_filename', type=str, default='/path/to/project/baselines_outputs/llmlingua2/NQ_outputs-llmlingua2-4.jsonl',
                        help='generation output file (not written by this script)')
    # The document/instruction options below are forwarded to InferDataset.
    parser.add_argument('--max_doc_tokens', type=int, default=5120)
    parser.add_argument('--max_num_documents', type=int, default=20)
    parser.add_argument('--min_num_documents', type=int, default=20)
    parser.add_argument('--random_num_documents', action='store_true')
    parser.add_argument('--instruction_name', type=str, default='base')
    parser.add_argument('--batch_size', type=int, default=1,
                        help='DataLoader batch size; only the first item of each batch is compressed')
    parser.add_argument('--prompt_filename', type=str, default='/path/to/project/baselines_outputs/llmlingua2/NQ_prompts-llmlingua2-4.jsonl',
                        help='JSONL file the compressed prompts are written to')
    parser.add_argument('--seed', type=int, default=42)

    # Parse the arguments (argv=None falls back to sys.argv[1:]).
    return parser.parse_args(argv)

def load_data(
    args,
    tokenizer
):
    """Build an ordered DataLoader over ``InferDataset`` for ``args.data_path``.

    Args:
        args: parsed options; reads ``data_path``, ``max_doc_tokens``,
            ``max_num_documents``, ``min_num_documents``,
            ``random_num_documents``, ``instruction_name`` and ``batch_size``.
        tokenizer: tokenizer handed straight to ``InferDataset``.

    Returns:
        torch.utils.data.DataLoader that batches with the dataset's own
        ``collate_fn`` and does not shuffle.
    """
    dataset_options = dict(
        filepath=args.data_path,
        tokenizer=tokenizer,
        max_doc_tokens=args.max_doc_tokens,
        max_num_documents=args.max_num_documents,
        min_num_documents=args.min_num_documents,
        random_num_documents=args.random_num_documents,
        instruction_name=args.instruction_name,
    )
    infer_dataset = InferDataset(**dataset_options)

    # shuffle=False keeps the dataset's own ordering in the saved results.
    return torch.utils.data.DataLoader(
        infer_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        collate_fn=infer_dataset.collate_fn,
    )

def get_llmlingua2_compression(
    prompt,
    llm_lingua,
    compression_rate):
    """Compress one prompt with LLMLingua-2 and return the token statistics.

    Args:
        prompt: the text to compress, either a plain ``str`` or a
            batch-style sequence of strings. For a sequence only the
            first element is compressed, so batches larger than one are
            effectively truncated to their first item.
        llm_lingua: object exposing ``compress_prompt`` (here an
            ``llmlingua.PromptCompressor`` built with ``use_llmlingua2=True``).
        compression_rate: target compression factor; the compressor is
            asked to keep ``1.0 / compression_rate`` of the tokens.

    Returns:
        dict with ``compressed_prompt``, ``origin_tokens`` and
        ``compressed_tokens`` copied from the compressor's result.
    """
    # A bare string used to be indexed as prompt[0], which silently
    # compressed only its first character; accept it whole instead.
    text = prompt if isinstance(prompt, str) else prompt[0]

    # Compress with LLMLingua-2.
    compressed_result = llm_lingua.compress_prompt(
        text,
        rate=1.0 / compression_rate,
        force_tokens=['\n', '?'],  # always keep newlines and question marks
    )

    return {
        'compressed_prompt': compressed_result['compressed_prompt'],
        'origin_tokens': compressed_result['origin_tokens'],
        'compressed_tokens': compressed_result['compressed_tokens'],
    }

args = get_args()

seed_everything(args.seed)

# One checkpoint (args.model_path) provides both the LLMLingua-2 compressor
# and the tokenizer handed to InferDataset.
model_path = args.model_path
llm_lingua = PromptCompressor(
    model_name=model_path,
    use_llmlingua2=True,  # use LLMLingua-2 rather than LongLLMLingua
)

tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

data_loader = load_data(
    args,
    tokenizer,
)

# Per-sample results, appended in lockstep and written out together at the end.
prompts = []
questions = []
answers = []
original_num_tokens = []
compressed_num_tokens = []

# NOTE(review): neither name is used further down in this script, and
# instruction_text ignores args.instruction_name; kept as-is for compatibility.
dataset_name = args.data_path.split("/")[-2]
instruction_text = instructions_map['base']

mean_ratio = 0  # running sum of per-sample compressed/original token ratios
num_ratios = 0  # number of samples that contributed to mean_ratio
print('compression_rate: ', 1.0 / args.compression_rate)
for batch in tqdm(data_loader):
    # Everything that can fail for a malformed batch happens here, before any
    # result list is touched, so the five lists always stay aligned.
    try:
        compressed_prompt = get_llmlingua2_compression(
            batch['prompt'],
            llm_lingua,
            args.compression_rate,
        )
        question = batch['question']
        answer = batch['answers']
        # NOTE(review): compressed context and question are concatenated with
        # no separator -- confirm this is the intended prompt layout.
        prompt = compressed_prompt['compressed_prompt'] + question[0]
        origin_tokens = compressed_prompt['origin_tokens']
        compressed_tokens = compressed_prompt['compressed_tokens']
    except Exception as err:
        # Best effort: report the failing batch and keep going. Catching
        # Exception (not a bare except) still lets Ctrl-C / SystemExit stop the run.
        pprint(err)
        pprint(batch)
        continue

    # Save the first-stage (compression) results.
    prompts.append(prompt)
    questions.append(question)
    answers.append(answer)
    original_num_tokens.append(origin_tokens)
    compressed_num_tokens.append(compressed_tokens)

    if origin_tokens:  # an empty original prompt would divide by zero
        ratio = compressed_tokens / origin_tokens
        print('ratio: ', ratio)
        mean_ratio += ratio
        num_ratios += 1

# Average only over samples that produced a ratio: dividing by
# len(data_loader) counted skipped batches as 0 and crashed on an empty loader.
if num_ratios:
    print('mean_ratio: ', mean_ratio / num_ratios)
else:
    print('mean_ratio: ', 'n/a (no sample was compressed successfully)')

# Write all results to file, one JSON object per line.
output_dir = os.path.dirname(args.prompt_filename)
if output_dir:
    os.makedirs(output_dir, exist_ok=True)
with open(args.prompt_filename, 'w', encoding='utf-8') as f:
    for q, p, a, n_orig, n_comp in zip(questions, prompts, answers, original_num_tokens, compressed_num_tokens):
        item = {'question': q, 'prompt': p, 'answers': a, 'original_num_tokens': n_orig, 'compressed_num_tokens': n_comp}
        f.write(json.dumps(item, ensure_ascii=False) + '\n')