import math
import os
import random
from collections import deque

import dmc2gym
import gym
import numpy as np
import torch
import torch.nn.functional as F
from torch import distributions as pyd
from torch import nn


def make_env(cfg):
    """Create and seed the environment described by ``cfg``.

    ``cfg.env`` is split at its first underscore into a domain and a task name.
    ``"ball_in_cup_catch"`` is special-cased because its domain itself contains
    an underscore. The ``"gym"`` domain is built with :func:`gym_env`; every
    other domain goes through ``dmc2gym.make`` with reward visualisation on.

    Args:
        cfg: config object exposing ``env`` (str) and ``seed``.

    Returns:
        The environment, seeded with ``cfg.seed``.

    Raises:
        AssertionError: if the action space bounds fall outside ``[-1, 1]``.
    """
    if cfg.env == "ball_in_cup_catch":
        domain_name, task_name = "ball_in_cup", "catch"
    else:
        # Everything before the first "_" is the domain; the rest is the task.
        domain_name, _, task_name = cfg.env.partition("_")

    if domain_name != "gym":
        env = dmc2gym.make(
            domain_name=domain_name,
            task_name=task_name,
            seed=cfg.seed,
            visualize_reward=True,
        )
    else:
        env = gym_env(task_name, cfg.seed)

    env.seed(cfg.seed)
    # Callers expect actions normalised to [-1, 1].
    action_space = env.action_space
    assert action_space.low.min() >= -1
    assert action_space.high.max() <= 1
    return env


class eval_mode(object):
    """Context manager that temporarily puts modules into evaluation mode.

    Example::

        with eval_mode(actor, critic):
            ...

    On entry, each module's ``training`` flag is remembered and the module is
    switched with ``train(False)``. On exit, ``train(flag)`` is called with the
    remembered flag. Exceptions raised in the ``with`` body are not suppressed.
    """

    def __init__(self, *models):
        self.models = models

    def __enter__(self):
        self.prev_states = []
        # Record and switch one model at a time. A later model's recorded flag
        # therefore reflects any change made by switching an earlier one.
        for net in self.models:
            was_training = net.training
            self.prev_states.append(was_training)
            net.train(False)

    def __exit__(self, *args):
        restore_pairs = zip(self.models, self.prev_states)
        for net, was_training in restore_pairs:
            net.train(was_training)
        # Returning a falsy value lets any exception from the body propagate.
        return False


class train_mode(object):
    """Context manager that temporarily puts modules into training mode.

    Example::

        with train_mode(actor):
            ...

    On entry, each module's ``training`` flag is remembered and the module is
    switched with ``train(True)``. On exit, ``train(flag)`` is called with the
    remembered flag. Exceptions raised in the ``with`` body are not suppressed.
    """

    def __init__(self, *models):
        self.models = models

    def __enter__(self):
        self.prev_states = []
        # Record and switch one model at a time. A later model's recorded flag
        # therefore reflects any change made by switching an earlier one.
        for net in self.models:
            was_training = net.training
            self.prev_states.append(was_training)
            net.train(True)

    def __exit__(self, *args):
        restore_pairs = zip(self.models, self.prev_states)
        for net, was_training in restore_pairs:
            net.train(was_training)
        # Returning a falsy value lets any exception from the body propagate.
        return False


def soft_update_params(net, target_net, tau):
    """Blend ``net``'s parameters into ``target_net`` in place.

    For every parameter pair this sets
    ``target = tau * online + (1 - tau) * target``. Parameters are paired by
    the order of ``parameters()``, and ``zip`` stops at the shorter sequence,
    so the two networks are presumably expected to share an architecture.
    """
    keep = 1 - tau
    param_pairs = zip(net.parameters(), target_net.parameters())
    for online, target in param_pairs:
        blended = tau * online.data + keep * target.data
        target.data.copy_(blended)


