#   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import time
import argparse
import numpy as np
import torch
from maac_model import MAModel
from maac_agent import MAAgent
from parl.algorithms import MAAC
from parl.env.multiagent_simple_env import MAenv, MTenv
from parl.utils import logger, summary


MIU_LR = 3e-4  # learning rate of the average reward (miu) estimate
CRITIC_LR = 3e-4  # learning rate for the critic model
ACTOR_LR = 3e-4  # learning rate of the actor model
GAMMA = 1.  # reward discount factor
ALPHA = 0.  # SoftAC coefficient: relative importance of the entropy term against the reward

# For this algorithm MEMORY_SIZE should equal BATCH_SIZE * EPISODE_LEN so that
# the actor is updated on fresh (on-time) data.
# NOTE(review): the current values (4000 vs 128 * 1) do not satisfy this rule;
# confirm which one is intended.
MEMORY_SIZE = int(4e3)  # replay memory capacity per agent
BATCH_SIZE = 128  # number of samples drawn from replay memory per update
TAU = 0.005  # soft-update rate; smaller tau -> slower target_model update

MAX_EPISODES = int(2e5)  # stop condition: number of training episodes
MAX_STEP_PER_EPISODE = 1  # maximum steps per episode during training
MIN_MEMORY_SIZE = BATCH_SIZE * MAX_STEP_PER_EPISODE  # samples required before learning
STAT_RATE = 1000  # interval (in episodes) for reward statistics and model saving

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


_lr = 0.  # module-level placeholder; not read by any code in this file


def _learn_step(env, agents):
    """Perform one learning update for every agent.

    Steps:
      0. Each agent samples its own replay-batch indices, and every agent's
         ``miu_i`` is gathered into a ``(BATCH_SIZE, env.n)`` tensor used by
         ``calc_miu``.
      1. If consensus is enabled and the last ``calc_miu`` call returned a
         truthy value, every agent's ``miu`` is replaced by the mean of all
         agents' ``miu``.
      2. Each agent learns on its sampled batch. Non-zero critic/actor losses
         and ``miu[0]`` are written to the summary writer.

    Args:
        env: environment exposing ``n`` (number of agents).
        agents: the ``MAAgent`` instances, in environment order.
    """
    # 0. sample batch indices and compute miu
    rpm_sample_index = []
    agents_miu_i = torch.zeros(BATCH_SIZE, env.n)
    for i, agent in enumerate(agents):
        rpm_sample_index.append(agent.rpm.make_index(BATCH_SIZE))
        agents_miu_i[:, i] = agent.miu_i
    calc_ret = False
    for i, agent in enumerate(agents):
        calc_ret = agent.calc_miu(rpm_sample_index[i], agents, agents_miu_i)

    # 1. consensus: share the average miu among all agents.
    # NOTE(review): only the LAST agent's calc_miu return value gates this,
    # as in the original code -- confirm that is intended.
    consensus = True  # True or False
    if consensus and calc_ret:
        # Allocate on the same device as miu; a CPU accumulator would fail
        # on the in-place add when miu lives on CUDA.
        agg_miu = torch.zeros(BATCH_SIZE, device=agents[0].miu.device)
        for agent in agents:
            agg_miu += agent.miu
        for agent in agents:
            agent.miu = agg_miu / env.n

    # 2. update individual NN parameters and log
    for i, agent in enumerate(agents):
        critic_loss, actor_loss = agent.learn(agents, rpm_sample_index[i])
        if critic_loss != 0.0 and actor_loss != 0.0:
            miu_rec = agent.miu[0].cpu().detach().numpy()
            summary.add_scalar('agent_%d_loss/critic_loss' % i, critic_loss, agent.global_train_step)
            summary.add_scalar('agent_%d_loss/actor_loss' % i, actor_loss, agent.global_train_step)
            summary.add_scalars(
                    "agents_average_reward",
                    {'agent_%d' % i: miu_rec},
                    agent.global_train_step)


