import json, os
from vllm import LLM, SamplingParams
from PIL import Image
from transformers import AutoProcessor, AutoTokenizer
from tqdm import tqdm
import argparse
import numpy as np
from datetime import datetime

# ====================================
#  COT PROMPT
# ====================================

# Chain-of-thought prompt template. It asks the model to reason inside
# <think>...</think> tags and to put the final answer inside
# <answer>...</answer> tags. The `{Question}` placeholder is filled in with
# str.format() by the evaluators' eval_batch methods.
# NOTE(review): the text is sent to the model verbatim, so any edit to it
# (including grammar fixes) changes what is being evaluated.
COT_MED_QUESTION_PROMPT_CLOSE = '''
Look the given medical image carefully, and complete the tasks below.
Your task:
1. Think through the question step by step, enclose your reasoning process in <think>...</think> tags.
2. Then provide the correct answer inside <answer>...</answer> tags.
3. No extra information or text outside of these tags.
{Question} 
'''

# Bare-question template: the question is passed through with no extra
# instructions. VL_Evaluator.run() selects it for the "*-zero" tasks.
ZERO_COT_MED_QUESTION_PROMPT = "{Question}" 

# Smallest width/height (in pixels) an image may have after resize_if_needed().
MIN_SIZE = 28


def resize_if_needed(image):
    """Upscale ``image`` so that neither side is shorter than ``MIN_SIZE``.

    An image whose width and height are both at least ``MIN_SIZE`` is
    returned unchanged (the very same object). A smaller image is enlarged
    with bicubic resampling using one scale factor for both axes, so the
    aspect ratio is kept.

    Args:
        image: Object exposing ``size`` as ``(width, height)`` and a
            ``resize(size, resample)`` method (a PIL image in this file).

    Returns:
        ``image`` itself, or the result of ``image.resize(...)``.
    """
    w, h = image.size
    if w >= MIN_SIZE and h >= MIN_SIZE:
        return image

    # Keep the aspect ratio: scale both axes by the larger of the two
    # per-axis factors so the shorter side reaches MIN_SIZE.
    ratio = max(MIN_SIZE / w, MIN_SIZE / h)
    # int() truncates; clamp so rounding can never leave a side below MIN_SIZE.
    new_size = (
        max(MIN_SIZE, int(w * ratio)),
        max(MIN_SIZE, int(h * ratio)),
    )
    return image.resize(new_size, Image.BICUBIC)

