from util.logger import logger

from typing import Optional, Union, List

import numpy as np

import torch

from .info import Info


class BSNode:
    """
    Func:
        A node of the beam-search tree. 
        Each node holds one `Info` per sample in `self.info_list` (indexed by 
        sample index; slots of samples that were not provided stay `None`), 
        a reference to its `parent`, its `depth` and its `children_list`. 
    """

    # Class-wide counter consumed by `get_new_node_idx()`. 
    _node_idx = 0


    @staticmethod
    def get_new_node_idx(
    ) -> int:
        """
        Func:
            Get an auto-incrementing node index (shared by all `BSNode`s). 

        Ret:
            `new_node_idx` (`int`): A new node index. 
        """

        new_node_idx = BSNode._node_idx
        BSNode._node_idx += 1

        # `get_new_node_idx()` done
        return new_node_idx


    def __init__(
        self, 

        bs_instance: "BeamSearch", 

        sample_idx_list: List[int], 

        state_list: Union[
            torch.Tensor, List[torch.Tensor], 
            np.ndarray, List[np.ndarray]
        ], 

        prev_action_list: Optional[Union[
            torch.Tensor, List[torch.Tensor], 
            np.ndarray, List[np.ndarray]
        ]] = None, 

        parent: Optional["BSNode"] = None, 

        merged_reward_to_root_list: Union[float, List[float]] = 0.0, 

        intermediate_reward_list: Optional[List[float]] = None, 
        pseudo_final_latent_list: Optional[List[torch.Tensor]] = None, 

        final_reward_list: Optional[List[float]] = None, 

        device: Optional[str] = "cpu"
    ):
        """
        Func:
            Create a node and fill `self.info_list` for the given samples. 

        Args:
            `bs_instance`: The owning beam-search object; only its `num_sample` 
                attribute is read here. 
            `sample_idx_list` (`List[int]`): Sample index of each entry of `state_list`. 
            `state_list`: One state per entry of `sample_idx_list`. 
            `prev_action_list`: Optional, one action per state. 
            `parent` (`BSNode`): Parent node; `None` for the root (depth 0). 
            `merged_reward_to_root_list`: NOTE(review): accepted for interface 
                compatibility but currently not used by this class. 
            `intermediate_reward_list`, `pseudo_final_latent_list`, 
            `final_reward_list`: Optional per-state values forwarded to `Info`. 
            `device` (`str`): Device that CPU tensors are moved to by `get_state()`. 
        """
        self.bs_instance = bs_instance

        self.device = device

        self.node_idx = BSNode.get_new_node_idx()

        self.parent = parent

        # Root has depth 0; every child is one deeper than its parent. 
        self.depth = 0
        if self.parent is not None:
            self.depth = self.parent.depth + 1

        self.children_list = []

        # One slot per sample; slots not covered by `sample_idx_list` stay `None`. 
        self.info_list = [None] * self.bs_instance.num_sample

        self._prepare_info_list(
            sample_idx_list = sample_idx_list, 

            state_list = state_list, 
            prev_action_list = prev_action_list, 

            intermediate_reward_list = intermediate_reward_list, 
            pseudo_final_latent_list = pseudo_final_latent_list, 

            final_reward_list = final_reward_list
        )

        # `__init__()` done
        pass


    def __eq__(
        self, 
        other
    ) -> bool:
        """
        Func:
            Two nodes are equal iff they share the same `node_idx`. 
            For non-`BSNode` operands `NotImplemented` is returned, so Python 
            falls back to its default comparison (which yields `False`). 
        """
        if not isinstance(other, BSNode):
            return NotImplemented

        # `__eq__()` done
        return self.node_idx == other.node_idx


    def __hash__(
        self
    ) -> int:
        """
        Func:
            Hash consistent with `__eq__()` (based on `node_idx` only). 
        """
        hash_val = hash(self.node_idx)

        # `__hash__()` done
        return hash_val


    def __repr__(
        self
    ) -> str:
        """
        Func:
            Short, readable representation (used e.g. when logging `children_list`). 
        """
        # `__repr__()` done
        return f"BSNode(node_idx={self.node_idx}, depth={self.depth})"


    def _prepare_info_list(
        self, 

        sample_idx_list: List[int], 

        state_list: Union[
            torch.Tensor, List[torch.Tensor], 
            np.ndarray, List[np.ndarray]
        ], 

        prev_action_list: Optional[Union[
            torch.Tensor, List[torch.Tensor], 
            np.ndarray, List[np.ndarray]
        ]] = None, 

        intermediate_reward_list: Optional[List[float]] = None, 
        pseudo_final_latent_list: Optional[List[torch.Tensor]] = None, 

        final_reward_list: Optional[List[float]] = None
    ):
        """
        Func:
            Prepare `self.info_list`: for every `(sample_idx, state, ...)` tuple, 
            store an `Info` at `self.info_list[sample_idx]`. 
            Missing optional lists are replaced by lists of `None`. 
        """

        # ---------= [Prepare Everything] =---------
        num_state = len(state_list)

        if prev_action_list is None:
            prev_action_list = [None] * num_state

        if intermediate_reward_list is None:
            intermediate_reward_list = [None] * num_state

        if pseudo_final_latent_list is None:
            pseudo_final_latent_list = [None] * num_state

        if final_reward_list is None:
            final_reward_list = [None] * num_state

        # ---------= [Get `info_dict`] =---------
        # NOTE(review): `zip()` silently stops at the shortest input, so callers 
        # must pass lists of equal length. 
        for (
            sample_idx, 

            state, 

            prev_action, 

            intermediate_reward, 
            pseudo_final_latent, 

            final_reward
        ) in zip(
            sample_idx_list, 

            state_list, 

            prev_action_list, 

            intermediate_reward_list, 
            pseudo_final_latent_list, 

            final_reward_list
        ):
            self.info_list[sample_idx] = Info(
                state = state, 

                prev_action = prev_action, 

                intermediate_reward = intermediate_reward, 
                pseudo_final_latent = pseudo_final_latent, 

                final_reward = final_reward
            )

            # goto `for`
            pass

        # `_prepare_info_list()` done
        pass


    def get_state(
        self, 

        sample_idx_list: Optional[Union[int, List[int]]] = None
    ) -> Union[List[torch.Tensor], List[np.ndarray]]:
        """
        Func:
            Get the `state` of the node for the requested samples. 
            A `torch.Tensor` state that is not already on CUDA is moved to 
            `self.device` (the stored state is not modified); tensors already 
            on CUDA and `np.ndarray` states are returned as-is. 
            Samples without an `Info` yield `None`. 

        Args:
            `sample_idx_list` (`int`, `List[int]`, `tuple` or `range`): The 
                sample(s) to query. `None` means all `num_sample` samples. 

        Ret:
            `state_list` (`List[torch.Tensor]` or `List[np.ndarray]`): The list of `state`s of the node.
                (`sample_idx_list` is provided) len(state_list) = len(sample_idx_list). 
                (`sample_idx_list` is not provided) len(state_list) = num_sample. 
        """

        # ---------= [Prepare `sample_idx_list`] =---------
        if sample_idx_list is None:
            sample_idx_list = list(range(self.bs_instance.num_sample))
        elif isinstance(sample_idx_list, (tuple, range)):
            sample_idx_list = list(sample_idx_list)
        elif not isinstance(sample_idx_list, list):
            # A single sample index. 
            sample_idx_list = [sample_idx_list]

        # ---------= [Get State] =---------
        state_list = []

        for sample_idx in sample_idx_list:
            info = self.info_list[sample_idx]

            if info is None:
                state_list.append(None)

                continue

            state = info.get_state()

            # `Tensor.to()` is not in-place: it returns the moved tensor, 
            # so the result has to be assigned back. 
            if isinstance(state, torch.Tensor) and not state.is_cuda:
                state = state.to(self.device)

            state_list.append(state)

            # goto `for sample_idx`
            pass

        # `get_state()` done
        return state_list


    def display_info(
        self, 

        display_state: bool = True, 
        display_prev_action: bool = True, 
        display_intermediate_reward: bool = True, 
        display_pseudo_final_latent: bool = True, 
        display_final_reward: bool = True
    ):
        """
        Func:
            Log the node index, depth, children and, per sample, the selected 
            fields of its `Info`. Samples without an `Info` are logged as `None`. 
        """
        node_idx = self.node_idx
        depth = self.depth
        children_list = self.children_list

        logger(f"[Node {node_idx}]")
        logger(f"    depth: {depth}")
        logger(f"    children_list: {children_list}")

        for sample_idx, info in enumerate(self.info_list):
            logger(f"    [Sample {sample_idx}]")

            # Slots of samples not given to `_prepare_info_list()` stay `None`. 
            if info is None:
                logger("        None")

                continue

            info.display_info(
                display_state = display_state, 
                display_prev_action = display_prev_action, 
                display_intermediate_reward = display_intermediate_reward, 
                display_pseudo_final_latent = display_pseudo_final_latent, 
                display_final_reward = display_final_reward
            )

            # goto `for sample_idx, info`
            pass

        # `display_info()` done
        pass