def run_episode(env, agents):
    """Run one episode and, unless only evaluating, train the agents.

    Behaviour depends on the global ``args`` parsed in ``__main__``:

    * ``--benchmark``: no training. When the episode ends, the per-agent
      rewards are averaged into two groups (first / second half of agents).
    * ``--show``: the environment is rendered every step (0.1 s pause).
    * ``--restore`` with ``--show``: no training, only rendering.
    * otherwise: one learning update (``_learn_step``) after every step.

    With ``--benchmark`` or ``--show`` the episode length is raised to 30
    steps by overwriting the module-level ``MAX_STEP_PER_EPISODE``.

    Args:
        env: multi-agent environment exposing ``n``, ``reset()``, ``step()``
            and ``render()``.
        agents: one ``MAAgent`` per environment agent, in environment order.

    Returns:
        tuple ``(total_reward, agents_reward, group_info, steps)``:

        * ``total_reward``: reward summed over all agents and steps.
        * ``agents_reward``: per-agent reward sums.
        * ``group_info``: length-2 array of group mean rewards, filled only in
          benchmark mode (zeros otherwise).
        * ``steps``: number of environment steps taken.
    """
    global MAX_STEP_PER_EPISODE
    if args.benchmark or args.show:
        MAX_STEP_PER_EPISODE = 30

    # Pure evaluation of a restored model: never train.
    evaluating = args.restore and (args.show or args.benchmark)

    obs_n = env.reset()
    total_reward = 0
    agents_reward = [0 for _ in range(env.n)]
    group_info = np.zeros(2)
    steps = 0
    while True:
        steps += 1
        action_n = []
        action_apply_n = []
        for agent, obs in zip(agents, obs_n):
            act, act_apply = agent.predict(obs)
            action_n.append(act)
            action_apply_n.append(act_apply)
        next_obs_n, neigh_n, reward_n, done_n, info_n = env.step(action_apply_n)
        episode_over = all(done_n) or steps >= MAX_STEP_PER_EPISODE

        # store experience
        for i, agent in enumerate(agents):
            agent.add_experience(obs_n[i], action_n[i], neigh_n[i], reward_n[i], next_obs_n[i], done_n[i])

        # accumulate rewards
        obs_n = next_obs_n
        for i, reward in enumerate(reward_n):
            total_reward += reward
            agents_reward[i] += reward

        # benchmarking learned policies: no training, group stats at the end
        if args.benchmark:
            if episode_over:
                rec_info = np.array(agents_reward).reshape(2, -1)
                group_info[0] = np.mean(rec_info[0])
                group_info[1] = np.mean(rec_info[1])
                break
            continue

        # show animation
        if args.show:
            time.sleep(0.1)
            env.render()

        if not evaluating:
            _learn_step(env, agents)

        # Always checked, including when evaluating; skipping it made
        # `--restore --show` loop forever.
        if episode_over:
            break

    return total_reward, agents_reward, group_info, steps


def _agent_model_path(model_dir, index):
    """Return the checkpoint path of agent ``index`` inside ``model_dir``."""
    return os.path.join(model_dir, 'agent_' + str(index))


def _build_agents(env):
    """Create one MAModel -> MAAC -> MAAgent stack per agent of ``env``."""
    agents = []
    for i, _env_agent in enumerate(env.agents):
        model = MAModel(
            env.obs_shape_n[i],
            env.act_shape_n[i])
        algorithm = MAAC(
            model,
            agent_index=i,
            gamma=GAMMA,
            tau=TAU,
            alpha=ALPHA,
            critic_lr=CRITIC_LR,
            actor_lr=ACTOR_LR)
        agent = MAAgent(
            algorithm,
            agent_index=i,
            obs_dim_n=env.obs_shape_n,
            act_dim_n=env.act_shape_n,
            memory_size=MEMORY_SIZE,
            min_memory_size=MIN_MEMORY_SIZE,
            batch_size=BATCH_SIZE,
            miu_lr=MIU_LR,
            speedup=(not args.restore))
        agents.append(agent)
    return agents


