from util.logger import logger

from typing import Optional, Union, List, Dict, Tuple

import numpy as np

import torch

from tqdm import tqdm

import time

import gc

from pathlib import Path

from types import MethodType

from util.torch_util import tsfm_to_1d_array
from util.image_util import save_pil_as_png
from util.yaml_util import (
    save_yaml, 
    convert_numpy_type_to_native_type
)
from util.pipeline_util import img_latent_to_pil

from OT_MCTS.src.beam_search.bs_node import BSNode
from OT_MCTS.src.beam_search.bs import BeamSearch

from .mdp.diffusion_mdp import DiffusionMDP


class DiffusionOTBS(BeamSearch):
    """
    Func:
        Beam search over a `DiffusionMDP`. 
        On top of `BeamSearch`, this class: 
            - exposes optional per-sample NFE budget checks 
              (dynamics / intermediate reward / final reward), 
            - records when, and at what cost, the best trajectory of each sample improved, 
            - saves a PNG for every best-trajectory update, one result YAML per sample, 
              and one `cfg.yaml` per prompt under `expansion_policy_root_path`. 

        Samples are grouped by prompt: global sample `i` belongs to prompt 
        `i // num_sample_per_prompt` and is sample `i % num_sample_per_prompt` of that prompt. 
    """

    def __init__(
        self, 

        is_eps_action: Optional[bool] = False, 

        # ---------= [Beam Search] =---------
        num_beam: Optional[int] = 4, 
        num_candidate_per_beam: Optional[int] = 2,

        mdp: DiffusionMDP = None, 

        init_state_list: Union[torch.Tensor, List[torch.Tensor]] = None, 

        # ---------= [Mode] =---------
        mdp_modeling: str = "max_reward", 

        # ---------= [Expansion Policy] =---------
        expansion_action_sampling_policy: str = "uniform", 

        # ---------= [NFE Limit] =---------
        nfe_cal_dynamics_lim: Optional[int] = None, 
        nfe_cal_intermediate_reward_lim: Optional[int] = None, 
        nfe_cal_final_reward_lim: Optional[int] = None, 

        # ---------= [LRU Cache] =---------
        lru_cache: "LRUCache" = None, 

        # ---------= [Dtype] =---------
        dtype: Optional[str] = "float32", 

        # ---------= [Save Root Path] =---------
        expansion_policy_root_path: Optional[Union[str, Path]] = None, 
        folder_name_list: List[str] = None, 
        cfg_dict: Dict = None, 

        **arg_dict: Optional[Dict]
    ):
        """
        Func:
            Initialize the beam search, the NFE bookkeeping and the per-prompt output paths, 
            then save one `cfg.yaml` per prompt. 

        Args:
            is_eps_action, num_beam, num_candidate_per_beam, mdp, init_state_list, 
            mdp_modeling, expansion_action_sampling_policy, dtype, **arg_dict: 
                Forwarded to `BeamSearch.__init__()`. 
            nfe_cal_dynamics_lim, nfe_cal_intermediate_reward_lim, nfe_cal_final_reward_lim: 
                NFE budgets per sample. `None` means unlimited. 
            lru_cache: Notified through `access(node, sample_idx = ...)` whenever a node 
                state is read via `get_state()`. `None` disables the notification. 
            expansion_policy_root_path: Root under which one folder per prompt lives. 
            folder_name_list: One folder name per prompt (same order as `mdp.prompt_list`). 
            cfg_dict: Task config. For every prompt, a deep copy with `["sample"]["prompt"]` 
                (and `["sample"]["optimized_prompt"]` when available) is saved as `cfg.yaml`. 
                The caller's dict is not modified. 

        Raises:
            ValueError: If `expansion_policy_root_path` is `None`, `folder_name_list` is empty, 
                or `num_sample` is not a positive multiple of the number of prompts. 
        """

        import copy

        # ---------= [LRU Cache] =---------
        # Assigned before `super().__init__()`: the patched node `get_state()` (see
        # `_get_new_node()`) reads `bs_instance.lru_cache`, and the base class may
        # already create nodes while initializing -- TODO confirm.
        self.lru_cache = lru_cache

        super().__init__(
            is_eps_action = is_eps_action, 

            # ---------= [Beam Search] =---------
            num_beam = num_beam, 
            num_candidate_per_beam = num_candidate_per_beam, 

            mdp = mdp, 

            # ---------= [Expansion Policy] =---------
            expansion_action_sampling_policy = expansion_action_sampling_policy, 

            init_state_list = init_state_list, 

            # ---------= [Mode] =---------
            mdp_modeling = mdp_modeling, 

            dtype = dtype, 

            **arg_dict
        )

        # ---------= [Best Trajectory Updated Records] =---------
        # One list per sample; `_best_trajectory_updated_callback()` appends one entry
        # to each of them every time that sample's best trajectory improves.
        self.best_trajectory_updated_bs_loop_idx_list_list = [[] for _ in range(self.num_sample)]
        self.best_trajectory_updated_wall_clock_time_list_list = [[] for _ in range(self.num_sample)]
        self.best_trajectory_updated_nfe_cal_dynamics_list_list = [[] for _ in range(self.num_sample)]
        self.best_trajectory_updated_nfe_cal_intermediate_reward_list_list = [[] for _ in range(self.num_sample)]
        self.best_trajectory_updated_nfe_cal_final_reward_list_list = [[] for _ in range(self.num_sample)]

        # ---------= [NFEs] =---------
        # A `None` limit means "unlimited" (see `_nfe_limit_reached()` / `_nfe_within_limit()`).
        self._nfe_cal_dynamics_lim = nfe_cal_dynamics_lim
        self._nfe_cal_intermediate_reward_lim = nfe_cal_intermediate_reward_lim
        self._nfe_cal_final_reward_lim = nfe_cal_final_reward_lim

        # One list per sample; `_bs_loop_callback()` appends the current NFEs after every BS loop.
        self._nfe_cal_dynamics_list_list = [[] for _ in range(self.num_sample)]
        self._nfe_cal_intermediate_reward_list_list = [[] for _ in range(self.num_sample)]
        self._nfe_cal_final_reward_list_list = [[] for _ in range(self.num_sample)]

        # ---------= [Folder Root Path] =---------
        if expansion_policy_root_path is None:
            raise ValueError(
                "`expansion_policy_root_path` must be provided. "
            )
        if not folder_name_list:
            raise ValueError(
                "`folder_name_list` must contain at least one folder name. "
            )

        expansion_policy_root_path = Path(expansion_policy_root_path)
        self._expansion_policy_root_path = expansion_policy_root_path

        self._folder_root_path_list = [
            expansion_policy_root_path / folder_name
                for folder_name in folder_name_list
        ]

        num_prompt = len(self._folder_root_path_list)

        # Every global sample index must map to an existing prompt folder; otherwise
        # `sample_idx // num_sample_per_prompt` would overflow the per-prompt path lists.
        if (self.num_sample <= 0) or (self.num_sample % num_prompt != 0):
            raise ValueError(
                f"`num_sample` ({self.num_sample}) must be a positive multiple of "
                f"the number of prompts ({num_prompt}). "
            )
        self._num_sample_per_prompt = self.num_sample // num_prompt

        self._save_png_root_path_list = [
            folder_root_path / "png"
                for folder_root_path in self._folder_root_path_list
        ]
        self._save_result_root_path_list = [
            folder_root_path / "result"
                for folder_root_path in self._folder_root_path_list
        ]

        # ---------= [Save Task Cfg] =---------
        if self.mdp.optimized_prompt_list is None:
            self.mdp.optimized_prompt_list = [None] * num_prompt

        base_cfg_dict = {"sample": {}} if cfg_dict is None else cfg_dict

        for (prompt, optimized_prompt, folder_root_path) in zip(
            self.mdp.prompt_list, 
            self.mdp.optimized_prompt_list, 

            self._folder_root_path_list
        ):
            # A fresh copy per prompt: an `optimized_prompt` written for one prompt
            # must not leak into the cfg of a later prompt, and the caller's
            # `cfg_dict` stays untouched.
            prompt_cfg_dict = copy.deepcopy(base_cfg_dict)

            prompt_cfg_dict["sample"]["prompt"] = prompt

            if optimized_prompt is not None:
                prompt_cfg_dict["sample"]["optimized_prompt"] = optimized_prompt

            save_yaml(
                prompt_cfg_dict, 

                yaml_root_path = folder_root_path, 
                yaml_filename = "cfg.yaml"
            )

        # `__init__()` done


    def _get_prompt_and_true_sample_idx(
        self, 

        sample_idx: int
    ) -> Tuple[int, int]:
        """
        Func:
            Map a global sample index to its prompt and its index within that prompt. 

        Ret:
            `(prompt_idx, true_sample_idx)` (`Tuple[int, int]`). 
        """

        # `_get_prompt_and_true_sample_idx()` done
        return divmod(sample_idx, self._num_sample_per_prompt)


    def _get_nfe(
        self, 

        sample_idx: int
    ) -> Tuple[int, int, int]:
        """
        Func: 
            Read the current NFEs of Sample `sample_idx` from `self.mdp.reward_model`. 

        Ret: 
            `(nfe_cal_dynamics, nfe_cal_intermediate_reward, nfe_cal_final_reward)`. 
        """

        reward_model = self.mdp.reward_model

        # `_get_nfe()` done
        return (
            reward_model.nfe_cal_dynamics_list[sample_idx], 
            reward_model.nfe_cal_intermediate_reward_list[sample_idx], 
            reward_model.nfe_cal_final_reward_list[sample_idx]
        )


    def _get_nfe_lim(self) -> Tuple[Optional[int], Optional[int], Optional[int]]:
        """
        Func:
            The NFE limits, in the same order as `_get_nfe()`. `None` means unlimited. 
        """

        # `_get_nfe_lim()` done
        return (
            self._nfe_cal_dynamics_lim, 
            self._nfe_cal_intermediate_reward_lim, 
            self._nfe_cal_final_reward_lim
        )


    @staticmethod
    def _nfe_limit_reached(
        nfe: int, 
        nfe_lim: Optional[int]
    ) -> bool:
        """
        Func:
            Whether `nfe` has reached `nfe_lim`. A `None` limit is never reached. 
        """

        return (nfe_lim is not None) and (nfe >= nfe_lim)


    @staticmethod
    def _nfe_within_limit(
        nfe: int, 
        nfe_lim: Optional[int]
    ) -> bool:
        """
        Func:
            Whether `nfe` does not exceed `nfe_lim`. A `None` limit is never exceeded. 
        """

        return (nfe_lim is None) or (nfe <= nfe_lim)


    def _get_is_time_to_stop(
        self, 

        sample_idx: int
    ) -> bool:
        """
        Func:
            Determine whether the search for Sample `sample_idx` should stop, i.e. whether 
            any of its NFE counters has reached its (non-`None`) limit. 

        Ret:
            `is_time_to_stop` (`bool`): Whether it is time to stop. 
        """

        # `_get_is_time_to_stop()` done
        return any(
            self._nfe_limit_reached(nfe, nfe_lim)
                for (nfe, nfe_lim) in zip(self._get_nfe(sample_idx = sample_idx), self._get_nfe_lim())
        )


    def _get_is_cost_legal(
        self, 

        sample_idx: int
    ) -> bool:
        """
        Func:
            Determine whether the cost of the beam search for Sample `sample_idx` is legal, 
            i.e. no NFE counter exceeds its (non-`None`) limit. 

        Ret:
            `is_cost_legal` (`bool`): Whether the cost is legal. 
        """

        # `_get_is_cost_legal()` done
        return all(
            self._nfe_within_limit(nfe, nfe_lim)
                for (nfe, nfe_lim) in zip(self._get_nfe(sample_idx = sample_idx), self._get_nfe_lim())
        )


    @torch.no_grad()
    def _best_trajectory_updated_callback(
        self, 

        caller: str, 
        pseudo: bool, 

        node: BSNode, 

        sample_idx: int, 

        timestep_idx: int
    ):
        """
        Func:
            The callback function after the best trajectory of Sample `sample_idx` is updated. 
            Logs the reward change, records the BS loop index / wall-clock time / NFEs 
            of the update, decodes the new best latent and saves it as a PNG. 

        Args:
            caller: Name of the caller, used only in the log message. 
            pseudo: If `True`, decode `node.info_list[sample_idx].pseudo_final_latent` 
                instead of the node's state. 
            node: The node that produced the new best trajectory. 
            sample_idx: Global sample index. 
            timestep_idx: Unused here. 
        """

        best_merged_reward_list = self.history_best_merged_reward_list_list[sample_idx]

        # NOTE(review): assumes the history already holds at least two entries when
        # this callback fires (previous best and new best).
        second_last_merged_reward = best_merged_reward_list[-2]
        last_merged_reward = best_merged_reward_list[-1]

        logger(
            f"[Sample {sample_idx}] Best trajectory updated by `{caller}` at node {node.node_idx} for sample {sample_idx}. "
        )
        logger(
            f"    merged_reward: [{second_last_merged_reward:.4f}] -> [{last_merged_reward:.4f}]"
        )

        # ---------= [Record Best Trajectory Updated] =---------
        self.best_trajectory_updated_bs_loop_idx_list_list[sample_idx].append(self._bs_loop_idx)

        time_cost = time.time() - self._time_st
        self.best_trajectory_updated_wall_clock_time_list_list[sample_idx].append(time_cost)

        (
            nfe_cal_dynamics, 
            nfe_cal_intermediate_reward, 
            nfe_cal_final_reward
        ) = self._get_nfe(sample_idx = sample_idx)

        self.best_trajectory_updated_nfe_cal_dynamics_list_list[sample_idx].append(nfe_cal_dynamics)
        self.best_trajectory_updated_nfe_cal_intermediate_reward_list_list[sample_idx].append(nfe_cal_intermediate_reward)
        self.best_trajectory_updated_nfe_cal_final_reward_list_list[sample_idx].append(nfe_cal_final_reward)

        # ---------= [Save Sample] =---------
        (prompt_idx, true_sample_idx) = self._get_prompt_and_true_sample_idx(sample_idx = sample_idx)

        # The history has one more entry than there are updates (its first entry is
        # skipped in `_run_sample_post_process()`), so this numbers updates from 0.
        img_pil_idx = len(best_merged_reward_list) - 2

        if pseudo:
            img_latent_list = node.info_list[sample_idx].pseudo_final_latent
        else:
            img_latent_list = node.get_state(
                sample_idx_list = sample_idx
            )

        if not isinstance(img_latent_list, torch.Tensor):
            img_latent_list = torch.stack(img_latent_list)

        img_latent_list = img_latent_list.to(
            self.device
        )

        # Add a batch dimension if missing; the decoder is given a 4-D batch.
        if len(img_latent_list.shape) < 4:
            img_latent_list = img_latent_list.unsqueeze(0)

        img_pil = img_latent_to_pil(
            img_latent_list = img_latent_list, 
            pipeline = self.mdp.pipeline
        )[0]

        save_pil_as_png(
            pil = img_pil, 

            png_root_path = self._save_png_root_path_list[prompt_idx] / f"{true_sample_idx}", 
            png_filename = f"{true_sample_idx}_{img_pil_idx}.png"
        )

        # `_best_trajectory_updated_callback()` done


    @torch.no_grad()
    def _get_new_node(
        self, 

        sample_idx_list: List[int], 

        state_list: Union[
            torch.Tensor, List[torch.Tensor], 
            np.ndarray, List[np.ndarray]
        ], 

        prev_action_list: Union[
            torch.Tensor, List[torch.Tensor], 
            np.ndarray, List[np.ndarray]
        ] = None, 

        parent: BSNode = None
    ) -> BSNode:
        """
        Func:
            Create a new `BSNode` from a (cloned / copied) batch of states, and patch its 
            `get_state()` so that every state read notifies `self.lru_cache` (when set). 
            Reward computation, if any, presumably happens inside `BSNode` -- not visible here. 

        Args:
            sample_idx_list: Sample indices covered by the node. 
            state_list: Batch of states; a list is stacked (with NumPy when every element 
                is an `np.ndarray`, otherwise with PyTorch). 
            prev_action_list: Actions leading to these states. 
            parent: Parent node. 

        Ret:
            `new_node` (`BSNode`): The created node. 
        """

        # ---------= [Prepare `state_list`] =---------
        if isinstance(state_list, list):
            if state_list and all(isinstance(state, np.ndarray) for state in state_list):
                state_list = np.stack(state_list)
            else:
                state_list = torch.stack(state_list)

        # Copy so the node never aliases the caller's buffer.
        if isinstance(state_list, np.ndarray):
            state_list = state_list.copy()
        else:
            state_list = state_list.clone()

        # ---------= [New Node] =---------
        new_node = BSNode(
            bs_instance = self, 

            sample_idx_list = sample_idx_list, 

            state_list = state_list, 

            prev_action_list = prev_action_list, 

            parent = parent, 

            device = self.device
        )

        # ---------= [LRU] =---------
        def _get_state(
            node: BSNode, 

            sample_idx_list: Union[int, List[int]] = None
        ) -> Union[List[torch.Tensor], List[np.ndarray]]:
            """
            Func:
                Get the states of `node` for the requested samples, notifying the beam 
                search's `lru_cache` (if any) of each access. 

            Ret:
                `state_list` (`list`): One entry per requested sample (`None` where the 
                    node holds no info for that sample). Its length is `len(sample_idx_list)`, 
                    or `num_sample` when `sample_idx_list` is not provided. 
            """

            bs_instance = node.bs_instance

            # ---------= [Prepare `sample_idx_list`] =---------
            if sample_idx_list is None:
                sample_idx_list = list(range(bs_instance.num_sample))

            if not isinstance(sample_idx_list, list):
                sample_idx_list = [sample_idx_list]

            # ---------= [Get State] =---------
            lru_cache = bs_instance.lru_cache

            state_list = []

            for sample_idx in sample_idx_list:
                info = node.info_list[sample_idx]

                if info is None:
                    state_list.append(None)

                    continue

                if lru_cache is not None:
                    lru_cache.access(
                        node, 
                        sample_idx = sample_idx
                    )

                state_list.append(info.get_state())

            # `_get_state()` done
            return state_list


        new_node.get_state = MethodType(
            _get_state, 
            new_node
        )

        # `_get_new_node()` done
        return new_node


    def _run_sample_pre_process(
        self, 

        **arg_dict
    ):
        """
        Func:
            The function called before starting a run of BS for a single sample. 
            Starts the wall-clock timer read by the other callbacks. 
        """

        self._time_st = time.time()

        # `_run_sample_pre_process()` done


    def _run_sample_post_process(
        self, 

        sample_idx: int, 

        **arg_dict
    ):
        """
        Func:
            The function called after finishing a run of BS for a single sample. 
            Records the wall-clock time cost and saves the per-sample result YAML 
            (BS loop index, NFE history, best-trajectory-update records, best merged rewards, 
            per-loop wall-clock times). 
        """

        # ---------= [Compute Wall-clock Time Cost] =---------
        time_cost = time.time() - self._time_st
        self._wall_clock_time_cost_list[sample_idx] = time_cost

        logger(
            f"[Sample {sample_idx}] Finished, wall-clock time cost: {round(time_cost)} second(s). "
        )

        # ---------= [Merged Reward] =---------
        # The first history entry is skipped (presumably the value before any update --
        # TODO confirm). `float()` accepts Python numbers, NumPy scalars and
        # single-element tensors alike.
        best_merged_reward_list = [
            float(best_merged_reward)
                for best_merged_reward in self.history_best_merged_reward_list_list[sample_idx][1: ]
        ]

        # ---------= [Save Result] =---------
        result_dict = {
            "bs_loop_idx": self._bs_loop_idx, 

            "wall_clock_time_cost": time_cost, 

            # ---------= [NFE] =---------
            "nfe_cal_dynamics_list": self._nfe_cal_dynamics_list_list[sample_idx], 
            "nfe_cal_intermediate_reward_list": self._nfe_cal_intermediate_reward_list_list[sample_idx], 
            "nfe_cal_final_reward_list": self._nfe_cal_final_reward_list_list[sample_idx], 

            # ---------= [Best Trajectory Updated] =---------
            "best_trajectory_updated_bs_loop_idx_list": self.best_trajectory_updated_bs_loop_idx_list_list[sample_idx], 
            "best_trajectory_updated_wall_clock_time_list": self.best_trajectory_updated_wall_clock_time_list_list[sample_idx], 

            "best_trajectory_updated_nfe_cal_dynamics_list": self.best_trajectory_updated_nfe_cal_dynamics_list_list[sample_idx], 
            "best_trajectory_updated_nfe_cal_intermediate_reward_list": self.best_trajectory_updated_nfe_cal_intermediate_reward_list_list[sample_idx], 
            "best_trajectory_updated_nfe_cal_final_reward_list": self.best_trajectory_updated_nfe_cal_final_reward_list_list[sample_idx], 

            # ---------= [Merged Reward] =---------
            "best_merged_reward_list": best_merged_reward_list, 

            # ---------= [Misc] =---------
            "bs_loop_wall_time_cost_list": self._bs_loop_wall_time_cost_list
        }
        result_dict = convert_numpy_type_to_native_type(result_dict)

        (prompt_idx, true_sample_idx) = self._get_prompt_and_true_sample_idx(sample_idx = sample_idx)

        save_yaml(
            result_dict, 

            yaml_root_path = self._save_result_root_path_list[prompt_idx], 
            yaml_filename = f"{true_sample_idx}.yaml"
        )

        # ---------= [Clean Up] =---------
        # The lists above are still owned by `self`, so there is nothing to `del`;
        # only drop the temporary dict and trigger a collection between samples.
        del result_dict
        gc.collect()

        # `_run_sample_post_process()` done


    @torch.no_grad()
    def _bs_loop_callback(
        self, 

        local_var_dict: Dict, 

        **arg_dict
    ):
        """
        Func:
            The callback function after each BS loop. 
            Records the elapsed wall-clock time, advances `self._bs_loop_idx`, and appends 
            (and logs) the current NFEs of every sample. 
        """

        self._bs_loop_wall_time_cost_list.append(time.time() - self._time_st)

        self._bs_loop_idx += 1

        # ---------= [NFEs] =---------
        for sample_idx in range(self.num_sample):
            (
                nfe_cal_dynamics, 
                nfe_cal_intermediate_reward, 
                nfe_cal_final_reward
            ) = self._get_nfe(sample_idx = sample_idx)

            self._nfe_cal_dynamics_list_list[sample_idx].append(nfe_cal_dynamics)
            self._nfe_cal_intermediate_reward_list_list[sample_idx].append(nfe_cal_intermediate_reward)
            self._nfe_cal_final_reward_list_list[sample_idx].append(nfe_cal_final_reward)

            logger(
                f"[Sample {sample_idx}] "
                f"nfe_cal_dynamics: {nfe_cal_dynamics}, "
                f"nfe_cal_intermediate_reward: {nfe_cal_intermediate_reward}, "
                f"nfe_cal_final_reward: {nfe_cal_final_reward}"
            )

        # `_bs_loop_callback()` done


    def display_sample_result(
        self, 

        sample_idx: int
    ):
        """
        Func:
            Log the latest best merged reward and the current NFEs of Sample `sample_idx`. 
        """

        best_merged_reward = self.history_best_merged_reward_list_list[sample_idx][-1]

        logger(f"[Main Results]")
        logger(f"    sample_idx: {sample_idx}")
        logger(f"    best_merged_reward: {best_merged_reward}")

        (
            nfe_cal_dynamics, 
            nfe_cal_intermediate_reward, 
            nfe_cal_final_reward
        ) = self._get_nfe(sample_idx = sample_idx)

        logger(
            f"        nfe_cal_dynamics: {nfe_cal_dynamics}, "
            f"nfe_cal_intermediate_reward: {nfe_cal_intermediate_reward}, "
            f"nfe_cal_final_reward: {nfe_cal_final_reward}"
        )

        # `display_sample_result()` done


    @torch.no_grad()
    def _sample_expansion_action(
        self, 

        node: BSNode, 

        sample_idx: int
    ) -> Union[torch.Tensor, np.ndarray]:
        """
        Func:
            Sample an action from the action space for `self._expand()`, according to 
            `self.expansion_action_sampling_policy` (only `"uniform"` is supported). 
            `node` and `sample_idx` are unused by the `"uniform"` policy. 

        Ret:
            `action` (`torch.Tensor` or `np.ndarray`): The sampled action. 

        Raises:
            NotImplementedError: For any other sampling policy. 
        """

        action_sampling_policy = self.expansion_action_sampling_policy

        if action_sampling_policy != "uniform":
            raise NotImplementedError(
                f"Unsupported `action_sampling_policy`, got `{action_sampling_policy}`. "
            )

        # `_sample_expansion_action()` done
        return self.mdp.action_space.sample_uniform_element()


    @torch.no_grad()
    def _batch_sample_expansion_action(
        self, 

        node: BSNode, 

        sample_idx_list: List[int]
    ) -> Union[torch.Tensor, np.ndarray]:
        """
        Func:
            Batch sample actions from the action space for `self._expand()`, one per 
            entry of `sample_idx_list`, stacked along a new leading dimension. 

        Ret:
            `action_list` (`Union[torch.Tensor, np.ndarray]`): The batch of the sampled actions 
                (an `np.ndarray` when every sampled action is an `np.ndarray`). 
        """

        action_list = [
            self._sample_expansion_action(
                node = node, 

                sample_idx = sample_idx
            )
                for sample_idx in sample_idx_list
        ]

        # `torch.stack` rejects NumPy arrays, which the action space may return.
        if action_list and all(isinstance(action, np.ndarray) for action in action_list):
            return np.stack(action_list, axis = 0)

        # `_batch_sample_expansion_action()` done
        return torch.stack(
            action_list, 
            dim = 0
        )
