from util.logger import logger

from typing import Optional, Union, Tuple

import numpy as np

import torch

from util.basic_util import get_attr

from OT_MCTS.src.markov_decision_process.space import Space


class ActionSpaceEps(Space):
    """
    Func:
        Action space whose elements are integer-valued "eps seeds" drawn
        from the closed range [`eps_seed_low`, `eps_seed_high`] and stored
        in a tensor of shape `shape`.

    Note:
        `clamp()` / `batch_clamp()` currently only return a copy of their
        input; they do not restrict values to the seed range.
    """

    def __init__(
        self, 

        eps_seed_low: int = 3072, 
        eps_seed_high: int = 4095, 

        shape: Tuple[int, ...] = (1, ), 
        
        dtype: Union[str, torch.dtype] = "float32", 
        device: str = "cpu", 

        ver: Optional[str] = "torch"
    ):
        """
        Func:
            Build the action space.

        Args:
            eps_seed_low: Smallest seed value that can be drawn (inclusive).
            eps_seed_high: Largest seed value that can be drawn (inclusive).
            shape: Shape of every action tensor produced by this space.
            dtype: Element type of produced tensors. A string is resolved
                through `get_attr(ver, dtype)`; any other value is stored
                unchanged.
            device: Device on which produced tensors are allocated.
            ver: Backend identifier handed to `get_attr` when resolving a
                string `dtype` (presumably a module name such as "torch" --
                confirm against `get_attr`).
        """

        # These attributes are assigned before `super().__init__()`, exactly
        # as in the original ordering, in case the base class reads them.
        self.ver = ver

        self.dtype = dtype
        if isinstance(self.dtype, str):
            self.dtype = get_attr(self.ver, self.dtype)

        self.device = device

        super().__init__()

        self.eps_seed_low = eps_seed_low
        self.eps_seed_high = eps_seed_high

        self.shape = shape

        # `__init__()` done


    def _draw_seed(
        self
    ) -> torch.Tensor:
        """
        Func:
            Draw one tensor of seeds uniformly from the closed range
            [`eps_seed_low`, `eps_seed_high`]. 
        """

        # `torch.randint` excludes `high`, hence the `+ 1` to make the
        # configured upper bound reachable.
        seed = torch.randint(
            low = self.eps_seed_low, 
            high = self.eps_seed_high + 1, 

            size = self.shape, 

            dtype = self.dtype, 
            device = self.device
        )

        # `_draw_seed()` done
        return seed


    def clamp(
        self, 
        
        var: torch.Tensor
    ) -> torch.Tensor:
        """
        Func:
            Clamp `var` into the action space. 

        Note:
            No bound is enforced at present: the result is a clone of
            `var`, so the caller's tensor is never modified in place.
        """

        var = var.clone()

        # `clamp()` done
        return var


    def batch_clamp(
        self, 
        
        var_list: torch.Tensor
    ) -> torch.Tensor:
        """
        Func:
            Batch clamp `var_list` into the action space. 

        Note:
            No bound is enforced at present: the result is a clone of
            `var_list`, so the caller's tensor is never modified in place.
        """

        var_list = var_list.clone()

        # `batch_clamp()` done
        return var_list


    def get_default_element(
        self
    ) -> torch.Tensor:
        """
        Func:
            Get a default action from the action space. 

        Note:
            The default action is itself a random draw from the closed
            seed range, so successive calls may return different values.
        """

        default_action = self.clamp(self._draw_seed())

        # `get_default_element()` done
        return default_action


    def sample_uniform_element(
        self
    ) -> torch.Tensor:
        """
        Func:
            Sample an action from the space uniformly. 

        Note:
            The draw covers the closed range [`eps_seed_low`,
            `eps_seed_high`], the same range used by
            `get_default_element()`.
        """

        action = self.clamp(self._draw_seed())

        # `sample_uniform_element()` done
        return action
