from util.logger import logger

from typing import Optional, Union, List

import numpy as np

import torch


class Info:
    """
    Func:
        Bookkeeping record attached to a node: the node's state, the action
        that produced it, visit/value statistics, and reward-related fields.

    NOTE(review): `num_vis` and `value` are only initialised here (to 0 and
        0.0); they look like search-tree visit-count / value statistics that
        are updated elsewhere — confirm against the callers.
    """

    def __init__(
        self, 

        state: Union[torch.Tensor, np.ndarray], 

        prev_action: Optional[Union[torch.Tensor, np.ndarray]] = None, 

        merged_reward_to_root: float = 0.0, 

        intermediate_reward: Optional[float] = None, 
        pseudo_final_latent: Optional[torch.Tensor] = None, 

        final_reward: Optional[float] = None
    ):
        """
        Func:
            Store the given state / action / reward fields; the visit counter
            `num_vis` starts at 0 and `value` at 0.0.

        Args:
            state: the node's state; kept privately and exposed via `get_state()`.
            prev_action: the action associated with reaching this node, or None.
            merged_reward_to_root: a reward value presumably accumulated along
                the path to the root (default 0.0) — confirm with callers.
            intermediate_reward: optional intermediate reward for this node.
            pseudo_final_latent: optional latent tensor; its meaning is defined
                by the caller.
            final_reward: optional final reward for this node.
        """
        self._state = state

        self.prev_action = prev_action

        self.num_vis = 0
        self.value = 0.0

        self.merged_reward_to_root = merged_reward_to_root

        self.intermediate_reward = intermediate_reward
        self.pseudo_final_latent = pseudo_final_latent

        self.final_reward = final_reward

        # `__init__()` done

    def get_state(
        self
    ) -> Union[torch.Tensor, np.ndarray]:
        """
        Func:
            Return the state passed to the constructor (the same object, not a copy).
        """
        # `get_state()` done
        return self._state

    @staticmethod
    def _format_scalar(value) -> str:
        """
        Func:
            Render `value` for display: "None" when it is None, otherwise with
            a ".4f" format spec.

            An explicit `is None` check is used instead of truthiness so that a
            legitimate 0 / 0.0 prints as "0.0000" rather than "None", and so that
            multi-element tensors/arrays (whose truth value is ambiguous and
            raises) do not crash. Values that reject the ".4f" spec (e.g.
            multi-element tensors/arrays, strings) fall back to `str(value)`.
        """
        if value is None:
            return "None"
        try:
            return f"{value:.4f}"
        except (TypeError, ValueError):
            return str(value)

    def display_info(
        self, 

        display_state: bool = True, 
        display_prev_action: bool = True, 
        display_num_vis: bool = True, 
        display_value: bool = True, 
        display_merged_reward_to_root: bool = True, 
        display_intermediate_reward: bool = True, 
        display_pseudo_final_latent: bool = True, 
        display_final_reward: bool = True
    ):
        """
        Func:
            Display the info of the node, one `logger` call per field. Each
            `display_*` flag toggles the corresponding line.

            Scalar-like fields (prev_action and the rewards) are printed with 4
            decimals, and as "None" only when they actually are None.
        """
        fmt = self._format_scalar

        if display_state:
            logger(f"        state: {self._state}")

        if display_prev_action:
            logger(f"        prev_action: {fmt(self.prev_action)}")

        if display_num_vis:
            logger(f"        num_vis: {self.num_vis}")

        if display_value:
            logger(f"        value: {self.value}")

        if display_merged_reward_to_root:
            logger(f"        merged_reward_to_root: {fmt(self.merged_reward_to_root)}")

        if display_intermediate_reward:
            logger(f"        intermediate_reward: {fmt(self.intermediate_reward)}")

        if display_pseudo_final_latent:
            logger(f"        pseudo_final_latent: {self.pseudo_final_latent}")

        if display_final_reward:
            logger(f"        final_reward: {fmt(self.final_reward)}")

        # `display_info()` done