class VL_Evaluator():
    """Evaluate a vision-language model on medical VQA benchmarks with vLLM.

    ``__init__`` loads a Hugging Face processor (used here only to render the
    chat template) and a vLLM engine. ``run`` picks the prompt and benchmark
    files for a task, sends the samples to ``eval_batch`` in batches and
    writes the samples, each extended with a ``"pred"`` entry, to a JSON file
    under ``<model_name_or_path>/vision-r1-result``.
    """

    # Task names accepted by run().
    _SUPPORTED_TASKS = (
        "slake", "slake-closed", "vqa-rad", "vqa-rad-closed", "path-vqa",
        "slake-zero", "vqa-rad-zero", "path-vqa-zero",
    )
    # Evaluation strategies accepted by run().
    _SUPPORTED_EVAL_TYPES = ("cot-sft",)
    # Tasks ending with this suffix use the bare-question prompt.
    _ZERO_SUFFIX = "-zero"

    # Dataset name -> (benchmark JSON, image directory). A "<dataset>-zero"
    # task reads the same files as "<dataset>". The "-closed" tasks have no
    # entry, so run() rejects them with an explicit error.
    _BENCHMARKS = {
        "slake": (
            "/home/duyuetian/projects/MedVLM-R1/dataset/slake/test_convert.json",
            "/sda/duyuetian/dataset/SLAKE/refined/imgs",
        ),
        "vqa-rad": (
            "/home/duyuetian/projects/MedVLM-R1/dataset/vqa-rad/test_convert.json",
            "/sda/duyuetian/dataset/VQA-RAD/VQA_RAD_Image_Folder/",
        ),
        "path-vqa": (
            "/home/duyuetian/projects/MedVLM-R1/dataset/path-vqa/test_convert.json",
            "/sda/duyuetian/dataset/PATH-VQA/path_vqa/image_test/",
        ),
    }

    def __init__(self, model_name_or_path, max_image_num=2):
        """Load the chat-template processor and the vLLM engine.

        Args:
            model_name_or_path: Checkpoint handed to ``AutoProcessor`` and
                ``LLM``. ``run`` also saves its results beneath this path.
            max_image_num: Maximum number of images allowed per prompt.
        """
        self.processor = AutoProcessor.from_pretrained(model_name_or_path)

        self.model = LLM(
            model=model_name_or_path,
            gpu_memory_utilization=0.9,
            limit_mm_per_prompt={"image": max_image_num},
            enable_prefix_caching=True,
            trust_remote_code=True,
        )
        self.sampling_params = SamplingParams(
            temperature=0.1,
            top_p=0.9,
            top_k=50,
            max_tokens=768,
        )

        self.model_name_or_path = model_name_or_path

    def _question_text(self, sample):
        """Return the prompt template filled with this sample's question."""
        # run() never accepts "geomath", so the first branch is reachable only
        # when task_name/eval_type are set by other code.
        if self.task_name == "geomath" and self.eval_type in ["sft", "zero-shot"]:
            return self.prompt.format(Question=sample['problem_no_prompt'])
        return self.prompt.format(Question=sample['problem'])

    @staticmethod
    def _write_predictions(path, records):
        """Write ``records`` to ``path`` as indented JSON."""
        with open(path, 'w', encoding='utf-8') as outfile:
            json.dump(records, outfile, indent=4)

    def eval_batch(self, sample_list, image_dir):
        """Generate one prediction per sample and store it under ``"pred"``.

        Reads ``self.task_name``, ``self.eval_type`` and ``self.prompt``,
        which ``run`` sets before calling this method.

        Args:
            sample_list: Sample dicts. ``sample["image"]`` is a single entry
                or a list of entries; string entries are opened as files
                relative to ``image_dir``, anything else is used as-is.
            image_dir: Directory that image file names are joined onto.

        Returns:
            ``sample_list`` itself, with ``"pred"`` set on every sample.

        Raises:
            AssertionError: If the engine returns a different number of
                outputs than prompts.
        """
        prompts_text_and_vision = []
        for sample in sample_list:
            image_field = sample["image"]
            entries = image_field if isinstance(image_field, list) else [image_field]
            images = [
                Image.open(os.path.join(image_dir, entry)) if isinstance(entry, str) else entry
                for entry in entries
            ]

            # One image placeholder per image, followed by the question text.
            messages = [
                {
                    "role": "user",
                    "content": [
                        *[{"type": "image"} for _ in images],
                        {"type": "text", "text": self._question_text(sample)},
                    ],
                }
            ]

            vllm_prompt = self.processor.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )

            prompts_text_and_vision.append(
                {
                    "prompt": vllm_prompt,
                    "multi_modal_data": {"image": images},
                }
            )

        outputs = self.model.generate(prompts_text_and_vision, sampling_params=self.sampling_params, use_tqdm=False)

        assert len(outputs) == len(prompts_text_and_vision), f"Out({len(outputs)}) != In({len(prompts_text_and_vision)})"

        for output, item in zip(outputs, sample_list):
            item["pred"] = output.outputs[0].text

        return sample_list

    def run(self, task_name="trance", eval_type="cot-sft", batch=16):
        """Evaluate one benchmark and save the predictions as JSON.

        The output file is rewritten after every tenth batch, so a crash
        loses at most the batches since the last checkpoint, and once more
        at the end.

        Args:
            task_name: One of the supported task names. The default
                ("trance") is not among them, so callers must pass a task
                explicitly.
            eval_type: Evaluation strategy; only "cot-sft" is supported.
            batch: Number of samples sent to ``eval_batch`` at once.

        Raises:
            AssertionError: If ``task_name`` or ``eval_type`` is unsupported.
            ValueError: If ``batch`` is smaller than 1, or the task is
                accepted but has no benchmark files configured (the
                "-closed" tasks).
        """
        assert task_name in self._SUPPORTED_TASKS, \
            f"Task ({task_name}) is not supported. Please choose in {list(self._SUPPORTED_TASKS)}"

        assert eval_type in self._SUPPORTED_EVAL_TYPES, \
            f"Type ({eval_type}) is not supported. Please choose in ['cot-sft']"

        if batch < 1:
            raise ValueError(f"batch must be at least 1, got {batch}")

        # Resolve the benchmark before touching any state: failing here keeps
        # a previous run's benchmark_json/image_dir from being reused silently.
        is_zero = task_name.endswith(self._ZERO_SUFFIX)
        dataset = task_name[:-len(self._ZERO_SUFFIX)] if is_zero else task_name
        if dataset not in self._BENCHMARKS:
            raise ValueError(f"Task ({task_name}) has no benchmark files configured.")

        # Path to benchmark
        self.benchmark_json, self.image_dir = self._BENCHMARKS[dataset]

        # Prompt
        self.prompt = ZERO_COT_MED_QUESTION_PROMPT if is_zero else COT_MED_QUESTION_PROMPT_CLOSE

        self.task_name = task_name
        self.eval_type = eval_type

        self.path_to_save = os.path.join(self.model_name_or_path, "vision-r1-result")
        os.makedirs(self.path_to_save, exist_ok=True)

        with open(self.benchmark_json, 'r') as file:
            data = json.load(file)

        now_str = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        result_path = os.path.join(self.path_to_save, f"{self.task_name}_confidence_{now_str}.json")

        sample_batch = []
        data_with_pred = []
        pred_times = 0
        last_idx = len(data) - 1

        for idx, sample in tqdm(enumerate(data), desc=f"{self.task_name}-{self.eval_type}", total=len(data)):
            sample_batch.append(sample)

            # Keep collecting until the batch is full or the data is exhausted.
            if idx % batch != batch - 1 and idx != last_idx:
                continue

            data_with_pred += self.eval_batch(sample_batch, self.image_dir)
            pred_times += 1
            sample_batch = []

            # Periodic checkpoint of everything predicted so far.
            if pred_times % 10 == 0:
                self._write_predictions(result_path, data_with_pred)

        self._write_predictions(result_path, data_with_pred)

        print(f"Save to {result_path}")

