from util.logger import logger

from typing import Optional, Tuple, Union, List, Dict

import gc

from pathlib import Path

from PIL import Image

import numpy as np

import torch

import ImageReward

from util.yaml_util import load_yaml
from util.pipeline_util import img_latent_to_pil

from .reward_model import RewardModel


class ImageReward_RewardModel(RewardModel):
    """
    Func:
        Reward model that scores decoded images against their text prompts with 
        ImageReward (`ImageReward.load(...)` + `inference_rank(...)`). 
        Settings are read from `./config/reward_model/image_reward.yaml`. 
    """

    def __init__(
        self, 

        # ---------= [Pipeline] =---------
        pipeline: Optional["StableDiffusionPipeline"] = None, 
        num_inference_step: Optional[int] = None, 

        # ---------= [Param] =---------
        prompt_emb_list: List[torch.Tensor] = None, 
        param_dict: Optional[Dict] = None, 
        num_sample_per_prompt: int = None, 

        # ---------= [Reward] =---------
        reward_shape: Optional[Tuple] = (1, ), 
        reward_dtype: Optional[str] = "float32", 
        offload_to_cpu: Optional[bool] = True, 

        # ---------= [Parallel] =---------
        cal_dynamics_batch_size: Optional[int] = 1, 
        cal_intermediate_reward_batch_size: Optional[int] = 1, 
        cal_final_reward_batch_size: Optional[int] = 1, 

        # ---------= [Reward Shaping] =---------
        reward_shaping_policy: str = "disabled", 
        # potential_exp_growing: bool = False, 
        # potential_exp_base: float = 1.0, 
        cal_intermediate_reward_policy: Optional[str] = "sequential", 

        device: Optional[str] = "cpu", 

        vae_decode_batch_size: Optional[int] = 10, 

        **arg_dict
    ):
        """
        Func:
            Forward the known arguments to `RewardModel.__init__()`, parse the YAML 
            config with `load_yaml`, then load the ImageReward model through 
            `_prepare_everything()`. 

        Note:
            Extra keyword arguments in `arg_dict` are accepted but ignored; they are 
            NOT forwarded to the base class. 
            The config path is relative, so it is resolved against the current 
            working directory. 
        """
        super().__init__(
            # ---------= [Pipeline] =---------
            pipeline = pipeline, 
            num_inference_step = num_inference_step, 

            # ---------= [Param] =---------
            param_dict = param_dict, 
            prompt_emb_list = prompt_emb_list, 
            num_sample_per_prompt = num_sample_per_prompt, 

            # ---------= [Reward] =---------
            reward_shape = reward_shape, 
            reward_dtype = reward_dtype, 
            offload_to_cpu = offload_to_cpu, 

            # ---------= [Parallel] =---------
            cal_dynamics_batch_size = cal_dynamics_batch_size, 
            cal_intermediate_reward_batch_size = cal_intermediate_reward_batch_size, 
            cal_final_reward_batch_size = cal_final_reward_batch_size, 

            # ---------= [Reward Shaping] =---------
            reward_shaping_policy = reward_shaping_policy, 
            # potential_exp_growing = potential_exp_growing, 
            # potential_exp_base = potential_exp_base, 
            cal_intermediate_reward_policy = cal_intermediate_reward_policy, 

            device = device, 

            vae_decode_batch_size = vae_decode_batch_size
        )

        self._cfg_yaml_path = Path("./config/reward_model/image_reward.yaml")
        self._model_dict = load_yaml(self._cfg_yaml_path)

        # ---------= [Prepare Everything] =---------
        self._model = None

        self._prepare_everything()

        # `__init__()` done
        pass


    def _prepare_everything(
        self
    ):
        """
        Func:
            Load the ImageReward model and read the `clip_range`, `norm_constant` 
            and `bias_constant` values from the `image_reward` section of the config. 

        Note:
            `ImageReward.load()` is called without a `device` argument, so the model 
            lands on that function's default device, not necessarily `self._device`. 
            `clip_range`, `norm_constant` and `bias_constant` are not used inside this 
            class; presumably the base class consumes them — TODO confirm. 
        """

        image_reward_cfg = self._model_dict["image_reward"]

        model_ckpt_path = image_reward_cfg["model_ckpt_path"]

        self._model = ImageReward.load(model_ckpt_path)

        self.clip_range = image_reward_cfg["clip_range"]

        self.norm_constant = image_reward_cfg["norm_constant"]
        self.bias_constant = image_reward_cfg["bias_constant"]

        # `_prepare_everything()` done
        pass


    @torch.no_grad()
    def batch_cal_image_reward(
        self, 

        img_pil_list: List[Image.Image], 
        prompt: str
    ) -> List[float]:
        """
        Func:
            Compute the ImageRewards for RGB images `img_pil_list`, all scored against 
            the same text prompt `prompt`, in a single `inference_rank()` call. 
            The ranking returned by `inference_rank()` is discarded. 

        Args:
            `img_pil_list` (`List[Image.Image]`): Images sharing `prompt`. 
            `prompt` (`str`): The text prompt. 

        Ret:
            `image_reward_list` (`List[float]`): The derived ImageRewards, in the order 
                of `img_pil_list`. For a single image this may come back as a bare 
                scalar rather than a list; `cal_final_reward_implement()` handles that. 
        """

        ranking_list, image_reward_list = self._model.inference_rank(
            prompt, 
            img_pil_list
        )

        # ---------= [Clean Up] =---------
        # Only the rewards are needed; drop the ranking and force a collection pass. 
        del ranking_list
        gc.collect()

        # `batch_cal_image_reward()` done
        return image_reward_list


    @torch.no_grad()
    def cal_final_reward_implement(
        self, 

        sample_idx_list: List[int], 

        latent_list: Optional[torch.Tensor] = None, 
        img_pil_list: Optional[List[Image.Image]] = None, 

        prompt_list: Union[str, List[str]] = None, 

        **arg_dict: Optional[Dict]
    ) -> torch.Tensor:
        """
        Func:
            Compute the final reward for the clean images `img_pil_list`. 
            Each run of consecutive images sharing the same prompt is scored with one 
            `batch_cal_image_reward()` call. 

        Args:
            `sample_idx_list` (`List[int]`): Not used by this implementation. 
            `latent_list` (`Optional[torch.Tensor]`): Not used by this implementation; 
                only the decoded PIL images are scored. 
            `img_pil_list` (`List[Image.Image]`): The images to score. Required. 
            `prompt_list` (`Union[str, List[str]]`): Either one prompt shared by every 
                image, or exactly one prompt per image. Required. 

        Ret:
            `final_reward_list` (`torch.Tensor`): One reward per image, in the order of 
                `img_pil_list`, with dtype `self._reward_dtype` on `self._device`. 
                Built from a flat list of per-image scalars, so the shape is 
                `(num_img, )`. Any reshape to `(num_img, *reward_shape)` presumably 
                happens in the caller — TODO confirm. 

        Raises:
            `ValueError`: If `img_pil_list` or `prompt_list` is missing, or if the 
                number of prompts is neither 1 nor `len(img_pil_list)`. 
        """

        # Explicit checks instead of `assert`, which is stripped under `python -O`. 
        if img_pil_list is None:
            raise ValueError("`img_pil_list` is required to compute the final reward. ")
        if prompt_list is None:
            raise ValueError("`prompt_list` is required to compute the final reward. ")

        # ---------= [Prepare Prompt List] =---------
        num_img = len(img_pil_list)

        if isinstance(prompt_list, str):
            prompt_list = [prompt_list]

        num_prompt = len(prompt_list)
        if num_prompt == 1:
            # Broadcast the single prompt to every image. 
            prompt_list = prompt_list * num_img
        elif num_prompt != num_img:
            raise ValueError(
                f"The length of `prompt_list` does not match the length of `img_pil_list`, "
                f"got `{num_prompt}` and `{num_img}`. "
            )

        # ---------= [Compute Final Reward] =---------
        final_reward_list = []

        start = 0
        while start < num_img:
            prompt = prompt_list[start]

            # Extend the half-open run `[start, stop)` over consecutive images that 
            # share `prompt`, so they are scored in one model call. 
            stop = start + 1
            while (stop < num_img) and (prompt_list[stop] == prompt):
                stop += 1

            tmp_reward_list = self.batch_cal_image_reward(
                img_pil_list = img_pil_list[start: stop], 
                prompt = prompt
            )

            # A single-image run may come back as a bare scalar. 
            if not isinstance(tmp_reward_list, list):
                tmp_reward_list = [tmp_reward_list]

            final_reward_list += tmp_reward_list

            start = stop

            # goto `while start`
            pass

        # final_reward_list.shape = (num_img, )
        final_reward_list = torch.tensor(
            final_reward_list, 

            dtype = self._reward_dtype, 
            device = self._device
        )

        # `cal_final_reward_implement()` done
        return final_reward_list
