import warnings
# Suppress all warnings process-wide. This runs at import time, so it also
# silences warnings emitted by the modules imported further below.
warnings.filterwarnings("ignore")

import os
import json
import tempfile
import shutil
from typing import Union, Dict, Any
from PIL import Image
import numpy as np
from argparse import Namespace

from rewards.MixedReward.unified_reward import UnifiedReward
from rewards.MixedReward.reward_gdino    import GDino
from rewards.MixedReward.score_git       import ColorEvaluator


class MixedReward:
    """
    Three-way reward fusion for a single image and its prompt metadata.

    Components (each expected in [0, 1] before fusion):
      • GroundDINO (GDino)   -- object / numeracy / spatial score
      • ColorEvaluator (GIT) -- colour-attribute score (``avg_score``)
      • UnifiedReward        -- raw 1..5 score, rescaled linearly to 0..1

    Fusion weights (gdino, git, unified):
      • ``tag`` in ``COLOR_TAGS``: ``COLOR_WEIGHTS``   = (0.4, 0.4, 0.2)
      • any other ``tag``:         ``DEFAULT_WEIGHTS`` = (0.6, 0.2, 0.2)

    ``get_reward`` returns the fused score minus 1.0, so it lies in [-1, 0]
    as long as the GIT and GDino scores stay within [0, 1] (assumed, not
    enforced here; only the UnifiedReward term is clipped).
    """

    # Tags for which colour fidelity matters most, so GIT gets more weight.
    COLOR_TAGS = ("colors", "color_attr")
    # (gdino, git, unified) weight triples; each sums to 1.0.
    COLOR_WEIGHTS = (0.4, 0.4, 0.2)
    DEFAULT_WEIGHTS = (0.6, 0.2, 0.2)

    def __init__(
        self,
        git_ckpt_path: str,
        unified_model_path: str,
        gdino_ckpt_path: str,
        gdino_config_path: str,
        device: str = "cuda:0"
    ):
        """
        Build the three scorers.

        Args:
            git_ckpt_path: checkpoint passed to ``ColorEvaluator``.
            unified_model_path: model path passed to ``UnifiedReward``.
            gdino_ckpt_path: GroundDINO checkpoint (wrapped in a Namespace).
            gdino_config_path: GroundDINO config (wrapped in a Namespace).
            device: device string for ``UnifiedReward`` and ``GDino``;
                ``ColorEvaluator`` is constructed without a device argument.
        """
        # 1) GIT colour evaluator
        self.git_evaluator = ColorEvaluator(git_ckpt_path)
        # 2) UnifiedReward
        self.unified_reward = UnifiedReward(unified_model_path, device=device)
        # 3) GDino expects an argparse-style namespace with these two fields.
        gdino_args = Namespace(
            gdino_ckpt_path=gdino_ckpt_path,
            gdino_config_path=gdino_config_path,
        )
        self.gdino = GDino(gdino_args)
        self.gdino.load_to_device(device)

    def get_reward(
        self,
        image: Union[str, Image.Image],
        solution: Union[str, Dict[str, Any]]
    ) -> float:
        """
        Score ``image`` against the prompt metadata in ``solution``.

        Args:
            image: path to an image file, or an in-memory PIL image.
            solution: metadata dict, or its JSON string encoding. This
                method reads ``"tag"`` (selects the weight triple); the
                helpers read ``"include"`` (list of item dicts).

        Returns:
            Weighted sum of the GDino, GIT and rescaled UnifiedReward scores,
            minus 1.0.

        Raises:
            json.JSONDecodeError: if ``solution`` is a string that is not
                valid JSON (raised before any temporary file is created).
        """
        # Parse first so malformed JSON cannot leak a temporary file.
        metadata = json.loads(solution) if isinstance(solution, str) else solution

        # The GIT evaluator is called with a filesystem path, so spill PIL
        # images to a temporary PNG that the ``finally`` below always removes.
        temp_img = None
        if isinstance(image, Image.Image):
            fd, temp_img = tempfile.mkstemp(suffix=".png")
            os.close(fd)
            image_path = temp_img
        else:
            image_path = image

        try:
            if temp_img is not None:
                # Saved inside ``try`` so a failed save still cleans up.
                image.save(temp_img)

            # --- GIT colour score (0..1) ---
            pairs = self.extract_color_objects(metadata)
            git_res = self.git_evaluator.evaluate_color(image_path, pairs)
            git_score = git_res.get("avg_score", 0.0)
            print(f"git_score = {git_score}")

            # --- UnifiedReward raw score (1..5) mapped linearly to [0, 1] ---
            raw_uni = self.unified_reward.get_reward(image, metadata)
            print(f"raw_uni = {raw_uni}")
            uni_score = np.clip((raw_uni - 1.0) / 4.0, 0.0, 1.0)
            print(f"uni_score = {uni_score}")

            # --- GDino object / numeracy / spatial score (0..1) ---
            gdino_score = self._evaluate_gdino(image, metadata)
            print(f"gdino_score = {gdino_score}")

            # --- weighted sum: colour tasks give GIT a larger share ---
            if metadata.get("tag") in self.COLOR_TAGS:
                w_gdino, w_git, w_uni = self.COLOR_WEIGHTS
            else:
                w_gdino, w_git, w_uni = self.DEFAULT_WEIGHTS
            final_score = (
                w_gdino * gdino_score +
                w_git * git_score +
                w_uni * uni_score
            )
            # Shift so a perfect sample scores 0.0.
            return float(final_score - 1.0)

        finally:
            if temp_img and os.path.exists(temp_img):
                os.remove(temp_img)

    def extract_color_objects(self, metadata: Dict[str, Any]):
        """
        Build ``(color, class, count)`` tuples from ``metadata['include']``.

        Missing keys default to ``""`` (color, class) and ``1`` (count); a
        missing ``include`` list yields ``[]``.
        """
        return [
            (item.get("color", ""),
             item.get("class", ""),
             item.get("count", 1))
            for item in metadata.get("include", [])
        ]

    def _evaluate_gdino(self, image: Union[str, Image.Image], metadata: Dict[str, Any]) -> float:
        """
        Run GroundDINO on one image and return its score as a float.

        Task selection from ``metadata['include']`` (each item needs
        ``"class"``):
          - ``spatial`` if any item has ``"position"``: the first such item
            is ``obj1``, the first other item (else ``nouns[0]``) is ``obj2``,
            and ``position[0]`` is the locality;
          - otherwise ``numeracy`` if any item has ``"count"``: all
            ``{obj_name, num}`` pairs are passed as one list;
          - otherwise ``object``.

        The image is converted to RGB in memory; a path is opened with a
        context manager so the file handle is released.
        """
        include = metadata.get("include", [])
        nouns = [itm["class"] for itm in include]
        text_prompt, token_spans = self.gdino.make_prompt(nouns)

        numeracy_info = None
        spatial_info = None
        if any("position" in itm for itm in include):
            pos = next(itm for itm in include if "position" in itm)
            others = [o for o in include if o is not pos]
            obj1 = pos["class"]
            # NOTE(review): with a single positioned item this falls back to
            # nouns[0], which is obj1 itself -- confirm that is intended.
            obj2 = others[0]["class"] if others else nouns[0]
            locality = pos["position"][0]
            task_type = ["spatial"]
            spatial_info = [{"obj1": obj1, "obj2": obj2, "locality": locality}]
        elif any("count" in itm for itm in include):
            task_type = ["numeracy"]
            # All counts for this single image go into one inner list.
            all_counts = [
                {"obj_name": itm["class"], "num": itm["count"]}
                for itm in include if "count" in itm
            ]
            numeracy_info = [all_counts]
        else:
            task_type = ["object"]

        # Convert directly instead of a save-to-PNG / reopen round trip;
        # ``convert`` returns a new image, so closing the source is safe.
        if isinstance(image, Image.Image):
            rgb_image = image.convert("RGB")
        else:
            with Image.open(image) as opened:
                rgb_image = opened.convert("RGB")

        # GDino takes batched inputs; build a batch of one.
        scores = self.gdino(
            prompts=[text_prompt],
            images=[rgb_image],
            task_type=task_type,
            nouns=[nouns],
            det_prompt=[{"text_prompt": text_prompt, "token_spans": token_spans}],
            numeracy_info=numeracy_info,
            spatial_info=spatial_info,
        )
        return float(scores[0])

    def judge_answer(self, image, data, threshold: float = -0.1) -> bool:
        """
        Return True unless the shifted reward falls below ``threshold``.

        With the default ``-0.1`` this accepts samples whose fused score is
        roughly 0.9 or higher. Written as ``not (x < t)`` to keep the
        original result for NaN rewards (accepted).
        """
        reward_score = self.get_reward(image, data)
        return not (reward_score < threshold)