from llmlingua import PromptCompressor
import json
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import os
import argparse
from tqdm import tqdm
import random
import numpy as np
from icecream import ic as pprint

import time

# Generation budget per dataset: the value is passed as `max_new_tokens`
# when generating an answer for a prompt of that dataset.
dataset2maxgen = dict(NQ=100, HQA=32, WikiQA=32)

# Instruction text keyed by instruction name. Both names currently map to the
# same wording, so the text is defined once and shared.
# NOTE(review): not referenced elsewhere in this script; presumably selected
# via --instruction_name by the prompt-building phase -- confirm.
_ANSWER_ONLY_INSTRUCTION = 'Answer the Question based on the given text. Only give me the answer and do not output any other words.'

instructions_map = {
    'base': _ANSWER_ONLY_INSTRUCTION,
    'short': _ANSWER_ONLY_INSTRUCTION,
}

# Token count associated with each dataset.
# NOTE(review): not referenced elsewhere in this script; presumably a token
# budget used by the compression phase -- confirm before relying on it.
dataset2num_tokens = dict(NQ=3265, HQA=1568, WikiQA=1148)

def seed_everything(TORCH_SEED):
  """Seed every random number generator this script touches.

  Seeds Python's `random`, NumPy and PyTorch (CPU and all CUDA devices) with
  the same value, exports it as `PYTHONHASHSEED`, and switches cuDNN to
  deterministic mode with auto-tuning (benchmark) disabled.

  Args:
    TORCH_SEED: Integer seed applied to all generators.
  """
  # NOTE: the running interpreter fixed its hash seed at startup, so this
  # assignment only influences processes spawned after this call.
  os.environ['PYTHONHASHSEED'] = str(TORCH_SEED)

  seeders = (
    random.seed,
    np.random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed_all,
  )
  for apply_seed in seeders:
    apply_seed(TORCH_SEED)

  cudnn = torch.backends.cudnn
  cudnn.deterministic = True
  cudnn.benchmark = False

def get_args():
  """Parse the command-line options of the inference script.

  The four options this script reads (`--data_path`, `--model_gen_path`,
  `--prompt_filename`, `--output_filename`) are required, so a missing one is
  reported by argparse at startup instead of surfacing later as an opaque
  failure on a `None` value (for the output file, only after all generation
  work has already been done).

  Returns:
    argparse.Namespace holding one attribute per option. Optional options
    that were not supplied are `None`, except `random_num_documents`
    (`False`) and `seed` (42).

  Raises:
    SystemExit: raised by argparse (exit status 2) on a missing required
      option or an unparsable value.
  """
  parser = argparse.ArgumentParser()

  # Register arguments
  parser.add_argument('--data_path', type=str, required=True,
                      help='Path to the data file; the name of its parent '
                           'directory is used as the dataset name.')
  parser.add_argument('--model_path', type=str)
  parser.add_argument('--model_gen_path', type=str, required=True,
                      help='Path or identifier of the model and tokenizer '
                           'used for generation.')
  parser.add_argument('--compression_rate', type=int)
  parser.add_argument('--output_filename', type=str, required=True,
                      help='JSONL file the generations are written to.')
  parser.add_argument('--max_doc_tokens', type=int)
  parser.add_argument('--max_num_documents', type=int)
  parser.add_argument('--min_num_documents', type=int)
  parser.add_argument('--random_num_documents', action='store_true')
  parser.add_argument('--instruction_name', type=str)
  parser.add_argument('--batch_size', type=int)
  parser.add_argument('--prompt_filename', type=str, required=True,
                      help='JSONL file holding the prompts to run.')
  parser.add_argument('--seed', type=int, default=42)

  # Parse arguments
  args = parser.parse_args()
  return args

args = get_args()

seed_everything(args.seed)

model_path = args.model_gen_path
# The dataset name is the directory that directly contains the data file,
# e.g. ".../NQ/test.jsonl" -> "NQ".
dataset_name = args.data_path.split("/")[-2]

# Resolve the generation budget once, before the model is loaded, so that an
# unsupported dataset is reported immediately rather than as a KeyError in
# the middle of the generation loop.
if dataset_name not in dataset2maxgen:
  raise ValueError(
    f"Unknown dataset {dataset_name!r} derived from --data_path "
    f"{args.data_path!r}; expected one of {sorted(dataset2maxgen)}"
  )
max_new_tokens = dataset2maxgen[dataset_name]

tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

model = AutoModelForCausalLM.from_pretrained(
  model_path,
  trust_remote_code=True,
  torch_dtype=torch.bfloat16).to("cuda")

model.eval()

# Parallel per-example result lists; index i of each list describes the
# same input record.
total_questions = []
total_generations = []
total_answers = []
total_original_num_tokens = []
total_compressed_num_tokens = []

# Phase 2: read the compressed prompts saved by phase 1 and run inference.
pprint(args.prompt_filename)

with open(args.prompt_filename, 'r', encoding='utf-8') as f:
  for line in f:
    # Tolerate blank lines (e.g. stray trailing newlines) instead of
    # crashing in json.loads.
    if not line.strip():
      continue

    # Each non-blank line is one JSON object with the keys read below.
    item = json.loads(line)
    question = item['question']
    prompt = item['prompt']
    answers = item['answers']
    original_num_tokens = item['original_num_tokens']
    compressed_num_tokens = item['compressed_num_tokens']

    # Run inference with the compressed prompt.
    inputs = tokenizer(
      prompt,
      padding=False,
      return_tensors="pt"
    ).to("cuda")

    # Greedy decoding. The slice drops the first `input length` positions
    # of every output row so that only the continuation is decoded.
    outputs = model.generate(
      **inputs,
      do_sample=False,
      top_p=1,
      temperature=1,
      max_new_tokens=max_new_tokens
    )[:, inputs["input_ids"].shape[1]:]

    # batch_decode returns a list of strings; it is stored as-is.
    generation = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    pprint(generation)

    total_questions.append(question)
    total_generations.append(generation)
    total_answers.append(answers)
    total_original_num_tokens.append(original_num_tokens)
    total_compressed_num_tokens.append(compressed_num_tokens)

# Save the final results.
with open(args.output_filename, 'w', encoding='utf-8') as f:
  for x, y, z, o, c in zip(total_questions, total_generations, total_answers, total_original_num_tokens, total_compressed_num_tokens):
    item = {'question': x, 'generation': y, 'answers': z, 'original_num_tokens': o, 'compressed_num_tokens': c}
    json_line = json.dumps(item, ensure_ascii=False)
    f.write(json_line + '\n')  # one JSON object per line