import sys

import torch
import torch.nn as nn
from torch.optim import Adam
from torch.autograd import Variable
import torch.nn.functional as F
from torch.nn.utils import clip_grad_norm_
import numpy as np
from torch.optim.lr_scheduler import LambdaLR
from models import model_factory
from models import graph_net


def soft_update(target, source, tau):
    """Blend ``source`` parameters into ``target`` in place (Polyak averaging).

    Each target parameter becomes ``(1 - tau) * target + tau * source``.
    Parameters are paired positionally, so both modules must yield their
    parameters in the same order.
    """
    keep = 1.0 - tau
    for dst, src in zip(target.parameters(), source.parameters()):
        blended = dst.data * keep + src.data * tau
        dst.data.copy_(blended)


def hard_update(target, source):
    """Overwrite every parameter of ``target`` with the matching one from ``source``.

    Parameters are paired positionally; values are copied in place, so the
    two modules keep separate storage.
    """
    param_pairs = zip(target.parameters(), source.parameters())
    for dst, src in param_pairs:
        dst.data.copy_(src.data)

def soft_update_list(targets, sources, tau):
    """Apply ``soft_update`` with rate ``tau`` to each (target, source) pair.

    The two sequences are walked in lockstep; extra entries in the longer
    one are ignored.
    """
    network_pairs = zip(targets, sources)
    for target_net, source_net in network_pairs:
        soft_update(target_net, source_net, tau)

def hard_update_list(targets, sources):
    """Apply ``hard_update`` to each (target, source) pair.

    The two sequences are walked in lockstep; extra entries in the longer
    one are ignored.
    """
    network_pairs = zip(targets, sources)
    for target_net, source_net in network_pairs:
        hard_update(target_net, source_net)
        
def adjust_lr(optimizer, init_lr, episode_i, num_episode, start_episode):
    """Linearly decay the optimizer's learning rate and return the rate in effect.

    Before ``start_episode`` the optimizer is left untouched and ``init_lr``
    is returned. From ``start_episode`` on, the rate falls linearly from
    ``init_lr`` to zero at ``num_episode`` and is written into every
    parameter group.

    Args:
        optimizer: object exposing ``param_groups`` (a list of dicts with an
            ``'lr'`` entry), e.g. a ``torch.optim`` optimizer.
        init_lr: learning rate at the start of the decay.
        episode_i: index of the current episode.
        num_episode: episode at which the rate reaches zero.
        start_episode: first episode at which decay is applied.

    Returns:
        The learning rate for ``episode_i``.
    """
    if episode_i < start_episode:
        return init_lr

    decay_span = num_episode - start_episode
    if decay_span <= 0:
        # Empty decay window: treat the schedule as already finished rather
        # than dividing by zero.
        remaining = 0.0
    else:
        # Clamp so the rate never goes negative once episode_i > num_episode.
        remaining = max(0.0, 1 - (episode_i - start_episode) / decay_span)

    lr = init_lr * remaining
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr


class LayerNorm(nn.Module):
    """Per-sample normalization over all non-batch elements.

    Every sample (index along dim 0) is flattened, shifted by its mean and
    divided by ``std + eps``, where ``std`` comes from ``Tensor.std`` with
    its default settings. With ``affine=True`` a learnable scale ``gamma``
    and shift ``beta`` of length ``num_features`` are applied along dim 1.
    """

    def __init__(self, num_features, eps=1e-5, affine=True):
        super(LayerNorm, self).__init__()
        self.num_features = num_features
        self.affine = affine
        self.eps = eps

        if not self.affine:
            return
        # gamma starts uniform in [0, 1), beta starts at zero.
        self.gamma = nn.Parameter(torch.Tensor(num_features).uniform_())
        self.beta = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        n_dims = x.dim()
        per_sample = x.view(x.size(0), -1)

        # Statistics are reshaped to (-1, 1, ..., 1) so they broadcast over x.
        stat_shape = [-1] + [1] * (n_dims - 1)
        mean = per_sample.mean(1).view(*stat_shape)
        std = per_sample.std(1).view(*stat_shape)
        normalized = (x - mean) / (std + self.eps)

        if not self.affine:
            return normalized

        # Scale and shift are laid out along dim 1: (1, -1, 1, ..., 1).
        affine_shape = [1, -1] + [1] * (n_dims - 2)
        scale = self.gamma.view(*affine_shape)
        shift = self.beta.view(*affine_shape)
        return scale * normalized + shift


