from util.logger import logger

from typing import Optional, Union, Tuple, List

import numpy as np

import torch

import gc

from util.basic_util import get_attr


class Trajectory:
    def __init__(
        self, 

        state_list: Union[List[torch.Tensor], List[np.ndarray]] = None, 
        action_list: Union[List[torch.Tensor], List[np.ndarray]] = None, 
        reward_list: Union[List[torch.Tensor], List[np.ndarray]] = None, 

        reward_shape: Optional[Tuple] = (1, ), 

        dtype: Optional[str] = "float32", 
        device: Optional[str] = "cpu", 

        ver: Optional[str] = "torch",  # ["torch", "numpy"]
    ):
        self.ver = ver

        self.dtype = dtype
        if isinstance(self.dtype, str):
            self.dtype = get_attr(self.ver, self.dtype)

        self.device = device

        self.state_list = state_list
        self.action_list = action_list
        self.reward_list = reward_list

        self.accumulated_reward_list = []
    
        self.reward_shape = reward_shape

        # ---------= [Initialize Total Reward] =---------
        if self.ver == "torch":
            merged_reward = torch.tensor(
                float("-inf"), 

                dtype = self.dtype, 
                device = self.device
            )
        elif self.ver == "numpy":
            merged_reward = np.array(
                float("-inf"), 
                
                dtype = self.dtype
            )

        self.merged_reward = merged_reward.reshape(reward_shape)

        self.incomplete = None

        # `__init__()` done
        pass


    def update_accumulated_reward_list(
        self, 

        mdp_modeling: str = "cumulative_reward"  # ["cumulative_reward", "max_reward"]
    ):
        """
        Func:
            Update `self.merged_reward` with `self.reward_list`. 
        """

        if (self.reward_list is None) \
            or (len(self.reward_list) <= 0):

            return

        if self.accumulated_reward_list is not None:
            # ---------= [Clean Up] =---------
            del self.accumulated_reward_list
            gc.collect()
            if self.ver == "torch":
                torch.cuda.empty_cache()
    
        # num_reward = len(self.reward_list) - 1
        num_reward = len(self.reward_list)

        self.accumulated_reward_list = [None] * num_reward
        
        for i in range(num_reward):
            self.reward_list[i] = self.reward_list[i].cpu()
            
            if i == 0:
                if self.ver == "torch":
                    self.accumulated_reward_list[i] = self.reward_list[i].clone()
                elif self.ver == "numpy":
                    self.accumulated_reward_list[i] = self.reward_list[i].copy()
            else:
                if mdp_modeling in [
                    "sparse_reward", 
                    "cumulative_reward"
                ]:
                    self.accumulated_reward_list[i] \
                        = self.accumulated_reward_list[i - 1] + self.reward_list[i]
                elif mdp_modeling == "max_reward":
                    self.accumulated_reward_list[i] \
                        = max(self.accumulated_reward_list[i - 1], self.reward_list[i])
            
            # if i == num_reward - 1:
            #     if mdp_modeling == "cumulative_reward":
            #         self.accumulated_reward_list[i] += self.reward_list[i + 1]
            #     elif mdp_modeling == "max_reward":
            #         self.accumulated_reward_list[i] = max(
            #             self.accumulated_reward_list[i], 
            #             self.reward_list[i + 1]
            #         )

            # goto `for i`
            pass
                
        self.merged_reward = self.accumulated_reward_list[-1]

        # `update_accumulated_reward_list()` done
        pass


    def concat_trajectory_by_list(
        self, 

        state_list: Union[List[torch.Tensor], List[np.ndarray]] = None, 
        action_list: Union[List[torch.Tensor], List[np.ndarray]] = None, 
        reward_list: Union[List[torch.Tensor], List[np.ndarray]] = None
    ):
        """
        Func:
            Concatenate a trajectory after `self`. 
        """

        if state_list:
            self.state_list += state_list
        
        if action_list:
            self.action_list += action_list

        if reward_list:
            self.reward_list += reward_list

        # `concat_trajectory_by_list()` done
        pass


    def concat_trajectory_by_trajectory(
        self, 

        trajectory: "Trajectory"
    ):
        """
        Func:
            Concatenate a trajectory after `self`. 
        """

        self.concat_trajectory_by_list(
            state_list = trajectory.state_list, 
            action_list = trajectory.action_list, 
            reward_list = trajectory.reward_list
        )

        # `concat_trajectory_by_trajectory()` done
        pass


    def display_trajectory(
        self, 

        display_state: bool = True, 
        display_action: bool = True, 
        display_reward: bool = True
    ):
        """
        Func:
            Display the trajectory `self` step by step. 
        """

        if self.state_list is None:
            return

        time_horizon = len(self.action_list)
        
        for timestep_idx in range(time_horizon):
            logger(f"[Timestep Index {timestep_idx}]")

            if display_state:
                state = self.state_list[timestep_idx]

                logger(f"    state: {state}")
            
            if display_action:
                action = self.action_list[timestep_idx]
                
                logger(f"    performs action: {action}")

            if display_reward:
                reward = self.reward_list[timestep_idx]
            
                logger(f"    gains intermediate reward: {reward.item():.4f}")

            # goto `for timestep_idx`
            pass

        if display_reward:
            final_reward = self.reward_list[-1]

            logger(f"    gains final reward: {final_reward.item():.4f}")

        # `display_trajectory()` done
        pass


    def get_trajectory_to_root(
        self, 

        node: "MCTSNode", 
        sample_idx: int, 

        include_final_reward: Optional[bool] = False
    ) -> Tuple[
        List[Union[torch.Tensor, np.ndarray]], 
        List[Union[torch.Tensor, np.ndarray]], 
        List[Union[torch.Tensor, np.ndarray]]
    ]:
        """
        NB: 
            The `reward_list` within the derived trajectory defaultly excludes the final reward. 
            If `node` is a terminal node, set `include_final_reward = True` to include the final reward. 

        Func:
            Get the trajectory from the root to `node`. 

        Ret:
            `trajectory_to_root` (`Trajectory`): The trajectory from the root to `node`. 
        """

        info = node.info_list[sample_idx]

        state_list = []
        action_list = []
        reward_list = []

        if include_final_reward:
            final_reward = info.final_reward
            
            if final_reward is None:
                raise ValueError(
                    f"Can not include the final reward of a non-terminal node. "
                )
            else:
                if self.ver == "torch":
                    if not isinstance(final_reward, torch.Tensor):
                        final_reward = torch.tensor(
                            final_reward, 

                            dtype = info._state.dtype, 
                            device = info._state.device
                        )
                elif self.ver == "numpy":
                    if not isinstance(final_reward, np.ndarray):
                        final_reward = np.array(
                            final_reward, 

                            dtype = info._state.dtype
                        )

                final_reward = final_reward.reshape(self.reward_shape)

                reward_list.append(final_reward)
        
        while node.parent:
            state_list.append(
                node.get_state(
                    sample_idx_list = sample_idx
                )[0]
            )
            action_list.append(info.prev_action)

            if info.final_reward is None:
                reward_list.append(info.intermediate_reward)

            node = node.parent
            info = node.info_list[sample_idx]

            # goto `while node.parent`
            pass
        
        # root
        if self.ver == "torch":
            init_state = node.get_state(
                sample_idx_list = sample_idx
            )[0].clone()
        elif self.ver == "numpy":
            init_state = node.get_state(
                sample_idx_list = sample_idx
            )[0].copy()

        state_list.append(init_state)

        state_list.reverse()
        action_list.reverse()
        reward_list.reverse()

        # `get_trajectory_to_root()` done
        return (
            state_list, 
            action_list, 
            reward_list
        )
        
        trajectory_to_root = Trajectory(
            state_list = state_list, 
            action_list = action_list, 
            reward_list = reward_list, 
            
            ver = ver
        )

        # `get_trajectory_to_root()` done
        return trajectory_to_root
    