# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from agent import Agent
from torch import linalg as LA


class Actor(nn.Module):
    """Deterministic policy network mapping a state batch to actions.

    The network is two 256-unit ReLU layers followed by a linear layer whose
    output is squashed with ``tanh`` and multiplied by ``max_action``.
    """

    def __init__(self, state_dim, action_dim, max_action):
        super().__init__()

        hidden_units = 256
        self.max_action = max_action

        # Layer attribute names (l1/l2/l3) are part of the state_dict keys.
        self.l1 = nn.Linear(state_dim, hidden_units)
        self.l2 = nn.Linear(hidden_units, hidden_units)
        self.l3 = nn.Linear(hidden_units, action_dim)

    def forward(self, state):
        """Return ``max_action * tanh(l3(relu(l2(relu(l1(state))))))``."""
        hidden = torch.relu(self.l1(state))
        hidden = torch.relu(self.l2(hidden))
        squashed = torch.tanh(self.l3(hidden))
        return self.max_action * squashed


class Critic(nn.Module):
    """Value network over concatenated ``(state, action)`` inputs.

    The inputs are joined along dimension 1, passed through two 256-unit
    ReLU layers and projected to ``output_dim`` values per sample.
    """

    def __init__(self, state_dim, action_dim, output_dim=1):
        super().__init__()

        hidden_units = 256
        input_units = state_dim + action_dim

        # Layer attribute names (l1/l2/l3) are part of the state_dict keys.
        self.l1 = nn.Linear(input_units, hidden_units)
        self.l2 = nn.Linear(hidden_units, hidden_units)
        self.l3 = nn.Linear(hidden_units, output_dim)

    def forward(self, state, action):
        """Return the ``output_dim`` critic outputs for each (state, action)."""
        joint = torch.cat((state, action), dim=1)
        hidden = torch.relu(self.l1(joint))
        hidden = torch.relu(self.l2(hidden))
        return self.l3(hidden)