# Rebind the `LayerNorm` attribute of the imported `torch.nn` module to the
# custom class defined in this file. Any later `nn.LayerNorm` lookup on that
# module resolves to this implementation instead of torch's built-in one.
nn.LayerNorm = LayerNorm

class Actor(nn.Module):
    """Policy network: two ReLU hidden layers followed by a linear output head.

    ``forward`` returns the raw output of the head; no activation is applied
    to it here.
    """

    def __init__(self, hidden_size, num_inputs, num_outputs):
        super(Actor, self).__init__()
        self.linear1 = nn.Linear(num_inputs, hidden_size)
        self.linear2 = nn.Linear(hidden_size, hidden_size)
        self.mu = nn.Linear(hidden_size, num_outputs)
        # Shrink the head's initial weight and bias to a tenth of their
        # default initialization.
        for head_tensor in (self.mu.weight, self.mu.bias):
            head_tensor.data.mul_(0.1)

    def forward(self, inputs):
        hidden = F.relu(self.linear1(inputs))
        hidden = F.relu(self.linear2(hidden))
        return self.mu(hidden)

class Critic(nn.Module):
    """Value network wrapper around a model built by ``model_factory``.

    The model constructor is looked up with ``critic_type`` and called as
    ``net_fn(num_inputs + num_outputs, 1, hidden_size)``. ``forward``
    concatenates state and action features and feeds them to that model.
    """

    def __init__(self, hidden_size, num_inputs, num_outputs, critic_type, agent_id):
        super(Critic, self).__init__()

        self.agent_id = agent_id
        self.critic_type = critic_type
        self.net_fn = model_factory.get_model_fn(critic_type)
        # The model consumes the state features and action features joined
        # along the last axis.
        self.net = self.net_fn(num_inputs + num_outputs, 1, hidden_size)

    def forward(self, inputs, actions):
        batch = inputs.size(0)
        # Flatten each sample to a single row: (batch, 1, features).
        states = inputs.view(batch, 1, -1)
        acts = actions.view(batch, 1, -1)
        joint = torch.cat((states, acts), dim=2)
        return self.net(joint)


