from util.logger import logger

from typing import Optional, Union, Tuple

import numpy as np

import torch

from util.basic_util import get_attr

from OT_MCTS.src.markov_decision_process.space import Space


class StateSpace(Space):
    """
    Func:
        State space of a Markov decision process, described by the shape of
        a single state, an element `dtype`, and a `device`.

        The clamping methods currently apply no bounds: they only return a
        copy of their input.
    """

    def __init__(
        self, 
        
        shape: Optional[Tuple] = (1, ), 

        dtype: Optional[Union[str, torch.dtype]] = "float32", 
        device: Optional[str] = "cpu", 

        ver: Optional[str] = "torch"
    ):
        """
        Func:
            Store the state-space configuration.

        Args:
            shape: Shape of a single state; stored as given (no conversion).
            dtype: Element dtype. A string is resolved through
                `get_attr(ver, dtype)`; any other value (e.g. an already
                resolved dtype object) is stored unchanged.
            device: Device identifier; stored as given.
            ver: Passed as the first argument to `get_attr` when resolving a
                string `dtype`. NOTE(review): presumably the backend module
                name (default "torch") -- confirm against `get_attr`.
        """
        self.ver = ver

        self.dtype = dtype
        if isinstance(self.dtype, str):
            # NOTE(review): presumably maps e.g. "float32" -> torch.float32
            # for ver="torch" -- confirm against util.basic_util.get_attr.
            self.dtype = get_attr(self.ver, self.dtype)

        self.device = device

        super().__init__()

        # `shape` is assigned after the base initializer runs; keep this
        # order. NOTE(review): presumably so `Space.__init__()` cannot
        # overwrite it -- confirm.
        self.shape = shape

        # `__init__()` done


    def clamp(
        self, 
        
        var: torch.Tensor
    ) -> torch.Tensor:
        """
        Func:
            Clamp `var` into the state space.

            No bounds are applied here: the result is a clone of `var`, so
            callers can modify it without affecting the input tensor.

        Args:
            var: A single state tensor.

        Returns:
            A new tensor equal to `var`.
        """

        var = var.clone()

        # `clamp()` done
        return var


    def batch_clamp(
        self, 
        
        var_list: torch.Tensor
    ) -> torch.Tensor:
        """
        Func:
            Batch clamp `var_list` into the state space.

            No bounds are applied here: the result is a clone of `var_list`,
            so callers can modify it without affecting the input tensor.

        Args:
            var_list: A batch of state tensors. NOTE(review): presumably
                stacked along dim 0 -- confirm with callers.

        Returns:
            A new tensor equal to `var_list`.
        """

        var_list = var_list.clone()

        # `batch_clamp()` done
        return var_list
