import os, sys  
import numpy as np
import gym
import torch
from pathlib import Path
from model import Agent
from buffer import ReplayBuffer
from config import get_config
from model import data_wrap
import matplotlib.pyplot as plt
from src.multiagent_mujoco.mujoco_multi import MujocoMulti
from src.mpe.MPE_env import MPEEnv


def set_seed(seed):
    """Seed the random number generators used by this script.

    Seeds Python's ``random`` module, NumPy and PyTorch (CPU and every CUDA
    device), and configures cuDNN for deterministic algorithm selection.

    Args:
        seed: Integer seed applied to all generators.
    """
    # Local import: ``random`` is not imported at the top of this file.
    import random

    random.seed(seed)
    np.random.seed(seed)
    # Seed the CPU generator so that results are reproducible.
    torch.manual_seed(seed)
    # Seed every CUDA device, not only the current one. This call is ignored
    # when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    # Deterministic cuDNN kernels; benchmark mode is disabled because its
    # auto-tuner may pick different algorithms from run to run.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Experiment configuration; the available fields are defined by get_config().
args = get_config()

# Use the first CUDA device only when it is both requested and available.
if args.cuda and torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    print("choose to use cpu...")
    device = torch.device("cpu")

# Build the training environment and a second instance, test_env, which is
# used by test_agent() so that evaluation does not disturb the training
# episode in progress.
if args.env_type == 'mujoco':
    env_args = {"scenario": args.scenario,
                "agent_conf": args.agent_conf,
                "agent_obsk": args.obsk,
                "episode_limit": args.max_ep_len}
    print('env_args = {}\nargs = {}'.format(env_args,args))
    env = MujocoMulti(env_args=env_args)
    test_env = MujocoMulti(env_args=env_args)
elif args.env_type == 'mpe':
    if args.num_landmarks is None:
        args.num_landmarks = args.num_agents
    args.scenario_name = args.scenario
    env = MPEEnv(args)
    # The evaluation environment must exist for MPE as well: it is seeded
    # below and stepped by test_agent().
    test_env = MPEEnv(args)
else:
    # Fail early with a clear message instead of a NameError on `env` below.
    raise ValueError(
        "unsupported env_type {!r}; expected 'mujoco' or 'mpe'".format(args.env_type))
env_info = env.get_env_info()

# Problem dimensions as reported by the environment.
n_actions = env_info["n_actions"]
n_ant = env_info["n_agents"]
state_space = env_info["obs_shape"]

max_ep_len = args.max_ep_len
max_steps = args.max_steps
gamma = args.gamma


# Seed the global generators and both environments with the same seed.
set_seed(args.seed)
env.seed(args.seed)
test_env.seed(args.seed)

method = args.method
batch_size = args.batch_size
capacity = args.capacity
# NOTE(review): assumes args.agent_conf is a string such as "2x3" for every
# env_type, including 'mpe' -- confirm against config.get_config().
agent_conf = args.agent_conf.split('x')


# Global counters: environment steps taken so far and the length of the
# current training episode.
steps = 0
ep_len = 0

# Results directory: ./results/<scenario[_agent_conf]>/<method>/<ex_name>
if args.env_type == 'mujoco':
    save_path_list = ['results','{}_{}'.format(args.scenario,args.agent_conf),args.method,args.ex_name]
else:
    save_path_list = ['results', args.scenario, args.method, args.ex_name]
save_path = '/'.join(['.'] + [str(part) for part in save_path_list])
# exist_ok=True keeps this safe when several runs start at the same time and
# try to create the same directories.
os.makedirs(save_path, exist_ok=True)


agents = Agent(state_space,n_actions,n_ant,method,device,args)
buff = ReplayBuffer(capacity,state_space,n_actions,n_ant)

# Placeholder batch arrays; the training loop rebinds X, A and next_X to the
# samples returned by buff.getBatch() before they are used.
X = np.zeros((batch_size,state_space))
next_X = np.zeros((batch_size,state_space))
A = np.zeros((n_ant,batch_size,n_actions))



def test_agent(n_episodes=10):
    """Evaluate the current agents on ``test_env`` and return the mean return.

    Each episode runs until the environment reports termination or
    ``max_ep_len`` steps have been taken. Actions come from
    ``agents.calc_actions`` called with the observation only.

    Args:
        n_episodes: Number of evaluation episodes to average over.

    Returns:
        The sum of all rewards collected, divided by ``n_episodes``.

    Raises:
        ValueError: If ``n_episodes`` is not positive.
    """
    if n_episodes <= 0:
        raise ValueError('n_episodes must be positive, got {}'.format(n_episodes))

    sum_reward = 0
    for _episode in range(n_episodes):
        o, d, ep_l = test_env.reset(), False, 0
        while not (d or ep_l == max_ep_len):
            p = agents.calc_actions(np.array(o))
            # Each per-agent entry carries an extra leading axis; keep only
            # its first element before handing the actions to the env.
            for i in range(n_ant):
                p[i] = p[i][0]
            r, d, _ = test_env.step(p)
            o = test_env.get_obs()
            sum_reward += r
            ep_l += 1
    return sum_reward / n_episodes

# Histories that are written to disk at every evaluation point.
save_episode_rewards = []
save_steps = []
save_critic_loss = []
save_policy_loss = []

# Choose the next free run number by scanning save_path for "run<N>" entries.
model_dir = Path(save_path)
exst_run_nums = []
if model_dir.exists():
    for folder in model_dir.iterdir():
        if not str(folder.name).startswith('run'):
            continue
        try:
            exst_run_nums.append(int(str(folder.name).split('run')[1]))
        except ValueError:
            # An entry such as "run.log" or "run_old" is not a run marker;
            # skip it instead of aborting start-up.
            continue
