from argparse import ArgumentTypeError

from torch.nn.utils.rnn import pad_sequence

from .imports import *

# Module-level seed shared by set_random_seed(): it is overwritten when that
# function receives an explicit value and reused when it is called without one.
seed = 0

def set_torch(n_threads: int = 0, deterministic: bool = True, cuda: bool = False) -> th.device:
    """Configure PyTorch settings including the number of threads, determinism, and CUDA usage.

    Args:
        n_threads: Number of threads for PyTorch CPU operations. A value of 0
            or less leaves PyTorch's current thread count untouched.
        deterministic: Value assigned to the cuDNN ``deterministic`` flag.
        cuda: Whether to use CUDA when it is available.

    Returns:
        The "cuda" device if it was requested and is available, otherwise "cpu".
    """
    # th.set_num_threads rejects non-positive values, so the default of 0 is
    # treated as "keep PyTorch's own setting" instead of raising.
    if n_threads > 0:
        th.set_num_threads(n_threads)
    th.backends.cudnn.deterministic = deterministic
    return th.device("cuda" if cuda and th.cuda.is_available() else "cpu")

def set_random_seed(s: Optional[int] = None) -> None:
    """Seed every random number generator used by the project.

    The value is stored in the module-level ``seed`` so that a later call
    without an argument re-applies the same seed.

    Args:
        s: Seed to apply and remember. When None, the currently stored
            module-level seed is re-applied.
    """
    global seed
    seed = seed if s is None else s
    # Same order as always: `rnd`, NumPy, PyTorch CPU, then all CUDA devices.
    seeders = (rnd.seed, np.random.seed, th.manual_seed, th.cuda.manual_seed_all)
    for apply_seed in seeders:
        apply_seed(seed)

def str2bool(s: str) -> bool:
    """Parse the strings 'true' / 'false' (in any letter case) into a bool.

    Args:
        s: Text to interpret.

    Returns:
        True for 'true' and False for 'false', compared case-insensitively.

    Raises:
        ArgumentTypeError: If the text is neither 'true' nor 'false'.
    """
    known = {'true': True, 'false': False}
    lowered = s.lower()
    if lowered in known:
        return known[lowered]
    raise ArgumentTypeError('Boolean value expected.')

def to_list_tensor(data, device, astype=th.float32, add_dim=True):
    """Convert each element of ``data`` into a tensor placed on ``device``.

    Args:
        data: Iterable of array-like elements.
        device: Target device for the resulting tensors.
        astype: Torch dtype every element is converted to.
        add_dim: When true, insert a singleton dimension at index 1 of every
            tensor, e.g. (n_envs, obs_size) -> (n_envs, 1, obs_size) for
            forward passes.

    Returns:
        A list with one tensor per element of ``data``.
    """
    tensors = (th.as_tensor(item, dtype=astype) for item in data)
    if add_dim:
        tensors = (tensor.unsqueeze(1) for tensor in tensors)
    return [tensor.to(device) for tensor in tensors]

def to_cat_3dtensor(data, device, astype=th.float32):
    """Concatenate arrays along their last axis into one tensor on ``device``.

    Args:
        data: Sequence of arrays that NumPy can concatenate along the last axis.
        device: Target device for the resulting tensor.
        astype: Torch dtype of the result.

    Returns:
        The concatenated data as a tensor with a singleton dimension inserted
        at index 1.
    """
    merged = np.concatenate(data, axis=-1)
    # Honor the requested dtype; it used to be ignored, so the tensor silently
    # inherited NumPy's dtype (often float64).
    return th.as_tensor(merged, dtype=astype).unsqueeze(1).to(device)

def to_agent_shape(data, device, type):
    """Stack arrays along a new last axis and return them as a tensor.

    Args:
        data: Sequence of equally shaped arrays.
        device: Target device for the resulting tensor.
        type: Torch dtype of the resulting tensor.

    Returns:
        A tensor on ``device`` whose last dimension indexes the elements of
        ``data``.
    """
    stacked = np.stack(data, axis=-1)
    as_tensor = th.tensor(stacked, dtype=type)
    return as_tensor.to(device)

def _get_pad_and_mask_from_obs(obs):
    """Pad variable-length observations and build the matching validity mask.

    Args:
        obs: Sequence of tensors accepted by ``pad_sequence``; they may differ
            in length along their first dimension.

    Returns:
        A tuple ``(padded, mask)``. ``padded`` is the batch-first, zero-padded
        batch. ``mask`` is a boolean tensor that is False wherever the
        NaN-padded copy contains a NaN along the last dimension, which also
        covers NaNs already present in the input.
    """
    # pad_sequence documents padding_value as a float, so pass plain floats
    # rather than 0-dim tensors that have to be coerced.
    zeropad_obs = pad_sequence(obs, batch_first=True, padding_value=0.0)
    # Padding a second copy with NaN marks the padded positions for the mask.
    nanpad_obs = pad_sequence(obs, batch_first=True, padding_value=float('nan'))
    valid_mask = ~th.isnan(nanpad_obs).any(-1)
    return zeropad_obs, valid_mask

def linear_schedule(start: float, end: float, duration: int, t: int, eps=True) -> float:
    """Linearly interpolate from ``start`` towards ``end`` over ``duration`` steps.

    Args:
        start: Value of the schedule at ``t == 0``.
        end: Value at which the schedule is clamped.
        duration: Number of steps taken to go from ``start`` to ``end``. A
            non-positive duration means the schedule is already finished.
        t: Current step.
        eps: If True the result is bounded below by ``end`` (decreasing
            schedule, ``start > end``); otherwise it is bounded above by
            ``end`` (increasing schedule).

    Returns:
        The scheduled value at step ``t``.
    """
    # Guard against division by zero: with no steps to interpolate over, the
    # schedule sits at its final value.
    if duration <= 0:
        return end
    slope = (end - start) / duration
    value = slope * t + start
    return max(value, end) if eps else min(value, end)

def Linear(input_dim: int, output_dim: int, act_fn: str = 'leaky_relu', init_weight_uniform: bool = True) -> nn.Linear:
    """Build an ``nn.Linear`` layer with Xavier-initialized weights and zero bias.

    Background on the initialization schemes:
    https://machinelearningmastery.com/weight-initialization-for-deep-learning-neural-networks/

    Args:
        input_dim: Number of input features.
        output_dim: Number of output features.
        act_fn: Name of the activation following the layer (case-insensitive);
            it determines the gain used for the initialization.
        init_weight_uniform: Use Xavier uniform initialization when True,
            Xavier normal otherwise.

    Returns:
        The initialized layer.
    """
    # The gain is resolved before the layer is built, so an unknown activation
    # name fails before any random numbers are drawn.
    gain = nn.init.calculate_gain(act_fn.lower())
    layer = nn.Linear(input_dim, output_dim)
    init_weights = nn.init.xavier_uniform_ if init_weight_uniform else nn.init.xavier_normal_
    init_weights(layer.weight, gain=gain)
    nn.init.zeros_(layer.bias)
    return layer

class Tanh(nn.Module):
    """Tanh activation that overwrites its input instead of allocating a new tensor."""

    def forward(self, input: th.Tensor) -> th.Tensor:
        """Apply tanh to ``input`` in place.

        Args:
            input (th.Tensor): Tensor to transform; its contents are overwritten.

        Returns:
            The same tensor, now holding the tanh of its previous values.
        """
        return input.tanh_()

# Lookup table from activation name to a ready-made module instance, so the
# networks can pick their activation by string.
th_act_fns = {
    "tanh": Tanh(),
    "relu": nn.ReLU(),
    "leaky_relu": nn.LeakyReLU(),
}

