from util.logger import logger

from typing import Optional, Tuple, Union, List, Dict

from pathlib import Path

from PIL import Image

import torch

import gc

from third_party import hpsv2

from util.yaml_util import load_yaml
from util.pipeline_util import img_latent_to_pil

from .reward_model import RewardModel


class HumanPreferenceScore_v2_RewardModel(RewardModel):
    """
    Func:
        Reward model that scores images with Human Preference Score v2 (HPS v2), 
        computed by `third_party.hpsv2.score` with `hps_version = "v2.1"`. 

        The HPS / ViT checkpoint paths and the `clip_range`, `norm_constant` and 
        `bias_constant` values are read from the `hps_v2` section of 
        `./config/reward_model/hps_v2.yaml`. That path is relative, so it is resolved 
        against the current working directory. 
    """

    def __init__(
        self, 

        # ---------= [Pipeline] =---------
        pipeline: Optional["StableDiffusionPipeline"] = None, 
        num_inference_step: Optional[int] = None, 

        # ---------= [Param] =---------
        prompt_emb_list: List[torch.Tensor] = None, 
        param_dict: Optional[Dict] = None, 
        num_sample_per_prompt: int = None, 

        # ---------= [Reward] =---------
        reward_shape: Optional[Tuple] = (1, ), 
        reward_dtype: Optional[str] = "float32", 
        offload_to_cpu: Optional[bool] = True, 

        # ---------= [Parallel] =---------
        cal_dynamics_batch_size: Optional[int] = 1, 
        cal_intermediate_reward_batch_size: Optional[int] = 1, 
        cal_final_reward_batch_size: Optional[int] = 1, 

        # ---------= [Reward Shaping] =---------
        reward_shaping_policy: str = "disabled", 
        # potential_exp_growing: bool = False, 
        # potential_exp_base: float = 1.0, 
        cal_intermediate_reward_policy: Optional[str] = "sequential", 

        device: Optional[str] = "cpu", 

        vae_decode_batch_size: Optional[int] = 10, 

        **arg_dict
    ):
        """
        Func:
            Forward every named argument unchanged to `RewardModel.__init__`, then load the 
            HPS v2 YAML config and cache the values needed for scoring. 

        Note:
            Extra keyword arguments collected in `**arg_dict` are accepted but ignored; 
            they are NOT forwarded to the base class. 
        """
        super().__init__(
            # ---------= [Pipeline] =---------
            pipeline = pipeline, 
            num_inference_step = num_inference_step, 

            # ---------= [Param] =---------
            prompt_emb_list = prompt_emb_list, 
            param_dict = param_dict, 
            num_sample_per_prompt = num_sample_per_prompt, 

            # ---------= [Reward] =---------
            reward_shape = reward_shape, 
            reward_dtype = reward_dtype, 
            offload_to_cpu = offload_to_cpu, 

            # ---------= [Parallel] =---------
            cal_dynamics_batch_size = cal_dynamics_batch_size, 
            cal_intermediate_reward_batch_size = cal_intermediate_reward_batch_size, 
            cal_final_reward_batch_size = cal_final_reward_batch_size, 

            # ---------= [Reward Shaping] =---------
            reward_shaping_policy = reward_shaping_policy, 
            # potential_exp_growing = potential_exp_growing, 
            # potential_exp_base = potential_exp_base, 
            cal_intermediate_reward_policy = cal_intermediate_reward_policy, 

            device = device, 

            vae_decode_batch_size = vae_decode_batch_size
        )

        # Relative path: resolved against the current working directory, not this file. 
        self._cfg_yaml_path = Path("./config/reward_model/hps_v2.yaml")
        self._model_dict = load_yaml(self._cfg_yaml_path)

        # ---------= [Prepare Everything] =---------
        self._hps_model_ckpt_path = None
        self._vit_model_ckpt_path = None
        
        self._prepare_everything()

        # `__init__()` done
        pass

    
    def _prepare_everything(
        self
    ):
        """
        Func:
            Read the HPS v2 settings from the `hps_v2` section of the loaded YAML config 
            (`self._model_dict`) and store them on the instance. 

        Raises:
            Indexing fails (a `KeyError` for a plain dict) if the `hps_v2` section or any 
            of the keys read below is missing. 
        """
        hps_v2_cfg = self._model_dict["hps_v2"]

        # Checkpoints handed to `hpsv2.score()` in `batch_cal_hps_v2()`. 
        self._hps_model_ckpt_path = hps_v2_cfg["hps_model_ckpt_path"]
        self._vit_model_ckpt_path = hps_v2_cfg["vit_model_ckpt_path"]

        # NOTE(review): the three values below are not read anywhere in this class; 
        # presumably they are consumed by `RewardModel` — confirm there. 
        self.clip_range = hps_v2_cfg["clip_range"]

        self.norm_constant = hps_v2_cfg["norm_constant"]
        self.bias_constant = hps_v2_cfg["bias_constant"]
        
        # `_prepare_everything()` done
        pass


    @torch.no_grad()
    def batch_cal_hps_v2(
        self, 

        img_pil_list: Union[Image.Image, List[Image.Image]], 
        prompt_list: Union[str, List[str]]
    ) -> List[float]:
        """
        Func:
            Compute the Human Preference Score v2 (HPS v2) of the RGB images in 
            `img_pil_list` against `prompt_list`, via `hpsv2.score` with 
            `hps_version = "v2.1"`. 

        Args:
            `img_pil_list`: A single PIL image or a list of PIL images. Anything that is 
                not a `list` is wrapped into a one-element list. 
            `prompt_list`: Passed to `hpsv2.score` unchanged as `prompt`. 
                NOTE(review): how a list of prompts is paired with the images is decided 
                inside `third_party.hpsv2` — confirm there. 

        Ret:
            `hps_v2_list` (`List[float]`): The scores returned by `hpsv2.score` 
                (presumably one per image), each multiplied by `100.0`. 
                An empty image list returns `[]` without calling `hpsv2.score`. 
        """

        if not isinstance(img_pil_list, list):
            img_pil_list = [img_pil_list]

        # Nothing to score: skip the (potentially expensive) model call entirely. 
        if not img_pil_list:
            return []

        hps_v2_list = hpsv2.score(
            img_pil_list, 
            prompt = prompt_list, 

            hps_model_ckpt_path = self._hps_model_ckpt_path, 
            ViT_model_ckpt_path = self._vit_model_ckpt_path, 
            hps_version = "v2.1"
        )

        # ---------= [Scaling] =---------
        hps_v2_list = [
            score * 100.0 \
                for score in hps_v2_list
        ]

        # `batch_cal_hps_v2()` done
        return hps_v2_list


    @staticmethod
    def _group_consecutive_prompt_runs(
        prompt_list: List[str]
    ) -> List[Tuple[str, int, int]]:
        """
        Func:
            Split `prompt_list` into maximal runs of consecutive, equal (`==`) prompts. 

        Ret:
            `run_list` (`List[Tuple[str, int, int]]`): One `(prompt, start, stop)` per run, 
                covering the half-open index range `[start, stop)`, in order. 
                Empty for an empty `prompt_list`. 
        """
        run_list = []
        num_prompt = len(prompt_list)

        start = 0
        while start < num_prompt:
            prompt = prompt_list[start]

            stop = start + 1
            while (stop < num_prompt) and (prompt_list[stop] == prompt):
                stop += 1

            run_list.append((prompt, start, stop))
            start = stop

        return run_list


    @staticmethod
    def _release_cached_memory() -> None:
        """
        Func:
            Run Python's garbage collector, then ask PyTorch to release unoccupied cached 
            CUDA memory via `torch.cuda.empty_cache()`. 
        """
        gc.collect()
        torch.cuda.empty_cache()


    @torch.no_grad()
    def cal_final_reward_implement(
        self, 

        sample_idx_list: List[int], 

        latent_list: Optional[torch.Tensor] = None, 
        img_pil_list: Optional[List[Image.Image]] = None, 

        prompt_list: Union[str, List[str]] = None, 

        **arg_dict: Optional[Dict]
    ) -> torch.Tensor:
        """
        Func:
            Compute the final (HPS v2) reward for the clean images in `img_pil_list`. 

            Consecutive images that share the same prompt are scored together, in chunks 
            of at most `self._cal_final_reward_batch_size` images per 
            `batch_cal_hps_v2()` call. 

        Args:
            `sample_idx_list`: Not used by this implementation. 
            `latent_list`: Not used by this implementation; `img_pil_list` is required. 
            `img_pil_list`: The images to score. 
            `prompt_list`: Either a single prompt (a string or a one-element list) shared 
                by every image, or exactly one prompt per image. 
            `**arg_dict`: Ignored. 

        Ret:
            `final_reward_list` (`torch.Tensor`): The rewards in the same order as 
                `img_pil_list`, built with `torch.tensor(..., dtype = self._reward_dtype, 
                device = self._device)`. An empty `img_pil_list` yields an empty tensor. 
                NOTE(review): the tensor is built from a flat list of per-image scores, so 
                its shape is `(num_img,)` if `hpsv2.score` yields plain floats; an earlier 
                docstring claimed `(num_img, 1)` — confirm what callers expect. 

        Raises:
            `ValueError`: If `img_pil_list` is `None`, or the number of prompts is neither 
                1 nor `len(img_pil_list)`. 
            `TypeError`: If `prompt_list` is `None`. 
        """

        if img_pil_list is None:
            raise ValueError(
                "`img_pil_list` is required to compute the HPS v2 final reward, got `None`. "
            )

        if prompt_list is None:
            raise TypeError(
                "`prompt_list` must be a string or a list of strings, got `None`. "
            )

        # ---------= [Prepare Prompt List] =---------
        num_img = len(img_pil_list)

        if isinstance(prompt_list, str):
            prompt_list = [prompt_list]

        num_prompt = len(prompt_list)
        if num_prompt == 1:
            # Broadcast the single prompt to every image. 
            prompt_list = prompt_list * num_img
        elif num_prompt != num_img:
            raise ValueError(
                f"The length of `prompt_list` does not match the length of `img_pil_list`, "
                f"got `{num_prompt}` and `{num_img}`. "
            )

        # ---------= [Compute Final Reward] =---------
        batch_size = self._cal_final_reward_batch_size
        reward_list = []

        # Score each run of identical consecutive prompts separately, so every 
        # `hpsv2.score()` call receives a single prompt string. 
        for prompt, start, stop in self._group_consecutive_prompt_runs(prompt_list):
            for batch_start in range(start, stop, batch_size):
                batch_img_pil_list = img_pil_list[batch_start: min(batch_start + batch_size, stop)]

                reward_list += self.batch_cal_hps_v2(
                    img_pil_list = batch_img_pil_list, 
                    prompt_list = prompt
                )

                # ---------= [Clean Up] =---------
                # Free memory between batches before scoring the next chunk. 
                del batch_img_pil_list
                self._release_cached_memory()

        final_reward_list = torch.tensor(
            reward_list, 

            dtype = self._reward_dtype, 
            device = self._device
        )
        
        # ---------= [Clean Up] =---------
        # Only memory release here: no loop-local names are deleted, because they are 
        # never bound when `img_pil_list` is empty. 
        self._release_cached_memory()

        # `cal_final_reward_implement()` done
        return final_reward_list
