#   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import time
import argparse
import numpy as np
import torch
from maac_model import MAModel
from maac_agent import MAAgent
from parl.algorithms import MAAC
from parl.env.multiagent_simple_env import MAenv
from parl.utils import logger, summary


MIU_LR = 3e-4  # learning rate of the average reward (miu) estimate
CRITIC_LR = 3e-4  # learning rate of the critic model
ACTOR_LR = 2e-4  # learning rate of the actor model
GAMMA = 0.98  # reward discount factor
# Soft actor-critic entropy coefficient: relative importance of the entropy
# term against the reward.
ALPHA = 0.

# For this algorithm MEMORY_SIZE should be about
# BATCH_SIZE * MAX_STEP_PER_EPISODE so the actor is updated on time.
MEMORY_SIZE = int(4e3)
BATCH_SIZE = 128
TAU = 0.005  # soft-update rate; smaller tau -> slower target_model update

MAX_EPISODES = int(5e5)  # stop condition: number of training episodes
MAX_STEP_PER_EPISODE = 30  # maximum steps per episode
# Passed to every MAAgent as its min_memory_size.
MIN_MEMORY_SIZE = BATCH_SIZE * MAX_STEP_PER_EPISODE
STAT_RATE = 1000  # episodes between reward statistics / model saves

# NOTE(review): not referenced by the code in this file.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


# NOTE(review): appears unused; kept in case other code references it.
_lr = 0.
def _learn_step(env, agents, trainers, done_n):
    """Run one learning update for the trainable agents.

    Steps: refresh each trainer's average-reward (miu) estimate, average miu
    across the trainers (consensus), then update every trainer's networks and
    log losses / miu for selected agents.

    Args:
        env: the multi-agent environment (only ``env.n`` is read).
        agents: all agents; passed through to ``calc_miu`` / ``learn``.
        trainers: the trainable prefix of ``agents``.
        done_n: per-agent done flags of the current step; done trainers are
            skipped for ``calc_miu`` and ``learn``.
    """
    # Every trainer draws a replay index (sampling behaviour unchanged), but
    # all updates below share the first trainer's sampled batch.
    rpm_sample_index = []
    agents_miu_i = torch.zeros(BATCH_SIZE, env.n)
    for i, agent in enumerate(trainers):
        rpm_sample_index.append(agent.rpm.make_index(BATCH_SIZE))
        agents_miu_i[:, i] = agent.miu_i

    # calc_miu's (last) return value gates the consensus step. It is reset
    # every step so that a step where no trainer computed miu (no trainers,
    # or all trainers done) skips consensus instead of raising
    # UnboundLocalError or reusing the previous step's result.
    miu_ready = None
    for i, agent in enumerate(trainers):
        if done_n[i]:
            continue
        miu_ready = agent.calc_miu(rpm_sample_index[0], agents, agents_miu_i)

    consensus = True
    if consensus and miu_ready:
        # Average over exactly the agents that were summed; dividing by
        # env.n would deflate miu whenever pretrained agents are present.
        agg_miu = torch.zeros(BATCH_SIZE)
        for agent in trainers:
            agg_miu += agent.miu
        for agent in trainers:
            agent.miu = agg_miu / len(trainers)

    # update individual NN parameters
    for i, agent in enumerate(trainers):
        if done_n[i]:
            continue
        critic_loss, actor_loss = agent.learn(agents, rpm_sample_index[0])
        # Only agents 1 and env.n - 1 are logged, and only for non-zero losses.
        if i in [1, env.n - 1] and critic_loss != 0.0 and actor_loss != 0.0:
            miu_rec = agent.miu[0].cpu().detach().numpy()
            miu_i_rec = agent.miu_i[0].cpu().detach().numpy()
            summary.add_scalar('agent_%d_loss/critic_loss' % i, critic_loss, agent.global_train_step)
            summary.add_scalar('agent_%d_loss/actor_loss' % i, actor_loss, agent.global_train_step)
            summary.add_scalars(
                "agents_average_reward",
                {'agent_%d' % i: miu_rec},
                agent.global_train_step)
            summary.add_scalars(
                "agent_average_reward",
                {'agent_%d' % i: miu_i_rec},
                agent.global_train_step)


