# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
import numpy as np
import torch
from torch import nn
from torch import distributions as pyd
import torch.nn.functional as F
import gym
import os
from collections import deque
import random
import math
import re

import dmc2gym


def make_env(cfg, task_id=0):
    """Create a dm_control environment through ``dmc2gym``.

    ``cfg.env`` is split at its first underscore into a domain name and a
    task name; ``ball_in_cup_catch`` is handled explicitly because its domain
    name contains underscores itself.  ``task_id`` is forwarded to
    ``dmc2gym.make`` inside ``environment_kwargs``.  The environment is
    seeded with ``cfg.seed`` and its action bounds are asserted to lie
    within [-1, 1] before it is returned.
    """
    if cfg.env == 'ball_in_cup_catch':
        domain_name, task_name = 'ball_in_cup', 'catch'
    else:
        # Everything before the first underscore is the domain, the rest
        # (underscores included) is the task.
        domain_name, _, task_name = cfg.env.partition('_')

    env = dmc2gym.make(
        domain_name=domain_name,
        task_name=task_name,
        seed=cfg.seed,
        visualize_reward=True,
        environment_kwargs={'task_id': task_id},
    )
    env.seed(cfg.seed)

    action_bounds = env.action_space
    assert action_bounds.low.min() >= -1
    assert action_bounds.high.max() <= 1

    return env


class eval_mode(object):
    """Context manager that temporarily switches models to evaluation mode.

    On entry the current ``training`` flag of every model is recorded and
    ``model.train(False)`` is called on each of them.  On exit every model is
    handed its recorded flag again through ``model.train(state)``.

    Args:
        *models: objects exposing a ``training`` attribute and a
            ``train(mode)`` method.
    """

    def __init__(self, *models):
        self.models = models

    def __enter__(self):
        # Record every flag before touching any model.  Switching while
        # recording would store an already-modified flag whenever the same
        # model appears twice or a model is listed after one that contains
        # it, and the exit step would then restore the wrong mode.
        self.prev_states = [model.training for model in self.models]
        for model in self.models:
            model.train(False)

    def __exit__(self, *args):
        for model, state in zip(self.models, self.prev_states):
            model.train(state)
        # Returning False lets exceptions from the ``with`` body propagate.
        return False


class train_mode(object):
    """Context manager that temporarily switches models to training mode.

    On entry the current ``training`` flag of every model is recorded and
    ``model.train(True)`` is called on each of them.  On exit every model is
    handed its recorded flag again through ``model.train(state)``.

    Args:
        *models: objects exposing a ``training`` attribute and a
            ``train(mode)`` method.
    """

    def __init__(self, *models):
        self.models = models

    def __enter__(self):
        # Record every flag before touching any model.  Switching while
        # recording would store an already-modified flag whenever the same
        # model appears twice or a model is listed after one that contains
        # it, and the exit step would then restore the wrong mode.
        self.prev_states = [model.training for model in self.models]
        for model in self.models:
            model.train(True)

    def __exit__(self, *args):
        for model, state in zip(self.models, self.prev_states):
            model.train(state)
        # Returning False lets exceptions from the ``with`` body propagate.
        return False


def soft_update_params(net, target_net, tau):
    """Blend the parameters of ``net`` into ``target_net`` in place.

    Parameters are paired in the order returned by ``parameters()`` and each
    target parameter becomes ``tau * source + (1 - tau) * target``.
    """
    for source, target in zip(net.parameters(), target_net.parameters()):
        blended = tau * source.data + (1 - tau) * target.data
        target.data.copy_(blended)


