import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
from goalsrl.TD3 import utils
from goalsrl.reimplementation.networks import Flatten, MultiInputNetwork, CNNHead

import rlutil.torch.pytorch_util as ptu

# Module-level device used by TD3.select_action and TD3.train when moving
# tensors. TD3.__init__ reassigns this global from ptu.default_device()
# each time an agent is constructed.
device = ptu.default_device()

# Implementation of Twin Delayed Deep Deterministic Policy Gradients (TD3)
# Paper: https://arxiv.org/abs/1802.09477

class Actor(nn.Module):
    """Deterministic goal-conditioned policy network.

    The observation and the goal are each passed through an embedding module
    (``Flatten()`` when none is supplied), concatenated along dim 1, and fed
    through a 400-300 ReLU MLP. The final ``tanh`` output is scaled by
    ``max_action``, so every action component lies in
    ``[-max_action, max_action]``.

    Args:
        state_dim: width of the state embedding output.
        goal_dim: width of the goal embedding output.
        action_dim: number of action components produced.
        max_action: scale applied to the ``tanh`` output.
        state_embedding: optional module applied to observations.
        goal_embedding: optional module applied to goals.
        detach_embeddings: when true, embedding outputs are detached so policy
            gradients do not flow back into the embedding modules.
    """

    def __init__(self, state_dim, goal_dim, action_dim, max_action, state_embedding=None, goal_embedding=None, detach_embeddings=False):
        super().__init__()

        self.state_embedding = Flatten() if state_embedding is None else state_embedding
        self.goal_embedding = Flatten() if goal_embedding is None else goal_embedding
        self.detach_embeddings = detach_embeddings

        # The attribute names l1..l3 are also the state_dict keys used by
        # save/load and target syncing, so they must stay fixed.
        self.l1 = nn.Linear(state_dim + goal_dim, 400)
        self.l2 = nn.Linear(400, 300)
        self.l3 = nn.Linear(300, action_dim)

        self.max_action = max_action

    def preprocess(self, x, g):
        """Embed observation ``x`` and goal ``g``; detach both if configured."""
        embedded_x = self.state_embedding(x)
        embedded_g = self.goal_embedding(g)
        if not self.detach_embeddings:
            return embedded_x, embedded_g
        return embedded_x.detach(), embedded_g.detach()

    def forward(self, x, g):
        """Return the scaled action for a batch of observations ``x`` and goals ``g``."""
        embedded_x, embedded_g = self.preprocess(x, g)
        hidden = torch.cat([embedded_x, embedded_g], dim=1)
        for layer in (self.l1, self.l2):
            hidden = F.relu(layer(hidden))
        return self.max_action * torch.tanh(self.l3(hidden))

class Critic(nn.Module):
    """Twin goal-conditioned Q-network used by TD3.

    The observation and the goal are embedded (``Flatten()`` by default),
    optionally detached, and concatenated with the action along dim 1. The
    result is then scored by two independent 400-300 ReLU MLP heads
    (``l1``-``l3`` and ``l4``-``l6``), each emitting one value per row.

    Args:
        state_dim: width of the state embedding output.
        goal_dim: width of the goal embedding output.
        action_dim: width of the action vector.
        state_embedding: optional module applied to observations.
        goal_embedding: optional module applied to goals.
        detach_embeddings: when true, embedding outputs are detached so critic
            gradients do not reach the embedding modules.
    """

    def __init__(self, state_dim, goal_dim, action_dim, state_embedding=None, goal_embedding=None, detach_embeddings=False):
        super().__init__()

        self.state_embedding = Flatten() if state_embedding is None else state_embedding
        self.goal_embedding = Flatten() if goal_embedding is None else goal_embedding
        self.detach_embeddings = detach_embeddings

        joint_dim = state_dim + goal_dim + action_dim
        # Layer attribute names double as state_dict keys, so they stay fixed.
        # First Q head (also exposed on its own via Q1).
        self.l1 = nn.Linear(joint_dim, 400)
        self.l2 = nn.Linear(400, 300)
        self.l3 = nn.Linear(300, 1)
        # Second Q head.
        self.l4 = nn.Linear(joint_dim, 400)
        self.l5 = nn.Linear(400, 300)
        self.l6 = nn.Linear(300, 1)

    def preprocess(self, x, g):
        """Embed observation ``x`` and goal ``g``; detach both if configured."""
        embedded_x = self.state_embedding(x)
        embedded_g = self.goal_embedding(g)
        if not self.detach_embeddings:
            return embedded_x, embedded_g
        return embedded_x.detach(), embedded_g.detach()

    def _joint_input(self, x, g, u):
        """Build the critic input: embedded state, embedded goal and action, joined on dim 1."""
        embedded_x, embedded_g = self.preprocess(x, g)
        return torch.cat([embedded_x, embedded_g, u], 1)

    def _first_head(self, xu):
        """Score a joint input with the first Q head (l1-l3)."""
        hidden = F.relu(self.l1(xu))
        hidden = F.relu(self.l2(hidden))
        return self.l3(hidden)

    def _second_head(self, xu):
        """Score a joint input with the second Q head (l4-l6)."""
        hidden = F.relu(self.l4(xu))
        hidden = F.relu(self.l5(hidden))
        return self.l6(hidden)

    def forward(self, x, g, u):
        """Return ``(q1, q2)``: both heads evaluated on the same joint input."""
        xu = self._joint_input(x, g, u)
        return self._first_head(xu), self._second_head(xu)

    def Q1(self, x, g, u):
        """Return only the first head's estimate (used for the actor loss)."""
        return self._first_head(self._joint_input(x, g, u))

