import argparse
import json
import tqdm
from transformers import LlamaTokenizerFast

from torch.utils.data import DataLoader

from transformers import AutoModelForCausalLM, set_seed, AutoTokenizer

from read_data import *
import deepspeed
import torch
import os

from accelerate import dispatch_model

from peft import PeftConfig, PeftModel

# Directory passed to PeftModel.from_pretrained as the adapter location.
model_name_or_path2 = "../output/sft_trained_model_longa_10_epoch_2e_5_1000"

# JSONL file inside that directory that receives one prediction per line.
output_path2 = f"{model_name_or_path2}/generation_web_i.jsonl"

def test_inference(ds_model, tokenizer, prompt):
    """Run greedy generation for a batch of prompts and decode the results.

    Args:
        ds_model: Model object exposing ``device`` and ``generate``.
        tokenizer: Tokenizer exposing ``batch_encode_plus`` and
            ``batch_decode``.
        prompt: Batch (list) of prompt strings to encode together.

    Returns:
        A list with one decoded string per prompt, special tokens removed.
        Each string is the full sequence returned by ``generate``
        (presumably prompt followed by continuation -- the caller splits on
        the instruction marker to isolate the answer).
    """
    # Run model inference.
    # NOTE(review): ``max_length`` is given without ``truncation=True``, so it
    # presumably does not shorten over-long prompts -- confirm against the
    # tokenizer documentation before relying on a 1600-token cap.
    encoded = tokenizer.batch_encode_plus(
        prompt, return_tensors="pt", padding=True, max_length=1600
    )

    # Use the device of the model that was passed in, not a module-level
    # global, so the function works for whichever model the caller supplies.
    device = ds_model.device
    generated_ids = ds_model.generate(
        input_ids=encoded.input_ids.to(device),
        attention_mask=encoded.attention_mask.to(device),
        num_beams=1,
        do_sample=False,
        max_new_tokens=100,
    )
    return tokenizer.batch_decode(generated_ids.tolist(), skip_special_tokens=True)

# Base checkpoint; the PEFT adapter weights are attached on top of it below.
model_name_or_path = 'mistralai/Mistral-7B-Instruct-v0.1'

base_model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    device_map="auto",
    torch_dtype=torch.float16,
)

# Wrap the base model with the adapter stored in model_name_or_path2.
model = PeftModel.from_pretrained(base_model, model_name_or_path2)

tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)

# Batched generation with a decoder-only model needs left padding so that the
# newly generated tokens directly follow the prompt instead of pad tokens.
# Setting it explicitly is a no-op if the tokenizer already defaults to left.
tokenizer.padding_side = "left"

# The tokenizer may ship without a pad token; fall back to the unknown token
# so that padding=True works when encoding batches.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.unk_token_id

# train_dataset = get_trivia_dataset('/home/ubuntu/dpo_training/trivia_data/','train',tokenizer=tokenizer,sft_i = True)
test_dataset = get_trivia_dataset('../openbookqa_data/', 'web_test', tokenizer=tokenizer, sft_i=True)

testloader = DataLoader(test_dataset, batch_size=16)

# Marker that closes the instruction part of each prompt.
inst_end = '[/INST]'

# Open the output once (append mode keeps results from earlier runs) instead
# of reopening it for every prediction. UTF-8 is set explicitly because
# ensure_ascii=False writes non-ASCII characters verbatim.
with open(output_path2, 'a', encoding='utf-8') as f:
    # Wrapping the loader itself lets tqdm show a total when one is available.
    for data in tqdm.tqdm(testloader):
        # Keep only the text up to and including the first instruction marker,
        # dropping anything that follows it in the dataset text.
        prompts = [text.split(inst_end)[0] + inst_end for text in data['text']]
        output = test_inference(model, tokenizer, prompts)
        for decoded in output:
            # The decoded text is split on the marker; the segment right
            # after the first marker is stored as the prediction.
            ot = decoded.split(inst_end)[1]
            f.write(json.dumps({'prediction': ot}, ensure_ascii=False) + "\n")
        # Flush per batch so finished predictions survive an abrupt stop.
        f.flush()



