#   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import parl
import torch
import torch.nn as nn
import torch.nn.functional as F
from parl.utils.utils import check_model_method
from copy import deepcopy
from torch.distributions import Normal

# Public API of this module.
__all__ = ['DECMAAC']


class DECMAAC(parl.Algorithm):
    def __init__(self,
                 model,
                 agent_index=None,
                 gamma=None,
                 tau=None,
                 alpha=None,
                 actor_lr=None,
                 critic_lr=None,
                 max_grad_norm=0.5):
        """ MAAC-style actor-critic algorithm for a single agent.

        Args:
            model (parl.Model): forward network of actor and critic. It must
                implement value(), policy(), get_actor_params() and
                get_critic_params().
            agent_index (int): index of this agent in the multi-agent env.
            gamma (float): discount factor for reward computation. Stored as
                self.gamma; not used by the methods of this class.
            tau (float): soft-update coefficient; sync_target() defaults to
                decay = 1 - tau.
            alpha (float): stored as self.alpha; not used by any loss in
                this class.
            actor_lr (float): learning rate of the actor Adam optimizer.
            critic_lr (float): learning rate of the critic Adam optimizer.
            max_grad_norm (float): max total L2 norm used to clip actor and
                critic gradients before each optimizer step (default 0.5).
        """
        # checks
        check_model_method(model, 'value', self.__class__.__name__)
        check_model_method(model, 'policy', self.__class__.__name__)
        check_model_method(model, 'get_actor_params', self.__class__.__name__)
        check_model_method(model, 'get_critic_params', self.__class__.__name__)
        assert isinstance(agent_index, int)
        assert isinstance(tau, float)
        assert isinstance(alpha, float)
        assert isinstance(actor_lr, float)
        assert isinstance(critic_lr, float)

        self.agent_index = agent_index
        self.gamma = gamma
        self.tau = tau
        self.actor_lr = actor_lr
        self.critic_lr = critic_lr
        self.alpha = alpha
        self.max_grad_norm = max_grad_norm

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.model = model.to(device)
        # Copy the (already moved) model so the target lives on the same device.
        self.target_model = deepcopy(self.model)
        # decay=0: start the target as an exact copy of the online model.
        self.sync_target(0)

        self.actor_optimizer = torch.optim.Adam(
            lr=self.actor_lr, params=self.model.get_actor_params())
        self.critic_optimizer = torch.optim.Adam(
            lr=self.critic_lr, params=self.model.get_critic_params())

    def predict(self, obs, use_target_model=False):
        """ Sample an action from the policy's Gaussian distribution.

        Args:
            obs (torch.Tensor): observation batch passed to model.policy().
            use_target_model (bool): sample from target_model instead of model.

        Returns:
            action (torch.Tensor): raw sample from Normal(mean, exp(log_std)).
            action_apply (torch.Tensor): softmax of `action` over the last dim.
        """
        model = self.target_model if use_target_model else self.model
        mean, log_std = model.policy(obs)
        dist = Normal(mean, torch.exp(log_std))
        # sample() is not reparameterized: no gradient flows through the
        # sampled action. Use dist.rsample() if backprop through it is needed.
        action = dist.sample()
        action_apply = F.softmax(action, dim=-1)
        return action, action_apply

    def Q(self, obs_i, act_i, use_target_model=False):
        """ Evaluate model.value() on an observation/action batch.

        Args:
            obs_i (torch.Tensor): observation batch.
            act_i (torch.Tensor): action batch.
            use_target_model (bool): use target_model instead of model.

        Returns:
            The result of model.value(obs_i, act_i), presumably a Q value per
            sample (TODO confirm against the model definition).
        """
        # NOTE(review): V() calls model.value with a single argument; both
        # call forms can only work if the model's value() accepts an optional
        # action argument.
        model = self.target_model if use_target_model else self.model
        return model.value(obs_i, act_i)

    def V(self, obs_n, use_target_model=False):
        """ Evaluate model.value() on an observation batch.

        Args:
            obs_n (torch.Tensor): observation batch.
            use_target_model (bool): use target_model instead of model.

        Returns:
            The result of model.value(obs_n), presumably a state value per
            sample (TODO confirm against the model definition).
        """
        model = self.target_model if use_target_model else self.model
        return model.value(obs_n)

    def learn(self, obs_i, act_i, td, advantage):
        """ Run one actor update and one critic update, then soft-sync the target.

        Args:
            obs_i (torch.Tensor): observation batch for the policy.
            act_i (torch.Tensor): actions whose log-probability is evaluated
                under the current Normal policy; presumably the raw
                (pre-softmax) `action` returned by predict().
            td (torch.Tensor): TD error; its mean square is minimized, so it
                is expected to be differentiable w.r.t. the critic parameters.
            advantage (torch.Tensor): advantage estimate; detached before use.

        Returns:
            (critic_cost, actor_cost): detached scalar loss tensors.
        """
        actor_cost = self._actor_learn(obs_i, act_i, advantage)
        critic_cost = self._critic_learn(td)
        self.sync_target()
        return critic_cost, actor_cost

    def _actor_learn(self, obs_i, act_i, advantage):
        """ Policy-gradient step minimizing -mean(log pi(a|s) * advantage).

        Returns:
            The detached scalar actor loss.
        """
        mean, log_std = self.model.policy(obs_i)
        dist = Normal(mean, torch.exp(log_std))
        # Sum per-dimension log-probs ([batch, act_dim] -> [batch]).
        action_log_prob = dist.log_prob(act_i).sum(-1)
        # The advantage only weights the log-prob; no gradient flows into it.
        adv = advantage.detach()
        cost = -1.0 * torch.mean(action_log_prob * adv)

        self.actor_optimizer.zero_grad()
        cost.backward()
        torch.nn.utils.clip_grad_norm_(self.model.get_actor_params(),
                                       self.max_grad_norm)
        self.actor_optimizer.step()
        return cost.detach()

    def _critic_learn(self, td):
        """ Critic step minimizing mean(td ** 2).

        Returns:
            The detached scalar critic loss.
        """
        cost = td.pow(2).mean()

        self.critic_optimizer.zero_grad()
        cost.backward()
        torch.nn.utils.clip_grad_norm_(self.model.get_critic_params(),
                                       self.max_grad_norm)
        self.critic_optimizer.step()
        return cost.detach()

    def sync_target(self, decay=None):
        """ Soft-update target_model from model.

        Args:
            decay (float|None): passed to model.sync_weights_to(); defaults to
                1 - tau. Under parl's sync_weights_to semantics, decay=0 is a
                hard copy (as used in __init__).
        """
        if decay is None:
            decay = 1.0 - self.tau
        self.model.sync_weights_to(self.target_model, decay=decay)