def run_episode(env, agents, CT_trner):
    """Run one episode and, unless only showing/benchmarking, train the agents.

    The first ``CT_trner`` entries of ``agents`` are trainable: only they
    store experience and get updated. The remaining agents only act.

    Args:
        env: multi-agent environment whose ``step`` returns
            ``(next_obs_n, neigh_n, reward_n, done_n, info_n)``.
        agents: one agent per environment agent, trainable agents first.
        CT_trner: number of trainable agents at the front of ``agents``.

    Returns:
        tuple ``(total_reward, agents_reward, ret_info, steps)``:
        ``total_reward`` is the summed reward of all agents over the episode.
        ``agents_reward`` is the per-agent episode reward. ``ret_info`` is
        ``np.sum(info_n)`` of the last step; in benchmark mode it is instead
        ``(sum over steps of info_n[0][0], mean over steps of info_n[0][1])``.
        ``steps`` is the episode length.
    """
    global MAX_STEP_PER_EPISODE
    if args.benchmark or args.show:
        MAX_STEP_PER_EPISODE = 30

    trainers = agents[:CT_trner]
    obs_n = env.reset()
    total_reward = 0
    agents_reward = [0 for _ in range(env.n)]
    ep_info = []  # per-step info_n[0] records (benchmark mode only)
    ret_info = 0
    steps = 0

    while True:
        steps += 1
        action_n = []
        action_apply_n = []
        for agent, obs in zip(agents, obs_n):
            act, act_apply = agent.predict(obs)
            action_n.append(act)
            action_apply_n.append(act_apply)
        next_obs_n, neigh_n, reward_n, done_n, info_n = env.step(action_apply_n)
        done = all(done_n)
        terminal = steps >= MAX_STEP_PER_EPISODE
        if done or terminal:
            ret_info = np.sum(info_n)

        # store experience (trainable agents only)
        for i, agent in enumerate(trainers):
            agent.add_experience(obs_n[i], action_n[i], neigh_n[i], reward_n[i], next_obs_n[i], done_n[i])

        # accumulate rewards of every agent
        obs_n = next_obs_n
        for i, reward in enumerate(reward_n):
            total_reward += reward
            agents_reward[i] += reward

        # benchmarking learned policies: record info, never learn
        if args.benchmark:
            ep_info.append(np.array(info_n[0]))
            if done or terminal:
                ret_info = (np.sum(ep_info, axis=0)[0], np.mean(ep_info, axis=0)[1])
                break
            continue

        # check the end of an episode
        if done or terminal:
            break

        # show animation
        if args.show:
            time.sleep(0.1)
            env.render()

        # show model effect without training
        if args.restore and (args.show or args.benchmark):
            continue

        _learn_step(env, agents, trainers, done_n)

    return total_reward, agents_reward, ret_info, steps


def _log_env_spaces(env):
    """Log the agent count and every agent's observation/action spaces."""
    logger.info('agent num: {}'.format(env.n))
    logger.info('observation_space: {}'.format(env.observation_space))
    logger.info('action_space: {}'.format(env.action_space))
    logger.info('obs_shape_n: {}'.format(env.obs_shape_n))
    logger.info('act_shape_n: {}'.format(env.act_shape_n))

    for i in range(env.n):
        logger.info('agent {} obs_low:{} obs_high:{}'.format(
            i, env.observation_space[i].low, env.observation_space[i].high))
        logger.info('agent {} act_n:{}'.format(i, env.act_shape_n[i]))
        if 'low' in dir(env.action_space[i]):
            logger.info('agent {} act_low:{} act_high:{} act_shape:{}'.format(
                i, env.action_space[i].low, env.action_space[i].high, env.action_space[i].shape))
            logger.info('num_discrete_space:{}'.format(
                env.action_space[i].num_discrete_space))


