from util.logger import logger

from typing import Optional, Tuple, Union, List, Dict

from abc import ABC, abstractmethod

import numpy as np

import torch

from util.basic_util import get_attr

from .trajectory import Trajectory


class MarkovDecisionProcess(ABC):
    """
    Func:
        Abstract finite-horizon Markov decision process. Subclasses implement the single-sample and 
        batched dynamics / reward methods; this base class provides trajectory simulation on top of them. 
        The array backend ("torch" or "numpy") is taken from `state_space.ver`. 
    """

    def __init__(
        self, 

        state_space: "Space", 
        action_space: "Space", 
        time_horizon: int, 

        # ---------= [Parallel] =---------
        # cal_dynamics_batch_size: Optional[int] = 1, 
        # cal_intermediate_reward_batch_size: Optional[int] = 1, 
        # cal_final_reward_batch_size: Optional[int] = 1, 

        reward_shape: Optional[Tuple] = (1, )
    ):
        """
        Args:
            `state_space`: Must expose `ver`, `dtype` and `device`; these are copied onto the MDP. Its 
                `clamp()` is used by `get_nominal_trajectory()`. 
            `action_space`: Its `get_default_element()`, `clamp()` and `batch_clamp()` are used. 
            `time_horizon` (`int`): Number of actions in a nominal trajectory. 
            `reward_shape` (`Tuple`): Shape of a single reward; forwarded to `Trajectory`. 
        """

        self.state_space = state_space

        # backend selector: "torch" or "numpy"
        self.ver = self.state_space.ver

        self.dtype = self.state_space.dtype
        self.device = self.state_space.device
        
        self.action_space = action_space
        self.time_horizon = time_horizon

        # ---------= [Parallel] =---------
        # self.cal_dynamics_batch_size = cal_dynamics_batch_size
        # self.cal_intermediate_reward_batch_size = cal_intermediate_reward_batch_size
        # self.cal_final_reward_batch_size = cal_final_reward_batch_size

        self.reward_shape = reward_shape

        # ---------= [Params for Some Compute Intermediate Reward Policy] =---------
        # Not read within this class; presumably filled / consumed by subclasses (TODO confirm).
        self._cal_intermediate_reward_arg_dict = {}

        # `__init__()` done
        pass


    @abstractmethod
    def cal_dynamics(
        self, 

        state: Union[torch.Tensor, np.ndarray], 
        action: Union[torch.Tensor, np.ndarray], 

        **arg_dict: Optional[Dict]
    ) -> Union[torch.Tensor, np.ndarray]:
        """
        NB: 
            Abstract, should be implemented. 

        Func:
            Compute the next state given the current state and action. 

        Ret:
            `next_state` (`torch.Tensor` or `np.ndarray`): The next state. 
        """

        # `cal_dynamics()` done
        pass

    
    @abstractmethod
    def batch_cal_dynamics(
        self, 

        state_list: Union[
            Union[List[torch.Tensor], List[np.ndarray]], 
            Union[torch.Tensor, np.ndarray]
        ], 
        action_list: Union[
            Union[List[torch.Tensor], List[np.ndarray]], 
            Union[torch.Tensor, np.ndarray]
        ], 

        **arg_dict: Optional[Dict]
    ) -> Union[torch.Tensor, np.ndarray]:
        """
        NB: 
            Abstract, should be implemented. 

        Func:
            Batch compute the next states given the current states and actions. 

        Ret:
            `next_state_list` (`torch.Tensor` or `np.ndarray`): The list of the next states. 
        """

        # `batch_cal_dynamics()` done
        pass


    @abstractmethod
    def cal_intermediate_reward(
        self, 

        state: Union[torch.Tensor, np.ndarray], 
        action: Union[torch.Tensor, np.ndarray], 

        **arg_dict: Optional[Dict]
    ) -> Union[torch.Tensor, np.ndarray]:
        """
        NB: 
            Abstract, should be implemented. 

        Func:
            Compute the intermediate reward for a given state and action. 

        Ret:
            `intermediate_reward` (`torch.Tensor` or `np.ndarray`): The intermediate reward. 
        """

        # `cal_intermediate_reward()` done
        pass


    @abstractmethod
    def batch_cal_intermediate_reward(
        self, 

        state_list: Union[
            Union[List[torch.Tensor], List[np.ndarray]], 
            Union[torch.Tensor, np.ndarray]
        ], 
        action_list: Union[
            Union[List[torch.Tensor], List[np.ndarray]], 
            Union[torch.Tensor, np.ndarray]
        ], 

        **arg_dict: Optional[Dict]
    ) -> Union[torch.Tensor, np.ndarray]:
        """
        NB: 
            Abstract, should be implemented. 

        Func:
            Compute the intermediate rewards for the given states and actions. 

        Ret:
            `intermediate_reward_list` (`torch.Tensor` or `np.ndarray`): The list of the intermediate rewards; 
                indexed by sample along the first axis. 
        """

        # `batch_cal_intermediate_reward()` done
        pass


    @abstractmethod
    def cal_final_reward(
        self, 

        state: Union[torch.Tensor, np.ndarray], 

        **arg_dict: Optional[Dict]
    ) -> Union[torch.Tensor, np.ndarray]:
        """
        NB: 
            Abstract, should be implemented. 

        Func:
            Compute the final reward for a terminal state.  

        Ret:
            `final_reward` (`torch.Tensor` or `np.ndarray`): The final reward. 
        """

        # `cal_final_reward()` done
        pass


    @abstractmethod
    def batch_cal_final_reward(
        self, 

        state_list: Union[torch.Tensor, np.ndarray], 

        **arg_dict: Optional[Dict]
    ) -> Union[torch.Tensor, np.ndarray]:
        """
        NB: 
            Abstract, should be implemented. 

        Func:
            Batch compute the final rewards for the terminal states. 

        Ret:
            `final_reward_list` (`torch.Tensor` or `np.ndarray`): The list of the final rewards; indexed by 
                sample along the first axis. 
        """

        # `batch_cal_final_reward()` done
        pass


    def get_default_action(
        self
    ) -> Union[torch.Tensor, np.ndarray]:
        """
        Func:
            Get a default action from the action space. 

        Ret:
            `default_action` (`torch.Tensor` or `np.ndarray`): The default action. 
        """

        default_action = self.action_space.get_default_element()

        # `get_default_action()` done
        return default_action


    # (discarded)
    # def simulate_step_list(
    #     self, 

    #     state: Union[torch.Tensor, np.ndarray], 
    #     action_list: Union[List[torch.Tensor], List[np.ndarray]], 
        
    #     verbose: bool = False
    # ) -> Tuple[
    #     Union[List[torch.Tensor], List[np.ndarray]],  # state list
    #     Union[List[torch.Tensor], List[np.ndarray]]  # reward list
    # ]:
    #     """
    #     Func:
    #         Simulate the transitions of states from `state` with the sequential actions `action_list`. 

    #     Ret:
    #         `state_list` (`List`): States along the trajectory, with input `state` included. 
    #         `reward_list` (`List`): Rewards along the trajectory, with final reward included. 
    #     """

    #     if self.ver == "torch":
    #         state = state.clone()
    #     elif self.ver == "numpy":
    #         state = state.copy()

    #     state_list = [state]

    #     if isinstance(action_list, list):
    #         if self.ver == "torch":
    #             action_list = torch.stack(
    #                 action_list, 
    #                 dim = 0
    #             )
    #         elif self.ver == "numpy":
    #             action_list = np.stack(
    #                 action_list, 
    #                 axis = 0
    #             )

    #     action_list = self.action_space.batch_clamp(action_list)

    #     reward_list = []

    #     if verbose:
    #         logger(f"Start at")
    #         logger(f"    state: {state}")

    #     for action_idx, action in enumerate(action_list):
    #         if verbose:
    #             logger(f"[Action {action_idx}]")
    #             logger(f"    action: {action}")
            
    #         # cal intermediate reward
    #         intermedia_reward = self.cal_intermediate_reward(
    #             state = state, 
    #             action = action
    #         )

    #         reward_list.append(intermedia_reward)

    #         state = self.cal_dynamics(
    #             state = state, 
    #             action = action
    #         )

    #         state_list.append(state)

    #         if verbose:
    #             logger(f"    reward: {intermedia_reward}")

    #             logger(f"[State {action_idx + 1}]")
    #             logger(f"    state: {state}")
        
    #     # cal final reward
    #     final_reward = self.cal_final_reward(state = state)

    #     reward_list.append(final_reward)

    #     if verbose:
    #         logger(f"final_reward: {final_reward}")

    #     return state_list, reward_list


    def _stack_array_list(
        self, 

        array_list: Union[List[torch.Tensor], List[np.ndarray]]
    ) -> Union[torch.Tensor, np.ndarray, List]:
        """
        Func:
            Stack `array_list` along a new leading axis using the backend selected by `self.ver`. 
            For any other `ver` the list is returned unchanged. 
        """

        if self.ver == "torch":
            return torch.stack(
                array_list, 
                dim = 0
            )
        elif self.ver == "numpy":
            return np.stack(
                array_list, 
                axis = 0
            )

        # `_stack_array_list()` done
        return array_list


    def batch_simulate_step_list(
        self, 

        # state_list.shape = (batch_size, *state_shape)
        state_list: Union[torch.Tensor, np.ndarray], 

        # action_list_list.shape = (batch_size, num_timesteps, *action_shape), or a list 
        # (of arrays or of lists of arrays) that stacks to that shape
        action_list_list: Union[
            Union[torch.Tensor, np.ndarray], 
            Union[List[torch.Tensor], List[np.ndarray]], 
            List[List[torch.Tensor]], 
            List[List[np.ndarray]]
        ], 
        
        verbose: bool = False
    ) -> Tuple[
        Union[
            List[torch.Tensor], 
            List[np.ndarray]
        ],  # state list (time-major, each entry batched)
        Union[
            List[List[torch.Tensor]], 
            List[List[np.ndarray]]
        ],  # reward list (sample-major)
    ]:
        """
        Func:
            Batch simulate the transitions of states from `state_list` with the sequential actions 
            `action_list_list`. The actions are clamped with `self.action_space.batch_clamp()` first; 
            states are NOT clamped. 

        Args:
            `state_list`: Batched initial states; copied, never modified in place. 
            `action_list_list`: Batched action sequences; axis 0 is the sample, axis 1 the timestep. 
                Any trailing axes form the action itself (none for scalar actions). 
            `verbose` (`bool`): Log actions, rewards and states at every step. 

        Ret:
            `state_list_list` (`List`): `num_timesteps + 1` batched states (time-major). Entry 0 is a 
                copy of the input `state_list`, and entry `t > 0` is the output of `batch_cal_dynamics()` 
                at step `t - 1`. 
            `reward_list_list` (`List[List]`): One list per sample (sample-major), holding that 
                sample's intermediate reward for every timestep followed by its final reward 
                (`num_timesteps + 1` entries). 
        """

        if self.ver == "torch":
            state_list = state_list.clone()
        elif self.ver == "numpy":
            state_list = state_list.copy()

        state_list_list = [state_list]

        if isinstance(action_list_list, list):
            # Stack inner per-sample lists first; stacking a list of lists directly fails.
            action_list_list = self._stack_array_list([
                self._stack_array_list(action_list) \
                    if isinstance(action_list, list) else action_list
                        for action_list in action_list_list
            ])

        # action_list_list.shape = (batch_size, num_timesteps, *action_shape)
        action_list_list = self.action_space.batch_clamp(action_list_list)

        # Only the two leading axes matter; unpacking the full shape would break for scalar 
        # actions (2-D input) and multi-dimensional actions (4-D+ input).
        batch_size, num_timesteps = action_list_list.shape[:2]

        reward_list_list = [
            [] \
                for _ in range(batch_size)
        ]

        if verbose:
            logger(f"Start at")
            logger(f"    state_list: {state_list}")

        for timestep_idx in range(num_timesteps):
            # action_list.shape = (batch_size, *action_shape)
            action_list = action_list_list[:, timestep_idx]
            
            if verbose:
                logger(f"[Timestep Idx {timestep_idx}]")
                logger(f"    action_list: {action_list}")

            # reward for taking `action_list` in the current (pre-transition) states
            intermediate_reward_list = self.batch_cal_intermediate_reward(
                state_list = state_list, 
                action_list = action_list
            )

            # record each sample's reward exactly once per timestep
            for sample_idx in range(batch_size):
                reward_list_list[sample_idx].append(
                    intermediate_reward_list[sample_idx]
                )

                # goto `for sample_idx`
                pass

            if verbose:
                logger(f"intermediate_reward_list: {intermediate_reward_list}")

            state_list = self.batch_cal_dynamics(
                state_list = state_list, 
                action_list = action_list
            )

            if verbose:
                logger(f"    state_list: {state_list}")

            state_list_list.append(state_list)

            # goto `for timestep_idx`
            pass

        final_reward_list = self.batch_cal_final_reward(
            state_list = state_list
        )
        
        if verbose:
            logger(f"final_reward_list: {final_reward_list}")

        for sample_idx in range(batch_size):
            reward_list_list[sample_idx].append(
                final_reward_list[sample_idx]
            )

            # goto `for sample_idx`
            pass

        # `batch_simulate_step_list()` done
        return state_list_list, reward_list_list


    def get_nominal_trajectory(
        self, 

        init_state: Union[torch.Tensor, np.ndarray], 
        action_list: Union[
            float, 
            Union[torch.Tensor, np.ndarray], 
            Union[List[torch.Tensor], List[np.ndarray]]
        ] = 0.5
    ) -> "Trajectory":
        """
        Func:
            Roll out a nominal trajectory from `init_state` using `action_list`, which defaults to the 
            constant `0.5`. Actions and every visited state are clamped to their spaces. 

        Args:
            `init_state`: Initial state; copied, never modified in place. 
            `action_list`: Converted by the backend's `tsfm_to_1d_array()` with 
                `target_length = self.time_horizon`. It presumably broadcasts a scalar to that length 
                (TODO confirm against `util.*_util`). 

        Ret: 
            `nominal_trajectory` (`Trajectory`): A trajectory from `init_state` with the sequential 
                actions `action_list`. Its reward list holds one intermediate reward per action followed 
                by the final reward. 
        """

        if self.ver == "torch":
            from util.torch_util import tsfm_to_1d_array

            action_list = tsfm_to_1d_array(
                array = action_list, 
                target_length = self.time_horizon, 

                dtype = self.dtype, 
                device = self.device
            )
        elif self.ver == "numpy":
            from util.numpy_util import tsfm_to_1d_array

            action_list = tsfm_to_1d_array(
                array = action_list, 
                target_length = self.time_horizon, 

                dtype = self.dtype
            )

        action_list = self.action_space.clamp(action_list)

        if self.ver == "torch":
            state = init_state.clone()
        elif self.ver == "numpy":
            state = init_state.copy()

        state = self.state_space.clamp(state)
        
        state_list = [state]
        reward_list = []

        for action in action_list:
            # reward is evaluated on the pre-transition state
            intermediate_reward = self.cal_intermediate_reward(
                state = state, 
                action = action
            )

            state = self.cal_dynamics(
                state = state, 
                action = action
            )
            state = self.state_space.clamp(state)

            state_list.append(state)
            reward_list.append(intermediate_reward)

        final_reward = self.cal_final_reward(state = state)
        reward_list.append(final_reward)

        nominal_trajectory = Trajectory(
            state_list = state_list, 
            action_list = action_list, 
            reward_list = reward_list, 

            reward_shape = self.reward_shape, 

            dtype = self.dtype, 

            ver = self.ver
        )

        # `get_nominal_trajectory()` done
        return nominal_trajectory
