import os
os.environ["CUDA_VISIBLE_DEVICES"] = '4,5,6,7'
import numpy as np
from tqdm import tqdm
import pandas as pd
import datasets
import vllm
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer
import gc
import torch
from vllm.outputs import RequestOutput
from vllm.distributed.parallel_state import destroy_model_parallel
import datasets
from typing import List
from huggingface_hub import login


# Hugging Face model identifier of the teacher model. The same value is used
# to load the tokenizer here and the vLLM engine in the __main__ section.
teacher_path = "meta-llama/Llama-3.3-70B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(teacher_path)

# Prompt template text, read once when the module is loaded. generate_prompt()
# fills it in through str.format with the keyword ``Q``.
with open('./template/med_annotation.txt', mode='r', encoding='utf-8') as template_file:
    template = template_file.read()
    

def generate_prompt(question):
    """Build the chat-formatted prompt string for a single question.

    The module-level ``template`` is formatted with ``question`` passed as
    the keyword ``Q``, wrapped in a single user turn, and rendered to plain
    text by the module-level ``tokenizer``'s chat template with the
    generation prompt appended. The literal ``'- **Answer**: '`` is then
    concatenated to the end of the rendered text (presumably so the model's
    completion continues from that prefix).

    Args:
        question: Value passed to ``template.format`` as ``Q``.

    Returns:
        The rendered prompt text, ending with ``'- **Answer**: '``.
    """
    # ``tokenizer`` and ``template`` are only read here, so they resolve to
    # the module globals without any ``global`` declaration.
    user_turn = {"role": "user", "content": template.format(Q=question)}
    rendered = tokenizer.apply_chat_template(
        [user_turn],
        tokenize=False,
        add_generation_prompt=True,
    )
    return rendered + '- **Answer**: '



if __name__ == "__main__":
    # Script entry point: read the question CSV, generate one answer per
    # question with the teacher model through vLLM, and write the questions
    # together with the answers to a new CSV.

    # Tensor-parallel degree; equals the number of device ids listed in
    # CUDA_VISIBLE_DEVICES at the top of the file.
    gpu_num = 4
    # Passed to vLLM as max_model_len.
    max_token = 16384

    # Model initialization
    model = LLM(
        model=teacher_path,
        tensor_parallel_size=gpu_num,
        max_model_len=max_token,
        trust_remote_code=True,
        gpu_memory_utilization=0.98,
        dtype='auto',
        enforce_eager=True,
    )

    sampling_params = SamplingParams(
        temperature=0.9,
        top_p=0.7,
        top_k=10,
        repetition_penalty=1,
        max_tokens=8192,
    )
    print('Model WAS PREPARED')

    # Load dataset
    df = pd.read_csv('./medqa_generated_final_step_Q.csv')
    # reset_index returns a new DataFrame rather than modifying in place; the
    # result has to be assigned back, otherwise the call has no effect.
    df = df.reset_index(drop=True)

    questions = df['question'].tolist()
    prompts = [generate_prompt(question) for question in questions]
    batch_size = 10000

    generated_texts = []

    for batch_start in tqdm(range(0, len(prompts), batch_size)):
        # Get batch prompts
        batch_prompts = prompts[batch_start:batch_start + batch_size]

        # Generate outputs for the batch
        outputs: List[RequestOutput] = model.generate(batch_prompts, sampling_params)

        # Keep the text of the first completion of every request.
        # NOTE(review): the column assignment below relies on model.generate
        # returning outputs in the same order as the input prompts - confirm
        # against the installed vLLM version.
        generated_texts.extend(output.outputs[0].text for output in outputs)
        torch.cuda.empty_cache()

    # One generated answer per input row, in row order.
    df['Answer'] = generated_texts

    df.to_csv('medqa_generated_final_step_QA.csv', index=False)

    print('Successfully generated the answers!')
    
    