from util.logger import logger

from typing import Optional, Tuple, Union, List, Dict

import gc

from pathlib import Path

from PIL import Image

from io import BytesIO

import torch

from util.yaml_util import load_yaml
from util.pipeline_util import img_latent_to_pil

from .reward_model import RewardModel


class Compressibility_RewardModel(RewardModel):
    """
    Func:
        Reward model that scores an image by the size (in bytes) of its
        compressed encoding. JPEG is the only compression algorithm
        supported by this class.

        The sign of the reward is controlled by the `inv` keyword argument:
            - `inv` truthy: the reward is the *negated* encoded size and the
              settings come from `compressibility_reward.yaml`.
            - `inv` falsy: the reward is the encoded size itself and the
              settings come from `incompressibility_reward.yaml`.

        Final rewards are reported in units of `1024 ** 2` bytes.
    """

    def __init__(
        self, 

        # ---------= [Pipeline] =---------
        pipeline: Optional["StableDiffusionPipeline"] = None, 
        num_inference_step: Optional[int] = None, 

        # ---------= [Param] =---------
        prompt_emb_list: Optional[List[torch.Tensor]] = None, 
        param_dict: Optional[Dict] = None, 
        num_sample_per_prompt: Optional[int] = None, 

        # ---------= [Reward] =---------
        reward_shape: Optional[Tuple] = (1, ), 
        reward_dtype: Optional[str] = "float32", 
        offload_to_cpu: Optional[bool] = True, 

        # ---------= [Parallel] =---------
        cal_dynamics_batch_size: Optional[int] = 1, 
        cal_intermediate_reward_batch_size: Optional[int] = 1, 
        cal_final_reward_batch_size: Optional[int] = 1, 

        # ---------= [Reward Shaping] =---------
        reward_shaping_policy: str = "disabled", 
        # potential_exp_growing: bool = False, 
        # potential_exp_base: float = 1.0, 
        cal_intermediate_reward_policy: Optional[str] = "sequential", 

        device: Optional[str] = "cpu", 

        vae_decode_batch_size: Optional[int] = 10, 

        **arg_dict
    ):
        """
        Func:
            Forward every named argument unchanged to `RewardModel.__init__()`, 
            then load the YAML configuration selected by `inv` and cache the 
            compression settings on the instance. 

        Args:
            All named parameters are passed straight through to the base 
            class; their meaning is defined there. 

            `arg_dict`: Extra keyword arguments. Only the key `"inv"` is read 
                here; it is required. `inv = False` is a valid value, only a 
                missing / `None` value is rejected. 

        Raises:
            `ValueError`: If `inv` is missing or `None`, or if the configured 
                compression algorithm is not supported. 
        """

        super().__init__(
            # ---------= [Pipeline] =---------
            pipeline = pipeline, 
            num_inference_step = num_inference_step, 

            # ---------= [Param] =---------
            param_dict = param_dict, 
            prompt_emb_list = prompt_emb_list, 
            num_sample_per_prompt = num_sample_per_prompt, 

            # ---------= [Reward] =---------
            reward_shape = reward_shape, 
            reward_dtype = reward_dtype, 
            offload_to_cpu = offload_to_cpu, 

            # ---------= [Parallel] =---------
            cal_dynamics_batch_size = cal_dynamics_batch_size, 
            cal_intermediate_reward_batch_size = cal_intermediate_reward_batch_size, 
            cal_final_reward_batch_size = cal_final_reward_batch_size, 

            # ---------= [Reward Shaping] =---------
            reward_shaping_policy = reward_shaping_policy, 
            # potential_exp_growing = potential_exp_growing, 
            # potential_exp_base = potential_exp_base, 
            cal_intermediate_reward_policy = cal_intermediate_reward_policy, 

            device = device, 

            vae_decode_batch_size = vae_decode_batch_size
        )

        inv = arg_dict.get("inv", None)
        if inv is None:
            raise ValueError(
                "`inv` must be provided when using compressibility / incompressibility reward model. "
            )

        self.inv = inv

        # `inv` selects both the sign of the reward and the config file. 
        # The path is relative to the current working directory. 
        if self.inv:
            cfg_yaml_path = "./config/reward_model/compressibility_reward.yaml"
        else:
            cfg_yaml_path = "./config/reward_model/incompressibility_reward.yaml"

        self._cfg_yaml_path = Path(cfg_yaml_path)
        self._model_dict = load_yaml(self._cfg_yaml_path)

        # ---------= [Prepare Everything] =---------
        self.compression_algorithm = None

        # used for JPEG
        self._quality = None
        self._optimize = None

        # avoid numerical instability: rewards are reported in units of 
        # 1024 ** 2 bytes instead of raw byte counts
        self._byte_to_MB = float(1024 ** 2)

        self._prepare_everything()


    def _prepare_everything(
        self
    ):
        """
        Func:
            Copy the settings found under the `"compressibility_reward"` key 
            of the loaded YAML onto the instance. The same top-level key is 
            read regardless of which of the two config files was loaded. 

        Raises:
            `ValueError`: If `compression_algorithm` is anything but `"JPEG"`. 
            `KeyError`: If a required entry is missing from the config. 
        """

        reward_cfg = self._model_dict["compressibility_reward"]

        self.compression_algorithm = reward_cfg["compression_algorithm"]

        if self.compression_algorithm == "JPEG":
            self._quality = reward_cfg["quality"]
            self._optimize = reward_cfg["optimize"]

        else:
            raise ValueError(
                f"Unsupported `compression_algorithm`, got `{self.compression_algorithm}`. "
            )

        # The three values below are stored but not referenced anywhere else 
        # in this class. 
        # NOTE(review): presumably consumed by `RewardModel` -- confirm. 
        self.clip_range = reward_cfg["clip_range"]

        self.norm_constant = reward_cfg["norm_constant"]
        self.bias_constant = reward_cfg["bias_constant"]


    @torch.no_grad()
    def cal_compressibility_reward(
        self, 

        img_pil: Image.Image
    ) -> int:
        """
        Func:
            Encode `img_pil` into an in-memory buffer with the configured 
            compression settings and use the encoded size as the reward. 

        Args:
            `img_pil` (`Image.Image`): The image to score. It must be in a 
                mode the encoder can write. 

        Ret:
            `compressibility_reward` (`int`): The encoded size in bytes, 
                negated when `self.inv` is truthy. 

        Raises:
            `ValueError`: If `self.compression_algorithm` is not `"JPEG"`. 
        """

        # Fail loudly rather than scoring an empty buffer as a reward of 0. 
        if self.compression_algorithm != "JPEG":
            raise ValueError(
                f"Unsupported `compression_algorithm`, got `{self.compression_algorithm}`. "
            )

        # The `with` block closes the buffer, so no explicit `del` / 
        # `gc.collect()` is needed per image. 
        with BytesIO() as buffer:
            img_pil.save(
                buffer, 

                format = "JPEG", 
                quality = self._quality, 
                optimize = self._optimize
            )

            # The stream position after saving is the number of bytes written. 
            num_byte = buffer.tell()

        compressibility_reward = -num_byte if self.inv else num_byte

        return compressibility_reward


    # TODO: batch cal
    @torch.no_grad()
    def cal_final_reward_implement(
        self, 

        sample_idx_list: List[int], 

        latent_list: Optional[torch.Tensor] = None, 
        img_pil_list: Optional[List[Image.Image]] = None, 

        prompt_list: Optional[Union[str, List[str]]] = None, 

        **arg_dict: Optional[Dict]
    ) -> torch.Tensor:
        """
        Func:
            Compute the final reward for the clean images. 

        Args:
            `sample_idx_list`: Not used by this implementation. 
            `latent_list`: Not used by this implementation. 
            `img_pil_list`: The images to score; required. 
            `prompt_list`: Optional. The reward does not depend on the 
                prompts; when given, only the length is validated (a single 
                prompt, or one prompt per image). 
            `arg_dict`: Ignored. 

        Ret:
            `final_reward_list` (`torch.Tensor`): One reward per image, in the 
                same order as `img_pil_list`, in units of `1024 ** 2` bytes. 
                The tensor is built from a flat list of floats, so it is 1-D 
                with `num_img` entries. 
                NOTE(review): earlier documentation described the shape as 
                `(num_img, *reward_shape)` -- confirm whether the caller 
                reshapes. 

        Raises:
            `ValueError`: If `img_pil_list` is `None`, or if `prompt_list` 
                holds more than one prompt and its length differs from the 
                number of images. 
        """

        if img_pil_list is None:
            raise ValueError(
                "`img_pil_list` must be provided to compute the final reward. "
            )

        num_img = len(img_pil_list)

        # ---------= [Validate Prompt List] =---------
        # Prompts do not affect the reward, so `None` is accepted as-is. 
        if prompt_list is not None:
            if isinstance(prompt_list, str):
                prompt_list = [prompt_list]

            num_prompt = len(prompt_list)
            if (num_prompt != 1) and (num_prompt != num_img):
                raise ValueError(
                    f"The length of `prompt_list` does not match the length of `img_pil_list`, "
                    f"got `{num_prompt}` and `{num_img}`. "
                )

        # ---------= [Compute Final Reward] =---------
        final_reward_list = [
            self.cal_compressibility_reward(
                img_pil = img_pil
            ) / self._byte_to_MB
            for img_pil in img_pil_list
        ]

        final_reward_list = torch.tensor(
            final_reward_list, 

            dtype = self._reward_dtype, 
            device = self._device
        )

        return final_reward_list
