"""
The script for evaluation of finetuning result on different evaluation datasets,
optimized with vLLM for faster generation.
"""
import torch
import os
from vllm import LLM, SamplingParams
# from vllm import LoraRequest
from modelscope import AutoTokenizer
from modelscope.msdatasets import MsDataset
from tqdm import tqdm
from datetime import datetime
import json
import argparse

class PromptComposer:
    """Build examinee and judge prompts for a supported evaluation dataset.

    The dataset is recognised by a case-insensitive substring match on
    ``dataset_path``: ``'gsm8k'`` is checked first, then ``'pubmedqa'``.
    Any other path makes the compose methods raise ``ValueError``.
    """

    def __init__(self, dataset_path):
        self.dataset_path = dataset_path

    def _dataset_kind(self):
        """Return 'gsm8k' or 'pubmedqa' for the configured path, else raise ValueError."""
        lowered_path = self.dataset_path.lower()
        for kind in ('gsm8k', 'pubmedqa'):
            if kind in lowered_path:
                return kind
        raise ValueError(f"Unsupported dataset path for prompt composition: {self.dataset_path}")

    def compose_examinee_prompt(self, data_test):
        """Return one plain-text prompt per sample in the batch.

        Args:
            data_test: Column-oriented batch; ``data_test['question']`` is
                iterated for both datasets, and for pubmedqa each item of
                ``data_test['context']`` is indexed with ``"contexts"`` and
                its pieces are concatenated without a separator.

        Raises:
            ValueError: If the dataset path is not a supported dataset.
        """
        kind = self._dataset_kind()
        # NOTE: the backslash continuations keep the next line's leading spaces
        # inside the prompt text. The literals are left untouched so prompts
        # stay identical to those used in earlier evaluation runs.
        if kind == 'gsm8k':
            template = "Solve the following math problem step by step, and give the final answer clearly. \
                Question: {question} \n\n Answer: "
            return [template.format(question=question) for question in data_test['question']]

        template = "Use the information from context and answer the following medical question step by step, and give the final answer clearly. \
                Question: {question} \n\n Context: {context} \n\n Answer: "
        return [
            template.format(question=question, context="".join(context_dict["contexts"]))
            for question, context_dict in zip(data_test['question'], data_test['context'])
        ]

    def compose_judge_prompt(self, data_test, examinee_response):
        """Return one ``[system, user]`` chat conversation per examinee response.

        Args:
            data_test: Column-oriented batch. gsm8k reads the ``question`` and
                ``answer`` columns; pubmedqa reads ``question``,
                ``final_decision`` and ``long_answer``.
            examinee_response: Model answers, aligned by position with the
                columns of ``data_test``.

        Raises:
            ValueError: If the dataset path is not a supported dataset.
        """
        kind = self._dataset_kind()
        # NOTE(review): the judge prompts contain typos ("obb", "Out put",
        # "a impartial") and the pubmedqa rules mention a "numeric result"
        # although the reference passed in is `final_decision`. They are kept
        # byte-identical here so results remain comparable; confirm before
        # rewording.
        if kind == 'gsm8k':
            sys_message = "You are a impartial grader. Your job is to decide if a model's answer to a math problem is correct. \n\n \
                Rules: 1. Compare the model's final answer with the reference solution's final numeric result. \
                2. if the model's final answer matches the reference, mark it as correct, otherwise mark it as incorrect. \
                3. Out put correct/wrong at the beginning, and then give a short reason (less than 20 words) \
                IMPORTANT: give correct judgement if the final answer of model's answer is correct. "
            user_template = "Question: {question} \n\n Model's answer: {model_answer} \n\n Reference answer: {reference} \n\n"
            batched_user_message = [
                user_template.format(question=question, model_answer=model_answer, reference=reference)
                for question, model_answer, reference in zip(
                    data_test['question'], examinee_response, data_test['answer'])
            ]
        else:
            sys_message = "You are a impartial grader. Your obb is to decide if a model's answer to a medical question is correct. \n\n\
                Rules: 1. Compare the model's answer with the reference solution's final numeric result. \
                2. if the model's final answer matches the reference and give reasonable reasons compared to reference reason, mark it as correct, otherwise mark it as incorrect. \
                3. Out put correct/wrong at the beginning, and then give a short reason (less than 20 words)"
            user_template = "Question: {question} \n\n Model's answer: {model_answer} \n\n Reference answer: {reference} \n\n Reference reason: {reason} \n\n"
            batched_user_message = [
                user_template.format(
                    question=question,
                    model_answer=model_answer,
                    reference=reference,
                    reason=reason)
                for question, model_answer, reference, reason in zip(
                    data_test['question'], examinee_response,
                    data_test['final_decision'], data_test['long_answer'])
            ]

        return [
            [
                {"role": "system", "content": sys_message},
                {"role": "user", "content": user_message},
            ]
            for user_message in batched_user_message
        ]

    def get_label(self, batched_response, batch_size=1):
        """Turn judge verdicts into booleans (True means judged correct).

        Only the first 10 characters of the last message of each conversation
        are inspected, case-insensitively. "correct" counts only when it is
        not the tail of a longer word, so a verdict starting with "Incorrect"
        is labelled False instead of True.

        Args:
            batched_response: List of conversation histories; the last message
                of each must have a ``'content'`` string holding the verdict.
            batch_size: Unused; kept so existing callers keep working.

        Returns:
            A list of bools, one per conversation.
        """
        labels = []
        for response in batched_response:
            head = response[-1]['content'][:10].lower()
            pos = head.find('correct')
            # A letter right before the match means we hit e.g. "incorrect".
            labels.append(pos != -1 and (pos == 0 or not head[pos - 1].isalpha()))
        return labels

