from util.logger import logger

from typing import Optional, Tuple, Union, List, Dict

import gc

from pathlib import Path

from PIL import Image

import numpy as np

import torch
import torch.nn.functional as F
import torchvision

import open_clip

from util.yaml_util import load_yaml
from util.pipeline_util import img_latent_to_pil

from .reward_model import RewardModel


class CLIPScore_RewardModel(RewardModel):
    def __init__(
        self, 

        # ---------= [Pipeline] =---------
        pipeline: Optional["StableDiffusionPipeline"] = None, 
        num_inference_step: Optional[int] = None, 

        # ---------= [Param] =---------
        prompt_emb_list: List[torch.Tensor] = None, 
        param_dict: Optional[Dict] = None, 
        num_sample_per_prompt: int = None, 

        # ---------= [Reward] =---------
        reward_shape: Optional[Tuple] = (1, ), 
        reward_dtype: Optional[str] = "float32", 
        offload_to_cpu: Optional[bool] = True, 

        # ---------= [Parallel] =---------
        cal_dynamics_batch_size: Optional[int] = 1, 
        cal_intermediate_reward_batch_size: Optional[int] = 1, 
        cal_final_reward_batch_size: Optional[int] = 1, 

        # ---------= [Reward Shaping] =---------
        reward_shaping_policy: str = "disabled", 
        # potential_exp_growing: bool = False, 
        # potential_exp_base: float = 1.0, 
        cal_intermediate_reward_policy: Optional[str] = "sequential", 

        device: Optional[str] = "cpu", 

        vae_decode_batch_size: Optional[int] = 10
    ):
        """
        Func:
            Build a CLIP-score reward model. 

            Every constructor argument is handed to `RewardModel.__init__()` unchanged; 
            this class itself only adds the CLIP-specific state: the YAML config loaded 
            from `./config/reward_model/clip_score.yaml` (resolved against the current 
            working directory), the OpenCLIP model / preprocess handles, and the 
            prompt-to-text-embedding dictionary. 
        """

        # Collect the arguments of the base class in one place and forward them as-is. 
        base_kwarg_dict = {
            # ---------= [Pipeline] =---------
            "pipeline": pipeline, 
            "num_inference_step": num_inference_step, 

            # ---------= [Param] =---------
            "param_dict": param_dict, 
            "prompt_emb_list": prompt_emb_list, 
            "num_sample_per_prompt": num_sample_per_prompt, 

            # ---------= [Reward] =---------
            "reward_shape": reward_shape, 
            "reward_dtype": reward_dtype, 
            "offload_to_cpu": offload_to_cpu, 

            # ---------= [Parallel] =---------
            "cal_dynamics_batch_size": cal_dynamics_batch_size, 
            "cal_intermediate_reward_batch_size": cal_intermediate_reward_batch_size, 
            "cal_final_reward_batch_size": cal_final_reward_batch_size, 

            # ---------= [Reward Shaping] =---------
            "reward_shaping_policy": reward_shaping_policy, 
            "cal_intermediate_reward_policy": cal_intermediate_reward_policy, 

            "device": device, 

            "vae_decode_batch_size": vae_decode_batch_size
        }
        super().__init__(**base_kwarg_dict)

        # ---------= [Config] =---------
        self._cfg_yaml_path = Path("./config/reward_model/clip_score.yaml")
        self._model_dict = load_yaml(self._cfg_yaml_path)

        # ---------= [Prepare Everything] =---------
        # Placeholders; `_prepare_everything()` assigns the real model and transform. 
        self._clip_model = None
        self._preprocess = None

        # Maps a prompt string to its text embedding (read by `cal_text_emb_list()`). 
        self._prompt_emb_dict = {}

        self._prepare_everything()


    def _prepare_everything(
        self
    ):
        """
        Func:
            Instantiate the OpenCLIP `ViT-H-14` model together with its image 
            preprocessing transform, and copy the scoring constants of the 
            `clip_score` config section onto the instance. 

            Sets `self._clip_model`, `self._preprocess`, `self.clip_range`, 
            `self.norm_constant` and `self.bias_constant`. 
        """

        clip_score_cfg = self._model_dict["clip_score"]

        # Of the three values returned by open_clip, the first (model) and the third 
        # (a preprocessing transform) are kept; the second one is discarded. 
        self._clip_model, _, self._preprocess = open_clip.create_model_and_transforms(
            model_name = "ViT-H-14", 
            pretrained = clip_score_cfg["open_clip_model_ckpt_path"], 

            device = self._device
        )

        # self._tfms = torchvision.transforms.Compose(
        #     [
        #         torchvision.transforms.Resize(224), 
        #         torchvision.transforms.CenterCrop(224), 
        #         torchvision.transforms.ToTensor(), 
        #         torchvision.transforms.Normalize(
        #             mean = (0.48145466, 0.4578275, 0.40821073), 
        #             std = (0.26862954, 0.26130258, 0.27577711)
        #         )
        #     ]
        # )

        # Plain config values, exposed under the same name as their config key. 
        for cfg_key in ("clip_range", "norm_constant", "bias_constant"):
            setattr(self, cfg_key, clip_score_cfg[cfg_key])


    @torch.no_grad()
    def cal_img_emb_list(
        self, 

        img_pil_list: List[Image.Image]
    ) -> torch.Tensor:
        """
        Func:
            Run every image of `img_pil_list` through `self._preprocess`, stack the 
            results into one batch on `self._device` and encode the batch with the 
            image encoder of `self._clip_model`. 

        Ret:
            `img_emb_list` (`torch.Tensor`): The output of `encode_image()` for the 
                stacked batch, in the order of `img_pil_list`. 
        """

        preprocessed_list = []
        for img_pil in img_pil_list:
            preprocessed_list.append(self._preprocess(img_pil))

        img_batch = torch.stack(preprocessed_list).to(self._device)

        # The method-level `torch.no_grad()` decorator already disables autograd here. 
        img_emb_list = self._clip_model.encode_image(img_batch)

        # ---------= [Clean Up] =---------
        del preprocessed_list
        del img_batch
        gc.collect()
        torch.cuda.empty_cache()

        return img_emb_list


    @torch.no_grad()
    def cal_text_emb_list(
        self, 

        prompt_list: List[str]
    ) -> torch.Tensor:
        """
        Func: 
            Compute text embedding for prompts in `prompt_list`. 

            Embeddings are memoised in `self._prompt_emb_dict` (prompt -> embedding), 
            so a prompt is tokenized and encoded only the first time it is seen; 
            repeated prompts, within one call or across calls, reuse the stored row. 
            The cache is unbounded and grows by one entry per distinct prompt. 

        Args:
            `prompt_list` (`List[str]`): The prompts to embed. A non-list value is 
                treated as a single prompt. 

        Ret:
            `text_emb_list` (`torch.Tensor`): The list of text embedding for prompts in `prompt_list`, 
                stacked along dim 0 in the order of `prompt_list`. 
                For an empty `prompt_list` an empty Python list is returned instead, 
                because there is nothing to stack. 
        """

        if not isinstance(prompt_list, list):
            prompt_list = [prompt_list]

        # `torch.stack()` rejects an empty sequence; keep the historical empty result. 
        if len(prompt_list) <= 0:
            return []

        # Distinct prompts that still have to be encoded, in first-occurrence order. 
        need_cal_prompt_list = [
            prompt \
                for prompt in dict.fromkeys(prompt_list) \
                    if prompt not in self._prompt_emb_dict
        ]

        if len(need_cal_prompt_list) > 0:
            need_cal_text_token_list = open_clip.tokenize(need_cal_prompt_list) \
                .to(self._device)

            # Autograd is already disabled by the method-level decorator. 
            need_cal_text_emb_list = self._clip_model.encode_text(need_cal_text_token_list)

            for prompt, need_cal_text_emb in zip(
                need_cal_prompt_list, 
                need_cal_text_emb_list
            ):
                # Iterating a tensor yields views; clone so that a cached row does 
                # not keep the whole batch result alive. 
                self._prompt_emb_dict[prompt] = need_cal_text_emb.clone()

                # goto `for prompt, need_cal_text_emb`
                pass

            # ---------= [Clean Up] =---------
            del need_cal_text_token_list
            del need_cal_text_emb_list
            gc.collect()
            torch.cuda.empty_cache()

        # Every prompt is cached by now; always hand back a stacked tensor. 
        text_emb_list = torch.stack(
            [
                self._prompt_emb_dict[prompt] \
                    for prompt in prompt_list
            ]
        )

        # `cal_text_emb_list()` done
        return text_emb_list


    @torch.no_grad()
    def cal_clip_score(
        self, 

        img_pil_list: List[Image.Image], 
        prompt_list: List[str]
    ) -> torch.Tensor:
        """
        Func:
            Compute the CLIP scores for a list of RGB images `img_pil_list`. 

            The score is the cosine similarity between the image embedding 
            (`cal_img_emb_list()`) and the text embedding (`cal_text_emb_list()`), 
            taken along dim 1. No scaling or clipping is applied here. 

        Args:
            `img_pil_list` (`List[Image.Image]`): The images to score. A non-list 
                value is treated as a single image. 
            `prompt_list` (`List[str]`): The prompts to score against. A non-list 
                value is treated as a single prompt. 

        Ret:
            `clip_score_list` (`torch.Tensor`): The derived CLIP scores, as returned 
                by `F.cosine_similarity(..., dim = 1)` (a tensor, not Python floats). 
        """

        if not isinstance(img_pil_list, list):
            img_pil_list = [img_pil_list]

        if not isinstance(prompt_list, list):
            prompt_list = [prompt_list]

        img_emb_list = self.cal_img_emb_list(
            img_pil_list = img_pil_list
        )
        text_emb_list = self.cal_text_emb_list(
            prompt_list = prompt_list
        )

        # Unit-normalize both sides before comparing them. 
        img_emb_normed_list = F.normalize(
            img_emb_list, 
            dim = 1
        )
        text_emb_normed_list = F.normalize(
            text_emb_list, 
            dim = 1
        )

        clip_score_list = F.cosine_similarity(
            img_emb_normed_list, text_emb_normed_list, 
            dim = 1
        )

        # ---------= [Clean Up] =---------
        del img_emb_list, text_emb_list
        del img_emb_normed_list, text_emb_normed_list
        gc.collect()
        torch.cuda.empty_cache()

        # `cal_clip_score()` done
        return clip_score_list


    # TODO: batch cal
    @torch.no_grad()
    def cal_final_reward_implement(
        self, 

        sample_idx_list: List[int], 

        latent_list: Optional[torch.Tensor] = None, 
        img_pil_list: Optional[List[Image.Image]] = None, 

        prompt_list: Union[str, List[str]] = None, 

        **arg_dict: Optional[Dict]
    ) -> torch.Tensor:
        """
        Func:
            Compute the final reward for the clean images, i.e. the CLIP score of 
            `img_pil_list` against `prompt_list` as returned by `cal_clip_score()`. 

        Args:
            `sample_idx_list` (`List[int]`): Currently unused; only the disabled NFE 
                bookkeeping below reads it. 
            `latent_list` (`Optional[torch.Tensor]`): Currently unused; decoding latents 
                into PIL images is disabled below, so `img_pil_list` is required. 
            `img_pil_list` (`Optional[List[Image.Image]]`): The images to score. 
            `prompt_list` (`Union[str, List[str]]`): A prompt or a list of prompts, 
                forwarded to `cal_clip_score()`. 
            `arg_dict`: Extra keyword arguments; accepted and ignored. 

        Ret:
            `final_reward_list` (`torch.Tensor`): The list of the final rewards, exactly 
                as produced by `cal_clip_score()`. The reshape to 
                `(num_img, *reward_shape)` and the `norm_constant` / `bias_constant` 
                scaling are disabled below, so the scores are returned raw. 

        Raise:
            `ValueError`: If `img_pil_list` is `None`. 
        """

        # ---------= [Prepare Image PIL List] =---------
        # [Disabled] fall back to decoding `latent_list` when no images are given. 
        # if (latent_list is not None) and (img_pil_list is not None):
        #     logger(
        #         f"Both `latent_list` and `img_pil_list` are provided, "
        #         f"`img_pil_list` prioritizes. ", 

        #         log_type = "warning"
        #     )

        # if img_pil_list is None:
        #     img_pil_list = img_latent_to_pil(
        #         img_latent_list = latent_list, 
        #         pipeline = self._pipeline, 

        #         batch_size = self._vae_decode_batch_size
        #     )

        # An explicit check instead of `assert`, so that it survives `python -O`. 
        if img_pil_list is None:
            raise ValueError(
                "`img_pil_list` must be provided, got `None`. "
            )

        # ---------= [Prepare Prompt List] =---------
        if isinstance(prompt_list, str):
            prompt_list = [prompt_list]

        # [Disabled] broadcast a single prompt over all images / validate the lengths. 
        # num_img = len(img_pil_list)
        # num_prompt = len(prompt_list)
        # if num_prompt == 1:
        #     prompt_list = prompt_list * num_img
        # elif num_prompt != num_img:
        #     raise ValueError(
        #         f"The length of `prompt_list` does not match the length of `img_pil_list`, "
        #         f"got `{num_prompt}` and `{num_img}`. "
        #     )

        # ---------= [Compute Final Reward] =---------
        final_reward_list = self.cal_clip_score(
            img_pil_list = img_pil_list, 
            prompt_list = prompt_list
        )

        # Defensive conversion, in case the scores do not come back as a tensor. 
        if not isinstance(final_reward_list, torch.Tensor):
            final_reward_list = torch.tensor(
                final_reward_list, 

                dtype = self._reward_dtype, 
                device = self._device
            )

        # [Disabled] reshape to the configured reward shape. 
        # # final_reward_list.shape = (num_img, *reward_shape)
        # final_reward_list = final_reward_list.reshape(
        #     (num_img, *self._reward_shape)
        # )

        # [Disabled] affine rescaling of the raw scores. 
        # # ---------= [Scaling] =---------
        # final_reward_list /= self.norm_constant
        # final_reward_list += self.bias_constant

        # [Disabled] per-sample counter of final reward evaluations. 
        # # ---------= [Update NFE] =---------
        # for sample_idx in sample_idx_list:
        #     self.nfe_cal_final_reward_list[sample_idx] += 1

        #     # goto `for sample_idx`
        #     pass

        # `cal_final_reward_implement()` done
        return final_reward_list