def _build_agents(env):
    """Create one MAModel -> MAAC -> MAAgent stack per environment agent."""
    agents = []
    for i in range(len(env.agents)):
        model = MAModel(
            env.obs_shape_n[i],
            env.act_shape_n[i])
        algorithm = MAAC(
            model,
            agent_index=i,
            gamma=GAMMA,
            tau=TAU,
            alpha=ALPHA,
            critic_lr=CRITIC_LR,
            actor_lr=ACTOR_LR)
        agent = MAAgent(
            algorithm,
            agent_index=i,
            obs_dim_n=env.obs_shape_n,
            act_dim_n=env.act_shape_n,
            memory_size=MEMORY_SIZE,
            min_memory_size=MIN_MEMORY_SIZE,
            batch_size=BATCH_SIZE,
            miu_lr=MIU_LR,
            speedup=(not args.restore))
        agents.append(agent)
    return agents


def _restore_pretrained(env, agents):
    """Restore checkpoints of pretrained agents (``tp == 9``) from args.model_dir.

    Each checkpoint is read from ``<model_dir>/agent_<env_agent.id>``.

    Returns:
        int: the number of pretrained agents restored. The caller treats the
        first ``len(agents) - count`` agents as trainable, so this presumably
        expects pretrained agents to come last in ``env.agents``
        (TODO confirm).
    """
    count = 0
    for i, env_agent in enumerate(env.agents):
        if env_agent.tp == 9:
            count += 1
            model_file = args.model_dir + '/agent_' + str(env_agent.id)
            agents[i].restore(model_file)
    return count


def _save_models(agents, n_trainers):
    """Save the first ``n_trainers`` agents to ``<model_dir>/agent_<i>``."""
    model_dir = args.model_dir
    # Create model_dir itself, since the checkpoints are written inside it.
    # os.path.dirname(model_dir) would only create its parent, and for a bare
    # name such as "models" it is '' and os.makedirs('') raises.
    os.makedirs(model_dir, exist_ok=True)
    for i, agent in enumerate(agents[:n_trainers]):
        agent.save(model_dir + '/agent_' + str(i))


