from util.logger import logger

from typing import Optional, Tuple, Union, List, Dict

import gc

from pathlib import Path

from PIL import Image

from io import BytesIO

import torch

import numpy as np

from third_party import hpsv2

from util.yaml_util import load_yaml
from util.pipeline_util import img_latent_to_pil

from .reward_model import RewardModel


class Compressibility_HPS_v2_RewardModel(RewardModel):
    def __init__(
        self,

        # ---------= [Pipeline] =---------
        pipeline: Optional["StableDiffusionPipeline"] = None,
        num_inference_step: Optional[int] = None,

        # ---------= [Param] =---------
        prompt_emb_list: List[torch.Tensor] = None,
        param_dict: Optional[Dict] = None,
        num_sample_per_prompt: int = None,

        # ---------= [Reward] =---------
        reward_shape: Optional[Tuple] = (1, ),
        reward_dtype: Optional[str] = "float32",
        offload_to_cpu: Optional[bool] = True,

        # ---------= [Parallel] =---------
        cal_dynamics_batch_size: Optional[int] = 1,
        cal_intermediate_reward_batch_size: Optional[int] = 1,
        cal_final_reward_batch_size: Optional[int] = 1,

        # ---------= [Reward Shaping] =---------
        reward_shaping_policy: str = "disabled",
        # potential_exp_growing: bool = False,
        # potential_exp_base: float = 1.0,
        cal_intermediate_reward_policy: Optional[str] = "sequential",

        device: Optional[str] = "cpu",

        vae_decode_batch_size: Optional[int] = 10,

        **arg_dict
    ):
        """
        Func:
            Build a reward model that combines a JPEG-size (compressibility) term with an HPS v2 term.

            All named parameters are forwarded unchanged to the parent `RewardModel` constructor.
            The only extra keyword read from `arg_dict` is `inv`:
                - truthy `inv` loads `compressibility_hps_v2.yaml`, and the byte count is negated
                  in `cal_compressibility_reward()` (smaller files score higher);
                - falsy (but not `None`) `inv` loads `incompressibility_hps_v2.yaml`.

            The config path is relative, so it is resolved against the current working directory.

        Raises:
            `ValueError`: If `inv` is missing from `arg_dict` or is `None`.
        """

        # Everything the parent constructor needs, forwarded by keyword.
        base_kwargs = dict(
            # ---------= [Pipeline] =---------
            pipeline = pipeline,
            num_inference_step = num_inference_step,

            # ---------= [Param] =---------
            param_dict = param_dict,
            prompt_emb_list = prompt_emb_list,
            num_sample_per_prompt = num_sample_per_prompt,

            # ---------= [Reward] =---------
            reward_shape = reward_shape,
            reward_dtype = reward_dtype,
            offload_to_cpu = offload_to_cpu,

            # ---------= [Parallel] =---------
            cal_dynamics_batch_size = cal_dynamics_batch_size,
            cal_intermediate_reward_batch_size = cal_intermediate_reward_batch_size,
            cal_final_reward_batch_size = cal_final_reward_batch_size,

            # ---------= [Reward Shaping] =---------
            reward_shaping_policy = reward_shaping_policy,
            cal_intermediate_reward_policy = cal_intermediate_reward_policy,

            device = device,

            vae_decode_batch_size = vae_decode_batch_size,
        )
        super().__init__(**base_kwargs)

        # ---------= [Direction Flag] =---------
        inv_flag = arg_dict.get("inv")
        if inv_flag is None:
            raise ValueError(
                "`inv` must be provided when using compressibility / incompressibility reward model. "
            )

        self.inv = inv_flag

        # ---------= [Load Config] =---------
        if self.inv:
            cfg_yaml_path = "./config/reward_model/compressibility_hps_v2.yaml"
        else:
            cfg_yaml_path = "./config/reward_model/incompressibility_hps_v2.yaml"

        self._cfg_yaml_path = Path(cfg_yaml_path)
        self._model_dict = load_yaml(self._cfg_yaml_path)

        # ---------= [Prepare Compressibility Reward] =---------
        # Placeholders first; `_prepare_compressibility_reward()` fills them from the config.
        self.compression_algorithm = None

        # JPEG encoder options
        self._quality = None
        self._optimize = None

        # Divisor used to express byte counts in MB (keeps reward magnitudes small).
        self._byte_to_MB = float(1024 ** 2)

        self._prepare_compressibility_reward()

        # ---------= [Prepare HPS v2] =---------
        self._hps_model_ckpt_path = None
        self._vit_model_ckpt_path = None

        self._prepare_hps_v2()

        # ---------= [Prepare Everything] =---------
        # Weight of the HPS v2 term in the combined reward.
        self._lam = None

        self._prepare_everything()


    def _prepare_compressibility_reward(
        self
    ):
        """
        Func:
            Prepare parameters needed. 
        """

        compression_algorithm = self._model_dict["compressibility_hps_v2"]["compressibility_reward"]["compression_algorithm"]
        self.compression_algorithm = compression_algorithm

        if self.compression_algorithm == "JPEG":
            quality = self._model_dict["compressibility_hps_v2"]["compressibility_reward"]["quality"]
            optimize = self._model_dict["compressibility_hps_v2"]["compressibility_reward"]["optimize"]

            self._quality = quality
            self._optimize = optimize
        
        else:
            raise ValueError(
                f"Unsupported `compression_algorithm`, got `{self.compression_algorithm}`. "
            )

        self._compressibility_clip_range = self._model_dict["compressibility_hps_v2"]["compressibility_reward"]["clip_range"]

        self._compressibility_norm_constant = self._model_dict["compressibility_hps_v2"]["compressibility_reward"]["norm_constant"]
        self._compressibility_bias_constant = self._model_dict["compressibility_hps_v2"]["compressibility_reward"]["bias_constant"]
        
        # `_prepare_compressibility_reward()` done
        pass


    def _prepare_hps_v2(
        self
    ):
        """
        Func:
            Prepare parameters needed. 
        """

        hps_model_ckpt_path = self._model_dict["compressibility_hps_v2"]["hps_v2"]["hps_model_ckpt_path"]
        vit_model_ckpt_path = self._model_dict["compressibility_hps_v2"]["hps_v2"]["vit_model_ckpt_path"]

        self._hps_model_ckpt_path = hps_model_ckpt_path
        self._vit_model_ckpt_path = vit_model_ckpt_path

        self._hps_v2_clip_range = self._model_dict["compressibility_hps_v2"]["hps_v2"]["clip_range"]

        self._hps_v2_norm_constant = self._model_dict["compressibility_hps_v2"]["hps_v2"]["norm_constant"]
        self._hps_v2_bias_constant = self._model_dict["compressibility_hps_v2"]["hps_v2"]["bias_constant"]
        
        # `_prepare_hps_v2()` done
        pass


    def _prepare_everything(
        self
    ):
        """
        Func:
            Prepare parameters needed. 
        """
        
        self._lam = self._model_dict["compressibility_hps_v2"]["lam"]

        self.clip_range = self._model_dict["compressibility_hps_v2"]["clip_range"]

        self.norm_constant = self._model_dict["compressibility_hps_v2"]["norm_constant"]
        self.bias_constant = self._model_dict["compressibility_hps_v2"]["bias_constant"]
        
        # `_prepare_everything()` done
        pass


    @torch.no_grad()
    def cal_compressibility_reward(
        self, 

        img_pil: Image.Image
    ) -> float:
        """
        Func:
            Compute the compressibility reward for an RGB image `img_pil`. 

        Ret:
            `compressibility_reward` (`float`): The derived compressibility reward. 
        """
        
        buffer = BytesIO()

        if self.compression_algorithm == "JPEG":
            img_pil.save(
                buffer, 

                format = "JPEG", 
                quality = self._quality, 
                optimize = self._optimize
            )

        compressibility_reward = buffer.tell()

        if self.inv:
            compressibility_reward = -compressibility_reward

        # ---------= [Clean Up] =---------
        del buffer
        gc.collect()

        # `cal_compressibility_reward()` done
        return compressibility_reward


    @torch.no_grad()
    def batch_cal_hps_v2(
        self, 

        img_pil_list: Union[Image.Image, List[Image.Image]], 
        prompt_list: Union[str, List[str]]
    ) -> List[float]:
        """
        Func:
            Score one image or a list of images with `hpsv2.score()` and rescale the
            returned scores by a factor of 100. 

            A single image is wrapped into a one-element list first. `prompt_list` is
            passed through to `hpsv2.score()` as its `prompt` argument without modification. 

            NOTE(review): the scorer is called with `hps_version = "v2.1"` although this
            class is named after HPS v2 -- confirm that this is intended. 

        Ret:
            `hps_v2_list` (`List[float]`): The rescaled scores, in the order `hpsv2.score()` returned them. 
        """

        # Normalise the single-image case to a batch of one.
        img_batch = img_pil_list if isinstance(img_pil_list, list) else [img_pil_list]

        raw_score_list = hpsv2.score(
            img_batch, 
            prompt = prompt_list, 

            hps_model_ckpt_path = self._hps_model_ckpt_path, 
            ViT_model_ckpt_path = self._vit_model_ckpt_path, 
            hps_version = "v2.1"
        )

        # ---------= [Scaling] =---------
        scaled_score_list = []
        for raw_score in raw_score_list:
            scaled_score_list.append(raw_score * 100.0)

        # `batch_cal_hps_v2()` done
        return scaled_score_list


    # TODO: batch cal
    @torch.no_grad()
    def cal_final_reward_implement(
        self, 

        sample_idx_list: List[int], 

        latent_list: Optional[torch.Tensor] = None, 
        img_pil_list: Optional[List[Image.Image]] = None, 

        prompt_list: Union[str, List[str]] = None, 

        **arg_dict: Optional[Dict]
    ) -> torch.Tensor:
        """
        Func:
            Compute the final reward for the clean images. 

        Ret:
            `final_reward_list` (`torch.Tensor`): The list of the final rewards. 
                final_reward_list.shape = (num_img, 1). 
        """

        assert (img_pil_list is not None)

        # ---------= [Prepare Prompt List] =---------
        num_img = len(img_pil_list)

        if isinstance(prompt_list, str):
            prompt_list = [prompt_list]

        num_prompt = len(prompt_list)
        if num_prompt == 1:
            prompt_list = prompt_list * num_img
        elif num_prompt != num_img:
            raise ValueError(
                f"The length of `prompt_list` does not match the length of `img_pil_list`, "
                f"got `{num_prompt}` and `{num_img}`. "
            )

        # ---------= [Compute Final Reward] =---------
        final_reward_list = []

        l = 0

        while l < num_img:
            prompt = prompt_list[l]

            r = l
            while (r + 1 < num_img) and (prompt_list[r + 1] == prompt):
                r += 1
            
            tmp_img_pil_list = img_pil_list[l: (r + 1)]

            # ---------= [Cal Compressibility Reward] =---------
            compressibility_reward_list = [
                self.cal_compressibility_reward(
                    img_pil = img_pil
                ) / self._byte_to_MB \
                    for img_pil in tmp_img_pil_list
            ]
            compressibility_reward_list = np.array(compressibility_reward_list)

            self._compressibility_bias_constant\

            compressibility_clip_range = self._compressibility_clip_range
            if compressibility_clip_range is not None:
                compressibility_reward_list = np.clip(
                    compressibility_reward_list, 
                    compressibility_clip_range[0], compressibility_clip_range[1]
                )

            compressibility_reward_list \
                = (compressibility_reward_list + self._compressibility_bias_constant) / self._compressibility_norm_constant

            # ---------= [Cal HPS v2 Reward] =---------
            length = r - l + 1
            batch_size = self._cal_final_reward_batch_size

            hps_v2_list = []

            for i in range(0, length, batch_size):
                batch_tmp_pil_list = tmp_img_pil_list[i: min(i + batch_size, length)]
                
                batch_hps_v2_list = self.batch_cal_hps_v2(
                    img_pil_list = batch_tmp_pil_list, 
                    prompt_list = prompt
                )

                hps_v2_list += batch_hps_v2_list

                # ---------= [Clean Up] =---------
                del batch_tmp_pil_list
                gc.collect()
                torch.cuda.empty_cache()

                # goto `for i`
                pass

            hps_v2_list = np.array(hps_v2_list)

            hps_v2_clip_range = self._hps_v2_clip_range
            if hps_v2_clip_range is not None:
                hps_v2_list = np.clip(
                    hps_v2_list, 
                    hps_v2_clip_range[0], hps_v2_clip_range[1]
                )

            hps_v2_list \
                = (hps_v2_list + self._hps_v2_bias_constant) / self._hps_v2_norm_constant

            tmp_reward_list = compressibility_reward_list + self._lam * hps_v2_list

            # dbg
            # print(f"    compressibility_reward_list: {compressibility_reward_list}")
            # print(f"    self._lam * hps_v2_list: {self._lam * hps_v2_list}")
            # print(f"    tmp_reward_list: {tmp_reward_list}")
            # print()

            tmp_reward_list = tmp_reward_list.tolist()

            final_reward_list += tmp_reward_list

            l = r + 1

            # goto `for l`
            pass
        
        # final_reward_list.shape = (num_img, *reward_shape)
        final_reward_list = torch.tensor(
            final_reward_list, 

            dtype = self._reward_dtype, 
            device = self._device
        )

        # `cal_final_reward_implement()` done
        return final_reward_list
