import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from gym.spaces import Box

from core.networks.actors import ContinuousDeterministicActor
from core.networks.critics import ContinuousQNetwork
from core.rl_agents.off_policy_algo import OffPolicyAlgorithm
from core.networks.base_actor import BaseActor
from core.utils.mpi_utils import sync_grads


class TD3(OffPolicyAlgorithm):
    """TD3 agent: twin critics, target-policy smoothing, delayed actor updates.

    The agent owns an actor/critic pair plus a target copy of each, and one
    Adam optimizer per trainable network. Only continuous (``Box``) action
    spaces are accepted.
    """

    def __init__(self, obs_space, action_space, config):
        """Build the networks, their targets and the optimizers.

        Args:
            obs_space: Observation space; only ``shape[0]`` is read.
            action_space: Must be a ``gym.spaces.Box``. ``shape[0]`` is the
                action dimension and ``high[0]`` is used as the action bound.
            config: Object exposing ``device``, ``layers_dim``, ``expl_noise``,
                ``discount``, ``tau``, ``policy_noise``, ``noise_clip`` and
                ``policy_freq``.

        Raises:
            AssertionError: If ``action_space`` is not a ``Box``.
        """
        assert isinstance(action_space, Box), 'action space is not Continuous'
        observation_dim = obs_space.shape[0]
        action_dim = action_space.shape[0]
        # Only the first upper bound is read; update_critic clamps actions to
        # [-max_action, max_action], i.e. the bounds are treated as symmetric
        # and identical across action dimensions.
        max_action = action_space.high[0]

        self.actor = ContinuousDeterministicActor(
            observation_dim, action_dim, max_action, config.expl_noise,
            layers_dim=config.layers_dim,
        ).to(config.device)
        self.actor_target = ContinuousDeterministicActor(
            observation_dim, action_dim, max_action, config.expl_noise,
            layers_dim=config.layers_dim,
        ).to(config.device)
        # The target starts as an exact copy of the online actor.
        self.actor_target.load_state_dict(self.actor.state_dict())
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=3e-4)

        self.critic = ContinuousQNetwork(
            observation_dim, action_dim, layers_dim=config.layers_dim,
        ).to(config.device)
        self.critic_target = ContinuousQNetwork(
            observation_dim, action_dim, layers_dim=config.layers_dim,
        ).to(config.device)
        # The target starts as an exact copy of the online critic.
        self.critic_target.load_state_dict(self.critic.state_dict())
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=3e-4)

        super(TD3, self).__init__(actor=self.actor, critic=self.critic, device=config.device, name='TD3')

        self.max_action = max_action
        self.action_dim = action_dim

        self.discount = config.discount
        self.tau = config.tau
        self.policy_noise = config.policy_noise
        self.noise_clip = config.noise_clip
        self.policy_freq = config.policy_freq
        self.expl_noise = config.expl_noise

        # `iteration` counts train_on_batch calls; `iteration_critic` counts
        # critic updates that actually ran (update_critic with a batch).
        self.iteration = 0
        self.iteration_critic = 0

    def reset_optimizer(self, lr=3e-4):
        """Replace both Adam optimizers with fresh ones (optimizer state is lost).

        Args:
            lr: Learning rate used for both the actor and the critic optimizer.
        """
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=lr)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=lr)

    def _to_tensor(self, array):
        """Return ``array`` as a float32 tensor on ``self.device``."""
        return torch.as_tensor(array, dtype=torch.float32, device=self.device)

    def update_critic(self, batch, sync_grads_bool=False, comm=None):
        """Run one gradient step on the twin critics.

        Args:
            batch: Mapping with keys ``'rewards'``, ``'terminals'``,
                ``'observations'``, ``'actions'`` and ``'new_observations'``,
                or ``None`` to skip the update entirely.
            sync_grads_bool: When true, ``sync_grads(self.critic, comm)`` is
                called between the backward pass and the optimizer step.
            comm: Communicator handed to ``sync_grads``
                (presumably an MPI communicator -- confirm in core.utils.mpi_utils).
        """
        if batch is None:
            return

        reward = self._to_tensor(batch['rewards']).reshape(-1, 1)
        # 1 for transitions that continue, 0 for terminal ones, so the
        # bootstrap term vanishes at episode ends.
        not_done = 1.0 - self._to_tensor(batch['terminals']).reshape(-1, 1)
        observation = self._to_tensor(batch['observations'])
        action = self._to_tensor(batch['actions'])
        new_observation = self._to_tensor(batch['new_observations'])

        # The target is a constant for the critic loss: build it without
        # recording an autograd graph through the target networks.
        with torch.no_grad():
            # Target-policy smoothing: clipped Gaussian noise on the target action.
            noise = (torch.randn_like(action) * self.policy_noise).clamp(
                -self.noise_clip, self.noise_clip)
            next_action = (self.actor_target(new_observation) + noise).clamp(
                -self.max_action, self.max_action)

            # Take the smaller of the two target estimates.
            target_Q1, target_Q2 = self.critic_target(new_observation, next_action)
            target_Q = reward + not_done * self.discount * torch.min(target_Q1, target_Q2)

        # Both critics regress onto the same target.
        current_Q1, current_Q2 = self.critic(observation, action)
        critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q)

        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        if sync_grads_bool:
            sync_grads(self.critic, comm)
        self.critic_optimizer.step()

        # The critic target is refreshed every `policy_freq` critic updates.
        # The check runs before the counter is incremented, so the very first
        # update (counter 0) always refreshes the target.
        # NOTE(review): the actor is updated when `self.iteration` (incremented
        # *before* the check) is a multiple of policy_freq, so the two target
        # refreshes happen on different calls -- confirm this is intended.
        if self.iteration_critic % self.policy_freq == 0:
            self.soft_update_from_to(self.critic, self.critic_target, self.tau)

        self.iteration_critic += 1

    def update_actor(self, batch):
        """Run one gradient step on the actor and refresh its target.

        Args:
            batch: Mapping with key ``'observations'``, or ``None`` to skip.
        """
        if batch is None:
            return

        observation = self._to_tensor(batch['observations'])

        # Maximize Q1 of the actor's own action (minimize its negation).
        # The backward pass also leaves gradients on the critic parameters;
        # they are cleared by critic_optimizer.zero_grad() in update_critic.
        actor_loss = -self.critic.Q1(observation, self.actor(observation)).mean()

        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        self.soft_update_from_to(self.actor, self.actor_target, self.tau)

    def train_on_batch(self, batch, sync_grads_bool=False, comm=None):
        """Run one training iteration: always the critic, periodically the actor.

        Args:
            batch: Batch forwarded to ``update_critic`` / ``update_actor``.
            sync_grads_bool: Forwarded to ``update_critic``.
            comm: Communicator forwarded to ``update_critic``; previously it
                was never passed on, so gradient syncing always received
                ``None``.
        """
        self.iteration += 1

        self.update_critic(batch, sync_grads_bool, comm)

        # Delayed policy update: the actor only moves every `policy_freq` iterations.
        if self.iteration % self.policy_freq == 0:
            self.update_actor(batch)
