import random
import torch
import numpy as np
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
import os

import vmas

from CORAPPO_Qsa import MAPPO as CORAPPO_Qsa
from HAPPO import HAPPO as HAPPO
from MAPPO_positive import MAPPO as MAPPO_positive
from MAPPO import MAPPO as MAPPO
from VDPPO import VDPPO as VDPPO
from QMIXPPO import QMIXPPO as QMIXPPO
from COMA import COMA as COMA
from LICA import LICA as LICA

from util import rollout


def train_algorithm(algorithm_class, algorithm_name: str, repeat_times: int = 5) -> None:
    """Train one algorithm on the VMAS "navigation" scenario for several seeds.

    For every repeat a fresh agent container is built from ``algorithm_class``,
    trained until at least ``num_step`` environment steps have been collected,
    and its per-rollout mean returns are logged and saved.

    Args:
        algorithm_class: Callable building the agent container. It is invoked
            with ``(agent_num, state_dim_list, hidden_dim, action_num_list,
            actor_lr, critic_lr, epochs, eps, gamma, device)`` plus the keyword
            arguments ``sample_size`` and ``entropy_soft``; the returned object
            must provide ``update(transition_dict)`` and be accepted by
            ``rollout``.
        algorithm_name: Label used in console output, the TensorBoard run
            directory and the result file name.
        repeat_times: Number of independent runs; run ``i`` is seeded with ``i``.

    Side effects:
        Writes TensorBoard events to ``runs/<name>_<env>_Repeat_<i>`` and saves
        the list of returns to ``result/<name>_<env>_Repeat_<i>.npy``.
    """
    # Hyper-parameters shared by every algorithm and every repeat.
    actor_lr = 5e-4
    critic_lr = 5e-3
    num_step = int(5e6)  # environment-step budget per repeat
    hidden_dim = 64
    gamma = 0.99
    epochs = 10
    eps = 0.2
    rollout_step = 100  # also passed to the environment as max_steps

    # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    device = torch.device("cpu")
    env_device = "cpu"

    env_name = "navigation"
    print(f"Training {algorithm_name} on {env_name} environment...")

    # NOTE(review): share_reward is passed as the *string* "True", not a bool.
    # Any non-empty string is truthy, so it is left untouched here; confirm the
    # kwarg name and type the scenario actually expects.
    env = vmas.make_env(
        scenario=env_name,
        share_reward="True",
        num_envs=64,
        device=env_device,
        continuous_actions=False,
        wrapper=None,
        max_steps=rollout_step,
        seed=0,
        dict_spaces=False,
        grad_enabled=False,
        terminated_truncated=False,
    )
    agent_num = len(env.agents)
    sample_size = int(2**agent_num - 2)

    state_dim_list = [env.observation_space[i].shape[0] for i in range(agent_num)]
    action_num_list = [env.action_space[i].n for i in range(agent_num)]

    # Entries of the rollout dict that arrive as lists of tensors and must be
    # concatenated into a single tensor before the update.
    tensor_keys = ('states', 'actions', 'next_states', 'rewards', 'dones')

    for repeat in range(repeat_times):
        writer = SummaryWriter(log_dir=f'runs/{algorithm_name}_{env_name}_Repeat_{repeat}')
        try:
            np.random.seed(repeat)
            torch.manual_seed(repeat)
            random.seed(repeat)
            # Call the environment's seed method; assigning to ``env.seed``
            # would replace the method with an int and never reseed the env.
            env.seed(repeat)

            agents = algorithm_class(
                agent_num,
                state_dim_list,
                hidden_dim,
                action_num_list,
                actor_lr,
                critic_lr,
                epochs,
                eps,
                gamma,
                device,
                sample_size=sample_size,
                entropy_soft=True
            )

            return_list = []
            total_steps = 0  # environment steps collected so far in this repeat
            with tqdm(total=num_step, desc=f'{algorithm_name}_{env_name} Repeat {repeat}') as pbar:
                while total_steps < num_step:
                    transition_dict, mean_return, env_step_num = rollout(env, agents, rollout_step)
                    total_steps += env_step_num

                    # Only existing keys are reassigned, so mutating the dict
                    # while iterating over it is safe.
                    for key in transition_dict:
                        if key in tensor_keys:
                            transition_dict[key] = torch.cat(transition_dict[key], dim=0).to(device)

                    return_list.append(mean_return.cpu())
                    writer.add_scalar('Return', mean_return, total_steps)

                    agents.update(transition_dict)

                    pbar.set_postfix({
                        'step': f'{total_steps}',
                        # Running average over the last ten rollouts.
                        'return': f'{np.mean(return_list[-10:]):.3f}',
                        'env_step_num': f'{env_step_num:.3f}'
                    })
                    pbar.update(env_step_num)

            os.makedirs('result', exist_ok=True)
            np.save(f'result/{algorithm_name}_{env_name}_Repeat_{repeat}.npy', return_list)
        finally:
            # Flush and release the event file even if training raised.
            writer.close()


if __name__ == "__main__":
    # Registry mapping a run label to its algorithm class. Insertion order is
    # the order in which the algorithms are trained.
    algorithms = dict(
        MAPPO=MAPPO,
        CORAPPO_Qsa=CORAPPO_Qsa,
        HAPPO=HAPPO,
        MAPPO_positive=MAPPO_positive,  # optimistic mappo
        VDPPO=VDPPO,
        QMIXPPO=QMIXPPO,
        COMA=COMA,
        LICA=LICA,
    )

    # NOTE(review): this dict is never filled in the visible code —
    # train_algorithm has no return value to store.
    results = {}

    # Train every registered algorithm in turn, five seeded repeats each.
    for algo_name, algo_class in algorithms.items():
        train_algorithm(algo_class, algo_name, repeat_times=5)
