from util.logger import logger

from typing import Optional, Tuple, Union, List, Dict

import gc

from pathlib import Path

from PIL import Image

import numpy as np

import torch

import cv2

from util.yaml_util import load_yaml
from util.pipeline_util import img_latent_to_pil

from .reward_model import RewardModel


class LaplacianVariance_RewardModel(RewardModel):
    """
    Func:
        Reward model whose score for an image is the variance of its 
        Laplacian (computed with OpenCV), taken either on the grayscale 
        image or averaged over the three color channels. 

    Note:
        The settings are read from 
        `./config/reward_model/laplacian_var_reward.yaml`, which is a path 
        relative to the current working directory. 
    """

    # The reward has always been computed on the CPU, whatever `self._device` 
    # is: the original code forced it by temporarily overwriting 
    # `self._device` (marked "TODO: remove"). This switch keeps that behavior 
    # without mutating shared state; set it to `False` to opt in to the CUDA 
    # path in `_laplacian_var_cuda()`. 
    _force_cpu_laplacian: bool = True

    def __init__(
        self, 

        # ---------= [Pipeline] =---------
        pipeline: Optional["StableDiffusionPipeline"] = None, 
        num_inference_step: Optional[int] = None, 

        # ---------= [Param] =---------
        prompt_emb_list: Optional[List[torch.Tensor]] = None, 
        param_dict: Optional[Dict] = None, 
        num_sample_per_prompt: Optional[int] = None, 

        # ---------= [Reward] =---------
        reward_shape: Optional[Tuple] = (1, ), 
        reward_dtype: Optional[str] = "float32", 
        offload_to_cpu: Optional[bool] = True, 

        # ---------= [Parallel] =---------
        cal_dynamics_batch_size: Optional[int] = 1, 
        cal_intermediate_reward_batch_size: Optional[int] = 1, 
        cal_final_reward_batch_size: Optional[int] = 1, 

        # ---------= [Reward Shaping] =---------
        reward_shaping_policy: str = "disabled", 
        cal_intermediate_reward_policy: Optional[str] = "sequential", 

        device: Optional[str] = "cpu", 

        vae_decode_batch_size: Optional[int] = 10, 

        **arg_dict
    ):
        """
        Func:
            Forward the shared settings to `RewardModel.__init__()`, load the 
            YAML config and cache the Laplacian variance parameters. 

        Args:
            Every named argument is passed through unchanged to 
            `RewardModel.__init__()`. 
            `arg_dict`: Extra keyword arguments. They are accepted so that 
                callers may pass a superset of options, but they are ignored 
                (not forwarded to the base class). 
        """

        super().__init__(
            # ---------= [Pipeline] =---------
            pipeline = pipeline, 
            num_inference_step = num_inference_step, 

            # ---------= [Param] =---------
            param_dict = param_dict, 
            prompt_emb_list = prompt_emb_list, 
            num_sample_per_prompt = num_sample_per_prompt, 

            # ---------= [Reward] =---------
            reward_shape = reward_shape, 
            reward_dtype = reward_dtype, 
            offload_to_cpu = offload_to_cpu, 

            # ---------= [Parallel] =---------
            cal_dynamics_batch_size = cal_dynamics_batch_size, 
            cal_intermediate_reward_batch_size = cal_intermediate_reward_batch_size, 
            cal_final_reward_batch_size = cal_final_reward_batch_size, 

            # ---------= [Reward Shaping] =---------
            reward_shaping_policy = reward_shaping_policy, 
            cal_intermediate_reward_policy = cal_intermediate_reward_policy, 

            device = device, 

            vae_decode_batch_size = vae_decode_batch_size
        )

        # NOTE: relative path, so it is resolved against the current working 
        # directory of the process. 
        self._cfg_yaml_path = Path("./config/reward_model/laplacian_var_reward.yaml")
        self._model_dict = load_yaml(self._cfg_yaml_path)

        # ---------= [Prepare Everything] =---------
        self.target_channel_idx = None

        self._prepare_everything()


    def _prepare_everything(
        self
    ):
        """
        Func:
            Copy the `laplacian_var_reward` section of the loaded YAML config 
            onto the instance. 

        Raises:
            `KeyError`: If the section or one of its keys is missing. 
        """

        cfg = self._model_dict["laplacian_var_reward"]

        # Whether the Laplacian is taken on the grayscale image (truthy) or 
        # per color channel (falsy). 
        self._tsfm_to_gray = cfg["tsfm_to_gray"]

        # Not used inside this class. 
        # NOTE(review): presumably consumed by `RewardModel` or by callers; 
        # confirm. 
        self.clip_range = cfg["clip_range"]
        self.norm_constant = cfg["norm_constant"]
        self.bias_constant = cfg["bias_constant"]


    @staticmethod
    def _laplacian_var_cpu(
        img_cv2: np.ndarray, 
        tsfm_to_gray: bool
    ) -> float:
        """
        Func:
            Compute the Laplacian variance of a BGR image with the CPU 
            implementation of OpenCV. 

        Ret:
            The variance of the grayscale Laplacian if `tsfm_to_gray` is 
            truthy, otherwise the mean of the three per-channel variances. 
        """

        if tsfm_to_gray:
            plane_list = [cv2.cvtColor(img_cv2, cv2.COLOR_BGR2GRAY)]
        else:
            plane_list = cv2.split(img_cv2)

        # `CV_64F` output keeps the negative responses of the filter. 
        var_list = [
            cv2.Laplacian(plane, cv2.CV_64F).var()
            for plane in plane_list
        ]

        return float(sum(var_list) / len(var_list))


    @staticmethod
    def _laplacian_var_cuda(
        img_cv2: np.ndarray, 
        tsfm_to_gray: bool
    ) -> float:
        """
        Func:
            CUDA counterpart of `_laplacian_var_cpu()`. Only reached when 
            `_force_cpu_laplacian` is `False` and the device is not the CPU. 

        Raises:
            `RuntimeError`: If OpenCV reports no CUDA-enabled device. 
        """

        # NOTE(review): this path was never reachable before, so it is 
        # untested. `cv2.cuda.Laplacian` may not exist in the OpenCV Python 
        # bindings (the CUDA module exposes `cv2.cuda.createLaplacianFilter`); 
        # verify against the installed OpenCV build before enabling it. 
        if cv2.cuda.getCudaEnabledDeviceCount() <= 0:
            raise RuntimeError(
                "OpenCV reports no CUDA-enabled device. "
            )

        gpu_mat = cv2.cuda_GpuMat()

        if tsfm_to_gray:
            # The color conversion stays on the CPU; only the filter runs on 
            # the GPU. 
            gpu_mat.upload(cv2.cvtColor(img_cv2, cv2.COLOR_BGR2GRAY))
            plane_list = [gpu_mat]
        else:
            gpu_mat.upload(img_cv2)
            plane_list = cv2.cuda.split(gpu_mat)

        var_list = [
            cv2.cuda.Laplacian(plane, cv2.CV_64F).download().var()
            for plane in plane_list
        ]

        # Kept from the original clean-up of the non-CPU path. 
        torch.cuda.empty_cache()

        return float(sum(var_list) / len(var_list))


    @torch.no_grad()
    def cal_laplacian_var_reward(
        self, 

        img_pil: Image.Image
    ) -> float:
        """
        Func:
            Compute the Laplacian variance reward for an RGB image `img_pil`. 

        Args:
            `img_pil` (`Image.Image`): The image to score. Its array form 
                must have shape (height, width, 3). 

        Ret:
            `laplacian_var_reward` (`float`): The derived Laplacian variance reward. 

        Raises:
            `ValueError`: If the image is not a 3-channel image. 
        """

        img_np = np.array(img_pil)

        # Check the rank first: a single-channel image gives a 2-D array, for 
        # which `shape[2]` would raise `IndexError` instead of `ValueError`. 
        if img_np.ndim != 3 or img_np.shape[2] != 3:
            raise ValueError(
                f"Only support RGB images, got an array of shape `{img_np.shape}`. "
            )

        img_cv2 = cv2.cvtColor(
            img_np, 
            cv2.COLOR_RGB2BGR
        )

        # Decide the code path locally. `self._device` is left untouched, so 
        # an exception raised by OpenCV cannot leave the instance pointing at 
        # the wrong device. 
        use_cpu = self._force_cpu_laplacian or str(self._device) == "cpu"

        if use_cpu:
            return self._laplacian_var_cpu(img_cv2, self._tsfm_to_gray)

        return self._laplacian_var_cuda(img_cv2, self._tsfm_to_gray)


    @torch.no_grad()
    def cal_final_reward_implement(
        self, 

        sample_idx_list: List[int], 

        latent_list: Optional[torch.Tensor] = None, 
        img_pil_list: Optional[List[Image.Image]] = None, 

        prompt_list: Union[str, List[str], None] = None, 

        **arg_dict: Optional[Dict]
    ) -> torch.Tensor:
        """
        Func:
            Compute the final reward for the clean images. 

        Args:
            `sample_idx_list`: Not used by this reward. 
            `latent_list`: Not used by this reward. 
            `img_pil_list`: The images to score; required. 
            `prompt_list`: Not used to compute the reward. When given, it 
                must be a single prompt or hold one prompt per image. 
            `arg_dict`: Extra keyword arguments; ignored. 

        Ret:
            `final_reward_list` (`torch.Tensor`): The list of the final rewards. 
                final_reward_list.shape = (num_img, *reward_shape). 

        Raises:
            `AssertionError`: If `img_pil_list` is `None`. 
            `ValueError`: If the length of `prompt_list` is neither 1 nor 
                the number of images. 
        """

        assert (img_pil_list is not None), "`img_pil_list` is required. "

        num_img = len(img_pil_list)

        # ---------= [Validate Prompt List] =---------
        # The prompts do not influence the reward, so a missing `prompt_list` 
        # (its declared default) is accepted. 
        if prompt_list is not None:
            if isinstance(prompt_list, str):
                prompt_list = [prompt_list]

            num_prompt = len(prompt_list)
            if num_prompt != 1 and num_prompt != num_img:
                raise ValueError(
                    f"The length of `prompt_list` does not match the length of `img_pil_list`, "
                    f"got `{num_prompt}` and `{num_img}`. "
                )

        # ---------= [Compute Final Reward] =---------
        # TODO: batch cal
        reward_list = [
            self.cal_laplacian_var_reward(
                img_pil = img_pil
            )
            for img_pil in img_pil_list
        ]

        # `torch.tensor()` needs a `torch.dtype`. The constructor default is 
        # the string "float32" and it is not visible here whether 
        # `RewardModel` converts it, so both forms are accepted. 
        reward_dtype = self._reward_dtype
        if isinstance(reward_dtype, str):
            reward_dtype = getattr(torch, reward_dtype)

        final_reward_list = torch.tensor(
            reward_list, 

            dtype = reward_dtype, 
            device = self._device
        )

        # final_reward_list.shape = (num_img, *reward_shape)
        final_reward_list = final_reward_list.reshape(
            (num_img, *self._reward_shape)
        )

        # ---------= [Clean Up] =---------
        # One collection per batch instead of one per image. 
        gc.collect()

        return final_reward_list
