from llmlingua import PromptCompressor
import json
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from QA_text_dataset_long_CNN import InferDataset
import os
import argparse
from tqdm import tqdm
import random
import numpy as np
from icecream import ic as pprint

import time



# Generation budget (max_new_tokens passed to model.generate) per dataset.
# The key is the dataset name derived from the parent directory of --data_path.
dataset2maxgen = dict(NQ=100, HQA=32, WikiQA=32)

# Instruction text keyed by instruction name. Both names currently map to the
# same answer-only instruction.
instructions_map = {
    variant: 'Answer the Question based on the given text. Only give me the answer and do not output any other words.'
    for variant in ('base', 'short')
}

# Per-dataset token counts. Not referenced by the inference code in this file;
# NOTE(review): presumably a token budget/target used by the compression
# stage — confirm before relying on it.
dataset2num_tokens = dict(NQ=3265, HQA=1568, WikiQA=1148)


def seed_everything(TORCH_SEED):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random`` module, NumPy's global RNG, and torch on the CPU
    and on all CUDA devices. Also exports ``PYTHONHASHSEED``. That variable
    only influences interpreters started after this call; the running
    process's hash seed is fixed at startup. Finally, it switches cuDNN to
    deterministic, non-benchmarking mode.

    Args:
        TORCH_SEED: Integer seed applied to every generator.
    """
    os.environ['PYTHONHASHSEED'] = str(TORCH_SEED)

    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for apply_seed in seeders:
        apply_seed(TORCH_SEED)

    # Give up cuDNN autotuning in exchange for repeatable kernel selection.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

def get_args(argv=None):
    """Parse the command-line options for this inference script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is the original CLI
            behaviour. Passing a list lets callers and tests parse options
            without touching ``sys.argv``.

    Returns:
        argparse.Namespace with one attribute per option registered below.
        ``--seed`` defaults to 42 and ``--random_num_documents`` is a flag
        defaulting to False. Every other option defaults to None when omitted.
    """
    parser = argparse.ArgumentParser()

    # Register the options. The inference code in this file reads only
    # --data_path, --model_gen_path, --prompt_filename, --output_filename and
    # --seed. NOTE(review): the remaining options are presumably shared with
    # the earlier prompt-compression stage — confirm before removing any.
    parser.add_argument('--data_path', type=str)
    parser.add_argument('--model_path', type=str)
    parser.add_argument('--model_gen_path', type=str)
    parser.add_argument('--compression_rate', type=int)
    parser.add_argument('--output_filename', type=str)
    parser.add_argument('--max_doc_tokens', type=int)
    parser.add_argument('--max_num_documents', type=int)
    parser.add_argument('--min_num_documents', type=int)
    parser.add_argument('--random_num_documents', action='store_true')
    parser.add_argument('--instruction_name', type=str)
    parser.add_argument('--batch_size', type=int)
    parser.add_argument('--prompt_filename', type=str)
    parser.add_argument('--seed', type=int, default=42)

    # Parse the options (argv=None falls back to sys.argv[1:]).
    return parser.parse_args(argv)

args = get_args()

seed_everything(args.seed)

model_path = args.model_gen_path
# Dataset name = parent directory of --data_path, e.g. "data/NQ/test.jsonl" -> "NQ".
# It selects the generation budget from dataset2maxgen.
dataset_name = args.data_path.split("/")[-2]

# Validate the dataset before the (slow) model load instead of failing with a
# bare KeyError on the first prompt.
if dataset_name not in dataset2maxgen:
    raise ValueError(
        f"Unknown dataset {dataset_name!r} (derived from --data_path "
        f"{args.data_path!r}); expected one of {sorted(dataset2maxgen)}"
    )
max_new_tokens = dataset2maxgen[dataset_name]

# Results are streamed to the output file while the prompt file is still being
# read, so writing over the prompt file would truncate it before it is read.
if os.path.abspath(args.output_filename) == os.path.abspath(args.prompt_filename):
    raise ValueError("--output_filename must differ from --prompt_filename")

tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

model = AutoModelForCausalLM.from_pretrained(
    model_path,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
).to("cuda")

model.eval()

# Stage 2: read the lingua_prompt file (one JSON object per line) and run inference.

pprint(args.prompt_filename)

with open(args.prompt_filename, 'r', encoding='utf-8') as fin, \
        open(args.output_filename, 'w', encoding='utf-8') as fout:
    for line in fin:
        if not line.strip():
            # Tolerate blank lines (e.g. a trailing newline) in the JSONL input.
            continue
        item = json.loads(line)
        # Read every field up front so a malformed record fails before generation.
        question = item['question']
        prompt = item['prompt']
        answers = item['answers']
        original_num_tokens = item['original_num_tokens']
        compressed_num_tokens = item['compressed_num_tokens']

        inputs = tokenizer(
            prompt,
            padding=False,
            return_tensors="pt",
        ).to(model.device)

        # Greedy decoding. Slicing off the prompt keeps only the new tokens.
        outputs = model.generate(
            **inputs,
            do_sample=False,
            top_p=1,
            temperature=1,
            max_new_tokens=max_new_tokens,
        )[:, inputs["input_ids"].shape[1]:]

        # batch_decode returns a list (one string per sequence). The list is
        # stored as-is to keep the output format unchanged.
        generation = tokenizer.batch_decode(outputs, skip_special_tokens=True)
        pprint(generation)

        record = {
            'question': question,
            'generation': generation,
            'answers': answers,
            'original_num_tokens': original_num_tokens,
            'compressed_num_tokens': compressed_num_tokens,
        }
        # One JSON object per line. Flush after each record so a crash mid-run
        # keeps every completed generation on disk.
        fout.write(json.dumps(record, ensure_ascii=False) + '\n')
        fout.flush()
