import gc
import pickle

import numpy as np
import torch
from datasets import load_dataset
from tqdm import tqdm
from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams


def format_example(example):
    """Render one multiple-choice example as a prompt string.

    Args:
        example: Mapping with the keys ``question`` (str), ``label``
            (sequence of option labels), ``choices`` (sequence of option
            texts) and ``answer``.

    Returns:
        A ``(prompt, answer)`` tuple. The prompt is the fixed instruction
        line, the question, one ``"<label>: <text>"`` line per choice and a
        trailing ``"Answer: "`` cue; ``answer`` is passed through unchanged.
    """
    stem = example['question']
    option_labels = example['label']
    gold = example['answer']
    option_texts = example['choices']

    pieces = [
        'The following are multi choice questions. Give ONLY the correct option, no other words or explanation:\n',
        'Question: ' + stem + '\n',
    ]
    # Labels are looked up by position, so a label list shorter than the
    # choice list raises IndexError instead of silently dropping choices.
    for position, option in enumerate(option_texts):
        pieces.append(option_labels[position] + ': ' + option + '\n')
    pieces.append('Answer: ')

    return ''.join(pieces), gold

def _option_logprobs(token_logprobs, options):
    """Return the log-probability of each option token, in ``options`` order.

    Args:
        token_logprobs: Mapping for one generated position whose values
            expose ``decoded_token`` and ``logprob`` attributes.
        options: Option strings to look up by decoded token text.

    Returns:
        A list of floats aligned with ``options``. An option missing from
        ``token_logprobs`` gets ``-inf`` so a single absent option does not
        abort the whole evaluation run.

    Raises:
        ValueError: If none of the options appears in ``token_logprobs``,
            since any prediction would then be arbitrary.
    """
    by_token = {}
    for entry in token_logprobs.values():
        # First entry wins if two token ids decode to the same text.
        by_token.setdefault(entry.decoded_token, entry.logprob)

    if not any(option in by_token for option in options):
        raise ValueError(
            f'None of the options {list(options)} were found among the '
            f'returned logprobs (decoded tokens: {list(by_token)})'
        )

    return [by_token.get(option, float('-inf')) for option in options]


def generate(model_type='pre'):
    """Score the MedMCQA validation split with a Qwen2.5-7B model.

    For every question a single token is generated under guided decoding
    restricted to ``A``/``B``/``C``/``D``; the option with the highest
    log-probability is taken as the prediction.

    Args:
        model_type: ``'pre'`` loads ``Qwen/Qwen2.5-7B``; any other value
            loads ``Qwen/Qwen2.5-7B-Instruct``.

    Returns:
        A dict of parallel lists with the keys ``question`` (prompt text),
        ``predicted_answer``, ``correct_answer``, ``logits`` (the four option
        log-probabilities in A-D order) and ``is_correct`` (0-d boolean
        ``numpy`` arrays).

    Side effects:
        Writes the same dict to ``prediction_results.pkl`` in the current
        working directory. The file name does not depend on ``model_type``,
        so consecutive runs overwrite each other.
    """
    option_labels = ['A', 'B', 'C', 'D']

    # Load Model
    model_name = 'Qwen/Qwen2.5-7B' if model_type == 'pre' else 'Qwen/Qwen2.5-7B-Instruct'
    model = LLM(model_name, tensor_parallel_size=2,
                gpu_memory_utilization=0.8, max_model_len=2048, guided_decoding_backend="xgrammar")

    # Load Datasets
    ds = load_dataset("openlifescienceai/medmcqa")
    cal_dataset = ds["validation"]

    # NOTE(review): assumes row['cop'] is a 0-based index into the four
    # options; a negative value would silently wrap around via Python's
    # negative indexing -- confirm against the dataset card.
    examples = [
        {
            'question': row['question'],
            'choices': [row['opa'], row['opb'], row['opc'], row['opd']],
            'answer': option_labels[row['cop']],
            'label': option_labels,
        }
        for row in cal_dataset
    ]

    results = {"question": [],
               "predicted_answer": [],
               "correct_answer": [],
               "logits": [],
               "is_correct": [],
               }

    guided_decoding_params = GuidedDecodingParams(choice=option_labels)
    sampling_params = SamplingParams(guided_decoding=guided_decoding_params,
                                     logprobs=len(option_labels), max_tokens=1)

    for example in tqdm(examples):
        input_text, label = format_example(example)

        outputs = model.generate(
            prompts=input_text,
            sampling_params=sampling_params,
            use_tqdm=False
        )

        # Logprob mapping for the first (and only) generated token.
        first_token_logprobs = outputs[0].outputs[0].logprobs[0]

        # Stored under the historical key 'logits', but these are logprobs.
        logits = _option_logprobs(first_token_logprobs, option_labels)
        preds = option_labels[int(np.argmax(logits))]

        results['question'].append(input_text)
        results['is_correct'].append(np.array(preds == label))
        results['logits'].append(logits)
        results['predicted_answer'].append(preds)
        results['correct_answer'].append(label)

    with open('prediction_results.pkl', "wb") as f:
        pickle.dump(results, f)

    # Drop the model reference and collect garbage *before* emptying the CUDA
    # cache; emptying it while the model is still alive releases nothing.
    del model
    gc.collect()
    torch.cuda.empty_cache()

    return results
