# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
"""Encoder component for the agent."""

from copy import deepcopy
from typing import List, cast

import torch
import torch.nn as nn
import torch.nn.functional as F


class Encoder_Decoder(nn.Module):
    """Shared MLP encoder over a (state, action) pair with three output heads.

    The encoder is two ReLU-activated linear layers applied to the
    concatenation of ``state`` and ``action``.  On top of the resulting
    latent code sit:

    * ``r1``       -- a single bias-free linear layer producing one scalar;
    * ``d1``/``d2`` -- a two-layer head producing a ``state_dim``-wide vector;
    * ``a1``/``a2`` -- a two-layer head producing an ``action_dim``-wide vector.

    NOTE(review): the attribute/variable names suggest these heads predict
    reward, next state and action respectively -- confirm against the
    training code.
    """

    def __init__(self, state_dim: int, action_dim: int):
        """Build the encoder and the three heads.

        Args:
            state_dim: Width of the state input (and of the ``d2`` output).
            action_dim: Width of the action input (and of the ``a2`` output).
        """
        super().__init__()

        # Width of the latent code and of every hidden layer.
        width = 256
        # Input width of the scalar head; it consumes the latent code, so it
        # has to equal ``width``.
        scalar_head_in = 256

        # Submodules are created in a fixed order (e1, e2, r1, a1, a2, d1, d2)
        # so parameter ordering and seeded initialisation stay stable.

        # Encoder: [state, action] -> latent.
        self.e1 = nn.Linear(state_dim + action_dim, width)
        self.e2 = nn.Linear(width, width)

        # Scalar head (no bias term).
        self.r1 = nn.Linear(scalar_head_in, 1, bias=False)

        # Head with an ``action_dim``-wide output.
        self.a1 = nn.Linear(width, width)
        self.a2 = nn.Linear(width, action_dim)

        # Head with a ``state_dim``-wide output.
        self.d1 = nn.Linear(width, width)
        self.d2 = nn.Linear(width, state_dim)

    def forward(self, state: torch.Tensor, action: torch.Tensor):
        """Encode ``(state, action)`` and evaluate all three heads.

        Args:
            state: Tensor concatenated with ``action`` along dim 1; the
                concatenation must be ``state_dim + action_dim`` wide, so
                2-D ``(batch, state_dim)`` input is expected.
            action: Tensor of matching batch size, ``(batch, action_dim)``.

        Returns:
            A 4-tuple ``(ns, r, a, l)``: the ``d2`` head output
            ``(batch, state_dim)``, the ``r1`` head output ``(batch, 1)``,
            the ``a2`` head output ``(batch, action_dim)`` and the latent
            code ``(batch, 256)``.
        """
        code = self.latent(state, action)

        state_out = self.d2(F.relu(self.d1(code)))
        scalar_out = self.r1(code)
        action_out = self.a2(F.relu(self.a1(code)))

        return state_out, scalar_out, action_out, code

    def latent(self, state: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        """Return only the latent code for ``(state, action)``.

        Both encoder layers are followed by a ReLU, so the code is
        element-wise non-negative.
        """
        joint = torch.cat((state, action), dim=1)
        hidden = F.relu(self.e1(joint))
        return F.relu(self.e2(hidden))


# Vanilla Variational Auto-Encoder
class VAE(nn.Module):
    """Variational auto-encoder whose decoder is conditioned on its input.

    The encoder maps a ``state_dim``-wide input to the mean and log standard
    deviation of a diagonal Gaussian over a ``latent_dim``-wide latent.  The
    decoder takes the same input concatenated with a latent sample and
    returns a ``state_dim``-wide vector squashed by ``tanh``.
    """

    def __init__(self, state_dim, action_dim, latent_dim, device):
        """Build encoder and decoder layers.

        Args:
            state_dim: Width of the encoder input and of the decoder output.
            action_dim: Not used by any layer; kept so the constructor
                signature stays compatible with existing callers.
            latent_dim: Width of the latent vector.
            device: Stored as ``self.device`` for callers that read it.
                Latent sampling in :meth:`decode` follows the device of the
                ``state`` argument instead, so it stays correct after the
                module is moved with ``.to(...)``.
        """
        super(VAE, self).__init__()

        hidden = 750

        # Encoder trunk and the two Gaussian parameter heads.
        self.e1 = nn.Linear(state_dim, hidden)
        self.e2 = nn.Linear(hidden, hidden)

        self.mean = nn.Linear(hidden, latent_dim)
        self.log_std = nn.Linear(hidden, latent_dim)

        # Decoder: [input, latent] -> state_dim.
        self.d1 = nn.Linear(state_dim + latent_dim, hidden)
        self.d2 = nn.Linear(hidden, hidden)
        self.d3 = nn.Linear(hidden, state_dim)

        self.latent_dim = latent_dim
        self.device = device

    def forward(self, state_action):
        """Encode the input, sample a latent and decode it.

        Args:
            state_action: 2-D tensor fed to the encoder and, unchanged, to
                the decoder as its conditioning input.  Despite the name,
                the first layer expects ``state_dim`` features.

        Returns:
            ``(u, mean, std)``: the decoder output in ``[-1, 1]``, and the
            mean and standard deviation of the latent Gaussian.
        """
        z = F.relu(self.e1(state_action))
        z = F.relu(self.e2(z))

        mean = self.mean(z)
        # Clamped for numerical stability: std stays within [e^-4, e^15].
        log_std = self.log_std(z).clamp(-4, 15)
        std = torch.exp(log_std)
        # Reparameterised sample: mean + std * eps, eps ~ N(0, I).
        z = mean + std * torch.randn_like(std)

        u = self.decode(state_action, z)

        return u, mean, std

    def decode(self, state, z=None):
        """Decode a latent vector conditioned on ``state``.

        Args:
            state: 2-D tensor ``(batch, state_dim)``.
            z: Optional latent ``(batch, latent_dim)``.  When omitted, a
                standard-normal sample clipped to ``[-0.5, 0.5]`` is drawn.

        Returns:
            Tensor ``(batch, state_dim)`` with values in ``[-1, 1]``.
        """
        # When sampling from the VAE, the latent vector is clipped to [-0.5, 0.5]
        if z is None:
            # Sample directly on the device that holds ``state``: this avoids
            # a CPU allocation followed by a copy, and cannot go stale the
            # way a device captured at construction time can.
            z = torch.randn(
                (state.shape[0], self.latent_dim),
                device=state.device,
            ).clamp(-0.5, 0.5)

        result = F.relu(self.d1(torch.cat([state, z], 1)))
        result = F.relu(self.d2(result))
        return torch.tanh(self.d3(result))
