#   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import parl
import torch
import numpy as np
from parl.utils import ReplayMemory
from parl.utils import machine_info, get_gpu_count


class MAAgent(parl.Agent):
    """One agent of a multi-agent actor-critic setup.

    Each agent owns its replay memory (`rpm`) and its algorithm (`alg`).
    During `learn` the critic is fed the observations and actions of *all*
    agents, while the reward and the terminal flag come from this agent's
    own replay memory.

    Attributes:
        agent_index: Position of this agent in the list of agents.
        obs_dim_n / act_dim_n: Per-agent observation / action dimensions.
        n: Number of agents (``len(act_dim_n)``).
        rpm: Replay memory of this agent.
        miu: Running estimate of this agent's sampled rewards, one entry per
            batch position. It is maintained on every update but is not part
            of the TD target that is currently active.
        global_train_step: Number of times `learn` has been called.
    """

    # `learn` only performs a real update every UPDATE_INTERVAL calls.
    UPDATE_INTERVAL = 100
    # Base step size used for the running reward estimate `miu`.
    MIU_LR = 1e-2

    def __init__(self,
                 algorithm,
                 agent_index=None,
                 obs_dim_n=None,
                 act_dim_n=None,
                 memory_size=None,
                 min_memory_size=None,
                 batch_size=None,
                 speedup=False):
        """Build the agent, its replay memory and sync the target model.

        Args:
            algorithm: Algorithm object forwarded to `parl.Agent`.
            agent_index (int): Index of this agent among all agents.
            obs_dim_n (list): Observation dimension of every agent.
            act_dim_n (list): Action dimension of every agent.
            memory_size (int): Capacity of the replay memory.
            min_memory_size (int | float): The replay memory must hold more
                than this many entries before `learn` performs an update.
            batch_size (int): Number of transitions per training batch.
            speedup (bool): Stored on the instance; not used by this class.
        """
        assert isinstance(agent_index, int)
        assert isinstance(obs_dim_n, list)
        assert isinstance(act_dim_n, list)
        assert isinstance(memory_size, int)
        # Previously unchecked: a missing value only surfaced as a TypeError
        # in `learn` at the first real update.
        assert isinstance(min_memory_size, (int, float))
        assert isinstance(batch_size, int)
        assert isinstance(speedup, bool)
        self.agent_index = agent_index
        self.obs_dim_n = obs_dim_n
        self.act_dim_n = act_dim_n
        self.memory_size = memory_size
        self.batch_size = batch_size
        self.speedup = speedup
        self.n = len(act_dim_n)

        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")

        # Allocate `miu` on the training device: it is combined in-place with
        # reward tensors living on `self.device`, so a CPU tensor would raise
        # a device-mismatch error as soon as CUDA is available.
        self.miu = torch.zeros(batch_size, device=self.device)
        # Trace used to normalise the step size of `miu` (see `learn`).
        self._lr = 0.

        # Warm-up threshold for the replay memory.
        self.min_memory_size = min_memory_size
        self.rpm = ReplayMemory(
            max_size=self.memory_size,
            obs_dim=self.obs_dim_n[agent_index],
            act_dim=self.act_dim_n[agent_index],
            num_agents=self.n)
        self.global_train_step = 0

        # Initialise the parent class (parl.Agent).
        super(MAAgent, self).__init__(algorithm)

        # Attention: In the beginning, sync target model totally.
        self.alg.sync_target(decay=0)

    def _to_tensor(self, array):
        """Convert array-like data to a float tensor on the agent's device."""
        return torch.FloatTensor(array).to(self.device)

    def predict(self, obs, use_target_model=False):
        """Predict an action for a single observation.

        Args:
            obs: Observation array; it is reshaped to one row before being
                passed to the algorithm.
            use_target_model (bool): Use the target model instead of the
                current model.

        Returns:
            The predicted action as a flattened numpy array.
        """
        # Reshape to a single row with as many columns as needed.
        obs = self._to_tensor(obs.reshape(1, -1))
        act = self.alg.predict(obs, use_target_model=use_target_model)
        return act.detach().cpu().numpy().flatten()

    def learn(self, agents, rpm_sample_index):
        """Sample a batch, compute the TD target and train the algorithm.

        An update only happens on every `UPDATE_INTERVAL`-th call and only
        once the replay memory holds more than `min_memory_size` entries;
        otherwise the call returns ``(0.0, 0.0)``.

        Args:
            agents: Indexable collection of all agents; the first `self.n`
                entries are used, each providing `rpm` and `alg`.
            rpm_sample_index: Indices handed to every replay memory's
                `sample_batch_by_index`, so all agents sample the same
                transitions.

        Returns:
            Tuple ``(critic_cost, actor_cost)`` as numpy values, or
            ``(0.0, 0.0)`` when no update was performed.
        """
        self.global_train_step += 1

        # Only update parameters every UPDATE_INTERVAL steps.
        if self.global_train_step % self.UPDATE_INTERVAL != 0:
            return 0.0, 0.0

        # Warm-up stage.
        if self.rpm.size() <= self.min_memory_size:
            return 0.0, 0.0

        # Step size MIU_LR / trace: starts at 1 on the first update and
        # decays towards MIU_LR as the trace `_lr` approaches 1.
        self._lr += self.MIU_LR * (1 - self._lr)
        self.m_lr = self.MIU_LR / self._lr

        # Gather the same transitions from every agent's replay memory.
        # Each list ends up with n tensors, presumably shaped
        # [batch_size, obs/act_dim] -- TODO confirm against ReplayMemory.
        batch_obs_n = []
        batch_act_n = []
        batch_obs_next_n = []
        for i in range(self.n):
            batch_obs, batch_act, _, _, batch_obs_next, _ = \
                agents[i].rpm.sample_batch_by_index(rpm_sample_index)
            batch_obs_n.append(self._to_tensor(batch_obs))
            batch_act_n.append(self._to_tensor(batch_act))
            batch_obs_next_n.append(self._to_tensor(batch_obs_next))

        # Reward and terminal flag come from this agent's own memory.
        _, _, _, batch_rew_i, _, batch_isOver = \
            self.rpm.sample_batch_by_index(rpm_sample_index)
        batch_rew_i = self._to_tensor(batch_rew_i)
        batch_isOver = self._to_tensor(batch_isOver)

        # Compute target q from every agent's target-model action.
        target_act_next_n = [
            agents[i].alg.predict(
                batch_obs_next_n[i], use_target_model=True).detach()
            for i in range(self.n)
        ]
        target_q_next = self.alg.Q(
            batch_obs_next_n, target_act_next_n, use_target_model=True)

        # NOTE: earlier experiments used `reward - miu` and a reward summed
        # over all agents in this target; the active variant uses this
        # agent's raw reward only.
        target_q = batch_rew_i + self.alg.gamma * (
            1.0 - batch_isOver) * target_q_next.detach()

        # learn
        critic_cost, actor_cost = self.alg.learn(batch_obs_n, batch_act_n,
                                                 target_q)
        critic_cost = critic_cost.cpu().detach().numpy()
        actor_cost = actor_cost.cpu().detach().numpy()

        # Keep `miu` on the same device as the rewards (no-op when it already
        # is) and move it towards the sampled rewards.
        # NOTE(review): assumes the sampled rewards broadcast against
        # `miu` of shape [batch_size] -- confirm with the caller.
        self.miu = self.miu.to(batch_rew_i.device)
        self.miu += self.m_lr * (batch_rew_i - self.miu)

        return critic_cost, actor_cost

    def add_experience(self, obs, act, neigh, reward, next_obs, terminal):
        """Append one transition to this agent's replay memory."""
        self.rpm.append(obs, act, neigh, reward, next_obs, terminal)