class QWEN_VL_Evaluator(VL_Evaluator):
    """Evaluator that also records token-level confidence statistics.

    On top of the generated text, ``eval_batch`` requests the top-20 token
    log-probabilities from vLLM and stores three extra entries per sample:
    ``"confidence"``, ``"self_certainty"`` and ``"token_confidence"``.
    """

    def __init__(self, model_name_or_path, max_image_num=2, min_pixels=3136, max_pixels=480000):
        """Load the processor and the vLLM engine.

        Args:
            model_name_or_path: Checkpoint handed to ``AutoProcessor`` and
                ``LLM``.
            max_image_num: Maximum number of images allowed per prompt.
            min_pixels: Value written to ``processor.image_processor.min_pixels``.
            max_pixels: Value written to ``processor.image_processor.max_pixels``.
        """
        self.processor = AutoProcessor.from_pretrained(model_name_or_path)
        try:
            self.processor.pad_token_id = self.processor.tokenizer.pad_token_id
            self.processor.eos_token_id = self.processor.tokenizer.eos_token_id
            self.processor.image_processor.max_pixels = max_pixels
            self.processor.image_processor.min_pixels = min_pixels
        except Exception as exc:
            # Best effort: a processor without these attributes is still
            # usable for chat templating, so report the problem and carry on.
            print(f"[QWEN_VL_Evaluator] Could not fully configure the processor: {exc!r}")
        self.model = LLM(
            model=model_name_or_path,
            gpu_memory_utilization=0.9,
            limit_mm_per_prompt={"image": max_image_num},
            enable_prefix_caching=True,
            trust_remote_code=True,
        )
        self.sampling_params = SamplingParams(
            temperature=0.1,
            top_p=0.9,
            top_k=50,
            max_tokens=768,
            logprobs=20,
        )

        self.model_name_or_path = model_name_or_path

    def _question_text(self, sample):
        """Return the prompt template filled with this sample's question."""
        if self.task_name == "geomath" and self.eval_type in ["sft", "zero-shot"]:
            return self.prompt.format(Question=sample['problem_no_prompt'])
        return self.prompt.format(Question=sample['problem'])

    @staticmethod
    def _token_statistics(token_ids, step_logprobs, tokenizer):
        """Summarise the per-token log-probabilities of one completion.

        Args:
            token_ids: Generated token ids, one per decoding step.
            step_logprobs: One mapping per step from token id to an object
                with a ``logprob`` attribute; may be empty or ``None``.
            tokenizer: Object whose ``decode`` turns a token id into text.

        Returns:
            ``(confidence, self_certainty, token_confidence)`` where
            ``confidence`` is exp(mean log-probability of the generated
            tokens), ``self_certainty`` is the mean KL(uniform || p) over the
            returned entries of each step, and ``token_confidence`` is a
            ``";"``-joined list of ``"<token text>,<probability in %>"``.
            A completion without log-probabilities yields ``(0.0, 0.0, "")``.
        """
        token_logprobs = []
        kl_values = []
        token_confidence_pairs = []

        for prob_dict, token in zip(step_logprobs or [], token_ids):
            token_logprob = prob_dict[token].logprob
            token_logprobs.append(token_logprob)

            # Renormalise the returned entries into a distribution P.
            P = np.exp(np.array([entry.logprob for entry in prob_dict.values()]))
            P = P / np.sum(P)
            # Uniform reference over the same support. The number of entries
            # can differ between steps, so it is taken per step rather than
            # from the first step only.
            Q = np.full_like(P, 1.0 / len(P))
            kl_values.append(kl_divergence(Q, P))

            # For visualisation: generated token text with its probability.
            token_text = tokenizer.decode([token], clean_up_tokenization_spaces=True)
            token_confidence = np.round(np.exp(token_logprob) * 100, 2)
            token_confidence_pairs.append(f"{token_text},{token_confidence}")

        # Nothing generated: avoid dividing by zero / averaging an empty list.
        if not token_logprobs:
            return 0.0, 0.0, ""

        avg_logprob = sum(token_logprobs) / len(token_logprobs)
        confidence = np.round(np.exp(avg_logprob), 4)
        self_certainty = np.round(np.mean(kl_values), 4)
        return confidence, self_certainty, ";".join(token_confidence_pairs)

    def eval_batch(self, sample_list, image_dir):
        """Generate predictions and confidence statistics for a batch.

        Images are passed through ``resize_if_needed`` before being sent to
        the engine. Reads ``self.task_name``, ``self.eval_type`` and
        ``self.prompt``.

        Args:
            sample_list: Sample dicts. ``sample["image"]`` is a single entry
                or a list of entries; string entries are opened as files
                relative to ``image_dir``, anything else is used as-is.
            image_dir: Directory that image file names are joined onto.

        Returns:
            ``sample_list`` itself, with ``"pred"``, ``"confidence"``,
            ``"self_certainty"`` and ``"token_confidence"`` set on every
            sample.

        Raises:
            AssertionError: If the engine returns a different number of
                outputs than prompts.
        """
        prompts_text_and_vision = []
        for sample in sample_list:
            image_field = sample["image"]
            entries = image_field if isinstance(image_field, list) else [image_field]
            images = [
                resize_if_needed(
                    Image.open(os.path.join(image_dir, entry)) if isinstance(entry, str) else entry
                )
                for entry in entries
            ]

            # One image placeholder per image, followed by the question text.
            messages = [
                {
                    "role": "user",
                    "content": [
                        *[{"type": "image"} for _ in images],
                        {"type": "text", "text": self._question_text(sample)},
                    ],
                }
            ]
            vllm_prompt = self.processor.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )

            prompts_text_and_vision.append(
                {
                    "prompt": vllm_prompt,
                    "multi_modal_data": {"image": images},
                }
            )

        outputs = self.model.generate(prompts_text_and_vision, sampling_params=self.sampling_params, use_tqdm=False)

        assert len(outputs) == len(prompts_text_and_vision), f"Out({len(outputs)}) != In({len(prompts_text_and_vision)})"

        # The tokenizer does not change between outputs; fetch it once.
        tokenizer = self.model.get_tokenizer()

        for output, item in zip(outputs, sample_list):
            completion = output.outputs[0]
            confidence, self_certainty, token_confidence = self._token_statistics(
                completion.token_ids, completion.logprobs, tokenizer
            )

            item["pred"] = completion.text
            item["confidence"] = confidence
            print(item["confidence"])
            item["self_certainty"] = self_certainty

            # For visualisation
            item["token_confidence"] = token_confidence
        return sample_list

