import argparse
import json
import tqdm
from transformers import LlamaTokenizerFast

from torch.utils.data import DataLoader

from transformers import AutoModelForCausalLM, set_seed, AutoTokenizer

from read_data import *
import deepspeed
import torch
import os

from accelerate import dispatch_model

from peft import PeftConfig, PeftModel

# Directory holding the fine-tuned adapter checkpoint; predictions are written
# to a JSONL file inside that same directory.
model_name_or_path2 = '../mc_output/mistral_model_longa_10_epoch_2e_5_1000_nq_i'
output_path = f"{model_name_or_path2}/generation_web_i.jsonl"

def test_inference(ds_model, tokenizer, prompt):
    """Generate greedy continuations for a batch of prompts.

    Args:
        ds_model: Model object exposing ``device`` and ``generate``.
        tokenizer: Tokenizer exposing ``batch_encode_plus`` and
            ``batch_decode``.
        prompt: Batch of prompt strings handed to ``batch_encode_plus``.

    Returns:
        The result of ``tokenizer.batch_decode`` on the generated sequences
        (one entry per row of the ``generate`` output), with special tokens
        skipped.
    """
    # NOTE(review): max_length is passed without a truncation argument, so it
    # is presumably not enforced by the tokenizer -- confirm. Truncation is
    # deliberately not enabled here because the caller relies on the trailing
    # '[/INST]' marker surviving in the prompt.
    encoded = tokenizer.batch_encode_plus(
        prompt, return_tensors="pt", padding=True, max_length=1600
    )

    # Use the device of the model that was actually passed in rather than a
    # module-level global, so the function works with any model argument.
    device = ds_model.device
    input_ids = encoded.input_ids.to(device)
    attention_mask = encoded.attention_mask.to(device)

    # num_beams=1 with do_sample=False requests plain greedy decoding, capped
    # at 100 newly generated tokens per sequence.
    generated = ds_model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        num_beams=1,
        do_sample=False,
        max_new_tokens=100,
    )
    return tokenizer.batch_decode(generated.tolist(), skip_special_tokens=True)

# Hub identifier of the base model that the adapter in model_name_or_path2
# is applied on top of.
model_name_or_path = 'mistralai/Mistral-7B-Instruct-v0.1'

# Load the base weights in float16 and let device_map="auto" decide placement.
base_model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path, device_map="auto", torch_dtype=torch.float16
)

# Wrap the base model with the adapter stored at model_name_or_path2.
model = PeftModel.from_pretrained(base_model, model_name_or_path2)
# This script only runs inference; make sure dropout-style layers are disabled.
model.eval()

tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)

# Batched generation pads the prompts, so a pad token must exist; fall back to
# the unknown token when the tokenizer does not define one.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.unk_token_id

test_dataset = get_mc_dataset('../mc_data/web_', 'test', tokenizer=tokenizer, sft_i=True)

# Decoder-only models continue from the last position of every row, so padded
# batches must be left-padded for generation. Set after the dataset is built
# so the tokenizer passed to get_mc_dataset is unchanged at that point.
tokenizer.padding_side = "left"

testloader = DataLoader(test_dataset, batch_size=16)

# Open the output file once instead of once per prediction. Append mode is
# kept so results are added to an existing file, and UTF-8 is forced because
# ensure_ascii=False writes non-ASCII characters verbatim.
with open(output_path, 'a', encoding='utf-8') as f:
    for data in tqdm.tqdm(testloader):
        # Keep only the instruction part of each sample (everything up to and
        # including the first '[/INST]') so the model has to produce the answer.
        prompts = [text.split('[/INST]')[0] + '[/INST]' for text in data['text']]
        # print(data)
        output = test_inference(model, tokenizer, prompts)
        for generated_text in output:
            # The decoded text is expected to contain the prompt followed by
            # the continuation; split only at the first marker so a response
            # that itself contains '[/INST]' is not cut short.
            ot = generated_text.split('[/INST]', 1)[1]
            f.write(json.dumps({'prediction': ot}, ensure_ascii=False) + "\n")
        # Persist each finished batch so partial progress survives a crash.
        f.flush()

# trainloader = DataLoader(train_dataset,batch_size = 16)
# for index,data in tqdm.tqdm(enumerate(trainloader)):
#     for i in range(len(data['text'])):
#         data['text'][i] =  data['text'][i].split('[/INST]')[0]+'[/INST]'

#     output = test_inference(model, tokenizer, data['text'])
#     for i in range(len(output)):
#         ot = output[i].split('[/INST]')[1]
#         with open(train_path,'a+') as f:
#             results = json.dumps({'prediction':ot},ensure_ascii=False)
#             f.write(results + "\n")
        



