import numpy as np
import torch
import os
import random
from torch import nn


class ModeManager:
    """Context manager that temporarily puts modules in train or eval mode.

    On entry, each model's own ``training`` flag and the flags of all of its
    submodules are recorded. Each model is then switched with
    ``model.train(mode == 'train')``.

    On exit (normal or via an exception), each model's ``train()`` is called
    with its previous flag. Then every submodule's flag is restored
    individually. This means submodules that were deliberately in a different
    mode from their parent come back unchanged.

    Args:
        mode: ``'train'`` or ``'eval'``.
        *models: ``nn.Module`` instances to switch.

    Raises:
        ValueError: if ``mode`` is neither ``'train'`` nor ``'eval'``.
    """

    _VALID_MODES = ('train', 'eval')

    def __init__(self, mode, *models):
        # Any other string (e.g. a typo such as 'Train') would otherwise
        # silently select eval mode; reject it up front instead.
        if mode not in self._VALID_MODES:
            raise ValueError(
                f"mode must be one of {self._VALID_MODES}, got {mode!r}")
        self.mode = mode
        self.models = models

    def __enter__(self):
        # Snapshot every submodule, not just the root: train() is recursive,
        # so restoring only the root flag would force it onto all children.
        self.prev_states = [
            (model.training,
             [(module, module.training) for module in model.modules()])
            for model in self.models
        ]
        for model in self.models:
            model.train(self.mode == 'train')
        return self

    def __exit__(self, *args):
        for model, (was_training, module_states) in zip(self.models,
                                                         self.prev_states):
            # Restore through the root's train() so any override still runs...
            model.train(was_training)
            # ...then put back each submodule's own recorded flag.
            for module, flag in module_states:
                module.training = flag
        # Never suppress an exception raised inside the with-body.
        return False


# eval_mode / train_mode are thin convenience wrappers around ModeManager.
def eval_mode(*models):
    """Return a ``ModeManager`` that switches ``models`` to eval mode.

    Use as ``with eval_mode(actor, critic): ...``; the manager restores the
    models' previous modes when the block exits.
    """
    manager = ModeManager('eval', *models)
    return manager


def train_mode(*models):
    """Return a ``ModeManager`` that switches ``models`` to train mode.

    Use as ``with train_mode(actor, critic): ...``; the manager restores the
    models' previous modes when the block exits.
    """
    manager = ModeManager('train', *models)
    return manager


def soft_update_params(net, target_net, tau):
    """Blend ``net``'s parameters into ``target_net`` in place.

    Every target parameter is overwritten with
    ``tau * param + (1 - tau) * target_param``. Parameters are paired by
    ``parameters()`` iteration order, and ``zip`` stops at the shorter
    network. Only parameters are touched; buffers are not copied.

    Args:
        net: Source module.
        target_net: Module updated in place.
        tau: Weight given to the source parameters.
    """
    keep = 1 - tau
    pairs = zip(net.parameters(), target_net.parameters())
    # no_grad: the in-place write must not be tracked by autograd.
    with torch.no_grad():
        for source, target in pairs:
            target.data.copy_(tau * source + keep * target)


def set_seed_everywhere(seed):
    """Seed Python's ``random``, NumPy's global RNG and torch (plus CUDA if available).

    Args:
        seed: Integer seed applied to every generator.
    """
    for seeder in (torch.manual_seed, np.random.seed, random.seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def sample_action(env):
    """Draw a uniformly random action within ``env.action_space`` bounds.

    Reads ``low`` and ``high`` from ``env.action_space`` (presumably a
    box-like space; confirm with callers) and samples with
    ``np.random.uniform``, which uses NumPy's global RNG.
    """
    space = env.action_space
    lower, upper = space.low, space.high
    return np.random.uniform(lower, upper)


def make_dir(*path_parts):
    """Join ``path_parts`` into a path and create that directory (with parents).

    Creation is best-effort: an ``OSError`` is reported with ``print`` rather
    than raised. The joined path is returned either way.
    """
    target = os.path.join(*path_parts)
    try:
        os.makedirs(target, exist_ok=True)  # no error if it already exists
    except OSError as err:
        print(f"Error creating directory {target}: {err}")
    return target


def weight_init(m):
    """Initialize ``nn.Linear`` modules in place; leave every other module alone.

    Weights get an orthogonal init and biases (when present) are set to zero.
    Intended for use with ``module.apply(weight_init)``.
    """
    if not isinstance(m, nn.Linear):
        return
    nn.init.orthogonal_(m.weight)
    if m.bias is not None:
        nn.init.zeros_(m.bias)


class MLP(nn.Module):
    """Feed-forward network: an ``mlp(...)`` trunk initialized with ``weight_init``.

    Args:
        input_dim: Size of the input features.
        hidden_dim: Width of each hidden layer.
        output_dim: Size of the output.
        hidden_depth: Number of hidden layers passed to ``mlp``.
        output_mod: Optional module appended after the last linear layer.
    """

    def __init__(self,
                 input_dim,
                 hidden_dim,
                 output_dim,
                 hidden_depth,
                 output_mod=None):
        super().__init__()
        trunk = mlp(input_dim, hidden_dim, output_dim, hidden_depth,
                    output_mod)
        self.trunk = trunk
        # Recursively visits every submodule; only nn.Linear layers change.
        self.apply(weight_init)

    def forward(self, x):
        """Run ``x`` through the trunk and return its output."""
        trunk = self.trunk
        return trunk(x)


def mlp(input_dim, hidden_dim, output_dim, hidden_depth, output_mod=None):
    """Build a multilayer perceptron as an ``nn.Sequential``.

    With ``hidden_depth == 0`` the network is a single ``Linear(input_dim,
    output_dim)``. Otherwise it consists of ``hidden_depth`` Linear + in-place
    ReLU pairs of width ``hidden_dim``, followed by ``Linear(hidden_dim,
    output_dim)``. ``output_mod``, if given, is appended last.
    """
    def linear_relu(in_features, out_features):
        # One hidden layer: affine map followed by an in-place ReLU.
        return [nn.Linear(in_features, out_features), nn.ReLU(inplace=True)]

    if hidden_depth == 0:
        layers = [nn.Linear(input_dim, output_dim)]
    else:
        layers = linear_relu(input_dim, hidden_dim)
        for _ in range(hidden_depth - 1):
            layers.extend(linear_relu(hidden_dim, hidden_dim))
        layers.append(nn.Linear(hidden_dim, output_dim))
    if output_mod is not None:
        layers.append(output_mod)
    return nn.Sequential(*layers)


def to_np(tensor):
    """Convert a tensor to a host NumPy array.

    Args:
        tensor: A ``torch.Tensor`` (on any device, possibly requiring grad),
            or ``None``.

    Returns:
        ``None`` if ``tensor`` is ``None``. Otherwise a NumPy array with the
        tensor's shape and dtype. Empty tensors take the same path, so e.g.
        a ``(0, 3)`` int64 tensor yields a ``(0, 3)`` int64 array.
    """
    if tensor is None:
        return None
    # Detach before moving to CPU so the device copy is not recorded in the
    # autograd graph; .numpy() refuses tensors that require grad.
    return tensor.detach().cpu().numpy()
