#   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import parl
import torch
import numpy as np
import copy
from parl.utils import ReplayMemory
from parl.utils import machine_info, get_gpu_count


class MAAgent(parl.Agent):
    def __init__(self,
                 algorithm,
                 agent_index=None,
                 obs_dim_n=None,
                 act_dim_n=None,
                 memory_size=None,
                 min_memory_size=None,
                 batch_size=None,
                 speedup=False):
        """Set up one agent of a multi-agent learner.

        Args:
            algorithm: learning algorithm handed to ``parl.Agent``; it is
                used afterwards through ``self.alg``.
            agent_index (int): position of this agent in the per-agent lists.
            obs_dim_n (list): observation dimension of every agent.
            act_dim_n (list): action dimension of every agent; its length
                defines the number of agents ``self.n``.
            memory_size (int): capacity passed to the replay memory.
            min_memory_size: replay-memory size that has to be exceeded
                before ``learn`` / ``calc_miu`` do any work (warm-up stage).
                NOTE(review): unlike the other arguments this one is not
                type-checked; leaving it at ``None`` would presumably make
                the warm-up comparison against ``rpm.size()`` fail.
            batch_size (int): length of the per-sample running-mean buffers.
            speedup (bool): stored on the instance; not read by the methods
                of this class.
        """
        assert isinstance(agent_index, int)
        assert isinstance(obs_dim_n, list)
        assert isinstance(act_dim_n, list)
        assert isinstance(memory_size, int)
        assert isinstance(batch_size, int)
        assert isinstance(speedup, bool)
        self.agent_index = agent_index
        self.obs_dim_n = obs_dim_n
        self.act_dim_n = act_dim_n
        self.memory_size = memory_size
        self.batch_size = batch_size
        self.speedup = speedup
        self.n = len(act_dim_n)

        # Resolve the device before allocating any buffer, so that the
        # running means below live on the same device as the batches that
        # learn / calc_miu move to ``self.device``.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Step sizes of the running reward means; both are recomputed in
        # calc_miu before they are used.
        self._lr = 0.
        self.m_lr = 0.

        # Per-sample running reward means, their previous values and the
        # reward differences derived from them (all of length batch_size).
        self.miu = torch.zeros(batch_size, device=self.device)
        self.miu_i = torch.zeros(batch_size, device=self.device)
        self.miu_Ni = torch.zeros(batch_size, device=self.device)
        self.prev_miu = torch.zeros(batch_size, device=self.device)
        self.prev_miu_i = torch.zeros(batch_size, device=self.device)
        self.prev_miu_Ni = torch.zeros(batch_size, device=self.device)
        self.delta_miu = torch.zeros(batch_size, device=self.device)
        self.delta_miu_i = torch.zeros(batch_size, device=self.device)

        self.min_memory_size = min_memory_size  # warm up stage threshold
        self.rpm = ReplayMemory(
            max_size=self.memory_size,
            obs_dim=self.obs_dim_n[agent_index],
            act_dim=self.act_dim_n[agent_index],
            num_agents=self.n)
        self.global_train_step = 0

        # Initialise the parent class (parl.Agent) with the algorithm.
        super().__init__(algorithm)

        # Attention: In the beginning, sync target model totally.
        self.alg.sync_target(decay=0)

    def predict(self, obs, use_target_model=False):
        """Return the action for a single observation as a flat numpy array.

        Args:
            obs: observation exposing ``reshape`` (e.g. a numpy array); it is
                reshaped into a single row before being fed to the algorithm.
            use_target_model (bool): forwarded to ``self.alg.predict`` to
                choose between the model and the target model.

        Returns:
            numpy.ndarray: the flattened output of ``self.alg.predict``.
        """
        # Reshape to one row with as many columns as needed (batch of size 1).
        single_row = obs.reshape(1, -1)
        obs_tensor = torch.FloatTensor(single_row).to(self.device)
        action = self.alg.predict(obs_tensor, use_target_model=use_target_model)
        return action.detach().cpu().numpy().flatten()

    def learn(self, agents, rpm_sample_index):
        """Run one critic/actor update on the batch picked by ``rpm_sample_index``.

        The TD target does not use the raw reward from the replay memory.
        It uses a shaped reward built from attributes that ``calc_miu``
        writes (``delt_i``, ``delt_Ei``, ``delta_miu_i``, ``delta_miu_Ni``),
        so ``calc_miu`` has to have completed successfully beforehand
        (presumably with the same ``rpm_sample_index`` - confirm against the
        training loop).

        Args:
            agents: not used by this method; kept in the signature for
                existing callers.
            rpm_sample_index: replay-memory indices handed to
                ``self.rpm.sample_batch_by_index``.

        Returns:
            tuple: ``(critic_cost, actor_cost)``. Both are ``0.0`` when the
            update is skipped; otherwise they are the numpy values of the
            costs returned by ``self.alg.learn``.
        """
        # Parameters are only updated every 100 steps. The step counter is
        # advanced in calc_miu, not here.
        if self.global_train_step % 100 != 0:
            return 0.0, 0.0

        # warm up stage
        if self.rpm.size() <= self.min_memory_size:
            return 0.0, 0.0

        # Sample this agent's batch; the neighbour mask and the raw reward
        # are not needed for the update.
        batch_obs_i, batch_act_i, _, _, batch_obs_next_i, batch_isOver = \
            self.rpm.sample_batch_by_index(rpm_sample_index)

        batch_obs_i = torch.FloatTensor(batch_obs_i).to(self.device)  # [batchsize, obs_dim]
        batch_act_i = torch.FloatTensor(batch_act_i).to(self.device)  # [batchsize, act_dim]
        batch_obs_next_i = torch.FloatTensor(batch_obs_next_i).to(self.device)
        batch_isOver = torch.FloatTensor(batch_isOver).to(self.device)  # [batchsize]

        target_act_next_i = self.alg.predict(batch_obs_next_i, use_target_model=True)
        target_q_next = self.alg.Q(
            batch_obs_next_i, target_act_next_i, use_target_model=True)  # [batchsize]

        # Per-sample choice of the shaped reward: where the agent's own
        # reward deviation and that of its other neighbours have the same
        # sign, use the agent's own difference; otherwise use the
        # neighbourhood difference. A NaN sign compares as "unequal".
        uneq = torch.sign(self.delt_i) != torch.sign(self.delt_Ei)
        eq = ~uneq
        diff_rew = eq * self.delta_miu_i + uneq * self.delta_miu_Ni

        target_q = diff_rew + self.alg.gamma * (1.0 - batch_isOver) * target_q_next.detach()

        # learn
        critic_cost, actor_cost = self.alg.learn(batch_obs_i, batch_act_i, target_q)
        critic_cost = critic_cost.cpu().detach().numpy()
        actor_cost = actor_cost.cpu().detach().numpy()

        return critic_cost, actor_cost

    def calc_miu(self, rpm_sample_index, agents, miu_i_N):
        """Advance the step counter and refresh the reward statistics read by ``learn``.

        On every 100th call, once the replay memory has passed the warm-up
        size, this method:

        * updates the step size ``m_lr`` of the running means,
        * gathers the rewards of all agents for ``rpm_sample_index``,
        * averages them over this agent's neighbourhood (``Ni``) and over
          the neighbourhood without the agent's own reward (``Ei``),
        * snapshots and updates the running means ``miu`` / ``miu_i``,
        * stores the reward differences (``delta_miu_*``, ``delt_*``,
          ``bf_delta_miu_i``, ``af_delta_miu_i``) on the instance.

        Args:
            rpm_sample_index: replay-memory indices; the same indices are
                applied to the replay memory of every agent.
            agents: indexable collection of agents; ``agents[i].rpm`` is read
                for ``i`` in ``range(self.n)``.
            miu_i_N: running means of all agents, multiplied element-wise
                with the neighbour mask (assumed to be [batchsize, n] -
                TODO confirm against the caller).

        Returns:
            bool: ``False`` when the call was skipped (not a 100th step, or
            still in the warm-up stage), ``True`` once the statistics have
            been refreshed.
        """
        self.global_train_step += 1

        # only update parameter every 100 steps
        if self.global_train_step % 100 != 0:
            return False

        # warm up stage
        if self.rpm.size() <= self.min_memory_size:
            return False

        # Step size of the running means: _lr moves towards 1 in steps of
        # 1e-2, so m_lr is 1 on the first update and decays towards 1e-2.
        self._lr += 1e-2 * (1 - self._lr)  # MIU_LR = 1e-2
        self.m_lr = 1e-2 / self._lr

        # Rewards of every agent for the same transitions, stacked to
        # [batchsize, n] and placed on the device of the other batch tensors.
        rew_per_agent = []
        for i in range(self.n):
            _, _, _, batch_rew, _, _ = \
                agents[i].rpm.sample_batch_by_index(rpm_sample_index)
            rew_per_agent.append(torch.as_tensor(batch_rew))
        batch_rew_n = torch.stack(rew_per_agent, dim=1).to(self.device)

        _, _, batch_neigh_i, batch_rew_i, _, _ = \
            self.rpm.sample_batch_by_index(rpm_sample_index)  # self: agent_i
        batch_rew_i = torch.FloatTensor(batch_rew_i).to(self.device)  # [batchsize]
        batch_neigh_i = torch.FloatTensor(batch_neigh_i).to(self.device)  # [batchsize, n]
        batch_Ni_i = torch.count_nonzero(batch_neigh_i, dim=-1)  # [batchsize]
        self.count_Ni = batch_Ni_i

        # Mean reward over the neighbourhood (Ni) and over the neighbourhood
        # with the agent's own reward removed (Ei).
        # NOTE(review): a row with no non-zero mask entry divides by zero in
        # the Ni mean, and a row with exactly one divides by zero in the Ei
        # mean (NaN/inf). learn() treats a NaN sign as "unequal"; confirm
        # this is intended before changing it.
        neigh_rew_sum = torch.sum(batch_neigh_i * batch_rew_n, dim=-1)
        batch_rew_ni = neigh_rew_sum / batch_Ni_i  # [batchsize]
        batch_rew_ei = (neigh_rew_sum - batch_rew_i) / (batch_Ni_i - 1)

        # Keep the running means and the caller-supplied means on the same
        # device as the batch tensors; ``.to`` / ``as_tensor`` return the
        # very same tensor when it already lives there.
        self.miu_i = self.miu_i.to(self.device)
        self.miu = self.miu.to(self.device)
        miu_i_N = torch.as_tensor(miu_i_N, device=self.device)

        # Snapshot the running means before they are updated in place.
        self.prev_miu_i = self.miu_i.detach().clone()
        self.prev_miu = self.miu.detach().clone()

        # Neighbourhood average of the per-agent running means.
        self.prev_miu_Ni = torch.sum(batch_neigh_i * miu_i_N, dim=-1) / batch_Ni_i

        self.miu_i += self.m_lr * (batch_rew_i - self.miu_i)
        self.miu += self.m_lr * (batch_rew_i - self.miu)

        # Reward differences relative to the per-sample neighbourhood mean.
        self.delta_miu_i = batch_rew_i - self.prev_miu_Ni
        self.delta_miu_Ei = batch_rew_ei - self.prev_miu_Ni
        self.delta_miu_Ni = batch_rew_ni - self.prev_miu_Ni

        # Reward differences relative to the batch average of that mean;
        # learn() compares the signs of delt_i and delt_Ei.
        muNi = torch.mean(self.prev_miu_Ni)
        self.delt_i = batch_rew_i - muNi
        self.delt_Ei = batch_rew_ei - muNi
        self.delt_Ni = batch_rew_ni - muNi

        # Differences relative to the agent's own previous running mean.
        self.bf_delta_miu_i = batch_rew_i - self.prev_miu_i
        self.af_delta_miu_i = batch_rew_ni - self.prev_miu_i

        return True

    def add_experience(self, obs, act, neigh, reward, next_obs, terminal):
        """Append one transition to this agent's replay memory.

        The six values are forwarded unchanged, in this order, to
        ``self.rpm.append``.
        """
        transition = (obs, act, neigh, reward, next_obs, terminal)
        self.rpm.append(*transition)
