from util.logger import logger

from typing import Optional, Union, Tuple

import numpy as np

import torch

from util.basic_util import get_attr

from OT_MCTS.src.markov_decision_process.space import Space


class ActionSpaceEta(Space):
    """
    Func:
        Box-shaped action space for an `eta` action whose every entry lies in
        the closed interval [`eta_low`, `eta_high`].

    Attributes:
        ver: Name of the backend used to resolve a string `dtype`.
        dtype: Element dtype; a string is resolved via `get_attr(ver, dtype)`.
        device: Device on which new action tensors are created.
        eta_low: Lower bound applied to every entry.
        eta_high: Upper bound applied to every entry.
        shape: Shape of a single action tensor.
    """

    def __init__(
        self, 

        eta_low: Optional[float] = 0.0, 
        eta_high: Optional[float] = 1.0, 

        shape: Optional[Tuple] = (1, ), 
        
        dtype: Optional[str] = "float32", 
        device: Optional[str] = "cpu", 

        ver: Optional[str] = "torch"
    ):
        """
        Func:
            Store the bounds, shape, dtype and device of the action space.

        Args:
            eta_low: Lower bound of every action entry.
            eta_high: Upper bound of every action entry.
            shape: Shape of a single action tensor.
            dtype: Element dtype, either a dtype object or its name. A name
                is resolved through `get_attr(ver, dtype)`.
            device: Device on which action tensors are created.
            ver: Backend name handed to `get_attr` when resolving `dtype`.
        """

        self.ver = ver

        # A dtype given by name is resolved with the project helper;
        # anything else is stored untouched.
        self.dtype = dtype
        if isinstance(self.dtype, str):
            self.dtype = get_attr(self.ver, self.dtype)

        self.device = device

        # NOTE(review): `ver`, `dtype` and `device` are assigned before the
        # base initializer runs, exactly as before -- confirm whether
        # `Space.__init__()` depends on that ordering.
        super().__init__()

        self.eta_low = eta_low
        self.eta_high = eta_high

        self.shape = shape


    def clamp(
        self, 
        
        var: torch.Tensor
    ) -> torch.Tensor:
        """
        Func:
            Clamp `var` into the action space. 

        Args:
            var: Tensor to clamp. It is not modified.

        Returns:
            A new tensor with every entry limited to [`eta_low`, `eta_high`].
        """

        # `torch.clip` is out-of-place and already returns a fresh tensor,
        # so no defensive `clone()` of the input is required.
        return torch.clip(
            var, 
            self.eta_low, self.eta_high
        )


    def batch_clamp(
        self, 
        
        var_list: torch.Tensor
    ) -> torch.Tensor:
        """
        Func:
            Batch clamp `var_list` into the action space. 

        Args:
            var_list: Tensor holding a batch of actions. It is not modified.

        Returns:
            A new tensor with every entry limited to [`eta_low`, `eta_high`].
        """

        # Clamping is element-wise, so a batch is handled exactly like a
        # single action.
        return self.clamp(var_list)


    def get_default_element(
        self
    ) -> torch.Tensor:
        """
        Func:
            Get a default action from the action space. 

        Returns:
            A tensor of shape `shape` filled with the midpoint of
            [`eta_low`, `eta_high`].
        """

        midpoint = (self.eta_low + self.eta_high) / 2

        default_action = torch.full(
            size = self.shape, 
            fill_value = midpoint, 

            dtype = self.dtype, 
            device = self.device
        )

        return self.clamp(default_action)


    def sample_uniform_element(
        self
    ) -> torch.Tensor:
        """
        Func:
            Sample an action from the space uniformly. 

        Returns:
            A tensor of shape `shape` whose entries are drawn uniformly from
            [`eta_low`, `eta_high`].
        """

        unit_sample = torch.rand(
            size = self.shape, 

            dtype = self.dtype, 
            device = self.device
        )

        # Without two finite bounds there is no interval to rescale to, so
        # fall back to clamping the raw unit sample.
        if self.eta_low is None or self.eta_high is None:
            return self.clamp(unit_sample)

        # `torch.rand` is uniform on [0, 1). Clamping it directly would pile
        # the probability mass onto a bound whenever the space is not the
        # unit interval, so map the sample affinely onto the real interval.
        action = self.eta_low + (self.eta_high - self.eta_low) * unit_sample

        # Guard against floating-point rounding pushing a value past a bound.
        return self.clamp(action)
