import os
import socket
import time
from pathlib import Path
import wandb
import argparse
import numpy as np
from copy import deepcopy
from torch.utils.tensorboard import SummaryWriter

from xuance import get_arguments
from xuance.common import get_configs, recursive_dict_update, get_time_string
from xuance.torch.utils.operations import set_seed
from xuance.environment import make_envs
from xuance.environment import REGISTRY_ENV, REGISTRY_MULTI_AGENT_ENV, REGISTRY_VEC_ENV

from marl4rna.env.rna_environment_multi import MultiAgentRNAEnv
from marl4rna.agents.mappo_agents import MAPPO_Agents
from marl4rna.env.subproc_vec_maenv import SubprocVecMultiAgentEnv_RNA
from marl4rna.env.dummy_vec_maenv import DummyVecMultiAgentEnv_RNA

# Add the RNA multi-agent environment and its two vectorised wrappers to
# xuance's registries so they can be selected by name from the configuration
# (e.g. ``configs.vectorize = "DummyVecMultiAgentEnv_RNA"`` below).
REGISTRY_MULTI_AGENT_ENV.update(MultiAgentRNAEnv=MultiAgentRNAEnv)
REGISTRY_VEC_ENV.update(
    SubprocVecMultiAgentEnv_RNA=SubprocVecMultiAgentEnv_RNA,
    DummyVecMultiAgentEnv_RNA=DummyVecMultiAgentEnv_RNA,
)


def parse_args(argv=None):
    """Parse the command-line options of the RNA MARL example.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, matching the previous behaviour.

    Returns:
        argparse.Namespace with the attributes ``env``, ``test``,
        ``benchmark`` and ``config``.
    """
    parser = argparse.ArgumentParser("MARL example for RNA")
    parser.add_argument("--env", type=str, default="MultiAgentRNAEnv",
                        help="environment name (default: MultiAgentRNAEnv)")
    parser.add_argument("--test", type=int, default=1,
                        help="1: load a saved model and evaluate it; 0: train")
    parser.add_argument("--benchmark", type=int, default=0,
                        help="1: alternate training and evaluation (takes precedence over --test)")
    parser.add_argument("--config", type=str, default="rna_config_mappo.yaml",
                        help="YAML configuration file")
    return parser.parse_args(argv)


# Number of episodes evaluated in pure test mode (``--test 1``).
# NOTE(review): hard-coded; consider moving it into the YAML config.
_NUM_TEST_EPISODES = 650


def _print_run_information(configs):
    """Print the toolbox, device, algorithm, environment and scenario in use."""
    train_information = {"Deep learning toolbox": configs.dl_toolbox,
                         "Calculating device": configs.device,
                         "Algorithm": configs.agent,
                         "Environment": configs.env_name,
                         "Scenario": configs.env_id}
    for k, v in train_information.items():
        print(f"{k}: {v}")


def _run_benchmark(agent, configs, env_fn):
    """Alternate training and evaluation, keeping the best-scoring checkpoint.

    The agent is evaluated once before training and again after every
    ``eval_interval // parallels`` training steps. Whenever the mean test
    score improves, the model is saved as ``best_model.pth``.
    ``final_model.pth`` is written every 10 epochs and once more at the end.

    Raises:
        ValueError: if ``configs.eval_interval`` is smaller than
            ``configs.parallels`` (the per-environment interval would be 0).
    """
    train_steps = configs.running_steps // configs.parallels
    eval_interval = configs.eval_interval // configs.parallels
    if eval_interval <= 0:
        # Previously surfaced as an opaque ZeroDivisionError.
        raise ValueError(
            f"eval_interval ({configs.eval_interval}) must be >= parallels "
            f"({configs.parallels}) to give a positive per-environment interval")
    test_episode = configs.test_episode
    # NOTE: any remainder of train_steps not divisible by eval_interval is not trained.
    num_epoch = train_steps // eval_interval

    def evaluate():
        # Record the step *after* testing, as the original script did.
        scores = agent.test(env_fn, test_episode)
        return {"mean": np.mean(scores),
                "std": np.std(scores),
                "step": agent.current_step}

    best_scores_info = evaluate()
    agent.save_model(model_name="best_model.pth")
    for i_epoch in range(num_epoch):
        print(f"Epoch: {i_epoch + 1}/{num_epoch}:")
        agent.train(eval_interval)
        scores_info = evaluate()

        if scores_info["mean"] > best_scores_info["mean"]:
            best_scores_info = scores_info
            agent.save_model(model_name="best_model.pth")
        if i_epoch % 10 == 0:
            # Periodic checkpoint so a crash does not lose all progress.
            agent.save_model(model_name="final_model.pth")
    agent.save_model(model_name="final_model.pth")
    print("Best Model Score: %.2f, std=%.2f" % (best_scores_info["mean"], best_scores_info["std"]))


def _run_test(agent, env_fn):
    """Load the saved model from ``agent.model_dir_load`` and report test scores."""
    agent.load_model(path=agent.model_dir_load)
    scores = agent.test(env_fn, _NUM_TEST_EPISODES)
    print(f"Mean Score: {np.mean(scores)}, Std: {np.std(scores)}")
    print("Finish testing.")


def _run_train(agent, configs):
    """Train for ``running_steps // parallels`` steps and save the final model."""
    agent.train(configs.running_steps // configs.parallels)
    agent.save_model("final_train_model.pth")
    print("Finish training!")


def main():
    """Load the config, build environments and the MAPPO agent, then benchmark, test or train."""
    args = parse_args()
    # Honour --config (it was previously parsed but ignored) and merge the
    # CLI arguments into the loaded YAML configuration.
    configs_dict = get_configs(file_dir=args.config)
    configs_dict = recursive_dict_update(configs_dict, vars(args))
    configs = argparse.Namespace(**configs_dict)

    set_seed(configs.seed)
    envs = make_envs(configs)

    # Evaluation uses a single environment (parallels=1) with the dummy RNA vectorizer.
    configs_test = deepcopy(configs)
    configs_test.parallels = 1
    configs_test.vectorize = "DummyVecMultiAgentEnv_RNA"
    configs_test.test = 1
    env_fn = make_envs(configs_test)
    try:
        agent = MAPPO_Agents(config=configs, envs=envs)
        try:
            _print_run_information(configs)
            if configs.benchmark:
                _run_benchmark(agent, configs, env_fn)
            elif configs.test:
                _run_test(agent, env_fn)
            else:
                _run_train(agent, configs)
        finally:
            # Always finalize the agent, even when training/testing fails.
            agent.finish()
    finally:
        # The evaluation env was previously closed only in benchmark mode.
        # NOTE(review): `envs` is handed to the agent; confirm whether
        # agent.finish() closes it, otherwise close it here as well.
        env_fn.close()


if __name__ == "__main__":
    main()