def kl_divergence(Q, P, eps=1e-12):
    """Return sum(Q * log(Q / P)) after adding ``eps`` to both inputs.

    Args:
        Q: Array-like of probabilities (first argument of the divergence).
        P: Array-like of probabilities, same shape as ``Q``.
        eps: Small constant added to every element so that neither the
            logarithm nor the division ever sees an exact zero.

    Returns:
        The summed divergence as a NumPy scalar, in nats.
    """
    q_safe = np.array(Q) + eps
    p_safe = np.array(P) + eps
    log_ratio = np.log(q_safe / p_safe)
    return (q_safe * log_ratio).sum()

class Mllama_VL_Evaluator(VL_Evaluator):
    """Evaluator that builds ``<|image|>...<|begin_of_text|>`` prompts by hand.

    No Hugging Face processor is loaded; the prompt string is assembled
    directly in ``eval_batch``.
    """

    def __init__(self, model_name_or_path, max_image_num=2):
        """Create the vLLM engine (two-way tensor parallel) and sampling setup.

        Args:
            model_name_or_path: Checkpoint handed to ``LLM``.
            max_image_num: Maximum number of images allowed per prompt.
        """
        engine_options = {
            "model": model_name_or_path,
            "gpu_memory_utilization": 0.8,
            "tensor_parallel_size": 2,
            "max_model_len": 4096,
            "limit_mm_per_prompt": {"image": max_image_num},
            "enable_prefix_caching": True,
            "trust_remote_code": True,
            "max_num_seqs": 16,
            "enforce_eager": True,
        }
        self.model = LLM(**engine_options)

        sampling_options = {
            "temperature": 0.1,
            "top_p": 0.9,
            "top_k": 50,
            "max_tokens": 768,
        }
        self.sampling_params = SamplingParams(**sampling_options)

        self.model_name_or_path = model_name_or_path

    def eval_batch(self, sample_list, image_dir):
        """Generate one prediction per sample and store it under ``"pred"``.

        Args:
            sample_list: Sample dicts. ``sample["image"]`` is a single entry
                or a list; string entries are opened relative to
                ``image_dir``, anything else is used as-is.
            image_dir: Directory that image file names are joined onto.

        Returns:
            ``sample_list`` itself, with ``"pred"`` set on every sample.
        """

        def as_image(entry):
            # File names are opened from image_dir; other objects pass through.
            if isinstance(entry, str):
                return Image.open(os.path.join(image_dir, entry))
            return entry

        requests = []
        for sample in sample_list:
            raw = sample["image"]
            if isinstance(raw, list):
                images = [as_image(entry) for entry in raw]
            else:
                images = [as_image(raw)]

            # Pick which sample field carries the question text.
            if self.task_name == "geomath" and self.eval_type in ["sft", "zero-shot"]:
                question = sample['problem_no_prompt']
            else:
                question = sample['problem']

            # One <|image|> marker per image, placed before <|begin_of_text|>.
            image_markers = "<|image|>" * len(images)
            body = self.prompt.format(Question=question)
            requests.append(
                {
                    "prompt": f"{image_markers}<|begin_of_text|>{body}",
                    "multi_modal_data": {"image": images},
                }
            )

        outputs = self.model.generate(requests, sampling_params=self.sampling_params, use_tqdm=False)

        assert len(outputs) == len(requests), f"Out({len(outputs)}) != In({len(requests)})"

        for result, item in zip(outputs, sample_list):
            item["pred"] = result.outputs[0].text

        return sample_list
    
