# Based on code from the MARL-code-pytorch repository.
# Reference (original code): https://github.com/Lizhi-sjtu/MARL-code-pytorch/blob/main/4.MADDPG_MATD3_MPE/MADDPG_MATD3_main.py


import torch
import numpy as np

from replay_buffer import ReplayBuffer
from maddpg import MADDPG
import copy
from pettingzoo.mpe import simple_adversary_v3

def trainmaddpg(agent_reward_estimator, adversary_reward_estimator, estimate_agent, estimate_adversary):
    """Train one MADDPG learner per agent on ``simple_adversary_v3``.

    The environment is created with ``N=2``, ``max_cycles=25`` and continuous
    actions, and training runs for 6000 episodes. Every transition is stored
    in a shared replay buffer; once the buffer holds more than 1024
    transitions, every agent is trained each ``update_freq`` (10) steps.

    Args:
        agent_reward_estimator: Callable applied to a float tensor built from
            the observation of agent index 1 and of agent index 2. Its output
            is detached, converted to NumPy and offset by a shared distance
            cost. Only used when ``estimate_agent`` is truthy.
        adversary_reward_estimator: Callable applied to a float tensor built
            from the observation of agent index 0. Only used when
            ``estimate_adversary`` is truthy.
        estimate_agent: If truthy, replace the environment rewards of agent
            indices 1 and 2 with the estimator output plus the cost term.
        estimate_adversary: If truthy, replace the environment reward of
            agent index 0 with the estimator output.

    Returns:
        A list with one deep-copied ``actor.state_dict()`` per agent, in the
        order of ``env.agents`` right after the initial reset.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    env = simple_adversary_v3.parallel_env(N=2, max_cycles=25, continuous_actions=True)
    try:
        env.reset()
        # Snapshot the agent list: the environment updates `env.agents` as
        # agents finish, and the per-agent ordering must stay fixed.
        agent_names = list(env.agents)
        agent_num = len(agent_names)
        obs_dim = [env.observation_space(name).shape[0] for name in agent_names]
        action_dim = [env.action_space(name).shape[0] for name in agent_names]
        replay_buffer = ReplayBuffer(agent_num, obs_dim, action_dim)
        agent_n = [MADDPG(agent_num, obs_dim, action_dim, agent_id) for agent_id in range(agent_num)]

        # Exploration noise schedule: linear decay from `noise_std_init` to
        # `noise_std_min` over `noise_decay_steps` environment steps.
        # (Previously the step count itself was subtracted each step, which
        # collapsed the noise to its minimum after the very first step.)
        use_noise_decay = True
        noise_std_init = 0.2
        noise_std_min = 0.05
        noise_decay_steps = 3e5
        noise_std_decay = (noise_std_init - noise_std_min) / noise_decay_steps
        noise_std = noise_std_init

        update_freq = 10
        total_steps = 0

        for _ in range(6000):
            observations, _ = env.reset()
            obs_n = [observations[name] for name in agent_names]
            while env.agents:
                # Each agent selects an action from its own local observation
                # (noise is added for exploration).
                a_n = [
                    agent.choose_action(obs, noise_std=noise_std).astype(float)
                    for agent, obs in zip(agent_n, obs_n)
                ]
                # NOTE: the environment must receive its own copy of every
                # action. The original author warns that MPE rescales the
                # action array (multiplies it by 5), which would otherwise
                # corrupt the `a_n` stored in the replay buffer. `astype`
                # returns a new array by default, so it doubles as the copy.
                a_in = {
                    name: action.astype('float32')
                    for name, action in zip(agent_names, a_n)
                }
                # `dones` holds the third element returned by `step`; the
                # fourth and fifth elements are discarded.
                observations_next, rewards, dones, _, _ = env.step(a_in)
                obs_next_n = [observations_next[name] for name in agent_names]
                r_n = [rewards[name] for name in agent_names]
                done_n = [dones[name] for name in agent_names]

                if estimate_agent:
                    # Euclidean norm of the first two observation entries of
                    # agents 1 and 2 (presumably the offset to the goal
                    # landmark -- confirm against the environment docs). The
                    # smaller norm, negated, is a cost shared by both agents.
                    distances = [
                        np.sqrt(obs_n[idx][0] ** 2 + obs_n[idx][1] ** 2)
                        for idx in (1, 2)
                    ]
                    cost = 0 - min(distances)
                    for idx in (1, 2):
                        estimator_input = torch.tensor(obs_n[idx], dtype=torch.float, device=device)
                        r_n[idx] = agent_reward_estimator(estimator_input).detach().cpu().numpy() + cost
                if estimate_adversary:
                    adversary_input = torch.tensor(obs_n[0], dtype=torch.float, device=device)
                    r_n[0] = adversary_reward_estimator(adversary_input).detach().cpu().numpy()

                # Store the transition
                replay_buffer.store_transition(obs_n, a_n, r_n, obs_next_n, done_n)
                obs_n = obs_next_n
                total_steps += 1

                # Decay noise_std, never going below noise_std_min
                if use_noise_decay:
                    noise_std = max(noise_std - noise_std_decay, noise_std_min)

                if replay_buffer.current_size > 1024 and total_steps % update_freq == 0:
                    # Train each agent individually
                    for agent in agent_n:
                        agent.train(replay_buffer, agent_n)
            # Periodic evaluation is currently disabled. To re-enable it, call
            # evaluate_policy(agent_n) here every N episodes (the original
            # setting was every 100 episodes) and log the result.
    finally:
        # Release the environment even if training is interrupted.
        env.close()

    # Return independent copies so later changes to the agents do not alter
    # the returned weights.
    return [copy.deepcopy(agent.actor.state_dict()) for agent in agent_n]


def evaluate_policy(agent_n, episodes=10):
    """Run noise-free evaluation episodes and return the mean reward per agent.

    A fresh ``simple_adversary_v3`` parallel environment (``N=2``,
    ``max_cycles=25``, continuous actions) is created for the evaluation and
    closed again before returning, including when an error occurs.

    Args:
        agent_n: Sequence of agents, one per environment agent and in the same
            order as ``env.agents``. Each must provide
            ``choose_action(obs, noise_std=...)`` returning a NumPy array.
        episodes: Number of evaluation episodes to average over. Defaults to
            10.

    Returns:
        A list of floats: for each agent, the reward summed over every step
        of every episode, divided by ``episodes``.

    Raises:
        ValueError: If ``episodes`` is smaller than 1.
    """
    if episodes < 1:
        raise ValueError("episodes must be at least 1")

    env_eva = simple_adversary_v3.parallel_env(N=2, max_cycles=25, continuous_actions=True)
    try:
        env_eva.reset()
        # Snapshot the agent list: the environment updates `env_eva.agents`
        # as agents finish, and the per-agent ordering must stay fixed.
        agent_names = list(env_eva.agents)
        evaluate_reward = [0.0] * len(agent_names)

        for _ in range(episodes):
            observations, _ = env_eva.reset()
            obs_n = [observations[name] for name in agent_names]
            while env_eva.agents:
                # We do not add noise when evaluating
                a_n = [
                    agent.choose_action(obs, noise_std=0).astype(float)
                    for agent, obs in zip(agent_n, obs_n)
                ]
                # `astype` returns a new array, so the environment gets its
                # own float32 copy of every action.
                a_in = {
                    name: action.astype('float32')
                    for name, action in zip(agent_names, a_n)
                }
                observations_next, rewards, _, _, _ = env_eva.step(a_in)
                for idx, name in enumerate(agent_names):
                    evaluate_reward[idx] += rewards[name]
                obs_n = [observations_next[name] for name in agent_names]
    finally:
        # Release the environment even if an episode fails midway.
        env_eva.close()

    return [total / episodes for total in evaluate_reward]


#
# if __name__ == '__main__':
#     trainmaddpg()