import pandas as pd
import re
import argparse
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer
from vllm.outputs import RequestOutput
from typing import List
import torch
from huggingface_hub import login
import random
import numpy as np
import torch
from transformers import set_seed

def seed_everything(seed: int = 111):
    """Seed the random number generators used by this script.

    Seeds Python's ``random``, NumPy's global RNG, torch (default generator,
    current CUDA device and all CUDA devices) and the ``transformers``
    ``set_seed`` helper. It also switches cuDNN to deterministic,
    non-benchmarking mode.

    Args:
        seed: Integer seed applied to every generator.
    """
    # Python, NumPy and torch's default generator.
    for base_seeder in (random.seed, np.random.seed, torch.manual_seed):
        base_seeder(seed)

    # Prefer reproducible cuDNN kernel selection over autotuned speed.
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False

    # CUDA generators (current device, then all devices), then transformers' helper.
    for late_seeder in (torch.cuda.manual_seed, torch.cuda.manual_seed_all, set_seed):
        late_seeder(seed)



def generate_extract_prompt_with_question(
    extract_template: str,
    answer_text: str,
    tokenizer,
    answer_prefix: str = "\n- **Model's Final Answer is:** ",
) -> str:
    """
    Build the prompt passed to the 2nd (answer-extraction) model.

    The 1st model's long-form answer is substituted into ``extract_template``.
    The result is wrapped as a single user turn with the tokenizer's chat
    template, and ``answer_prefix`` is appended after the rendered prompt.

    - extract_template: ``str.format`` template containing an ``{answer_text}``
      placeholder. Any other literal braces must be doubled (``{{`` / ``}}``).
    - answer_text: The long answer generated by the 1st model (Chain of
      Thought, etc. possible). ``None`` and float NaN (how pandas represents
      an empty CSV cell) are treated as an empty answer. Other non-string
      values are converted with ``str()``.
    - tokenizer: Object providing ``apply_chat_template`` (e.g. a Hugging Face
      tokenizer).
    - answer_prefix: Text appended after the chat-formatted prompt. The default
      is the string this function always appended before the parameter existed.

    Returns the prompt as a string (not token ids).

    Raises whatever ``str.format`` raises (e.g. ``KeyError``) if the template
    references a placeholder other than ``answer_text``.
    """
    # Missing answers would otherwise crash on ``.strip()``.
    # NaN is the only float that compares unequal to itself.
    if answer_text is None or (isinstance(answer_text, float) and answer_text != answer_text):
        answer_text = ""

    # (1) Insert the 1st model's answer into the extraction template.
    filled_prompt = extract_template.format(answer_text=str(answer_text).strip())

    # (2) Wrap it as a single user turn and render it with the chat template.
    #     add_generation_prompt=True appends the assistant-turn header.
    messages = [{"role": "user", "content": filled_prompt}]
    prompt_for_model = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    # (3) Pre-fill the start of the assistant reply so the model continues
    #     directly with the final answer.
    return prompt_for_model + answer_prefix




def parse_args():
    """Parse the command-line options for this evaluation script.

    Returns an ``argparse.Namespace`` with ``model``, ``data_file`` and
    ``tensor_parallel_size`` attributes.
    """
    cli = argparse.ArgumentParser()
    # Label for the evaluated model (originally commented as "model path").
    # main() only prints it; the extractor model is hard-coded there.
    cli.add_argument("--model", type=str)
    # CSV file to read, and overwrite, with the evaluation data.
    cli.add_argument("--data_file", type=str, default='./data/gsm8k_test')
    # Number of GPUs vLLM shards the extractor model across.
    cli.add_argument("--tensor_parallel_size", type=int, default=4)
    namespace = cli.parse_args()
    return namespace


def main():
    """
    Extract final answers with an LLM and report exact-match accuracy.

    Reads the CSV at ``--data_file``. It must have an ``answer`` column
    (long-form answers from the 1st model) and a ``GT`` column (ground truth).
    Each answer is turned into an extraction prompt and run through a fixed
    extractor model (Llama-3.1-8B-Instruct via vLLM). The text before the
    first '.' of each generation is compared with the stripped ground truth.
    ``--model`` is only used to label the printed accuracy.

    Side effect: the input CSV is overwritten in place with two extra
    columns, ``prompts`` and ``gen_answer``.

    Raises:
        ValueError: if the CSV lacks the ``answer``/``GT`` columns or has no
            rows. This is checked before the model is loaded, so no GPU time
            is spent on unusable input.
    """
    args = parse_args()
    df = pd.read_csv(args.data_file)

    # Fail fast. Loading the model and generating are expensive, and both of
    # these problems used to surface only after generation had finished.
    missing = [col for col in ('answer', 'GT') if col not in df.columns]
    if missing:
        raise ValueError(f"{args.data_file} is missing required column(s): {missing}")
    if df.empty:
        raise ValueError(f"{args.data_file} contains no rows; cannot compute accuracy")

    # The extractor model is fixed; --model identifies whose answers are scored.
    model_name = 'meta-llama/Llama-3.1-8B-Instruct'
    model = LLM(
        model_name,
        tensor_parallel_size=args.tensor_parallel_size,
        max_model_len=16392,  # NOTE(review): possibly intended as 16384 — confirm
        trust_remote_code=True,
        gpu_memory_utilization=0.95,
        dtype='auto',
        enforce_eager=True)

    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        padding_side="left")

    template_path = './template/extract_answer_med.txt'
    with open(template_path, 'r', encoding='utf-8') as file:
        extract_template = file.read()

    prompts = [
        generate_extract_prompt_with_question(extract_template, answer, tokenizer)
        for answer in df['answer']
    ]
    df['prompts'] = prompts

    sampling_params = SamplingParams(temperature=0.9, top_p=0.7, top_k=10, repetition_penalty=1, max_tokens=100)

    # vLLM returns outputs in the same order as the input prompts.
    outputs: List[RequestOutput] = model.generate(prompts, sampling_params)
    generated_texts = [output.outputs[0].text for output in outputs]
    df['gen_answer'] = generated_texts

    # The prediction is the text before the first '.'.
    # NOTE(review): a decimal answer such as "3.5" is therefore cut to "3".
    # GT goes through str() because pandas parses numeric columns (e.g. GSM8K
    # answers) as numbers, which have no .strip().
    count = sum(
        generated.split('.')[0].strip() == str(gt).strip()
        for generated, gt in zip(generated_texts, df['GT'])
    )

    print('='*80)
    print(f'ACC of {args.model}:')
    print('ACC:', count / len(df))
    print('='*80)

    # Overwrites the input file, adding the prompts and generations.
    df.to_csv(args.data_file, index=False)
    
    

if __name__ == "__main__":
    # Seed the Python, NumPy, torch and transformers RNGs before any model or
    # tokenizer is created.
    seed_everything(seed=666)
    main()