class PHI3V_VL_Evaluator(VL_Evaluator):
    """Evaluator that builds ``<|user|> ... <|assistant|>`` prompts by hand.

    Images are referenced with numbered ``<|image_N|>`` placeholders, one per
    line, ahead of the question text.
    """

    def __init__(self, model_name_or_path, max_image_num=2):
        """Create the vLLM engine and the sampling parameters.

        Args:
            model_name_or_path: Checkpoint handed to ``LLM``.
            max_image_num: Maximum number of images allowed per prompt.
        """
        image_limit = {"image": max_image_num}
        self.model = LLM(
            model=model_name_or_path,
            gpu_memory_utilization=0.9,
            limit_mm_per_prompt=image_limit,
            enable_prefix_caching=True,
            trust_remote_code=True,
        )
        self.sampling_params = SamplingParams(
            temperature=0.1, top_p=0.9, top_k=50, max_tokens=768,
        )
        self.model_name_or_path = model_name_or_path

    def eval_batch(self, sample_list, image_dir):
        """Generate one prediction per sample and store it under ``"pred"``.

        Args:
            sample_list: Sample dicts. ``sample["image"]`` is a single entry
                or a list; string entries are opened relative to
                ``image_dir``, anything else is used as-is.
            image_dir: Directory that image file names are joined onto.

        Returns:
            ``sample_list`` itself, with ``"pred"`` set on every sample.
        """
        batch_inputs = []
        for sample in sample_list:
            source = sample["image"]
            candidates = source if isinstance(source, list) else [source]
            images = []
            for candidate in candidates:
                if isinstance(candidate, str):
                    candidate = Image.open(os.path.join(image_dir, candidate))
                images.append(candidate)

            # Placeholders are numbered from 1 and separated by newlines.
            placeholders = "\n".join(
                f"<|image_{number}|>" for number in range(1, len(images) + 1)
            )

            use_plain_problem = not (
                self.task_name == "geomath" and self.eval_type in ["sft", "zero-shot"]
            )
            field = 'problem' if use_plain_problem else 'problem_no_prompt'
            question_text = self.prompt.format(Question=sample[field])

            vllm_prompt = f"<|user|>\n{placeholders}\n{question_text}<|end|>\n<|assistant|>\n"
            batch_inputs.append(
                {"prompt": vllm_prompt, "multi_modal_data": {"image": images}}
            )

        outputs = self.model.generate(batch_inputs, sampling_params=self.sampling_params, use_tqdm=False)

        assert len(outputs) == len(batch_inputs), f"Out({len(outputs)}) != In({len(batch_inputs)})"

        for generation, item in zip(outputs, sample_list):
            item["pred"] = generation.outputs[0].text

        return sample_list
    

