from util.logger import logger

from typing import Optional, Tuple, Union, List, Dict

import gc

from pathlib import Path

from PIL import Image

import numpy as np

import torch

from util.yaml_util import load_yaml
from util.pipeline_util import img_latent_to_pil

from .reward_model import RewardModel


class ColorChannel_RewardModel(RewardModel):
    def __init__(
        self, 

        # ---------= [Pipeline] =---------
        pipeline: Optional["StableDiffusionPipeline"] = None, 
        num_inference_step: Optional[int] = None, 

        # ---------= [Param] =---------
        prompt_emb_list: List[torch.Tensor] = None, 
        param_dict: Optional[Dict] = None, 
        num_sample_per_prompt: int = None, 

        # ---------= [Reward] =---------
        reward_shape: Optional[Tuple] = (1, ), 
        reward_dtype: Optional[str] = "float32", 
        offload_to_cpu: Optional[bool] = True, 

        # ---------= [Parallel] =---------
        cal_dynamics_batch_size: Optional[int] = 1, 
        cal_intermediate_reward_batch_size: Optional[int] = 1, 
        cal_final_reward_batch_size: Optional[int] = 1, 

        # ---------= [Reward Shaping] =---------
        reward_shaping_policy: str = "disabled", 
        # potential_exp_growing: bool = False, 
        # potential_exp_base: float = 1.0, 
        cal_intermediate_reward_policy: Optional[str] = "sequential", 

        device: Optional[str] = "cpu", 

        vae_decode_batch_size: Optional[int] = 10, 

        **arg_dict
    ):
        """
        Func:
            Build the color channel reward model. 

            Every named argument is handed to the base class initializer 
            unchanged. Afterwards the YAML config at 
            `./config/reward_model/color_channel_reward.yaml` (relative to the 
            current working directory) is loaded and the reward parameters are 
            derived from it by `_prepare_everything()`. 

        Note:
            Extra keyword arguments collected in `arg_dict` are accepted but 
            are not forwarded to the base class. 
        """

        # ---------= [Base Class] =---------
        # Settings are gathered first so that the forwarding call stays a 
        # one-liner; the keys mirror the keyword names of the base initializer. 
        # (`potential_exp_growing` / `potential_exp_base` stay disabled.) 
        base_setting_dict = dict(
            # pipeline
            pipeline = pipeline, 
            num_inference_step = num_inference_step, 

            # param
            param_dict = param_dict, 
            prompt_emb_list = prompt_emb_list, 
            num_sample_per_prompt = num_sample_per_prompt, 

            # reward
            reward_shape = reward_shape, 
            reward_dtype = reward_dtype, 
            offload_to_cpu = offload_to_cpu, 

            # parallel
            cal_dynamics_batch_size = cal_dynamics_batch_size, 
            cal_intermediate_reward_batch_size = cal_intermediate_reward_batch_size, 
            cal_final_reward_batch_size = cal_final_reward_batch_size, 

            # reward shaping
            reward_shaping_policy = reward_shaping_policy, 
            cal_intermediate_reward_policy = cal_intermediate_reward_policy, 

            device = device, 

            vae_decode_batch_size = vae_decode_batch_size, 
        )

        super().__init__(**base_setting_dict)

        # ---------= [Config] =---------
        self._cfg_yaml_path = Path("./config/reward_model/color_channel_reward.yaml")
        self._model_dict = load_yaml(self._cfg_yaml_path)

        # ---------= [Prepare Everything] =---------
        # Placeholder until `_prepare_everything()` has validated the config. 
        self.target_channel_idx = None

        self._prepare_everything()


    def _prepare_everything(
        self
    ):
        """
        Func:
            Prepare parameters needed. 
        """

        target_channel_idx = self._model_dict["color_channel_reward"]["target_channel_idx"]

        if target_channel_idx not in [0, 1, 2]:
            raise ValueError(
                f"Unsupported `target_channel_idx`, got `{target_channel_idx}`. "
            )

        self.target_channel_idx = target_channel_idx

        self.clip_range = self._model_dict["color_channel_reward"]["clip_range"]

        self.norm_constant = self._model_dict["color_channel_reward"]["norm_constant"]
        self.bias_constant = self._model_dict["color_channel_reward"]["bias_constant"]
        
        # `_prepare_everything()` done
        pass


    @torch.no_grad()
    def cal_color_channel_reward(
        self, 

        img_pil: Image.Image
    ) -> float:
        """
        Func:
            Compute the color channel reward for an RGB image `img_pil`. 

        Ret:
            `color_channel_reward` (`float`): The derived color channel reward. 
        """

        # norm to [0, 1]
        img_np = np.array(img_pil) / 255.0

        if img_np.shape[2] != 3:
            raise ValueError(
                f"Only support RGB images. "
            )
        
        c_1, c_2 = [
            c_i \
                for c_i in range(3) \
                    if c_i != self.target_channel_idx
        ]
        
        # color_channel_reward = np.sum(
        #     img_np[:, :, self.target_channel_idx] \
        #         - img_np[:, :, c_1] - img_np[:, :, c_2]
        # ).item()

        color_channel_reward = np.sum(
            (img_np[:, :, self.target_channel_idx] ** 2) \
                - (img_np[:, :, c_1] ** 2) \
                - (img_np[:, :, c_2] ** 2)
        ).item()

        # ---------= [Scaling] =---------
        height, width, _ = img_np.shape
        color_channel_reward /= (height * width)

        # ---------= [Clean Up] =---------
        del img_np
        gc.collect()

        # `cal_color_channel_reward()` done
        return color_channel_reward


    # TODO: batch cal
    @torch.no_grad()
    def cal_final_reward_implement(
        self, 

        sample_idx_list: List[int], 

        latent_list: Optional[torch.Tensor] = None, 
        img_pil_list: Optional[List[Image.Image]] = None, 

        prompt_list: Union[str, List[str]] = None, 

        **arg_dict: Optional[Dict]
    ) -> torch.Tensor:
        """
        Func:
            Compute the final reward for the clean images. 

        Ret:
            `final_reward_list` (`torch.Tensor`): The list of the final rewards. 
                final_reward_list.shape = (num_img, 1). 
        """

        assert (img_pil_list is not None)

        # ---------= [Prepare Prompt List] =---------
        num_img = len(img_pil_list)

        if isinstance(prompt_list, str):
            prompt_list = [prompt_list]

        num_prompt = len(prompt_list)
        if num_prompt == 1:
            prompt_list = prompt_list * num_img
        elif num_prompt != num_img:
            raise ValueError(
                f"The length of `prompt_list` does not match the length of `img_pil_list`, "
                f"got `{num_prompt}` and `{num_img}`. "
            )

        # ---------= [Compute Final Reward] =---------
        final_reward_list = []

        l = 0

        while l < num_img:
            prompt = prompt_list[l]

            r = l
            while (r + 1 < num_img) and (prompt_list[r + 1] == prompt):
                r += 1
            
            tmp_img_pil_list = img_pil_list[l: (r + 1)]

            tmp_reward_list = [
                self.cal_color_channel_reward(
                    img_pil = img_pil
                ) \
                    for img_pil in tmp_img_pil_list
            ]

            final_reward_list += tmp_reward_list

            l = r + 1

            # goto `for l`
            pass
        
        # final_reward_list.shape = (num_img, *reward_shape)
        final_reward_list = torch.tensor(
            final_reward_list, 

            dtype = self._reward_dtype, 
            device = self._device
        )

        # `cal_final_reward_implement()` done
        return final_reward_list