class TD3(object):
    """Goal-conditioned Twin Delayed DDPG (TD3) agent.

    Holds an actor, a twin-headed critic, a Polyak-averaged target copy of
    each, and one Adam optimizer per online network.

    Args:
        env: environment exposing ``observation_space``, ``goal_space`` (read
            for non-image observations) and an ``action_space`` with
            ``shape`` and ``high``.
        actor_kwargs: normalised to a dict but currently not forwarded.
        critic_kwargs: normalised to a dict but currently not forwarded.
        lr: learning rate for both Adam optimizers.
    """

    def __init__(self, env, actor_kwargs=None, critic_kwargs=None, lr=1e-3):
        # NOTE(review): actor_kwargs / critic_kwargs are never passed to
        # Actor / Critic below — confirm whether that is intended.
        if actor_kwargs is None:
            actor_kwargs = dict()
        if critic_kwargs is None:
            critic_kwargs = dict()
        # Re-resolve the module-level device that select_action/train use.
        global device
        device = ptu.default_device()

        # Factories (not instances) so each network gets its own embedding
        # module; returning None makes Actor/Critic fall back to Flatten().
        state_embedding_fn = lambda: None
        goal_embedding_fn = lambda: None

        if len(env.observation_space.shape) > 1:  # Images
            state_dim = 64
            goal_dim = 64
            # assumes observations are laid out (C, H, W) — TODO confirm
            imsize = env.observation_space.shape[1]
            state_embedding_fn = lambda: CNNHead(imsize, spatial_softmax=True, output_size=64)
            goal_embedding_fn = lambda: CNNHead(imsize, spatial_softmax=True, output_size=64)
        else:
            # int(): np.prod yields numpy scalars (a float for an empty shape),
            # while nn.Linear expects plain integer feature counts.
            state_dim = int(np.prod(env.observation_space.shape))
            goal_dim = int(np.prod(env.goal_space.shape))

        action_dim = env.action_space.shape[0]
        # Only the first bound is used for every component; this presumably
        # assumes a symmetric box with identical bounds per dimension.
        max_action = env.action_space.high[0]

        self.actor = Actor(state_dim, goal_dim, action_dim, max_action, state_embedding_fn(), goal_embedding_fn()).to(device)
        self.actor_target = Actor(state_dim, goal_dim, action_dim, max_action, state_embedding_fn(), goal_embedding_fn()).to(device)
        self.actor_target.load_state_dict(self.actor.state_dict())
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=lr)

        self.critic = Critic(state_dim, goal_dim, action_dim, state_embedding_fn(), goal_embedding_fn()).to(device)
        self.critic_target = Critic(state_dim, goal_dim, action_dim, state_embedding_fn(), goal_embedding_fn()).to(device)
        self.critic_target.load_state_dict(self.critic.state_dict())
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=lr)

        self.max_action = max_action

    def select_action(self, state, goal):
        """Return the actor's deterministic action for a single (state, goal) pair.

        A leading batch axis of size 1 is added before the forward pass, and
        the result is returned as a flattened numpy array.
        """
        state = torch.FloatTensor(state).to(device)[None]
        goal = torch.FloatTensor(goal).to(device)[None]
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            action = self.actor(state, goal)
        return action.cpu().numpy().flatten()

    @staticmethod
    def _soft_update(source, target, tau):
        """Polyak-average in place: target <- tau * source + (1 - tau) * target."""
        for param, target_param in zip(source.parameters(), target.parameters()):
            target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)

    def train(self, replay_buffer, iterations, batch_size=100, discount=0.99, tau=0.005, policy_noise=0.2, noise_clip=0.5, policy_freq=2):
        """Run ``iterations`` TD3 gradient steps on batches from ``replay_buffer``.

        ``replay_buffer.sample(batch_size)`` must return
        ``(state, next_state, action, goal, reward, done)`` arrays. Every
        iteration updates the critic. Every ``policy_freq`` iterations the
        actor is updated and both targets are soft-updated with rate ``tau``.

        Target policy smoothing adds Gaussian noise with std ``policy_noise``,
        clipped to ``[-noise_clip, noise_clip]``, to the target action. The
        noisy action is then clipped to ``[-max_action, max_action]``.
        """
        import tqdm
        with tqdm.trange(iterations, leave=False) as ranger:
            for it in ranger:

                # Sample replay buffer
                x, y, u, g, r, d = replay_buffer.sample(batch_size)
                state = torch.FloatTensor(x).to(device)
                action = torch.FloatTensor(u).to(device)
                next_state = torch.FloatTensor(y).to(device)
                goal = torch.FloatTensor(g).to(device)
                # 1 - d: zeroes the bootstrap term on terminal transitions.
                not_done = torch.FloatTensor(1 - d).to(device)
                # NOTE(review): assumes r and d are shaped (batch, 1) like the
                # Q outputs; a (batch,) shape would broadcast to (batch, batch).
                reward = torch.FloatTensor(r).to(device)

                # The TD target is a constant for the critic update, so build
                # it without recording a graph through the target networks.
                with torch.no_grad():
                    noise = (torch.randn_like(action) * policy_noise).clamp(-noise_clip, noise_clip)
                    next_action = (self.actor_target(next_state, goal) + noise).clamp(-self.max_action, self.max_action)

                    # Clipped double-Q: bootstrap from the smaller target estimate.
                    target_Q1, target_Q2 = self.critic_target(next_state, goal, next_action)
                    target_Q = reward + not_done * discount * torch.min(target_Q1, target_Q2)

                # Get current Q estimates
                current_Q1, current_Q2 = self.critic(state, goal, action)

                # Compute critic loss
                critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q)

                # Optimize the critic
                self.critic_optimizer.zero_grad()
                critic_loss.backward()
                self.critic_optimizer.step()

                # Delayed policy updates
                if it % policy_freq == 0:

                    # Compute actor loss
                    actor_loss = -self.critic.Q1(state, goal, self.actor(state, goal)).mean()

                    # Optimize the actor. This backward also writes gradients
                    # into the critic's parameters. critic_optimizer.zero_grad()
                    # clears them before the next critic step.
                    self.actor_optimizer.zero_grad()
                    actor_loss.backward()
                    self.actor_optimizer.step()
                    ranger.set_description('Critic loss: %f Actor loss: %f' % (critic_loss.item(), actor_loss.item()))

                    # Update the frozen target models
                    self._soft_update(self.critic, self.critic_target, tau)
                    self._soft_update(self.actor, self.actor_target, tau)

    def save(self, filename, directory):
        """Write the online actor and critic state_dicts to ``directory``.

        Target networks are not saved; ``load`` re-derives them from the
        online weights.
        """
        torch.save(self.actor.state_dict(), '%s/%s_actor.pth' % (directory, filename))
        torch.save(self.critic.state_dict(), '%s/%s_critic.pth' % (directory, filename))

    def load(self, filename, directory):
        """Load actor and critic weights saved by ``save`` and re-sync both targets.

        Checkpoints are deserialised onto the CPU, so files written on a GPU
        load on CPU-only hosts. ``load_state_dict`` then copies the values
        onto each network's own device.
        """
        actor_state = torch.load('%s/%s_actor.pth' % (directory, filename), map_location='cpu')
        self.actor.load_state_dict(actor_state)
        critic_state = torch.load('%s/%s_critic.pth' % (directory, filename), map_location='cpu')
        self.critic.load_state_dict(critic_state)
        # Targets are not stored in the checkpoint. Without this re-sync,
        # resumed training would bootstrap from stale/random target networks.
        self.actor_target.load_state_dict(actor_state)
        self.critic_target.load_state_dict(critic_state)