class DDPGAgent(Agent):
    """DDPG agent with an optional vector-valued (successor-feature) critic.

    When ``use_sf_reward`` is true (``__init__`` hard-wires it to ``True``)
    the critic outputs a 4-dimensional vector ``psi``. It is trained with a
    TD target whose reward term is the per-sample feature vector ``phi``, and
    the actor maximises ``psi . W`` with

        phi = [1, s_0, s_1, ||s_{0:2}||^2]
        W   = [-||g||^2, 2 * g, -1]

    so for a 2-dimensional ``goal`` ``g``, ``phi . W = -||s_{0:2} - g||^2``.
    When ``use_sf_reward`` is false the critic is a scalar Q-function trained
    on the sampled reward.
    """

    def __init__(self, obs_dim, action_dim, goal_dim, action_range, device,
                 discount, critic_tau, batch_size):
        """Build the actor/critic networks, their targets and optimizers.

        Args:
            obs_dim: Size of the observation vector fed to actor and critic.
            action_dim: Size of the action vector.
            goal_dim: Not used by this class.
            action_range: Indexable; element 1 is used as the actor's
                output scale (``max_action``).
            device: Torch device for the networks and the phi/W buffers.
            discount: Discount factor applied in the TD target.
            critic_tau: Interpolation coefficient used when soft-updating
                both target networks.
            batch_size: Number of transitions requested per update; also the
                number of rows of the phi/W buffers.
        """
        max_action = action_range[1]
        self.batch_size = batch_size
        self.device = device
        self.actor = Actor(obs_dim, action_dim, max_action).to(device)
        self.actor_target = copy.deepcopy(self.actor)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(),
                                                lr=3e-4)

        # Width of phi / W; the critic output must match it when the
        # feature vector replaces the scalar reward.
        state_features = 4
        self.use_sf_reward = True
        q_output_dim = state_features if self.use_sf_reward else 1

        self.critic = Critic(obs_dim, action_dim, q_output_dim).to(device)
        self.critic_target = copy.deepcopy(self.critic)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(),
                                                 lr=3e-4)

        self.max_action = max_action
        self.discount = discount
        self.tau = critic_tau

        # Persistent buffers overwritten in place on every update. Column 0
        # of phi is never written again and therefore stays at 1.
        self.phi = torch.ones(self.batch_size,
                              state_features,
                              requires_grad=False,
                              device=device)

        self.W = torch.ones(self.batch_size,
                            state_features,
                            requires_grad=False,
                            device=device)

    def act(self, state):
        """Return the actor's action for one state as a flat numpy array.

        Args:
            state: Array-like exposing ``reshape``; it is reshaped to a
                single row before being passed to the actor.
        """
        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            obs = torch.as_tensor(state.reshape(1, -1),
                                  dtype=torch.float32,
                                  device=self.device)
            return self.actor(obs).cpu().numpy().flatten()

    def update(self, replay_buffer, logger, writer, step):
        """Run one critic step, one actor step and one target soft-update.

        Args:
            replay_buffer: Object whose ``sample(batch_size)`` returns the
                tuple ``(state, action, reward, next_state, not_done,
                not_done_no_max, goal)``.
            logger: Object with ``log(key, value, step)``; receives the mean
                sampled reward under ``'train/batch_reward'``.
            writer: Unused.
            step: Step index forwarded to ``logger``.
        """
        # Sample replay buffer
        (state, action, reward, next_state, not_done, not_done_no_max,
         goal) = replay_buffer.sample(self.batch_size)
        # NOTE(review): `not_done_no_max` is sampled but never used; confirm
        # that `not_done` is the intended bootstrap mask.

        logger.log('train/batch_reward', reward.mean(), step)

        # Refresh the feature / weight buffers for this batch.
        # The slice assignments require `state` to have at least 2 columns
        # and `goal` to broadcast to (batch_size, 2).
        self.phi[:, 1:3] = state[:, 0:2]
        self.phi[:, 3:] = LA.norm(state[:, 0:2], dim=1, keepdim=True)**2

        self.W[:, 0:1] = -(LA.norm(goal, dim=1, keepdim=True)**2)
        self.W[:, 1:3] = 2 * goal
        self.W[:, 3:] = -1.0

        if self.use_sf_reward:
            reward = self.phi

        # Compute the TD target; no gradient is needed through the target
        # networks, so do not record a graph for them at all.
        with torch.no_grad():
            next_value = self.critic_target(next_state,
                                            self.actor_target(next_state))
            target_Q = reward + not_done * self.discount * next_value

        # Get current Q estimate
        current_Q = self.critic(state, action)

        # Compute critic loss
        critic_loss = F.mse_loss(current_Q, target_Q)

        # Optimize the critic
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        # Value of the current policy's action under the updated critic.
        q_value = self.critic(state, self.actor(state))
        if self.use_sf_reward:
            # Row-wise dot product psi . W, kept as shape (batch, 1).
            q_value = torch.bmm(q_value.unsqueeze(1),
                                self.W.unsqueeze(2)).squeeze(-1)

        # Compute actor loss
        actor_loss = -q_value.mean()

        # Optimize the actor
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        # Update the frozen target models
        self._soft_update(self.critic, self.critic_target)
        self._soft_update(self.actor, self.actor_target)

    def _soft_update(self, net, target_net):
        """Blend ``net``'s parameters into ``target_net`` in place.

        Each target parameter becomes
        ``tau * param + (1 - tau) * target_param``.
        """
        with torch.no_grad():
            for param, target_param in zip(net.parameters(),
                                           target_net.parameters()):
                target_param.copy_(self.tau * param +
                                   (1 - self.tau) * target_param)

    def save(self, filename):
        """Save network and optimizer state dicts.

        Four files are written, named ``filename`` plus one of the suffixes
        ``_critic``, ``_critic_optimizer``, ``_actor``, ``_actor_optimizer``.
        The target networks are not saved.
        """
        torch.save(self.critic.state_dict(), filename + "_critic")
        torch.save(self.critic_optimizer.state_dict(),
                   filename + "_critic_optimizer")

        torch.save(self.actor.state_dict(), filename + "_actor")
        torch.save(self.actor_optimizer.state_dict(),
                   filename + "_actor_optimizer")

    def load(self, filename):
        """Restore the state written by ``save`` and rebuild both targets.

        Tensors are mapped onto ``self.device`` while loading, so a
        checkpoint saved on one device can be restored on another. The
        target networks are re-created as copies of the loaded networks.
        """
        self.critic.load_state_dict(
            torch.load(filename + "_critic", map_location=self.device))
        self.critic_optimizer.load_state_dict(
            torch.load(filename + "_critic_optimizer",
                       map_location=self.device))
        self.critic_target = copy.deepcopy(self.critic)

        self.actor.load_state_dict(
            torch.load(filename + "_actor", map_location=self.device))
        self.actor_optimizer.load_state_dict(
            torch.load(filename + "_actor_optimizer",
                       map_location=self.device))
        self.actor_target = copy.deepcopy(self.actor)