def main():
    """Evaluate each examinee model on the test set and grade it with a judge model.

    Steps:
      1. Parse and validate the command-line arguments.
      2. Load the test split, truncated to at most ``--test_size_limit`` samples.
      3. Load the judge model once; then, for each examinee model, generate
         answers batch by batch, have the judge grade them, and write one
         JSON line per sample to ``<output_path>/eval_<dataset>_<model>.jsonl``.
      4. Print the accuracy per model and, if ``--summary_file`` is given,
         append a JSON summary line to that file.

    Examinee models come from ``--examinee_model_list`` when ``--mode`` is
    ``'list'``; for any other mode every entry of ``--models_folder`` is used.
    """
    import gc  # local import: only needed for the per-model cleanup at the end of the loop

    parser = argparse.ArgumentParser(description="The script for evaluation of finetuning result on defferent evaluation datasets.")
    parser.add_argument("--judge_model_name", type=str, required=True, help="The model used for evaluation.")
    parser.add_argument("--gpu", type=str, default="0", help="Comma-separated list of GPU ids to use for tensor parallelism, e.g., '0,1,2,3'.")
    parser.add_argument("--batch_size", type=int, default=16, help="The batch size for inference. Adjust based on your GPU memory.")
    # Only mandatory in 'list' mode; this is enforced after parsing.
    parser.add_argument("--examinee_model_list", type=json.loads, required=False, default=None, help="A JSON list of model paths to be evaluated, e.g., '[\"/path/to/model1\", \"/path/to/model2\"]'")
    parser.add_argument("--base_model_name", type=str, required=True, help="The base model used for the examinees.")
    parser.add_argument("--test_data_path", type=str, required=True, help="The path of the test data.")
    parser.add_argument("--test_size_limit", type=int, default=300, help="The size of test data.")
    parser.add_argument("--test_subset", type=str, required=False, help="The subset of test data.")
    parser.add_argument("--test_split", type=str, default='test', help="The split of the test dataset.")
    parser.add_argument("--output_path", type=str, required=True, help="The path of the output file.")
    parser.add_argument("--max_new_length_examinee", type=int, default=512, help="The max new length of examinee.")
    parser.add_argument("--max_new_length_judge", type=int, default=512, help="The max new length of judge.")
    parser.add_argument("--summary_file", type=str, required=False, help="Optional file path to store the final accuracy summary for each model.")
    parser.add_argument("--mode", type=str, default='list', help="The mode of input mode.")
    parser.add_argument("--models_folder", type=str, default=None, help="The folder containing models.")

    args = parser.parse_args()

    # Reject inconsistent arguments now, before any model is loaded.
    folder_mode = args.mode != 'list'
    if folder_mode and args.models_folder is None:
        parser.error("--models_folder is required when --mode is not 'list'.")
    if not folder_mode and args.examinee_model_list is None:
        parser.error("--examinee_model_list is required when --mode is 'list'.")
    if args.batch_size < 1:
        parser.error("--batch_size must be a positive integer.")

    # Set up GPU for vLLM tensor parallelism
    num_gpus = len(args.gpu.split(','))
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    # Create the output directory (no error if it already exists)
    os.makedirs(args.output_path, exist_ok=True)

    # Load the test data
    test_dataset = MsDataset.load(args.test_data_path, subset_name=args.test_subset, split=args.test_split)
    if not hasattr(test_dataset, 'select'):
        test_dataset = test_dataset.to_hf_dataset()
    # Clamp the limit so a split smaller than --test_size_limit does not
    # request indices past the end of the dataset.
    test_dataset = test_dataset.select(range(min(args.test_size_limit, len(test_dataset))))

    # Load the judge model using vLLM
    print(f"Loading judge model: {args.judge_model_name}")
    judge_model = LLM(model=args.judge_model_name, tensor_parallel_size=num_gpus, gpu_memory_utilization=0.65,trust_remote_code=True)
    judge_tokenizer = AutoTokenizer.from_pretrained(args.judge_model_name, trust_remote_code=True, padding_side="left")
    judge_sampling_params = SamplingParams(temperature=0.0, max_tokens=args.max_new_length_judge)

    # Define prompt composer
    prompt_composer = PromptComposer(args.test_data_path)

    if folder_mode:
        # Sorted so the evaluation order does not depend on directory listing order.
        examinee_list = [
            os.path.join(args.models_folder, model_name)
            for model_name in sorted(os.listdir(args.models_folder))
        ]
    else:
        examinee_list = args.examinee_model_list

    # basename(normpath(...)) ignores trailing slashes, so "path/to/model/"
    # still yields "model" instead of an empty name that would make the
    # output files of different models collide.
    dataset_tag = os.path.basename(os.path.normpath(args.test_data_path))

    for examinee_model_name in examinee_list:
        print(f"Evaluating examinee model: {examinee_model_name}")
        # NOTE: adapter-only (LoRA) checkpoints get no special handling here;
        # every path is handed to vLLM as a full model.
        examinee_model = LLM(model=examinee_model_name, tensor_parallel_size=num_gpus, gpu_memory_utilization=0.15,trust_remote_code=True)
        examinee_sampling_params = SamplingParams(temperature=0.0, max_tokens=args.max_new_length_examinee)

        model_tag = os.path.basename(os.path.normpath(examinee_model_name))
        file_name = f"eval_{dataset_tag}_{model_tag}.jsonl"

        # Counters for accuracy calculation
        correct_predictions = 0
        total_samples = 0

        # Explicit UTF-8: records are dumped with ensure_ascii=False, which can
        # fail to encode under a non-UTF-8 default locale.
        with open(os.path.join(args.output_path, file_name), 'w', encoding='utf-8') as f:
            # Process the dataset in batches
            for i in tqdm(range(0, len(test_dataset), args.batch_size), desc=f"Evaluating {model_tag}"):
                batched_samples = test_dataset[i:min(len(test_dataset), i + args.batch_size)]

                # 1. Generate responses from the examinee model
                batched_message_examinee = prompt_composer.compose_examinee_prompt(batched_samples)
                examinee_outputs = examinee_model.generate(batched_message_examinee, examinee_sampling_params)
                batched_examinee_response = [output.outputs[0].text.strip() for output in examinee_outputs]

                # 2. Generate judgments from the judge model
                batched_message_judge = prompt_composer.compose_judge_prompt(batched_samples, batched_examinee_response)
                batched_judge_prompts_str = [
                    judge_tokenizer.apply_chat_template(conv, tokenize=False, add_generation_prompt=True)
                    for conv in batched_message_judge
                ]
                judge_outputs = judge_model.generate(batched_judge_prompts_str, judge_sampling_params)
                batched_judge_generated_text = [output.outputs[0].text.strip() for output in judge_outputs]

                # 3. Get labels from the judgment: append the judge's reply to a
                #    copy of each conversation so get_label can read the last message.
                full_conversations_for_labeling = [
                    conversation + [{"role": "assistant", "content": judge_text}]
                    for conversation, judge_text in zip(batched_message_judge, batched_judge_generated_text)
                ]
                batched_label = prompt_composer.get_label(full_conversations_for_labeling, batch_size=len(batched_samples['question']))

                # Update accuracy counters
                correct_predictions += sum(batched_label)
                total_samples += len(batched_label)

                # 4. Write results to file, one JSON object per line
                for prompt, examinee_response, judge_response, label in zip(
                        batched_message_examinee, batched_examinee_response,
                        batched_judge_generated_text, batched_label):
                    result = {
                        "prompt": prompt,
                        "examinee_response": examinee_response,
                        "judge_response": judge_response,
                        "label": label,
                    }
                    f.write(json.dumps(result, ensure_ascii=False) + '\n')

        # Calculate and report final accuracy for the current model
        accuracy = correct_predictions / total_samples if total_samples > 0 else 0
        print(f"Final Accuracy for {examinee_model_name}: {accuracy:.4f} ({correct_predictions}/{total_samples})")

        # Save the summary result if the file path is provided
        if args.summary_file:
            summary_result = {
                "time": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                "model_name": examinee_model_name,
                "dataset": args.test_data_path,
                "accuracy": f"{accuracy:.4f}",
                "correct": correct_predictions,
                "total": total_samples
            }
            # Append mode: one line per evaluated model, across runs.
            with open(args.summary_file, 'a', encoding='utf-8') as sf:
                sf.write(json.dumps(summary_result) + '\n')

        # Drop the examinee model before loading the next one. Collect garbage
        # first so objects that are no longer referenced are released before
        # the CUDA cache is emptied.
        del examinee_model
        gc.collect()
        torch.cuda.empty_cache()

# Run the evaluation only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()