#   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import parl
import torch
import torch.nn as nn
import torch.nn.functional as F


# Clamp bounds for the log standard deviation (log_std) of the action
# distribution. Within this file they are referenced only by the
# commented-out PLAN B actor further down.
LOG_SIG_MAX = 2.0
LOG_SIG_MIN = -20.0

def weights_init_(m):
    """Initialize an ``nn.Linear`` layer in place; leave other modules alone.

    Written to be passed to ``nn.Module.apply``, which calls it once for
    every sub-module. The weight matrix gets Xavier-uniform values
    (gain 1) and the bias, when the layer has one, is set to zero.

    Args:
        m: Any module. Only ``nn.Linear`` instances are modified.
    """
    if not isinstance(m, nn.Linear):
        return
    torch.nn.init.xavier_uniform_(m.weight, gain=1)
    # A layer built with ``bias=False`` has ``m.bias is None``; passing that
    # to ``constant_`` would raise, so only touch a bias that exists.
    if m.bias is not None:
        torch.nn.init.constant_(m.bias, 0)


class MAModel(parl.Model):
    """Container that pairs an ``ActorModel`` with a ``CriticModel``.

    Args:
        obs_dim: Observation size passed to both sub-networks.
        act_dim: Action size passed to the actor.
    """

    def __init__(self, obs_dim, act_dim):
        # Run the parent (parl.Model) initializer before attaching
        # sub-networks.
        super().__init__()
        self.actor_model = ActorModel(obs_dim, act_dim)
        self.critic_model = CriticModel(obs_dim)

    def policy(self, obs):
        """Feed ``obs`` through the actor and return its output unchanged."""
        actor_output = self.actor_model(obs)
        return actor_output

    def value(self, obs_n):
        """Feed ``obs_n`` through the critic and return its output unchanged."""
        critic_output = self.critic_model(obs_n)
        return critic_output

    def get_actor_params(self):
        """Return the actor's ``parameters()`` iterator."""
        actor = self.actor_model
        return actor.parameters()

    def get_critic_params(self):
        """Return the critic's ``parameters()`` iterator."""
        critic = self.critic_model
        return critic.parameters()


class ActorModel(parl.Model):  # PLAN A
    """Actor network: a three-layer MLP for the mean plus a free ``log_std``.

    ``forward`` returns ``(mean, log_std)``. ``mean`` comes from the MLP,
    while ``log_std`` is a standalone learnable parameter that does not
    depend on the observation.

    Args:
        obs_dim: Size of the last dimension of the observation input.
        act_dim: Number of action components (size of ``mean``/``log_std``).
    """

    def __init__(self, obs_dim, act_dim):
        super().__init__()
        self.act_dim = act_dim
        hidden_sizes = (64, 64)
        self.fc1 = nn.Linear(obs_dim, hidden_sizes[0])
        self.fc2 = nn.Linear(hidden_sizes[0], hidden_sizes[1])
        self.fc3 = nn.Linear(hidden_sizes[1], act_dim)
        # One learnable log-std entry per action component, starting at 0.
        self.log_std = nn.Parameter(torch.zeros(act_dim))
        # Re-initialize every Linear layer created above.
        self.apply(weights_init_)

    def forward(self, obs):
        """Return ``(mean, log_std)`` for ``obs``."""
        # ReLU on both hidden layers (the original notes torch.tanh as an
        # alternative activation).
        features = F.relu(self.fc2(F.relu(self.fc1(obs))))
        return self.fc3(features), self.log_std

# class ActorModel(parl.Model):    ###PLAN B  # converges too slowly
#     def __init__(self, obs_dim, act_dim):
#         super(ActorModel, self).__init__()
#         self.act_dim = act_dim
#         hid1_size = 64
#         hid2_size = 64
#         self.fc1 = nn.Linear(obs_dim, hid1_size)
#         self.fc2 = nn.Linear(hid1_size, hid2_size)
#         self.fc3 = nn.Linear(hid1_size, self.act_dim)
#         self.mean_fc3 = nn.Linear(hid2_size, self.act_dim)
#         self.std_fc3 = nn.Linear(hid2_size, self.act_dim)
#         self.apply(weights_init_)
        
#     def forward(self, obs):
#         hid1 = F.relu(self.fc1(obs))  #torch.tanh
#         hid2 = F.relu(self.fc2(hid1))
#         mean = self.mean_fc3(hid2)
#         std = self.std_fc3(hid2)
#         log_std = torch.clamp(std, min=LOG_SIG_MIN, max=LOG_SIG_MAX)  # keep exp(log_std) from becoming too large
#         return mean, log_std


class CriticModel(parl.Model):
    """Value network: a three-layer MLP mapping an observation to a scalar.

    Args:
        obs_dim: Size of the last dimension of the observation input.
    """

    def __init__(self, obs_dim):
        super(CriticModel, self).__init__()
        hid1_size = 64
        hid2_size = 64
        out_dim = 1
        self.fc1 = nn.Linear(obs_dim, hid1_size)
        self.fc2 = nn.Linear(hid1_size, hid2_size)
        self.fc3 = nn.Linear(hid2_size, out_dim)
        # Re-initialize every Linear layer created above.
        self.apply(weights_init_)

    def forward(self, obs):
        """Return the value estimate for ``obs`` with the size-1 output dim removed.

        Args:
            obs: Tensor whose last dimension has size ``obs_dim``.

        Returns:
            For ``[batch, obs_dim]`` input, a tensor of shape ``[batch]``;
            for an unbatched ``[obs_dim]`` input, a 0-dim tensor.
        """
        hid1 = F.relu(self.fc1(obs))
        hid2 = F.relu(self.fc2(hid1))
        V = self.fc3(hid2)
        # Drop the trailing size-1 output dimension: [batch, 1] -> [batch].
        # Using dim=-1 rather than dim=1 gives the same result for batched
        # input and avoids an IndexError when obs is a single 1-D observation.
        V = torch.squeeze(V, dim=-1)
        return V
    
    # def forward(self, obs_n):
    #     # concatenate the agent_num 2-D tensors: agent_num * [batchsize, obs_shape] --> [batchsize, agent_num*obs_shape]
    #     inputs = torch.cat(obs_n, dim=1)
    #     hid1 = F.relu(self.fc1(inputs))
    #     hid2 = F.relu(self.fc2(hid1))
    #     V = self.fc3(hid2)
    #     V = torch.squeeze(V, dim=1)  # drop the second dimension -> [batchsize]
    #     return V
