from util.logger import logger

from typing import Optional, Union, Tuple, List, Dict

import numpy as np

import torch

import gc

from util.torch_util import (
    tsfm_to_1d_array, 
    get_latent
)

from OT_MCTS.src.markov_decision_process.mdp import MarkovDecisionProcess


class DiffusionMDP(MarkovDecisionProcess):
    def __init__(
        self,

        # ---------= [Beam Search] =---------
        is_eps_action: Optional[bool] = False,

        # ---------= [Eta] =---------
        eta_random: Optional[bool] = False,
        default_eta: Optional[float] = 1.0,

        state_space: "Space" = None,
        action_space_eta: "Space" = None,
        action_space_eps: "Space" = None,
        time_horizon: int = None,

        # ---------= [Reward] =---------
        reward_shape: Optional[Tuple] = (1, ),

        # ---------= [Pipeline] =---------
        pipeline: "StableDiffusionPipeline" = None,
        eps_list: Union[torch.Tensor, List[torch.Tensor]] = None,

        cal_intermediate_reward_policy: Optional[bool] = False
    ):
        """
        Func:
            Build the MDP and store its configuration.

        Args:
            `is_eps_action`: If true, `action_space_eps` is handed to the parent
                as the action space; otherwise `action_space_eta` is used and
                `eps_list` becomes mandatory.
            `eta_random`: Stored as `self._eta_random`. `default_eta` is only
                kept (as `self._default_eta`) when this flag is false.
            `state_space`, `time_horizon`, `reward_shape`: Forwarded unchanged
                to `MarkovDecisionProcess.__init__()`.
            `pipeline`: Stored as `self.pipeline`.
            `eps_list`: Tensor, or list of tensors that gets stacked, stored as
                `self._eps_list` (eta-action mode only).
            `cal_intermediate_reward_policy`: Stored as-is.

        Raises:
            `ValueError`: If `self.ver` is not `"torch"`, or if `eps_list` is
                missing while `is_eps_action` is false.
        """

        # Both flags are stored before the parent constructor runs;
        # `_is_eps_action` decides which action space the parent receives.
        self._is_eps_action = is_eps_action
        self._eta_random = eta_random

        if not self._eta_random:
            self._default_eta = default_eta

        if self._is_eps_action:
            chosen_action_space = action_space_eps
        else:
            chosen_action_space = action_space_eta

        super().__init__(
            state_space=state_space,
            action_space=chosen_action_space,
            time_horizon=time_horizon,
            reward_shape=reward_shape
        )

        # Keep both spaces around regardless of which one is active.
        self._action_space_eta = action_space_eta
        self._action_space_eps = action_space_eps

        if self.ver != "torch":
            raise ValueError(
                f"Only support `self.ver = \"torch\"` in `DiffusionMDP`, "
                f"got `{self.ver}`. "
            )

        # ---------= [Pipeline] =---------
        self.pipeline = pipeline

        # ---------= [Filled in later by `prepare_everything()`] =---------
        self.init_latent_list = None
        self.prompt_list = None
        self.optimized_prompt_list = None
        self.prompt_emb_list = None
        self.param_dict = None

        # ---------= [Filled in later by `set_reward_model()`] =---------
        self.reward_model = None

        # ---------= [Eps List] =---------
        # Only the eta-action mode reads `self._eps_list`, so it is only
        # validated and stored in that mode.
        if not self._is_eps_action:
            if eps_list is None:
                raise ValueError(
                    f"`eps_list` must be provided. "
                )

            self._eps_list = (
                torch.stack(eps_list) if isinstance(eps_list, list)
                else eps_list
            )

        # ---------= [Intermediate Reward] =---------
        self._cal_intermediate_reward_policy = cal_intermediate_reward_policy


    def prepare_everything(
        self, 

        prompt_list: Union[str, List[str]] = None,
        optimized_prompt_list: Union[str, List[str]] = None,
        prompt_2: Union[str, List[str]] = None,
        negative_prompt: Union[str, List[str]] = None,
        negative_prompt_2: Union[str, List[str]] = None,

        height: Optional[int] = None,
        width: Optional[int] = None,

        guidance_scale: float = None, 

        num_sample_per_prompt: int = None, 

        init_latent_list: Union[torch.Tensor, List[torch.Tensor]] = None, 

        inference_step_minus_one: Optional[bool] = False
    ):
        """
        Func:
            Prepare `self.prompt_emb_list`, `self.param_dict` and 
            `self.init_latent_list`. 

        Args:
            `prompt_list`: Prompt or list of prompts; a single string is 
                wrapped into a one-element list. 
            `optimized_prompt_list`: Optional replacement prompts. When given, 
                they are passed to the pipeline instead of `prompt_list`. 
            `prompt_2`, `negative_prompt`, `negative_prompt_2`: Optional extra 
                prompts. A single string is repeated once per prompt; a list 
                is passed through as a flat list. 
            `height`, `width`, `guidance_scale`, `inference_step_minus_one`: 
                Forwarded unchanged to `self.pipeline.prepare_everything()`. 
            `num_sample_per_prompt`: Stored as `self._num_sample_per_prompt`; 
                the reward methods use it to map a sample index to a prompt. 
            `init_latent_list`: Initial latents, as a tensor or a list of 
                tensors (stacked along a new first dimension). 
                Expected shape per the original notes: 
                (num_sample, num_channel, height, width) - TODO confirm. 

        Raises:
            `ValueError`: If `init_latent_list` or `prompt_list` is missing. 
        """

        def _one_per_prompt(value, count):
            # A single string is broadcast to every prompt. A list is assumed 
            # to already hold one entry per prompt and must stay flat: 
            # wrapping it again would produce a nested list. 
            if value is None:
                return None
            if isinstance(value, str):
                return [value] * count
            return list(value)

        if init_latent_list is None:
            raise ValueError(
                f"`init_latent_list` must be provided. "
            )

        if prompt_list is None:
            raise ValueError(
                f"`prompt_list` must be provided. "
            )

        if isinstance(prompt_list, str):
            prompt_list = [prompt_list]

        # Same normalization as `prompt_list`, so the attribute is always a 
        # list (or `None`). 
        if isinstance(optimized_prompt_list, str):
            optimized_prompt_list = [optimized_prompt_list]

        self.prompt_list = prompt_list
        self.optimized_prompt_list = optimized_prompt_list

        num_prompt = len(self.prompt_list)
        self._num_sample_per_prompt = num_sample_per_prompt

        # The signature allows a list of tensors; multiplying a Python list by 
        # the noise sigma would fail (or repeat the list), so stack it first, 
        # the same way `__init__()` treats `eps_list`. 
        if isinstance(init_latent_list, (list, tuple)):
            init_latent_list = torch.stack(
                list(init_latent_list), 
                dim = 0
            )

        # Scale the initial noise by the scheduler's `init_noise_sigma`. 
        self.init_latent_list \
            = init_latent_list * self.pipeline.scheduler.init_noise_sigma

        num_inference_step = self.time_horizon

        (
            prompt_emb_list, 
            param_dict, 
            latents
        ) = self.pipeline.prepare_everything(
            prompt = self.optimized_prompt_list if (self.optimized_prompt_list is not None) \
                else self.prompt_list, 

            prompt_2 = _one_per_prompt(prompt_2, num_prompt), 
            negative_prompt = _one_per_prompt(negative_prompt, num_prompt), 
            negative_prompt_2 = _one_per_prompt(negative_prompt_2, num_prompt), 

            height = height, width = width, 
            guidance_scale = guidance_scale, 
            num_inference_steps = num_inference_step, 

            num_images_per_prompt = 1, 

            return_dict = False, 

            inference_step_minus_one = inference_step_minus_one
        )

        self.prompt_emb_list = prompt_emb_list
        self.param_dict = param_dict

        # ---------= [Clean Up] =---------
        # The latents returned by the pipeline are not used: the MDP keeps its 
        # own `self.init_latent_list`. 
        del latents
        gc.collect()
        torch.cuda.empty_cache()


    def set_reward_model(
        self,

        reward_model: "RewardModel",

        **arg_dict: Optional[Dict]
    ):
        """
        Func:
            Attach `reward_model` and build `self._cal_intermediate_reward_arg_dict`,
            the keyword arguments later passed to
            `reward_model.cal_intermediate_reward()`.

        Args:
            `reward_model`: Object exposing `cal_intermediate_reward_policy`.
            `num_look_ahead_step` (in `arg_dict`): Used when the policy is
                `"look_ahead"`; defaults to `2` with a warning.
            `gamma` (in `arg_dict`): Used when the policy is `"discount"`;
                defaults to `0.99` with a warning.
        """

        self.reward_model = reward_model

        # Both keys are always present; the one that does not apply to the
        # active policy stays `None`.
        policy_arg_dict = {
            "num_look_ahead_step": None,
            "gamma": None
        }

        policy = self.reward_model.cal_intermediate_reward_policy

        if policy == "look_ahead":
            look_ahead = arg_dict.pop("num_look_ahead_step", None)

            if look_ahead is None:
                logger(
                    f"`num_look_ahead_step` is not provided, set default to `2`. ",
                    log_type = "warning"
                )
                look_ahead = 2

            policy_arg_dict["num_look_ahead_step"] = look_ahead

        elif policy == "discount":
            discount = arg_dict.pop("gamma", None)

            if discount is None:
                logger(
                    f"`gamma` is not provided, set default to `0.99`. ",
                    log_type = "warning"
                )
                discount = 0.99

            policy_arg_dict["gamma"] = discount

        self._cal_intermediate_reward_arg_dict = policy_arg_dict


    def cal_dynamics(
        self,

        state: torch.Tensor,
        action: torch.Tensor,

        **arg_dict: Optional[Dict]
    ) -> torch.Tensor:
        """
        Func:
            Single-sample wrapper around `batch_cal_dynamics()`: wraps every
            input into a one-element list and returns the first result.

        Args:
            `state`: Current state.
            `action`: Action applied to `state`.
            `sample_idx` (in `arg_dict`): Required.
            `timestep_idx` (in `arg_dict`): Required.
            The whole `arg_dict` is also forwarded to `batch_cal_dynamics()`.

        Ret:
            `next_state` (`torch.Tensor`): The next state.

        Raises:
            `ValueError`: If `sample_idx` or `timestep_idx` is missing or `None`.
        """

        sample_idx = arg_dict.get("sample_idx", None)
        timestep_idx = arg_dict.get("timestep_idx", None)

        if sample_idx is None:
            raise ValueError(
                f"`sample_idx` must be provided. "
            )

        if timestep_idx is None:
            raise ValueError(
                f"`timestep_idx` must be provided. "
            )

        # Batch of size one; `[0]` strips the batch dimension again.
        next_state = self.batch_cal_dynamics(
            sample_idx_list=[sample_idx],
            state_list=[state],
            action_list=[action],
            timestep_idx_list=[timestep_idx],

            **arg_dict
        )[0]

        # ---------= [Clean Up] =---------
        gc.collect()
        torch.cuda.empty_cache()

        return next_state


    def batch_cal_dynamics(
        self,

        state_list: Union[
            List[torch.Tensor],
            torch.Tensor
        ],
        action_list: Union[
            List[torch.Tensor],
            torch.Tensor
        ],

        **arg_dict: Optional[Dict]
    ) -> torch.Tensor:
        """
        Func:
            Compute the next states for a batch of (state, action) pairs by
            delegating to `self.reward_model.batch_cal_dynamics()` and clamping
            the result with `self.state_space.batch_clamp()`.

            Two modes, selected by `self._is_eps_action`:
                - eps-action: every action is converted with `int()` into a
                  seed for `get_latent()`; eta comes from
                  `self._action_space_eta` (when `self._eta_random`) or
                  `self._default_eta`.
                - eta-action: the actions are the etas; the noise is read from
                  `self._eps_list` at the given timestep indices.

        Args:
            `state_list`, `action_list`: Tensors, or lists of tensors that get
                stacked along a new first dimension.
            `sample_idx_list` (in `arg_dict`): Required; only checked here and
                passed on inside `arg_dict`.
            `timestep_idx_list` (in `arg_dict`): Required. The converted copy
                is only used to index `self._eps_list`; the reward model
                receives the caller's original value through `arg_dict`.

        Ret:
            `next_state_list` (`torch.Tensor`): The clamped next states.

        Raises:
            `ValueError`: If `sample_idx_list` or `timestep_idx_list` is
                missing or `None`.
        """

        if isinstance(state_list, list):
            state_list = torch.stack(state_list, dim=0)

        if isinstance(action_list, list):
            action_list = torch.stack(action_list, dim=0)

        batch_size = state_list.shape[0]

        if arg_dict.get("sample_idx_list", None) is None:
            raise ValueError(
                f"`sample_idx_list` must be provided. "
            )

        raw_timestep_idx_list = arg_dict.get("timestep_idx_list", None)
        if raw_timestep_idx_list is None:
            raise ValueError(
                f"`timestep_idx_list` must be provided. "
            )

        timestep_idx_list = tsfm_to_1d_array(
            array=raw_timestep_idx_list,
            target_length=batch_size,

            dtype=torch.int32,
            device="cpu"
        )

        eps_action_mode = self._is_eps_action

        if eps_action_mode:
            if self._eta_random:
                eta_list = self._action_space_eta.sample_uniform_element()
            else:
                eta_list = self._default_eta

            # Convert every action to a seed first, then build one noise
            # latent per seed with the shape / device / dtype of one state.
            eps_seed_list = list(map(int, action_list))

            single_state_shape = state_list.shape[1: ]
            eps_list = []
            for seed in eps_seed_list:
                eps_list.append(
                    get_latent(
                        shape=single_state_shape,
                        seed=seed,
                        device=state_list.device,
                        dtype=state_list.dtype
                    )
                )
        else:
            eta_list = action_list
            eps_list = self._eps_list[timestep_idx_list]

        # `timestep_idx_list` / `sample_idx_list` reach the reward model
        # through `arg_dict`.
        next_state_list = self.state_space.batch_clamp(
            self.reward_model.batch_cal_dynamics(
                latent_list=state_list,
                eta_list=eta_list,
                eps_list=eps_list,

                **arg_dict
            )
        )

        # ---------= [Clean Up] =---------
        # Only the eps-action mode creates temporaries worth releasing.
        if eps_action_mode:
            if self._eta_random:
                del eta_list
            del eps_seed_list
            del eps_list
            gc.collect()
            torch.cuda.empty_cache()

        return next_state_list


    def cal_intermediate_reward(
        self, 

        sample_idx: int, 

        state: torch.Tensor, 
        action: torch.Tensor, 

        **arg_dict: Optional[Dict]
    ) -> Tuple[
        torch.Tensor, 
        torch.Tensor
    ]:
        """
        Func:
            Compute the intermediate reward for a single state and action. 
            Single-sample wrapper around `batch_cal_intermediate_reward()`. 

        Args:
            `sample_idx`: Index of the sample the state belongs to. 
            `state`: The state to score. 
            `action`: The action that led to `state`; forwarded to the batch 
                method as `prev_action_list`. 
            `timestep_idx_list` or `timestep_idx` (in `arg_dict`): One of them 
                is required unless reward shaping is disabled. A scalar 
                `timestep_idx` is wrapped into a one-element list, mirroring 
                `cal_dynamics()`. 

        Ret:
            `intermediate_reward` (`torch.Tensor`): The intermediate reward, 
                all zeros of shape `self.reward_shape` when 
                `reward_shaping_policy == "disabled"`. 
            `pseudo_final_latent` (`torch.Tensor` or `None`): 
                The predicted final latent used for computing the reward, 
                `None` when reward shaping is disabled. 

        Raises:
            `ValueError`: Raised by the batch method when no timestep index 
                is available. 
        """

        if self.reward_model.reward_shaping_policy == "disabled":
            intermediate_reward = torch.zeros(
                self.reward_shape, 

                dtype = self.dtype, 
                device = self.device
            )
             
            return (
                intermediate_reward, 
                None
            )

        # The batch method only understands `timestep_idx_list`. Accept the 
        # scalar `timestep_idx` as well so this method can be called with the 
        # same keywords as `cal_dynamics()`. 
        batch_arg_dict = dict(arg_dict)
        if batch_arg_dict.get("timestep_idx_list", None) is None:
            timestep_idx = batch_arg_dict.get("timestep_idx", None)
            if timestep_idx is not None:
                batch_arg_dict["timestep_idx_list"] = [timestep_idx]

        # NOTE: the batch method's parameter is named `prev_action_list`. 
        # Passing `action_list` would land in its `**arg_dict` and leave the 
        # required parameter unbound (`TypeError`). 
        (
            intermediate_reward_list, 
            pseudo_final_latent_list
        ) = self.batch_cal_intermediate_reward(
            sample_idx_list = [sample_idx], 

            state_list = [state], 
            prev_action_list = [action], 

            **batch_arg_dict
        )

        # Strip the batch dimension of size one. 
        intermediate_reward = intermediate_reward_list[0]
        pseudo_final_latent = None if pseudo_final_latent_list is None \
            else pseudo_final_latent_list[0]

        # `cal_intermediate_reward()` done
        return (
            intermediate_reward, 
            pseudo_final_latent
        )
    

    def batch_cal_intermediate_reward(
        self,

        sample_idx_list: List[int],

        state_list: Union[
            List[torch.Tensor],
            torch.Tensor
        ],
        prev_action_list: Union[
            List[torch.Tensor],
            torch.Tensor
        ],

        **arg_dict: Optional[Dict]
    ) -> Tuple[
        torch.Tensor,
        torch.Tensor
    ]:
        """
        Func:
            Compute the intermediate rewards for a batch of states by
            delegating to `self.reward_model.cal_intermediate_reward()`.

        Args:
            `sample_idx_list`: Sample index of every state; also used to pick
                the prompt (`sample_idx // self._num_sample_per_prompt`).
            `state_list`: The latents `z_{t_i}`; a list is stacked.
            `prev_action_list`: The previous actions `\\eta_{t_{i - 1}}`;
                a list is stacked.
            `timestep_idx_list` (in `arg_dict`): Required unless reward
                shaping is disabled.

        Ret:
            `intermediate_reward_list` (`torch.Tensor`):
                intermediate_reward_list.shape = (num_state, *self.reward_shape).
                All zeros when `reward_shaping_policy == "disabled"`.
            `pseudo_final_latent_list` (`torch.Tensor` or `None`):
                The predicted final latents returned by the reward model,
                `None` when reward shaping is disabled.

        Raises:
            `ValueError`: If `timestep_idx_list` is missing or `None`.
        """

        if isinstance(state_list, list):
            state_list = torch.stack(state_list, dim=0)

        if isinstance(prev_action_list, list):
            prev_action_list = torch.stack(prev_action_list, dim=0)

        num_state = state_list.shape[0]
        batched_reward_shape = (num_state, *self.reward_shape)

        # ---------= [Reward Shaping Disabled] =---------
        if self.reward_model.reward_shaping_policy == "disabled":
            zero_reward_list = torch.zeros(
                batched_reward_shape,

                dtype=self.dtype,
                device=self.device
            )

            return (
                zero_reward_list,
                None
            )

        # ---------= [Prepare Timestep Index List] =---------
        raw_timestep_idx_list = arg_dict.get("timestep_idx_list", None)
        if raw_timestep_idx_list is None:
            raise ValueError(
                f"`timestep_idx_list` must be provided. "
            )

        timestep_idx_list = tsfm_to_1d_array(
            array=raw_timestep_idx_list,
            target_length=num_state,

            dtype=torch.int32,
            device="cpu"
        )

        # ---------= [Prepare Prompts] =---------
        # Consecutive blocks of `self._num_sample_per_prompt` sample indices
        # share one prompt.
        samples_per_prompt = self._num_sample_per_prompt
        prompt_list = []
        for idx in sample_idx_list:
            prompt_list.append(self.prompt_list[idx // samples_per_prompt])

        # ---------= [Cal Intermediate Reward List] =---------
        reward_model_output = self.reward_model.cal_intermediate_reward(
            sample_idx_list=sample_idx_list,

            prev_action_list=prev_action_list,
            latent_list=state_list,

            prompt_list=prompt_list,
            timestep_idx_list=timestep_idx_list,

            # Policy-specific extras prepared by `set_reward_model()`.
            **self._cal_intermediate_reward_arg_dict
        )
        (
            intermediate_reward_list,
            pseudo_final_latent_list
        ) = reward_model_output

        # Bring the rewards to this MDP's dtype / device and batched shape.
        intermediate_reward_list = intermediate_reward_list.to(
            dtype=self.dtype,
            device=self.device
        ).reshape(batched_reward_shape)

        # Leading dimension becomes `num_state`; trailing dimensions are kept.
        trailing_shape = pseudo_final_latent_list.shape[1: ]
        pseudo_final_latent_list \
            = pseudo_final_latent_list.reshape(num_state, *trailing_shape)

        # ---------= [Clean Up] =---------
        del prompt_list
        gc.collect()
        torch.cuda.empty_cache()

        return (
            intermediate_reward_list,
            pseudo_final_latent_list
        )


    def cal_final_reward(
        self,

        sample_idx: int,

        state: torch.Tensor,

        **arg_dict: Optional[Dict]
    ) -> torch.Tensor:
        """
        Func:
            Single-sample wrapper around `batch_cal_final_reward()`: scores one
            terminal state and returns the first (only) entry of the batch
            result.

        Args:
            `sample_idx`: Index of the sample the state belongs to.
            `state`: The terminal state.
            `arg_dict`: Forwarded unchanged to `batch_cal_final_reward()`.

        Ret:
            `final_reward` (`torch.Tensor`): The final reward.
        """

        final_reward_list = self.batch_cal_final_reward(
            sample_idx_list=[sample_idx],
            state_list=[state],

            **arg_dict
        )

        # Batch of size one; drop the batch dimension.
        return final_reward_list[0]


    def batch_cal_final_reward(
        self,

        sample_idx_list: List[int],

        state_list: Union[torch.Tensor, List[torch.Tensor]],

        **arg_dict: Optional[Dict]
    ) -> torch.Tensor:
        """
        Func:
            Compute the final rewards for a batch of terminal states by
            delegating to `self.reward_model.cal_final_reward()`.

        Args:
            `sample_idx_list`: Sample index of every state; also used to pick
                the prompt (`sample_idx // self._num_sample_per_prompt`).
            `state_list`: The terminal latents `z_{t_n}`; a list is stacked.
            `arg_dict`: Accepted but not read by this method.

        Ret:
            `final_reward_list` (`torch.Tensor`):
                final_reward_list.shape = (num_state, *self.reward_shape).
        """

        if isinstance(state_list, list):
            state_list = torch.stack(state_list, dim=0)

        num_state = state_list.shape[0]

        # Consecutive blocks of `self._num_sample_per_prompt` sample indices
        # share one prompt.
        samples_per_prompt = self._num_sample_per_prompt
        prompt_list = []
        for idx in sample_idx_list:
            prompt_list.append(self.prompt_list[idx // samples_per_prompt])

        # ---------= [Cal Final Reward List] =---------
        raw_reward_list = self.reward_model.cal_final_reward(
            sample_idx_list=sample_idx_list,
            latent_list=state_list,
            prompt_list=prompt_list
        )

        # Bring the rewards to this MDP's dtype / device and batched shape.
        return raw_reward_list.to(
            dtype=self.dtype,
            device=self.device
        ).reshape(num_state, *self.reward_shape)