run_num = max(exst_run_nums, default=0) + 1

# In this file tag_dir only reserves the run number; the result files
# themselves are written into save_path with a "_<run_num>" suffix.
tag_dir = '{}/run{}'.format(save_path,run_num)
os.makedirs(tag_dir, exist_ok=True)

with open('{}/test_interval_{}.txt'.format(save_path,run_num),'w') as f:
    f.write('test_interval = {}'.format(args.test_interval))

obs = env.reset()
critic_loss = None
policy_loss = None
update_cnt = 0
if args.method == 'iql':
    # Linear epsilon schedule: from eps_start down to eps_end over eps_step
    # environment steps.
    eps_start = args.eps_start
    eps_end = args.eps_end
    eps_step = args.eps_step
    eps_decay_per_step = (eps_start - eps_end) / eps_step
    eps_now = eps_start
def _save_curve(xs, ys, xlabel, ylabel, path):
    """Plot ys against xs on a new figure, save it to path as PNG, then close it."""
    fig = plt.figure()
    plt.plot(xs, ys)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.savefig(path, format='png')
    # Close the figure so repeated evaluations do not accumulate open figures.
    plt.close(fig)


# Main loop: evaluate periodically, take one environment step, store the
# transition, and run gradient updates on the configured schedule.
while steps < max_steps:

    # ---- periodic evaluation and logging --------------------------------
    if steps % args.test_interval == 0:
        test_episode_rewards = test_agent()
        print('{}_{} {} {} seed {} epsiode_rewards = {} for steps {}'.format(args.scenario,args.agent_conf,args.method,args.ex_name,args.seed,test_episode_rewards,steps))
        save_episode_rewards.append(test_episode_rewards)
        save_steps.append(steps)
        np.save(save_path + '/t_env_{}'.format(run_num), save_steps)
        np.save(save_path + '/mean_episode_rewards_{}'.format(run_num), save_episode_rewards)
        if args.save_loss:
            np.save(save_path + '/critic_loss_{}'.format(run_num), save_critic_loss)
            np.save(save_path + '/policy_loss_{}'.format(run_num), save_policy_loss)
        if args.auto_draw:
            _save_curve(save_steps, save_episode_rewards,
                        't_envs', 'episode_rewards',
                        save_path + '/mean_episode_rewards_{}.png'.format(run_num))
            # Index each loss curve by its own length: a loss list can be
            # shorter than update_cnt (policy_loss is never produced for
            # 'iql'), and mismatched x/y lengths make plt.plot raise.
            _save_curve(np.arange(len(save_critic_loss)), save_critic_loss,
                        'update_num', 'critic_loss',
                        save_path + '/plt_{}_critic_loss.png'.format(run_num))
            _save_curve(np.arange(len(save_policy_loss)), save_policy_loss,
                        'update_num', 'policy_loss',
                        save_path + '/plt_{}_policy_loss.png'.format(run_num))

    # ---- action selection ------------------------------------------------
    # eps_now exists only when args.method == 'iql'.
    if args.method == 'iddpg':
        p = agents.calc_actions(np.array(obs))
    elif args.method == 'iql':
        p = agents.calc_actions(np.array(obs),eps_now)
    for i in range(n_ant):
        if args.method == 'iddpg':
            if args.new_actor:
                p[i] = p[i][0]
            else:
                if steps < args.warm_up_steps:
                    # Warm-up: uniform random actions in [-1, 1).
                    p[i] = 2*np.random.rand(n_actions) - 1
                else:
                    # Gaussian exploration noise (std 0.1), clipped to [-1, 1].
                    p[i] = np.clip(p[i][0] + 0.1*np.random.randn(n_actions),-1,1)
        elif args.method == 'iql':
            p[i] = p[i][0]

    # ---- environment step ------------------------------------------------
    reward, terminated, info = env.step(p)
    next_obs = env.get_obs()
    steps += 1
    ep_len += 1
    if args.method == 'iql':
        eps_now = max(eps_now - eps_decay_per_step,eps_end)
    # Hitting the step limit is stored as a non-terminal transition
    # (presumably so the target still bootstraps from next_obs -- confirm in
    # agents.calc_next_Q).
    if ep_len == max_ep_len:
        terminated = False
    buff.add(obs, p, reward, next_obs, terminated)
    obs = next_obs

    # Start a new episode on true termination or on reaching the step limit.
    if terminated or ep_len == max_ep_len:
        obs = env.reset()
        terminated = False
        ep_len = 0

    # ---- learning --------------------------------------------------------
    # Skip updates until sample_interval steps have been taken, and otherwise
    # update only every update_interval steps.
    if steps < args.sample_interval or steps % args.update_interval != 0:
        continue

    for e in range(args.mini_batch_num):
        X, A, R, next_X, D = buff.getBatch(batch_size)
        X = data_wrap(X,device)
        A = data_wrap(A,device)
        R = data_wrap(R,device)
        D = data_wrap(D,device)
        next_X = data_wrap(next_X,device)
        Q_target = agents.calc_next_Q(next_X,R,D,gamma)

        critic_loss = agents.train_critic(X, A, Q_target)
        if args.method == 'iddpg':
            policy_loss = agents.train_actors(X)
        agents.update()
        if critic_loss is not None:
            save_critic_loss.append(critic_loss)
        if policy_loss is not None:
            save_policy_loss.append(policy_loss)
        update_cnt += 1