def set_seed_everywhere(seed):
    """Seed Python's ``random``, NumPy's global RNG and torch with ``seed``.

    All CUDA devices are seeded too, but only when CUDA is available.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def make_dir(*path_parts):
    """Join ``path_parts`` into a path, ensure that directory exists, and return it.

    Missing parent directories are created as well. A directory that already
    exists is not an error.

    Returns:
        The joined directory path.

    Raises:
        OSError: if the directory cannot be created, e.g. permission is denied
            or the path already exists as a non-directory.
    """
    dir_path = os.path.join(*path_parts)
    # ``exist_ok`` replaces the old blanket ``except OSError: pass``. That
    # handler also hid real failures (missing parent, permission denied) and
    # then returned a path that did not exist.
    os.makedirs(dir_path, exist_ok=True)
    return dir_path


def weight_init(m):
    """Custom weight init for ``nn.Linear`` layers; all other modules are untouched.

    Meant to be used as ``module.apply(weight_init)``. Each Linear layer gets an
    orthogonal weight matrix and, if it has a bias, an all-zero bias.
    Convolutional layers are NOT re-initialised.
    """
    if isinstance(m, nn.Linear):
        # nn.init functions run under no_grad, so ``.data`` is unnecessary.
        nn.init.orthogonal_(m.weight)
        # ``bias`` is None when the layer was built with bias=False.
        if m.bias is not None:
            nn.init.zeros_(m.bias)


class MLP(nn.Module):
    """Multilayer perceptron module wrapping the ``nn.Sequential`` from ``mlp``.

    Linear layers are initialised with ``weight_init`` right after the network
    is built.

    Args:
        input_dim: size of the input features.
        hidden_dim: width of each hidden layer.
        output_dim: size of the output.
        hidden_depth: number of hidden layers (0 means a single Linear layer).
        output_mod: optional module appended after the final Linear layer.
    """

    def __init__(
        self, input_dim, hidden_dim, output_dim, hidden_depth, output_mod=None
    ):
        super().__init__()
        self.trunk = mlp(
            input_dim,
            hidden_dim,
            output_dim,
            hidden_depth,
            output_mod=output_mod,
        )
        self.apply(weight_init)

    def forward(self, x):
        """Run ``x`` through the underlying sequential trunk."""
        out = self.trunk(x)
        return out


def mlp(
    input_dim, hidden_dim, output_dim, hidden_depth, output_mod=None, untrainable=0
):
    """Build an ``nn.Sequential`` multilayer perceptron.

    With ``hidden_depth == 0`` the network is a single Linear layer. Otherwise
    it is ``Linear(input, hidden) -> ReLU``, then ``hidden_depth - 1`` repeats
    of ``Linear(hidden, hidden) -> ReLU``, then ``Linear(hidden, output)``.

    Args:
        input_dim: size of the input features.
        hidden_dim: width of each hidden layer.
        output_dim: size of the output.
        hidden_depth: number of hidden layers.
        output_mod: optional module appended after the final Linear layer.
        untrainable: number of leading Linear layers whose parameters get
            ``requires_grad = False``. A message is printed when this is > 0.

    Returns:
        The assembled ``nn.Sequential``.
    """
    if hidden_depth == 0:
        layers = [nn.Linear(input_dim, output_dim)]
    else:
        layers = [nn.Linear(input_dim, hidden_dim), nn.ReLU(inplace=True)]
        for _ in range(hidden_depth - 1):
            layers.append(nn.Linear(hidden_dim, hidden_dim))
            layers.append(nn.ReLU(inplace=True))
        layers.append(nn.Linear(hidden_dim, output_dim))

    if untrainable > 0:
        print(f"Setting {untrainable} layers to not update.")
        # Only Linear layers count towards ``untrainable``; ReLUs are skipped.
        linear_layers = (mod for mod in layers if isinstance(mod, nn.Linear))
        for rank, layer in enumerate(linear_layers):
            if rank < untrainable:
                for param in layer.parameters():
                    param.requires_grad = False

    if output_mod is not None:
        layers.append(output_mod)
    return nn.Sequential(*layers)


def to_np(t):
    """Convert a torch tensor to a NumPy array.

    Returns ``None`` for ``None``. Any tensor with zero elements becomes a
    1-D empty ``np.array([])``, so its original shape is not kept. Otherwise
    the tensor is moved to CPU, detached, and converted.
    """
    if t is None:
        return None
    if t.numel() == 0:
        return np.array([])
    return t.cpu().detach().numpy()


def get_stats(name, x):
    """Return min/max/mean/std of ``x`` keyed as ``"<name>_<stat>"``.

    ``x`` only needs ``min``/``max``/``mean``/``std`` methods (e.g. a NumPy
    array or a torch tensor). Each statistic uses that object's own default,
    so ``std`` may be population or sample depending on the type.
    """
    reducers = (
        ("min", x.min),
        ("max", x.max),
        ("mean", x.mean),
        ("std", x.std),
    )
    return {f"{name}_{stat}": reduce() for stat, reduce in reducers}


# Short task names (the part after "gym_" in a config env name) mapped to the
# Gym environment ids passed to gym.make.
gym_ids = dict(
    humanoid="Humanoid-v2",
    halfcheetah="HalfCheetah-v2",
    hopper="Hopper-v2",
    walker="Walker2d-v2",
)


def gym_env(env_name, seed):
    """Create the Gym environment registered under short name ``env_name``.

    The environment, its action space and its observation space are all
    seeded with ``seed``, in that order.

    Raises:
        KeyError: if ``env_name`` is not a key of ``gym_ids``.
    """
    env_id = gym_ids[env_name]
    env = gym.make(env_id)
    env.seed(seed)
    for space in (env.action_space, env.observation_space):
        space.seed(seed)
    return env