def set_seed_everywhere(seed):
    """Seed Python's ``random``, NumPy and PyTorch with the same value.

    CUDA generators are seeded as well when CUDA is available.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)


def make_dir(*path_parts):
    """Join ``path_parts`` into a path, make sure it is a directory, return it.

    Missing intermediate directories are created too, and an already
    existing directory is left untouched.

    Args:
        *path_parts: path components, joined with ``os.path.join``.

    Returns:
        The joined directory path.

    Raises:
        OSError: if the directory cannot be created (for example because of
            missing permissions, or because the path exists as a regular
            file).  Failing here is preferred over returning a path that
            does not exist.
    """
    dir_path = os.path.join(*path_parts)
    # exist_ok makes repeated calls harmless without hiding real failures.
    os.makedirs(dir_path, exist_ok=True)
    return dir_path


def weight_init(m):
    """Orthogonally initialise ``nn.Linear`` weights and zero their biases.

    Modules of any other type are left untouched.  Intended to be passed to
    ``nn.Module.apply``.
    """
    if not isinstance(m, nn.Linear):
        return

    nn.init.orthogonal_(m.weight.data)
    bias = m.bias
    # ``bias`` is None for layers built with bias=False; skip those.
    if hasattr(bias, 'data'):
        bias.data.fill_(0.0)


class MLP(nn.Module):
    """Feed-forward network whose layer stack is built by ``mlp``.

    The constructor arguments are forwarded unchanged to ``mlp`` and the
    resulting module is stored as ``self.trunk``; ``weight_init`` is then
    applied to every submodule.
    """

    def __init__(self,
                 input_dim,
                 hidden_dim,
                 output_dim,
                 hidden_depth,
                 output_mod=None):
        super().__init__()
        layer_stack = mlp(input_dim, hidden_dim, output_dim, hidden_depth,
                          output_mod)
        self.trunk = layer_stack
        self.apply(weight_init)

    def forward(self, x):
        """Run ``x`` through the trunk and return its output."""
        out = self.trunk(x)
        return out


def mlp(input_dim,
        hidden_dim,
        output_dim,
        hidden_depth,
        output_mod=None,
        weight_norm=False,
        custome=False):
    """Assemble an ``nn.Sequential`` of Linear/ReLU layers.

    * ``hidden_depth == 0``: a single ``Linear(input_dim, output_dim)``,
      wrapped in ``torch.nn.utils.weight_norm`` when ``weight_norm`` is set.
    * ``custome`` set: fixed hidden widths 256, 256 and 1024; ``hidden_dim``
      and the exact value of ``hidden_depth`` are ignored.
    * otherwise: ``hidden_depth`` hidden layers of width ``hidden_dim``.

    Every hidden Linear is followed by an in-place ReLU, the final Linear is
    not.  ``weight_norm`` has no effect when ``hidden_depth`` is non-zero.
    ``output_mod``, when given, is appended as the last module.
    """
    if hidden_depth == 0:
        projection = nn.Linear(input_dim, output_dim)
        if weight_norm:
            projection = torch.nn.utils.weight_norm(projection)
        layers = [projection]
    else:
        if custome:
            widths = [input_dim, 256, 256, 1024]
        else:
            # First hidden layer is always present; the remaining
            # ``hidden_depth - 1`` layers map hidden_dim -> hidden_dim.
            widths = [input_dim, hidden_dim]
            widths += [hidden_dim] * (hidden_depth - 1)

        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.ReLU(inplace=True))
        layers.append(nn.Linear(widths[-1], output_dim))

    if output_mod is not None:
        layers.append(output_mod)
    return nn.Sequential(*layers)


class PhiWJointMLP(nn.Module):
    """Shared feature network ``phi`` followed by four bias-free linear heads.

    ``phi`` is a three-layer MLP mapping the input back to ``input_dim``
    features.  Each row is routed to one head by its ``env_id``:

    * ids 0 and 4 -> ``W1``
    * ids 1 and 5 -> ``W2``
    * id 2        -> ``W3``
    * id 3        -> ``W4``

    For any other id every head is masked out.
    """

    def __init__(self, input_dim, hidden_dim_1, hidden_dim_2, output_dim):
        super().__init__()

        # Shared feature extractor: input_dim -> h1 -> h2 -> input_dim.
        self.phi_l1 = nn.Linear(input_dim, hidden_dim_1)
        self.phi_l2 = nn.Linear(hidden_dim_1, hidden_dim_2)
        self.phi_l3 = nn.Linear(hidden_dim_2, input_dim)

        # Heads W1..W4, registered in index order.
        for head_idx in range(1, 5):
            setattr(self, 'W%d' % head_idx,
                    nn.Linear(input_dim, output_dim, bias=False))

        self.apply(weight_init)

    def forward(self, x, env_id=None):
        """Return the output of the head selected by ``env_id`` for each row.

        ``env_id`` is compared with ``torch.eq`` and the resulting masks are
        multiplied into the head outputs, so it has to be a tensor whose
        shape broadcasts against them; the ``None`` default is not usable.
        """
        features = self.latent_phi(x)

        head_outputs = [
            head(features) for head in (self.W1, self.W2, self.W3, self.W4)
        ]
        head_masks = [
            # W1 serves task_train 0 and task_eval 4.
            torch.logical_or(torch.eq(env_id, 0), torch.eq(env_id, 4)),
            # W2 serves task_train 1 and task_eval 5.
            torch.logical_or(torch.eq(env_id, 1), torch.eq(env_id, 5)),
            torch.eq(env_id, 2),
            torch.eq(env_id, 3),
        ]

        combined = None
        for branch, mask in zip(head_outputs, head_masks):
            # Zero the rows that do not belong to this head (in place).
            branch *= mask
            combined = branch if combined is None else combined + branch
        return combined

    def latent_phi(self, x):
        """Return the features produced by the shared ``phi`` network."""
        hidden = F.relu(self.phi_l1(x))
        hidden = F.relu(self.phi_l2(hidden))
        return self.phi_l3(hidden)


class PhiWJointVAE(nn.Module):
    """Conditional VAE whose decoded output feeds four bias-free linear heads.

    The encoder (``e1``, ``e2``) produces the mean and log-std of a latent
    ``z``; the decoder (``d1``..``d3``) maps ``[state, z]`` back to a
    ``state_dim`` vector squashed by ``tanh``.  That vector is passed through
    the heads ``W1``..``W4`` and each row keeps only the head selected by its
    ``env_id``:

    * ids 0 and 4 -> ``W1``
    * ids 1 and 5 -> ``W2``
    * id 2        -> ``W3``
    * id 3        -> ``W4``

    For any other id every head is masked out.
    """

    def __init__(self, state_dim, hidden_dim_1, hidden_dim_2, latent_dim,
                 output_dim):
        super(PhiWJointVAE, self).__init__()

        # Encoder trunk.
        self.e1 = nn.Linear(state_dim, hidden_dim_1)
        self.e2 = nn.Linear(hidden_dim_1, hidden_dim_2)

        # Latent distribution parameters.
        self.mean = nn.Linear(hidden_dim_2, latent_dim)
        self.log_std = nn.Linear(hidden_dim_2, latent_dim)

        # Decoder, conditioned on the state concatenated with the latent.
        self.d1 = nn.Linear(state_dim + latent_dim, hidden_dim_1)
        self.d2 = nn.Linear(hidden_dim_1, hidden_dim_2)
        self.d3 = nn.Linear(hidden_dim_2, state_dim)

        self.latent_dim = latent_dim

        # Per-task heads applied to the decoder output.
        self.W1 = nn.Linear(state_dim, output_dim, bias=False)
        self.W2 = nn.Linear(state_dim, output_dim, bias=False)
        self.W3 = nn.Linear(state_dim, output_dim, bias=False)
        self.W4 = nn.Linear(state_dim, output_dim, bias=False)

        self.apply(weight_init)

    def forward(self, state, env_id=None):
        """Encode ``state``, sample a latent, decode and apply the heads.

        ``env_id`` is compared with ``torch.eq`` and the resulting masks are
        multiplied into the head outputs, so it has to be a tensor whose
        shape broadcasts against them; the ``None`` default is not usable.

        Returns:
            A tuple ``(W_out, mean, std)``: the routed head output and the
            mean and standard deviation of the latent distribution.
        """
        z = F.relu(self.e1(state))
        z = F.relu(self.e2(z))

        mean = self.mean(z)
        # Clamped for numerical stability
        log_std = self.log_std(z).clamp(-4, 15)
        std = torch.exp(log_std)
        # Reparameterised sample: mean + std * unit Gaussian noise.
        z = mean + std * torch.randn_like(std)

        u = self.decode(state, z)

        W1_out = self.W1(u)
        W2_out = self.W2(u)
        W3_out = self.W3(u)
        W4_out = self.W4(u)

        # Use W1 Branch for task_train 0 and task_eval 4
        W1_mask = torch.logical_or(torch.eq(env_id, 0), torch.eq(env_id, 4))
        # Use W2 Branch for task_train 1 and task_eval 5
        W2_mask = torch.logical_or(torch.eq(env_id, 1), torch.eq(env_id, 5))
        W3_mask = torch.eq(env_id, 2)
        W4_mask = torch.eq(env_id, 3)

        # Zero the rows that do not belong to each head, then merge.
        W1_out *= W1_mask
        W2_out *= W2_mask
        W3_out *= W3_mask
        W4_out *= W4_mask

        W_out = W1_out + W2_out + W3_out + W4_out

        return W_out, mean, std

    def decode(self, state, z=None):
        """Decode ``[state, z]`` into a ``tanh``-squashed vector.

        When ``z`` is omitted a standard-normal latent with one row per
        state row is drawn and clipped to [-0.5, 0.5].
        """
        # When sampling from the VAE, the latent vector is clipped to [-0.5, 0.5]
        if z is None:
            # Follow the device of ``state``: this class never assigns a
            # ``device`` attribute and nn.Module does not provide one, and
            # the latent must live next to ``state`` for the concatenation.
            z = torch.randn(
                (state.shape[0],
                 self.latent_dim)).to(state.device).clamp(-0.5, 0.5)

        result = F.relu(self.d1(torch.cat([state, z], 1)))
        result = F.relu(self.d2(result))
        return torch.tanh(self.d3(result))


def to_np(t):
    """Convert a tensor to a NumPy array.

    ``None`` is passed through unchanged and a tensor without elements
    becomes an empty one-dimensional array; any other tensor is detached,
    moved to the CPU and converted with ``numpy()``.
    """
    if t is None:
        return None
    if t.nelement() == 0:
        return np.array([])
    return t.detach().cpu().numpy()


def get_latest_file(path):
    """Return the largest checkpoint step found among the entries of ``path``.

    The step of an entry is the first run of digits in its name.  Entries
    whose name contains no digits are skipped instead of aborting the scan.

    Args:
        path: directory to list.

    Returns:
        The highest step found, or 0 when no entry carries a number.

    Raises:
        OSError: propagated from ``os.listdir`` when ``path`` cannot be read.
    """
    max_ckpt = 0
    for ckpt in os.listdir(path):
        match = re.search(r'\d+', ckpt)
        if match is None:
            # No digits in the name, so there is no step to compare.
            continue
        max_ckpt = max(max_ckpt, int(match.group()))
    print('Loading Latest ckpt ', max_ckpt)
    return max_ckpt