class DDPG(object):
    """Per-agent DDPG learners with softmax policies.

    One actor, target actor, perturbed actor, critic, target critic and a
    pair of Adam optimizers are kept per agent, in lists indexed by agent
    id. Actors emit logits; ``select_action`` turns them into a probability
    vector with a softmax, optionally after adding sampled noise.

    Batches passed to the update methods must expose ``state``, ``action``,
    ``reward``, ``mask`` and ``next_state`` as sequences of tensors that
    ``torch.cat`` can join; states and actions are then viewed as
    ``(batch, n_agent, -1)``.
    """

    def __init__(self, update_type, gamma, tau, hidden_size, obs_dim, n_action, n_agent, obs_dims, actor_lr, critic_lr,
                 fixed_lr, critic_type, actor_type, train_noise, num_episodes, num_steps,
                 critic_dec_cen, target_update_mode='soft', device='cpu'):
        """Build all per-agent networks and optimizers.

        Args:
            update_type: stored on the instance; not read by this class.
            gamma: discount factor used in the critic target.
            tau: interpolation rate for soft target updates.
            hidden_size: hidden width passed to every actor and critic.
            obs_dim: per-agent observation size fed to the actors.
            n_action: number of actor outputs.
            n_agent: number of agents (length of every per-agent list).
            obs_dims: sequence of observation sizes; ``obs_dims[1]`` is the
                critic's state input size and the cumulative sum is stored.
            actor_lr: initial Adam learning rate for the actors.
            critic_lr: initial Adam learning rate for the critics.
            fixed_lr: stored on the instance; not read by this class.
            critic_type: key handed to ``Critic`` to select the model.
            actor_type: accepted for interface compatibility; unused here.
            train_noise: passed as ``action_noise`` during updates.
            num_episodes: training length used by the LR schedules.
            num_steps: steps per episode used by ``lambda1``.
            critic_dec_cen: stored on the instance; not read by this class.
            target_update_mode: ``'soft'`` or ``'hard'`` critic-target update
                after each critic step; any other value skips it.
            device: device the networks are moved to.
        """
        self.device = device
        self.update_type = update_type
        self.obs_dim = obs_dim
        self.n_agent = n_agent
        self.n_action = n_action

        self.actor_list = [Actor(hidden_size, obs_dim, n_action).to(self.device) for _ in range(self.n_agent)]
        self.actor_target_list = [Actor(hidden_size, obs_dim, n_action).to(self.device) for _ in range(self.n_agent)]
        # Perturbed actors receive inputs on `device` too, so they must live there.
        self.actor_perturbed_list = [Actor(hidden_size, obs_dim, n_action).to(self.device)
                                     for _ in range(self.n_agent)]
        self.actor_optim_list = [Adam(actor.parameters(),
                                      lr=actor_lr, weight_decay=0) for actor in self.actor_list]

        self.critic_list = [Critic(hidden_size, obs_dims[1],
                                   n_action, critic_type, i).to(self.device) for i in range(self.n_agent)]
        self.critic_target_list = [Critic(hidden_size,
                                          obs_dims[1], n_action, critic_type, i).to(self.device)
                                   for i in range(self.n_agent)]
        self.critic_optim_list = [Adam(list(critic.parameters()), lr=critic_lr) for critic in self.critic_list]

        self.fixed_lr = fixed_lr
        self.init_act_lr = actor_lr
        self.init_critic_lr = critic_lr
        self.num_episodes = num_episodes
        self.start_episode = 0
        self.num_steps = num_steps
        self.gamma = gamma
        self.tau = tau
        self.train_noise = train_noise
        self.obs_dims_cumsum = np.cumsum(obs_dims)
        self.critic_dec_cen = critic_dec_cen
        self.debug = False
        self.target_update_mode = target_update_mode

        # Materialize as lists: `parameters()` returns a one-shot generator,
        # which would be exhausted by the first gradient-clipping call and
        # leave every later call clipping nothing.
        self.actor_params_list = [list(actor.parameters()) for actor in self.actor_list]
        self.critic_params_list = [list(critic.parameters()) for critic in self.critic_list]
        hard_update_list(self.actor_target_list, self.actor_list)
        hard_update_list(self.critic_target_list, self.critic_list)

    def adjust_lr(self, i_episode):
        """Apply the module-level linear LR decay to every actor and critic optimizer."""
        for actor_optim, critic_optim in zip(self.actor_optim_list, self.critic_optim_list):
            adjust_lr(actor_optim, self.init_act_lr, i_episode, self.num_episodes, self.start_episode)
            adjust_lr(critic_optim, self.init_critic_lr, i_episode, self.num_episodes, self.start_episode)

    def lambda1(self, step):
        """Return an LR multiplier: 1 for the first half, then a linear ramp to 0.

        ``step`` is compared against thresholds of ``num_episodes * num_steps / 100``
        (end) and half of that (start of the ramp).
        """
        start_decrease_step = ((self.num_episodes / 2) * self.num_steps) / 100
        max_step = (self.num_episodes * self.num_steps) / 100
        if step <= start_decrease_step:
            return 1
        return 1 - (step - start_decrease_step) / (max_step - start_decrease_step)

    def select_action(self, agent_id, state, action_noise=None, param_noise=False, grad=False):
        """Compute the action distribution of one agent for a batch of states.

        Args:
            agent_id: index of the agent whose actor is used.
            state: 2-D tensor of actor inputs (softmax is taken over dim 1).
            action_noise: if truthy, ``log(-log(U))`` noise with
                ``U ~ Uniform(0, 1)`` is subtracted from the logits.
            param_noise: if true, use the perturbed actor instead.
            grad: if true, keep the autograd graph and also return the logits.

        Returns:
            The softmax action tensor, or ``(action, logits)`` when ``grad``
            is true. The returned logits include the noise, if any.
        """
        actor = self.actor_list[agent_id]
        actor.eval()
        policy = self.actor_perturbed_list[agent_id] if param_noise else actor
        mu = policy(Variable(state))
        actor.train()

        if not grad:
            mu = mu.data

        if action_noise:
            noise = np.log(-np.log(np.random.uniform(0, 1, mu.size())))
            # Create the noise directly on the logits' device and dtype so the
            # subtraction cannot fail on a device mismatch.
            mu = mu - torch.as_tensor(noise, dtype=mu.dtype, device=mu.device)

        action = F.softmax(mu, dim=1)
        if not grad:
            return action
        return action, mu

    def update_critic_parameters(self, agent_id, batch, shuffle=None, eval=False):
        """Run one TD update (or evaluation) of an agent's critic.

        Args:
            agent_id: index of the agent whose critic is updated.
            batch: transition batch (see the class docstring).
            shuffle: unused; kept for interface compatibility.
            eval: if true, only compute the loss and skip the optimizer step.

        Returns:
            ``(loss, perturb_out)`` when ``eval`` is true, otherwise
            ``(loss, perturb_out, grad_norm)`` where ``grad_norm`` is the
            gradient norm before clipping. ``perturb_out`` is always 0.
        """
        state_batch = Variable(torch.cat(batch.state)).to(self.device)
        action_batch = Variable(torch.cat(batch.action)).to(self.device)
        reward_batch = Variable(torch.cat(batch.reward)).to(self.device)
        mask_batch = Variable(torch.cat(batch.mask)).to(self.device)
        next_state_batch = torch.cat(batch.next_state).to(self.device)

        batch_size = len(state_batch)

        state_batch = state_batch.view(batch_size, self.n_agent, -1)
        action_batch = action_batch.view(batch_size, self.n_agent, -1)
        next_state_batch = next_state_batch.view(batch_size, self.n_agent, -1)

        agent_state = state_batch[:, agent_id, :]
        agent_action = action_batch[:, agent_id, :]
        agent_next_state = next_state_batch[:, agent_id, :]

        # The next action comes from the agent's current (non-target) actor.
        next_action_batch = self.select_action(
            agent_id, agent_next_state, action_noise=self.train_noise)

        self.critic_optim_list[agent_id].zero_grad()

        next_state_action_values = self.critic_target_list[agent_id](
            agent_next_state, next_action_batch)

        reward_batch = reward_batch[:, agent_id].unsqueeze(1)
        mask_batch = mask_batch[:, agent_id].unsqueeze(1)
        # Detach the bootstrap target: it is a constant for this update, and
        # this keeps backward() from accumulating gradients on the target critic.
        expected_state_action_batch = (
            reward_batch + (self.gamma * mask_batch * next_state_action_values)).detach()
        state_action_batch = self.critic_list[agent_id](agent_state, agent_action)
        perturb_out = 0
        value_loss = ((state_action_batch - expected_state_action_batch) ** 2).mean()

        if eval:
            return value_loss.item(), perturb_out

        value_loss.backward()
        unclipped_norm = clip_grad_norm_(self.critic_params_list[agent_id], 0.5)
        self.critic_optim_list[agent_id].step()
        if self.target_update_mode == 'soft':
            soft_update(self.critic_target_list[agent_id], self.critic_list[agent_id], self.tau)
        elif self.target_update_mode == 'hard':
            hard_update(self.critic_target_list[agent_id], self.critic_list[agent_id])
        return value_loss.item(), perturb_out, unclipped_norm

    def update_actor_parameters(self, update_agent_id, batch, shuffle=None):
        """Run one policy-gradient update of an agent's actor.

        The loss is the negated mean critic value of the actor's own actions
        plus ``1e-3`` times the mean squared logit.

        Args:
            update_agent_id: index of the agent whose actor is updated.
            batch: transition batch; only ``batch.state`` is read.
            shuffle: unused; kept for interface compatibility.

        Returns:
            The scalar policy loss as a Python float.
        """
        state_batch = Variable(torch.cat(batch.state)).to(self.device)
        batch_size = len(state_batch)
        state_batch = state_batch.view(batch_size, self.n_agent, -1)
        agent_state = state_batch[:, update_agent_id, :]

        self.actor_optim_list[update_agent_id].zero_grad()
        action_batch_n, logit = self.select_action(
            update_agent_id, agent_state, action_noise=self.train_noise, grad=True)

        policy_loss = -self.critic_list[update_agent_id](agent_state, action_batch_n)
        policy_loss = policy_loss.mean() + 1e-3 * (logit ** 2).mean()
        policy_loss.backward()
        clip_grad_norm_(self.actor_params_list[update_agent_id], 0.5)
        self.actor_optim_list[update_agent_id].step()

        # NOTE(review): both targets are always soft-updated here, regardless
        # of `target_update_mode`, and the critic target is also updated in
        # update_critic_parameters. Kept as-is to preserve training dynamics;
        # confirm this is intended.
        soft_update(self.actor_target_list[update_agent_id], self.actor_list[update_agent_id], self.tau)
        soft_update(self.critic_target_list[update_agent_id], self.critic_list[update_agent_id], self.tau)

        return policy_loss.item()

    def perturb_actor_parameters(self, param_noise):
        """Apply parameter noise to every perturbed actor, for exploration.

        Each perturbed actor is first reset to a copy of its agent's actor,
        then Gaussian noise scaled by ``param_noise.current_stddev`` is added
        in place to its state-dict tensors.
        """
        hard_update_list(self.actor_perturbed_list, self.actor_list)
        with torch.no_grad():
            for perturbed in self.actor_perturbed_list:
                for name, param in perturbed.state_dict().items():
                    # Entries whose name contains 'ln' are left unperturbed.
                    # Presumably these are layer-norm parameters; the earlier
                    # check was a no-op `pass` — TODO confirm the intent.
                    if 'ln' in name:
                        continue
                    # randn_like keeps the noise on the parameter's device.
                    param += torch.randn_like(param) * param_noise.current_stddev

    def save_model(self, env_name, suffix="", actor_path=None, critic_path=None):
        """Save all actors and critics.

        Each file holds a list of state dicts, one per agent, in agent order.
        Default paths are ``models/ddpg_actor_<env>_<suffix>`` and
        ``models/ddpg_critic_<env>_<suffix>``; the ``models/`` directory is
        created if missing.
        """
        import os  # local import: `os` is not imported at module level

        os.makedirs('models/', exist_ok=True)

        if actor_path is None:
            actor_path = "models/ddpg_actor_{}_{}".format(env_name, suffix)
        if critic_path is None:
            critic_path = "models/ddpg_critic_{}_{}".format(env_name, suffix)
        print('Saving models to {} and {}'.format(actor_path, critic_path))
        torch.save([actor.state_dict() for actor in self.actor_list], actor_path)
        torch.save([critic.state_dict() for critic in self.critic_list], critic_path)

    def load_model(self, actor_path, critic_path):
        """Load actors and/or critics from files written by ``save_model``.

        A path of ``None`` skips that group. Only the main networks are
        loaded; target and perturbed networks are not touched.
        """
        print('Loading models from {} and {}'.format(actor_path, critic_path))
        if actor_path is not None:
            actor_states = torch.load(actor_path, map_location=self.device)
            for actor, state in zip(self.actor_list, actor_states):
                actor.load_state_dict(state)
        if critic_path is not None:
            critic_states = torch.load(critic_path, map_location=self.device)
            for critic, state in zip(self.critic_list, critic_states):
                critic.load_state_dict(state)

    @property
    def actor_lr(self):
        """Current learning rate of the first agent's actor optimizer."""
        return self.actor_optim_list[0].param_groups[0]['lr']