class Pixtral_VL_Evaluator(VL_Evaluator):
    """Evaluator that builds ``<s>[INST] ... [/INST]`` prompts by hand.

    The question text comes first, followed by a newline and one ``[IMG]``
    marker per image.
    """

    def __init__(self, model_name_or_path, max_image_num=2):
        """Create the vLLM engine and the sampling parameters.

        Args:
            model_name_or_path: Checkpoint handed to ``LLM``.
            max_image_num: Maximum number of images allowed per prompt.
        """
        self.model_name_or_path = model_name_or_path
        self.model = LLM(
            model=self.model_name_or_path,
            gpu_memory_utilization=0.9,
            limit_mm_per_prompt={"image": max_image_num},
            enable_prefix_caching=True,
            trust_remote_code=True,
        )
        self.sampling_params = SamplingParams(
            temperature=0.1,
            top_p=0.9,
            top_k=50,
            max_tokens=768,
        )

    def _open_images(self, image_field, image_dir):
        """Return the sample's images as a list, opening file names from disk."""
        if not isinstance(image_field, list):
            image_field = [image_field]
        return [
            Image.open(os.path.join(image_dir, ref)) if isinstance(ref, str) else ref
            for ref in image_field
        ]

    def eval_batch(self, sample_list, image_dir):
        """Generate one prediction per sample and store it under ``"pred"``.

        Args:
            sample_list: Sample dicts. ``sample["image"]`` is a single entry
                or a list; string entries are opened relative to
                ``image_dir``, anything else is used as-is.
            image_dir: Directory that image file names are joined onto.

        Returns:
            ``sample_list`` itself, with ``"pred"`` set on every sample.
        """
        engine_inputs = []
        for sample in sample_list:
            images = self._open_images(sample["image"], image_dir)

            geomath_plain = self.task_name == "geomath" and self.eval_type in ["sft", "zero-shot"]
            raw_question = sample['problem_no_prompt'] if geomath_plain else sample['problem']

            # Instruction text, newline, then the image markers.
            img_markers = "[IMG]" * len(images)
            instruction = self.prompt.format(Question=raw_question)
            engine_inputs.append(
                {
                    "prompt": f"<s>[INST]{instruction}\n{img_markers}[/INST]",
                    "multi_modal_data": {"image": images},
                }
            )

        outputs = self.model.generate(engine_inputs, sampling_params=self.sampling_params, use_tqdm=False)

        assert len(outputs) == len(engine_inputs), f"Out({len(outputs)}) != In({len(engine_inputs)})"

        for reply, item in zip(outputs, sample_list):
            item["pred"] = reply.outputs[0].text

        return sample_list