def train_agent():
    """Build the environment and agents, then run the training/benchmark loop.

    Reads the global ``args``:

    * ``--restore``: load each agent from ``<model_dir>/agent_<i>``. Models are
      not saved in this mode.
    * ``--benchmark``: run 10000 episodes and log group/individual rewards
      every 200 episodes.
    * otherwise: run ``MAX_EPISODES`` episodes. Every ``STAT_RATE`` episodes,
      log reward statistics and timing, and save all agents.

    Raises:
        Exception: if ``--restore`` is set and a model file is missing.
    """
    env = MTenv(args.env, args.benchmark)
    logger.info('agent num: {}'.format(env.n))
    logger.info('observation_space: {}'.format(env.observation_space))
    logger.info('action_space: {}'.format(env.action_space))
    logger.info('obs_shape_n: {}'.format(env.obs_shape_n))
    logger.info('act_shape_n: {}'.format(env.act_shape_n))

    agents = _build_agents(env)

    if args.restore:
        # restore model
        for i, agent in enumerate(agents):
            model_file = _agent_model_path(args.model_dir, i)
            if not os.path.exists(model_file):
                raise Exception('model file {} does not exist'.format(model_file))
            agent.restore(model_file)

    if args.benchmark:
        max_episodes, stat_rate = 10000, 200
    else:
        max_episodes, stat_rate = MAX_EPISODES, STAT_RATE

    total_steps = 0
    total_episodes = 0
    episode_rewards = []  # sum of rewards for all agents
    agent_rewards = [[] for _ in range(env.n)]  # individual agent reward

    t_start = time.time()
    logger.info('Starting...')
    # `<` so that exactly max_episodes episodes run (`<=` ran one extra).
    while total_episodes < max_episodes:
        ep_reward, ep_agent_rewards, _info, steps = run_episode(env, agents)
        total_steps += steps
        total_episodes += 1

        # record rewards
        episode_rewards.append(ep_reward)
        for history, reward in zip(agent_rewards, ep_agent_rewards):
            history.append(reward)

        if total_episodes % stat_rate != 0:
            continue

        # statistics over the last `stat_rate` episodes
        mean_episode_reward = round(np.mean(episode_rewards[-stat_rate:]), 2)  # episode reward of all agents
        final_ep_ag_reward = []  # per agent rewards for training curve
        ep_grp_rew = [[] for _ in range(2)]  # two groups: first / second half of agents
        for i, rew in enumerate(agent_rewards):
            ep_ag_rew = round(np.mean(rew[-stat_rate:]), 2)
            final_ep_ag_reward.append(ep_ag_rew)
            summary.add_scalars(
                    "final_episode_agent_reward",
                    {'agent_%d' % i: ep_ag_rew},
                    total_episodes)
            ep_grp_rew[0 if i < (env.n / 2) else 1].append(ep_ag_rew)
        for i in range(2):
            ep_grp_rew[i] = round(np.mean(ep_grp_rew[i]), 2)
            summary.add_scalars(
                    "final_episode_group_reward",
                    {'group_%d' % i: ep_grp_rew[i]},
                    total_episodes)

        if args.benchmark:
            logger.info(
                'Episode: {}, Grp reward {}, Ind reward {}'
                .format(total_episodes, ep_grp_rew, final_ep_ag_reward))
            continue

        use_time = round(time.time() - t_start, 3)
        logger.info(
            'Steps: {}, Episodes: {}, Mean episode reward: {}, mean groups rewards {}, mean agents rewards {}, Time: {}'
            .format(total_steps, total_episodes, mean_episode_reward, ep_grp_rew, final_ep_ag_reward, use_time))
        t_start = time.time()
        summary.add_scalar('mean_episode_reward/episode', mean_episode_reward, total_episodes)
        summary.add_scalar('mean_episode_reward/step', mean_episode_reward, total_steps)
        summary.add_scalar('use_time/1000episode', use_time, total_episodes)

        # save model
        if not args.restore:
            # Create model_dir itself. os.path.dirname(model_dir) created only
            # the parent, and crashed when model_dir had no directory part.
            os.makedirs(args.model_dir, exist_ok=True)
            for i, agent in enumerate(agents):
                agent.save(_agent_model_path(args.model_dir, i))


if __name__ == '__main__':
    # NOTE: `args` must remain a module-level global; run_episode and
    # train_agent read it directly.
    parser = argparse.ArgumentParser()

    # which MultiAgentEnv scenario to load
    parser.add_argument('--env', type=str, default='simple_spread',
                        help='scenario of MultiAgentEnv')
    # models are saved here while training and loaded from here with --restore
    parser.add_argument('--model_dir', type=str,
                        default='./results_maac/model3',
                        help='directory for saving model')

    # boolean switches, all off by default
    for switch, switch_help in (
            ('--restore', 'restore or not, must have model_dir'),
            ('--show', 'display or not'),
            ('--benchmark', 'benchmark or not')):
        parser.add_argument(switch, action='store_true', default=False,
                            help=switch_help)

    args = parser.parse_args()

    # benchmark runs and training runs log to separate directories
    if args.benchmark:
        log_root = './results_maac/benchmark_log3/'
    else:
        log_root = './results_maac/train_log3/'
    logger.set_dir(log_root + str(args.env))

    train_agent()
