#   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import parl
import torch
import numpy as np
import copy
from parl.utils import ReplayMemory
from parl.utils import machine_info, get_gpu_count


class MAAgent(parl.Agent):
    """Per-agent wrapper around a PARL algorithm for multi-agent training.

    Each agent owns a replay buffer and keeps running, per-sample estimates of
    its average reward (``miu`` / ``miu_i``). ``calc_miu`` refreshes those
    estimates together with the reward-difference signals. ``learn`` consumes
    the signals to build a one-step TD advantage and calls ``self.alg.learn``.
    """

    def __init__(self,
                 algorithm,
                 agent_index=None,
                 obs_dim_n=None,
                 act_dim_n=None,
                 memory_size=None,
                 min_memory_size=None,
                 batch_size=None,
                 miu_lr=None,
                 speedup=False):
        """
        Args:
            algorithm: PARL algorithm instance. This class calls its
                ``sync_target``, ``predict``, ``V`` and ``learn`` methods and
                reads its ``gamma``.
            agent_index (int): index of this agent in ``obs_dim_n`` / ``act_dim_n``.
            obs_dim_n (list): observation dimension of every agent.
            act_dim_n (list): action dimension of every agent; its length is
                the number of agents ``self.n``.
            memory_size (int): replay buffer capacity.
            min_memory_size: warm-up threshold. Updates only happen once
                ``self.rpm.size()`` exceeds it, so it must be comparable with
                an int (``None`` will raise on the first update step).
            batch_size (int): length of the running-average tensors
                (presumably equal to ``len(rpm_sample_index)`` — TODO confirm).
            miu_lr (float): base step size for the average-reward estimates.
            speedup (bool): stored on the instance; not read in this class.
        """
        assert isinstance(agent_index, int)
        assert isinstance(obs_dim_n, list)
        assert isinstance(act_dim_n, list)
        assert isinstance(memory_size, int)
        assert isinstance(batch_size, int)
        assert isinstance(miu_lr, float)
        assert isinstance(speedup, bool)
        self.agent_index = agent_index
        self.obs_dim_n = obs_dim_n
        self.act_dim_n = act_dim_n
        self.memory_size = memory_size
        self.batch_size = batch_size
        self.speedup = speedup
        self.n = len(act_dim_n)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.Mu_lr = miu_lr
        self._lr = 0.  # accumulator used to bias-correct the step size
        self.m_lr = 0.  # effective (bias-corrected) step size, set by calc_miu

        # Running reward-average state, one entry per sampled transition.
        # Allocated on self.device because calc_miu combines it with batch
        # tensors that live there (a CPU/CUDA mix would raise).
        self.miu = torch.zeros(batch_size, device=self.device)
        self.miu_i = torch.zeros(batch_size, device=self.device)
        self.miu_Ni = torch.zeros(batch_size, device=self.device)
        self.prev_miu = torch.zeros(batch_size, device=self.device)
        self.prev_miu_i = torch.zeros(batch_size, device=self.device)
        self.prev_miu_Ni = torch.zeros(batch_size, device=self.device)
        self.delta_miu = torch.zeros(batch_size, device=self.device)
        self.delta_miu_i = torch.zeros(batch_size, device=self.device)
        # Becomes True once calc_miu has produced the delta/delt tensors that
        # learn() needs; until then learn() skips instead of crashing.
        self._miu_ready = False

        self.min_memory_size = min_memory_size  # warm-up stage threshold
        self.rpm = ReplayMemory(
            max_size=self.memory_size,
            obs_dim=self.obs_dim_n[agent_index],
            act_dim=self.act_dim_n[agent_index],
            num_agents=self.n)
        self.global_train_step = 0

        # Initialise the parent parl.Agent, which provides self.alg used below.
        super(MAAgent, self).__init__(algorithm)

        # Attention: In the beginning, sync target model totally.
        self.alg.sync_target(decay=0)

    def predict(self, obs, use_target_model=False):
        """Predict an action for a single observation.

        Args:
            obs: observation array; reshaped to a batch of one row.
            use_target_model (bool): query the target model instead of the
                online model.

        Returns:
            The two outputs of ``self.alg.predict`` (``act``, ``act_apply``),
            each converted to a flat numpy array.
        """
        obs = torch.FloatTensor(obs.reshape(1, -1)).to(self.device)  # [1, obs_dim]
        act, act_apply = self.alg.predict(obs, use_target_model=use_target_model)
        act_numpy = act.detach().cpu().numpy().flatten()
        act_apply_numpy = act_apply.detach().cpu().numpy().flatten()
        return act_numpy, act_apply_numpy

    def learn(self, agents, rpm_sample_index):
        """Build a one-step TD advantage from the reward-difference signal and train.

        Args:
            agents: not used here; kept for interface compatibility.
            rpm_sample_index: indices passed to ``self.rpm.sample_batch_by_index``.

        Returns:
            ``(critic_cost, actor_cost)`` as numpy values, or ``(0.0, 0.0)``
            when the update is skipped. An update is skipped if this is not a
            100th step, during warm-up, or before ``calc_miu`` has produced
            statistics.
        """
        # Only update every 100 steps; the counter is advanced by calc_miu.
        if self.global_train_step % 100 != 0:
            return 0.0, 0.0

        # Warm-up stage. Use the same threshold as calc_miu so both skip
        # together; with '<' here, learn could run on a step calc_miu skipped.
        if self.rpm.size() <= self.min_memory_size:
            return 0.0, 0.0

        # The reward-difference tensors read below are produced by calc_miu.
        if not self._miu_ready:
            return 0.0, 0.0

        batch_obs_i, batch_act_i, _, _, batch_obs_next_i, batch_isOver = \
            self.rpm.sample_batch_by_index(rpm_sample_index)

        batch_obs_i = torch.FloatTensor(batch_obs_i).to(self.device)
        batch_act_i = torch.FloatTensor(batch_act_i).to(self.device)
        batch_obs_next_i = torch.FloatTensor(batch_obs_next_i).to(self.device)
        batch_isOver = torch.FloatTensor(batch_isOver).to(self.device)

        v_next_target = self.alg.V(batch_obs_next_i, use_target_model=True)

        # Where this agent's reward deviation and that of its neighbours
        # (excluding itself) have the same sign, train on the agent's own
        # difference reward. Otherwise use the neighbourhood-average one.
        same_sign = torch.sign(self.delt_i) == torch.sign(self.delt_Ei)
        diff_rew = torch.where(same_sign, self.delta_miu_i, self.delta_miu_Ni)

        # 1-step TD advantage: r_diff + gamma * (1 - done) * V_target(s') - V(s)
        advantage = diff_rew + \
            self.alg.gamma * (1.0 - batch_isOver) * v_next_target.detach() - \
            self.alg.V(batch_obs_i)

        critic_cost, actor_cost = self.alg.learn(batch_obs_i, batch_act_i, advantage)
        critic_cost = critic_cost.cpu().detach().numpy()
        actor_cost = actor_cost.cpu().detach().numpy()
        return critic_cost, actor_cost

    def calc_miu(self, rpm_sample_index, agents, miu_i_N):
        """Update the running reward averages and the reward-difference signals.

        Always advances ``global_train_step``. The statistics are only
        recomputed every 100 steps, and only once the replay buffer holds more
        than ``min_memory_size`` transitions.

        Args:
            rpm_sample_index: indices used to sample every agent's buffer, so
                all agents contribute rewards for the same transitions.
            agents: sequence of agents; the first ``self.n`` are read.
            miu_i_N: per-agent ``miu_i`` estimates, weighted by this agent's
                neighbour mask over the last dimension (presumably shape
                [batch_size, n] — TODO confirm with caller).

        Returns:
            True if the statistics were updated, False if the step was skipped.
        """
        self.global_train_step += 1

        # Only update parameters every 100 steps.
        if self.global_train_step % 100 != 0:
            return False

        # Warm-up stage.
        if self.rpm.size() <= self.min_memory_size:
            return False

        # Bias-corrected step size: _lr = 1 - (1 - Mu_lr)^k after k updates,
        # so m_lr starts at 1 and decays towards Mu_lr.
        self._lr += self.Mu_lr * (1 - self._lr)
        self.m_lr = self.Mu_lr / self._lr

        # Rewards of every agent for the shared sample indices -> [batch_size, n].
        rew_columns = []
        for i in range(self.n):
            _, _, _, rew, _, _ = agents[i].rpm.sample_batch_by_index(rpm_sample_index)
            rew_columns.append(rew)
        batch_rew_n = torch.as_tensor(
            np.stack(rew_columns, axis=1), dtype=torch.float32, device=self.device)

        _, _, batch_neigh_i, batch_rew_i, _, _ = \
            self.rpm.sample_batch_by_index(rpm_sample_index)
        batch_rew_i = torch.FloatTensor(batch_rew_i).to(self.device)  # [batch_size]
        batch_neigh_i = torch.FloatTensor(batch_neigh_i).to(self.device)  # [batch_size, n]
        batch_Ni_i = torch.count_nonzero(batch_neigh_i, dim=-1)  # neighbour count per sample
        self.count_Ni = batch_Ni_i

        neigh_rew_sum = torch.sum(batch_neigh_i * batch_rew_n, dim=-1)
        # Neighbourhood mean reward.
        batch_rew_ni = neigh_rew_sum / batch_Ni_i
        # Mean reward of the neighbours other than agent i. This assumes
        # agent i's own mask entry is set. NOTE(review): it divides by zero
        # (nan/inf) when a sample has at most one neighbour.
        batch_rew_ei = (neigh_rew_sum - batch_rew_i) / (batch_Ni_i - 1)

        # Snapshot the previous estimates. clone() is required because
        # miu_i / miu are updated in place below.
        self.prev_miu_i = self.miu_i.clone()
        self.prev_miu = self.miu.clone()

        # Neighbourhood average of the other agents' reward estimates.
        miu_i_N = torch.as_tensor(miu_i_N, device=self.device)
        self.prev_miu_Ni = torch.sum(batch_neigh_i * miu_i_N, dim=-1) / batch_Ni_i

        self.miu_i += self.m_lr * (batch_rew_i - self.miu_i)
        self.miu += self.m_lr * (batch_rew_i - self.miu)

        # Per-sample difference rewards against the neighbourhood estimate.
        self.delta_miu_i = batch_rew_i - self.prev_miu_Ni
        self.delta_miu_Ei = batch_rew_ei - self.prev_miu_Ni
        self.delta_miu_Ni = batch_rew_ni - self.prev_miu_Ni

        # Deviations against the batch-mean neighbourhood estimate. learn()
        # compares the signs of delt_i and delt_Ei.
        mean_prev_miu_Ni = torch.mean(self.prev_miu_Ni)
        self.delt_i = batch_rew_i - mean_prev_miu_Ni
        self.delt_Ei = batch_rew_ei - mean_prev_miu_Ni
        self.delt_Ni = batch_rew_ni - mean_prev_miu_Ni

        self.bf_delta_miu_i = batch_rew_i - self.prev_miu_i
        self.af_delta_miu_i = batch_rew_ni - self.prev_miu_i

        self._miu_ready = True
        return True

    def add_experience(self, obs, act, neigh, reward, next_obs, terminal):
        """Append one transition to this agent's replay buffer."""
        self.rpm.append(obs, act, neigh, reward, next_obs, terminal)
