from omegaconf import DictConfig, OmegaConf
import hydra

import numpy as np
import torch
import random
import os

from envs.gridworld import GridWorldEnv
from envs.appledoor import AppleDoorEnv
from envs.mpe.environment import MPEEnv
from algos.ppo import PPO
from algos.ppo_wprior import PPOwPrior
from algos.dm2 import DM2
from algos.pegmarl import PegMARL
from algos.base import ExpBuffer
import utils


@hydra.main(version_base=None, config_path="config", config_name="train")
def main(args: DictConfig) -> None:
    """Train the configured multi-agent algorithm.

    Seeds all RNGs, builds the environment and algorithm named in ``args``,
    resumes from ``<model_dir>/last_status.pt`` when that file exists, then
    alternates experience collection and parameter updates until
    ``args.frames`` frames have been consumed. After every update it writes
    ``last_status.pt``; it also writes ``best_status.pt`` when every entry of
    the average return beats the best seen so far, and a numbered snapshot
    every 100 updates when ``args.save_interval`` is truthy.

    Args:
        args: config composed by Hydra from ``config/train.yaml``.

    Raises:
        ValueError: if ``args.env.env_name`` or ``args.algo`` is not recognised.
    """
    print(OmegaConf.to_yaml(args))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    _seed_everything(args.seed)

    env = _make_env(args)
    obs_dim = env.observation_space[0].shape[0]
    action_dim = env.action_space[0].n
    agent_num = env.agent_num

    # Logging directory; the config is saved next to the checkpoints.
    model_dir = utils.get_model_dir_name(args)
    print("Model save at: ", model_dir)
    os.makedirs(model_dir, exist_ok=True)
    with open(os.path.join(model_dir, "config.yaml"), "w") as f:
        f.write(OmegaConf.to_yaml(args))

    target_steps = args.buffer_size
    algo = _make_algo(env, args, target_steps)

    # The buffer gets env.max_steps of headroom beyond the algorithm's target,
    # presumably so an episode in progress can finish — confirm in ExpBuffer.
    buffer = ExpBuffer(target_steps + env.max_steps, obs_dim, action_dim, agent_num, args)
    tb_writer = utils.tb_writer(model_dir, agent_num, args.algo == "PPOwPrior")

    update, num_frames, best_return = _resume_if_possible(algo, tb_writer, model_dir, device, args)

    while num_frames < args.frames:
        frames = algo.collect_experiences(buffer, tb_writer)
        algo.update_parameters(buffer, tb_writer)
        num_frames += frames
        avg_returns = tb_writer.log(num_frames)
        update += 1

        tb_writer.log_csv()
        tb_writer.empty_buffer()

        # Decide "best" BEFORE building the checkpoint so the saved best_return
        # is current; otherwise a resumed run compares against a stale best and
        # may overwrite best_status.pt with a worse model.
        is_best = bool(np.all(avg_returns > best_return))
        if is_best:
            best_return = np.copy(avg_returns)

        status = _build_status(algo, args, update, num_frames, tb_writer.ep_num, best_return)
        torch.save(status, os.path.join(model_dir, "last_status.pt"))
        # NOTE: save_interval is used only as an on/off flag; the period is fixed at 100.
        if args.save_interval and update % 100 == 0:
            torch.save(status, os.path.join(model_dir, f"status_{update}_{num_frames}.pt"))
        if is_best:
            torch.save(status, os.path.join(model_dir, "best_status.pt"))


def _seed_everything(seed: int) -> None:
    """Seed torch (CPU and every CUDA device, when present), NumPy and ``random``."""
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)


def _make_env(args: DictConfig):
    """Build the environment chosen by substring match on ``args.env.env_name``.

    Raises:
        ValueError: if the name matches none of the known environment families.
    """
    env_name = args.env.env_name
    if "centerSquare" in env_name:
        return GridWorldEnv(env_name, seed=args.seed, dense_reward=args.env.dense_reward)
    if "appleDoor" in env_name:
        return AppleDoorEnv(env_name, seed=args.seed, dense_reward=args.env.dense_reward)
    if "mpe" in env_name:
        return MPEEnv(args.env)
    raise ValueError("Invalid environment name: {}".format(env_name))


def _make_algo(env, args: DictConfig, target_steps: int):
    """Instantiate the algorithm named by ``args.algo``.

    MAGAIL is run with the DM2 class on expert trajectories, exactly like DM2.

    Raises:
        ValueError: if ``args.algo`` is not one of the supported names.
    """
    if args.algo == "PPO":
        return PPO(env, args, target_steps=target_steps)
    if args.algo == "PPOwPrior":
        prior = utils.load_prior(args)
        return PPOwPrior(env, args, target_steps=target_steps, prior=prior)
    if args.algo in ("MAGAIL", "DM2"):
        expert_traj = utils.load_expert_trajectory(args)
        return DM2(env, args, expert_traj, target_steps=target_steps)
    if args.algo == "PegMARL":
        expert_traj = utils.load_expert_trajectory(args)
        return PegMARL(env, args, expert_traj, target_steps=target_steps)
    raise ValueError("Incorrect algorithm name: {}".format(args.algo))


def _resume_if_possible(algo, tb_writer, model_dir: str, device: torch.device, args: DictConfig):
    """Restore training state from ``last_status.pt`` when it exists.

    Returns:
        ``(update, num_frames, best_return)``. A fresh run gets ``(0, 0, -inf)``.
    """
    try:
        # NOTE(review): torch>=2.6 defaults torch.load to weights_only=True,
        # which may reject non-tensor entries such as a NumPy best_return — confirm.
        status = torch.load(os.path.join(model_dir, "last_status.pt"), map_location=device)
    except FileNotFoundError:
        # Only a missing checkpoint means "start fresh". Other I/O errors
        # propagate instead of silently restarting and overwriting the file.
        return 0, 0, -np.inf
    algo.load_status(status)
    tb_writer.ep_num = status["num_episode"]
    if args.use_shadow_reward:
        algo.pweight = status["pweight"]
    return status["update"], status["num_frames"], status["best_return"]


def _build_status(algo, args: DictConfig, update: int, num_frames: int, num_episode, best_return) -> dict:
    """Assemble the checkpoint dict that is later passed to ``algo.load_status`` on resume."""
    status = {
        "num_frames": num_frames,
        "update": update,
        "num_episode": num_episode,
        "best_return": best_return,
        "model_state": [acmodel.state_dict() for acmodel in algo.acmodels],
        "optimizer_state": [optimizer.state_dict() for optimizer in algo.optimizers],
    }
    # MAGAIL is built with the DM2 class (see _make_algo), so it owns the same
    # discriminators and must checkpoint them as well.
    if args.algo in ("DM2", "MAGAIL"):
        status["discriminator_state"] = [discrim.state_dict() for discrim in algo.discriminators]
        status["d_optimizer_state"] = [optimizer.state_dict() for optimizer in algo.d_optimizers]
    elif args.algo == "PegMARL":
        status["dyn_discriminator_state"] = [discrim.state_dict() for discrim in algo.dyn_discriminators]
        status["dyn_optimizer_state"] = [optimizer.state_dict() for optimizer in algo.dyn_optimizers]
    if args.use_shadow_reward:
        status["pweight"] = algo.pweight
    return status


if __name__ == "__main__":
    # @hydra.main composes the config (config/train.yaml plus command-line
    # overrides) and passes it to main, so it is called without arguments.
    main()