import torch
import json
import argparse
from datasets import load_dataset
from vllm import LLM, SamplingParams
# ---------------------------------------------------------------------------
# Command-line configuration and model setup. Runs at import time; the
# functions below read `llm`, `sampling_params` and `args` as module globals.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(
    description="Greedy multiple-choice evaluation of a model with vLLM.")
parser.add_argument("--model", default='', type=str,
                    help="Model name or path passed to vllm.LLM.")
parser.add_argument("--dtype", default='bfloat16', type=str,
                    help="dtype passed to vllm.LLM.")
# Required: without it load_dataset would receive data_files={"train": None}
# and fail far away from the actual mistake.
parser.add_argument("--input_test_file", type=str, required=True,
                    help="JSON file whose rows carry a 'messages' list.")
args = parser.parse_args()
model = args.model
# Upper bound on the number of tokens generated per answer.
model_max_length = 2048
# Use tensor parallelism across every GPU torch can see.
llm = LLM(model=model, tensor_parallel_size=torch.cuda.device_count(),
          dtype=args.dtype, trust_remote_code=True)
tokenizer = llm.get_tokenizer()

# Greedy decoding (temperature=0). The EOS *string* is added as an extra stop
# sequence only when the tokenizer defines one: [None] is not a usable stop
# list. Termination on the EOS token id is handled by vLLM itself unless
# ignore_eos is set.
sampling_params = SamplingParams(
    temperature=0,
    top_p=1,
    max_tokens=model_max_length,
    stop=[tokenizer.eos_token] if tokenizer.eos_token else None,
)

def run_question_answer(question):
    """Run the module-level ``llm`` on *question* with ``sampling_params``.

    Args:
        question: prompt(s) forwarded unchanged to ``llm.generate``.

    Returns:
        A list holding the text of the first completion of every request
        output, in the order ``llm.generate`` returned them.
    """
    generations = llm.generate(question, sampling_params)
    return [generation.outputs[0].text for generation in generations]


def datasetFromJson(file_name):
    """Load a local JSON file through the generic ``"json"`` dataset builder.

    The file is registered as the ``"train"`` split and that split is
    returned directly.
    """
    splits = load_dataset("json", data_files={"train": file_name})
    return splits['train']

def _process_doc_mcf(example):
    """Extract the prompt and the gold answer from a chat-formatted example.

    Args:
        example: mapping with a ``"messages"`` list of dicts, each holding
            ``"role"`` and ``"content"`` keys.

    Returns:
        ``(question, answer)``: the content of the last ``"user"`` message and
        of the last ``"assistant"`` message. Messages with other roles are
        ignored.

    Raises:
        ValueError: if the example has no user message or no assistant message.
    """
    question = None
    answer = None
    for msg in example["messages"]:
        role = msg['role']
        # Later turns overwrite earlier ones, so a multi-turn example yields
        # its final user/assistant pair.
        if role == 'user':
            question = msg['content']
        elif role == 'assistant':
            answer = msg['content']
    if question is None or answer is None:
        raise ValueError(
            "example must contain both a 'user' and an 'assistant' message")
    return question, answer

def mcfGreedy(dataset):
    """Score greedy answers of the module-level model on *dataset* and print accuracy.

    Each example's user prompt is sent through ``run_question_answer``. A
    prediction counts as correct when its first non-whitespace character
    matches, case-insensitively, the first non-whitespace character of the gold
    answer (presumably the option letter; confirm against the data format).

    Args:
        dataset: iterable of examples accepted by ``_process_doc_mcf``.

    Returns:
        The accuracy as a float, or ``None`` when *dataset* yields no examples
        (accuracy is undefined then).
    """
    count = 0
    zero_shot_correct = 0
    for example in dataset:
        question, gold = _process_doc_mcf(example)
        prediction = run_question_answer(question)[0].strip()
        target = gold.strip()
        # Only the leading character is compared, so any explanation the
        # model appends after its choice is ignored.
        if prediction[:1].upper() == target[:1].upper():
            zero_shot_correct += 1
        count += 1

    if count == 0:
        print('accuracy undefined: dataset is empty')
        return None

    accuracy = zero_shot_correct / count
    print('accuracy ', accuracy)
    return accuracy


# Script entry: load the test file named on the command line and evaluate it.
file_name = args.input_test_file
dataset = datasetFromJson(file_name=file_name)
mcfGreedy(dataset=dataset)