def train_agent():
    """Build the environment and agents, then train (or show/benchmark) them.

    Every ``STAT_RATE`` episodes (500 when benchmarking), reward/info
    statistics are logged. Outside benchmark/restore mode, the trainable
    agents are also saved. Runs for ``MAX_EPISODES`` episodes (10000 when
    benchmarking).

    Raises:
        FileNotFoundError: with ``--restore``, if an agent checkpoint is missing.
    """
    env = MAenv(args.env, args.benchmark)
    _log_env_spaces(env)

    from gym import spaces
    from multiagent.multi_discrete import MultiDiscrete
    for space in env.action_space:
        assert isinstance(space, (spaces.Discrete, MultiDiscrete))

    critic_in_dim = env.obs_shape_n[0]
    logger.info('critic_in_dim: {}'.format(critic_in_dim))

    agents = _build_agents(env)
    CT_pretrned = _restore_pretrained(env, agents) if args.use_pretrained else 0
    CT_trner = len(agents) - CT_pretrned

    total_steps = 0
    total_episodes = 0

    if args.benchmark:
        max_ep, stat_rt = 10000, 500
    else:
        max_ep, stat_rt = MAX_EPISODES, STAT_RATE

    # Ring buffers over the last stat_rt episodes; they are completely filled
    # whenever total_episodes is a multiple of stat_rt.
    episode_rewards = [None] * stat_rt  # sum of rewards for all agents
    agent_rewards = [[None] * stat_rt for _ in range(env.n)]  # individual agent reward
    rec_info = [None] * stat_rt  # record benchmark info

    if args.restore:
        for i, agent in enumerate(agents):
            model_file = args.model_dir + '/agent_' + str(i)
            if not os.path.exists(model_file):
                raise FileNotFoundError('model file {} does not exist'.format(model_file))
            agent.restore(model_file)

    t_start = time.time()
    logger.info('Starting...')
    # `<` rather than `<=`: the final report/save happens when total_episodes
    # reaches max_ep, so one more episode would only be wasted work.
    while total_episodes < max_ep:
        ep_reward, ep_agent_rewards, info, steps = run_episode(env, agents, CT_trner)

        total_steps += steps
        total_episodes += 1
        k = total_episodes % stat_rt

        episode_rewards[k] = ep_reward
        for i in range(env.n):
            agent_rewards[i][k] = ep_agent_rewards[i]
        rec_info[k] = info

        if total_episodes % stat_rt != 0:
            continue

        # evaluating
        mean_rec_info = round(np.sum(rec_info), 2)
        summary.add_scalar('mean_episode_info/episode', mean_rec_info, total_episodes)
        if args.benchmark:
            mean_rec_info = (np.round(np.sum(rec_info, axis=0)[0], 2),
                             np.round(np.mean(rec_info, axis=0)[1], 2))
            logger.info(
                'Steps: {}, Episodes: {}, Mean episode info: {}'
                .format(total_steps, total_episodes, mean_rec_info))
            continue

        mean_episode_reward = round(np.mean(episode_rewards), 2)  # episode reward of all agents
        final_ep_ag_reward = []  # per agent rewards for training curve
        ep_grp_rew = [[], []]  # group 0: agents with tp == 0, group 1: all others
        for i, rew in enumerate(agent_rewards):
            ep_ag_rew = round(np.mean(rew), 2)
            final_ep_ag_reward.append(ep_ag_rew)
            summary.add_scalars(
                "final_episode_agent_reward",
                {'agent_%d' % i: ep_ag_rew},
                total_episodes)
            group = 0 if env.world.agents[i].tp == 0 else 1
            ep_grp_rew[group].append(ep_ag_rew)
        for g in range(2):
            ep_grp_rew[g] = round(np.sum(ep_grp_rew[g]), 2)
            summary.add_scalars(
                "final_episode_group_reward",
                {'group_%d' % g: ep_grp_rew[g]},
                total_episodes)

        use_time = round(time.time() - t_start, 3)
        logger.info(
            'Steps: {}, Episodes: {}, Mean episode reward: {}, mean groups rewards {}, mean agents rewards {}, Info {}, Time: {}'
            .format(total_steps, total_episodes, mean_episode_reward, ep_grp_rew, final_ep_ag_reward, mean_rec_info, use_time))
        t_start = time.time()
        summary.add_scalar('mean_episode_reward/episode', mean_episode_reward, total_episodes)
        summary.add_scalar('mean_episode_reward/step', mean_episode_reward, total_steps)
        summary.add_scalar('use_time/1000episode', use_time, total_episodes)

        # save model (only the trainable agents)
        if not args.restore:
            _save_models(agents, CT_trner)
                    


if __name__ == '__main__':
    # Command-line interface; `args` is read as a module-level global by
    # run_episode and train_agent.
    parser = argparse.ArgumentParser()
    # Environment scenario to load.
    parser.add_argument(
        '--env',
        type=str,
        default='simple_spread',
        help='scenario of MultiAgentEnv')
    # Models are auto-saved here and optionally restored from here.
    parser.add_argument(
        '--model_dir',
        type=str,
        default='./results_maac/model3',
        help='directory for saving model')
    # Boolean switches, all off by default, registered in a fixed order.
    for flag, flag_help in (
            ('--use_pretrained', 'use pretrained agents or not, must have model_dir'),
            ('--restore', 'restore or not, must have model_dir'),
            ('--show', 'display or not'),
            ('--benchmark', 'benchmark or not'),
    ):
        parser.add_argument(flag, action='store_true', default=False, help=flag_help)

    args = parser.parse_args()

    # Benchmark runs and training runs log to separate directories.
    log_root = ('./results_maac/benchmark_log3/' if args.benchmark
                else './results_maac/train_log3/')
    logger.set_dir(log_root + str(args.env))

    train_agent()