class Internvl_VL_Evaluator(VL_Evaluator):
    """Evaluator that renders prompts with the checkpoint's tokenizer template.

    Each image is announced as ``Image-N: <image>`` ahead of the question,
    and the resulting user message is passed through
    ``tokenizer.apply_chat_template``.
    """

    def __init__(self, model_name_or_path, max_image_num=2):
        """Load the tokenizer, the vLLM engine and the sampling parameters.

        Args:
            model_name_or_path: Checkpoint handed to ``AutoTokenizer`` and
                ``LLM``.
            max_image_num: Maximum number of images allowed per prompt.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)

        llm_kwargs = dict(
            model=model_name_or_path,
            gpu_memory_utilization=0.9,
            limit_mm_per_prompt={"image": max_image_num},
            enable_prefix_caching=True,
            trust_remote_code=True,
        )
        self.model = LLM(**llm_kwargs)
        self.sampling_params = SamplingParams(temperature=0.1, top_p=0.9, top_k=50, max_tokens=768)

        self.model_name_or_path = model_name_or_path

    def eval_batch(self, sample_list, image_dir):
        """Generate one prediction per sample and store it under ``"pred"``.

        Args:
            sample_list: Sample dicts. ``sample["image"]`` is a single entry
                or a list; string entries are opened relative to
                ``image_dir``, anything else is used as-is.
            image_dir: Directory that image file names are joined onto.

        Returns:
            ``sample_list`` itself, with ``"pred"`` set on every sample.
        """
        pending = []
        for sample in sample_list:
            image_value = sample["image"]
            if isinstance(image_value, list):
                images = []
                for member in image_value:
                    if isinstance(member, str):
                        images.append(Image.open(os.path.join(image_dir, member)))
                    else:
                        images.append(member)
            elif isinstance(image_value, str):
                images = [Image.open(os.path.join(image_dir, image_value))]
            else:
                images = [image_value]

            # Each label already ends with a newline and the labels are then
            # joined with another one, exactly as the prompt format expects.
            labels = [f"Image-{position}: <image>\n" for position in range(1, len(images) + 1)]
            placeholders = "\n".join(labels)

            if self.task_name == "geomath" and self.eval_type in ["sft", "zero-shot"]:
                filled = self.prompt.format(Question=sample['problem_no_prompt'])
            else:
                filled = self.prompt.format(Question=sample['problem'])

            messages = [{'role': 'user', 'content': f"{placeholders}\n{filled}"}]
            rendered = self.tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
            )

            pending.append({"prompt": rendered, "multi_modal_data": {"image": images}})

        outputs = self.model.generate(pending, sampling_params=self.sampling_params, use_tqdm=False)

        assert len(outputs) == len(pending), f"Out({len(outputs)}) != In({len(pending)})"

        for answer, item in zip(outputs, sample_list):
            item["pred"] = answer.outputs[0].text

        return sample_list


if __name__ == "__main__":

    # Define the argument parser
    parser = argparse.ArgumentParser(description="Evaluate a model on different benchmarks with specified strategies.")
    parser.add_argument('--batch_size', type=int, default=16, help="Batch size for evaluation (default: 16)")
    parser.add_argument('--model_name_or_path', type=str, required=True, help="Path to the model checkpoint.")
    parser.add_argument('--benchmark_list', type=str, nargs='+', help="List of benchmarks to evaluate on.")
    parser.add_argument('--stratage_list', type=str, nargs='+', help="List of strategies for each benchmark.")

    # Parse the arguments
    args = parser.parse_args()

    # Validate before the (expensive) model load: both lists are needed and
    # are paired up one-to-one, so a missing or shorter list would otherwise
    # crash late or silently skip benchmarks.
    if not args.benchmark_list:
        parser.error("--benchmark_list requires at least one benchmark")
    if not args.stratage_list:
        parser.error("--stratage_list requires one strategy per benchmark")
    if len(args.benchmark_list) != len(args.stratage_list):
        parser.error(
            f"--benchmark_list has {len(args.benchmark_list)} entries but "
            f"--stratage_list has {len(args.stratage_list)}; they must match"
        )

    print(f"Benchmark List: {args.benchmark_list}")
    print(f"Stratage List: {args.stratage_list}")

    print(f"Loading Model Path from {args.model_name_or_path} ...")

    # The evaluator is chosen by the first keyword found in the lower-cased
    # model path; the order below is the match priority.
    # NOTE(review): "llava" selects Mllama_VL_Evaluator - confirm that this
    # keyword (rather than e.g. "llama") is the intended one.
    evaluator_by_keyword = (
        ('qwen', QWEN_VL_Evaluator),
        ('llava', Mllama_VL_Evaluator),
        ('phi', PHI3V_VL_Evaluator),
        ('internvl', Internvl_VL_Evaluator),
        ('pixtral', Pixtral_VL_Evaluator),
    )
    model_path_lower = args.model_name_or_path.lower()
    for keyword, evaluator_cls in evaluator_by_keyword:
        if keyword in model_path_lower:
            print(f"======== Using {evaluator_cls.__name__} ==========")
            evaluator = evaluator_cls(args.model_name_or_path)
            break
    else:
        print("======== Using Default VL_Evaluator ==========")
        evaluator = VL_Evaluator(args.model_name_or_path)

    for benchmark, stratage in zip(args.benchmark_list, args.stratage_list):
        print(f"================== Evaluating {benchmark}-{stratage} ==================")
        evaluator.run(benchmark, stratage, args.batch_size)