from util.logger import logger

from typing import Optional, Tuple, Union, List, Dict

import gc

from pathlib import Path

from PIL import Image

import numpy as np

import torch

from transformers import AutoProcessor, AutoModel

from util.yaml_util import load_yaml
from util.pipeline_util import img_latent_to_pil

from .reward_model import RewardModel


class PickScore_RewardModel(RewardModel):
    """
    Func:
        Reward model that scores a generated image against its text prompt with 
        PickScore: the scaled dot product between the normalised image embedding 
        and the normalised text embedding produced by a `transformers` model. 

        The processor / model locations and a few constants are read from 
        `./config/reward_model/pick_score.yaml` (section `pick_score`). 
    """

    def __init__(
        self, 

        # ---------= [Pipeline] =---------
        pipeline: Optional["StableDiffusionPipeline"] = None, 
        num_inference_step: Optional[int] = None, 

        # ---------= [Param] =---------
        prompt_emb_list: Optional[List[torch.Tensor]] = None, 
        param_dict: Optional[Dict] = None, 
        num_sample_per_prompt: Optional[int] = None, 

        # ---------= [Reward] =---------
        reward_shape: Optional[Tuple] = (1, ), 
        reward_dtype: Optional[str] = "float32", 
        offload_to_cpu: Optional[bool] = True, 

        # ---------= [Parallel] =---------
        cal_dynamics_batch_size: Optional[int] = 1, 
        cal_intermediate_reward_batch_size: Optional[int] = 1, 
        cal_final_reward_batch_size: Optional[int] = 1, 

        # ---------= [Reward Shaping] =---------
        reward_shaping_policy: str = "disabled", 
        # potential_exp_growing: bool = False, 
        # potential_exp_base: float = 1.0, 
        cal_intermediate_reward_policy: Optional[str] = "sequential", 

        device: Optional[str] = "cpu", 

        vae_decode_batch_size: Optional[int] = 10, 

        **arg_dict
    ):
        """
        Func:
            Forward every named argument unchanged to `RewardModel.__init__()`, 
            then load the PickScore config and build the processor / model. 

        Args:
            All named arguments are passed through to the base class as-is; 
            their semantics are defined there. 
            `**arg_dict`: Extra keyword arguments are accepted and ignored 
                (they are NOT forwarded to the base class). 
        """

        super().__init__(
            # ---------= [Pipeline] =---------
            pipeline = pipeline, 
            num_inference_step = num_inference_step, 

            # ---------= [Param] =---------
            param_dict = param_dict, 
            prompt_emb_list = prompt_emb_list, 
            num_sample_per_prompt = num_sample_per_prompt, 

            # ---------= [Reward] =---------
            reward_shape = reward_shape, 
            reward_dtype = reward_dtype, 
            offload_to_cpu = offload_to_cpu, 

            # ---------= [Parallel] =---------
            cal_dynamics_batch_size = cal_dynamics_batch_size, 
            cal_intermediate_reward_batch_size = cal_intermediate_reward_batch_size, 
            cal_final_reward_batch_size = cal_final_reward_batch_size, 

            # ---------= [Reward Shaping] =---------
            reward_shaping_policy = reward_shaping_policy, 
            # potential_exp_growing = potential_exp_growing, 
            # potential_exp_base = potential_exp_base, 
            cal_intermediate_reward_policy = cal_intermediate_reward_policy, 

            device = device, 

            vae_decode_batch_size = vae_decode_batch_size
        )

        # NOTE: the path is relative, so it is resolved against the current 
        # working directory of the process, not against this file. 
        self._cfg_yaml_path = Path("./config/reward_model/pick_score.yaml")
        self._model_dict = load_yaml(self._cfg_yaml_path)

        # ---------= [Prepare Everything] =---------
        # Declared here so the attributes always exist; filled in by 
        # `_prepare_everything()`. 
        self._processor = None
        self._model = None

        self._prepare_everything()


    def _prepare_everything(
        self
    ):
        """
        Func:
            Load the processor and the model named in the `pick_score` section 
            of the config, and copy the constants stored next to them. 

        Raise:
            `KeyError`: If the config lacks the `pick_score` section or one of 
                the keys read below. 
        """

        pick_score_cfg = self._model_dict["pick_score"]

        self._processor = AutoProcessor.from_pretrained(
            pick_score_cfg["open_clip_model_root_path"]
        )

        # Inference only: switch to eval mode and move to the target device. 
        self._model = AutoModel.from_pretrained(
            pick_score_cfg["pick_score_model_root_path"]
        ) \
            .eval() \
            .to(self._device)

        # These constants are only stored here; nothing in this class reads 
        # them. Presumably the base class or callers use them -- TODO confirm. 
        self.clip_range = pick_score_cfg["clip_range"]

        self.norm_constant = pick_score_cfg["norm_constant"]
        self.bias_constant = pick_score_cfg["bias_constant"]


    @torch.no_grad()
    def _pick_score_text_emb(
        self, 

        prompt: str
    ) -> torch.Tensor:
        """
        Func:
            Tokenise `prompt`, run the text encoder and L2-normalise the result 
            along the last dimension. 

        Ret:
            `text_emb` (`torch.Tensor`): The normalised text embedding, on 
                `self._device`. 
        """

        text_input = self._processor(
            text = prompt, 
            padding = True, 
            truncation = True, 
            max_length = 77, 
            return_tensors = "pt"
        ).to(self._device)

        text_emb = self._model.get_text_features(**text_input)

        return text_emb / torch.norm(
            text_emb, 
            dim = -1, 
            keepdim = True
        )


    @torch.no_grad()
    def _pick_score_from_text_emb(
        self, 

        img_pil: Image.Image, 
        text_emb: torch.Tensor
    ) -> torch.Tensor:
        """
        Func:
            Score one image against an already-normalised text embedding. 

        Ret:
            `pick_score` (`torch.Tensor`): Row 0 of 
                `logit_scale.exp() * (img_emb @ text_emb.T)`. 
        """

        # The text-style keyword arguments are kept exactly as the original 
        # call passed them to the processor. 
        img_input = self._processor(
            images = [img_pil], 
            padding = True, 
            truncation = True, 
            max_length = 77, 
            return_tensors = "pt"
        ).to(self._device)

        img_emb = self._model.get_image_features(**img_input)
        img_emb = img_emb / torch.norm(
            img_emb, 
            dim = -1, 
            keepdim = True
        )

        score_mat = self._model.logit_scale.exp() * (img_emb @ text_emb.T)

        return score_mat[0]


    @staticmethod
    def _pick_score_free_memory() -> None:
        """
        Func:
            Run the garbage collector and release the cached CUDA memory held 
            by PyTorch. 
        """

        gc.collect()
        torch.cuda.empty_cache()


    @torch.no_grad()
    def cal_pick_score(
        self, 

        img_pil: Image.Image, 
        prompt: str
    ) -> torch.Tensor:
        """
        Func:
            Compute the PickScore for an RGB image `img_pil` with text prompt `prompt`. 

        Ret:
            `pick_score` (`torch.Tensor`): The derived PickScore, i.e. the row of 
                the image-text logit matrix that belongs to `img_pil`. It is a 
                tensor on `self._device`, not a Python `float`. 
        """

        text_emb = self._pick_score_text_emb(prompt)
        pick_score = self._pick_score_from_text_emb(img_pil, text_emb)

        # ---------= [Clean Up] =---------
        del text_emb
        self._pick_score_free_memory()

        # `cal_pick_score()` done
        return pick_score


    @torch.no_grad()
    def cal_final_reward_implement(
        self, 

        sample_idx_list: List[int], 

        latent_list: Optional[torch.Tensor] = None, 
        img_pil_list: Optional[List[Image.Image]] = None, 

        prompt_list: Union[str, List[str]] = None, 

        **arg_dict: Optional[Dict]
    ) -> torch.Tensor:
        """
        Func:
            Compute the final reward for the clean images. 

        Args:
            `sample_idx_list`: Not used by this implementation. 
            `latent_list`: Not used by this implementation. 
            `img_pil_list`: The images to score. Required. 
            `prompt_list`: Either one prompt (a `str` or a list of length 1), 
                which is applied to every image, or one prompt per image. 
            `**arg_dict`: Accepted and ignored. 

        Ret:
            `final_reward_list` (`torch.Tensor`): One PickScore per image, in 
                the order of `img_pil_list`, with dtype `self._reward_dtype` on 
                `self._device`. 
                NOTE(review): the previous docstring stated shape `(num_img, 1)`, 
                but the tensor is built from a flat list of single-element 
                scores and appears to be 1-D of length `num_img` -- confirm 
                against what the base class expects. 

        Raise:
            `ValueError`: If `img_pil_list` or `prompt_list` is `None`, or if 
                the number of prompts is neither 1 nor the number of images. 
        """

        # Explicit checks instead of `assert`, which is stripped under `-O`. 
        if img_pil_list is None:
            raise ValueError("`img_pil_list` must not be `None`. ")
        if prompt_list is None:
            raise ValueError("`prompt_list` must not be `None`. ")

        # ---------= [Prepare Prompt List] =---------
        num_img = len(img_pil_list)

        if isinstance(prompt_list, str):
            prompt_list = [prompt_list]

        num_prompt = len(prompt_list)
        if num_prompt == 1:
            # Broadcast the single prompt to every image. 
            prompt_list = prompt_list * num_img
        elif num_prompt != num_img:
            raise ValueError(
                f"The length of `prompt_list` does not match the length of `img_pil_list`, "
                f"got `{num_prompt}` and `{num_img}`. "
            )

        # ---------= [Compute Final Reward] =---------
        final_reward_list = []

        # The text embedding is recomputed only when the prompt differs from 
        # the previous image's prompt, so a run of images sharing one prompt 
        # costs a single pass through the text encoder. 
        cur_prompt = None
        cur_text_emb = None

        for img_pil, prompt in zip(img_pil_list, prompt_list):
            if (cur_text_emb is None) or (prompt != cur_prompt):
                cur_prompt = prompt
                cur_text_emb = self._pick_score_text_emb(prompt)

            final_reward_list.append(
                self._pick_score_from_text_emb(img_pil, cur_text_emb)
            )

        # ---------= [Clean Up] =---------
        # Once per batch rather than once per image. 
        del cur_text_emb
        self._pick_score_free_memory()

        final_reward_list = torch.tensor(
            final_reward_list, 

            dtype = self._reward_dtype, 
            device = self._device
        )

        # `cal_final_reward_implement()` done
        return final_reward_list
