import random
import numpy as np
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
import vmas
import os

from CORAPPO_Qsa import MAPPO as CORAPPO_Qsa
from HAPPO import HAPPO as HAPPO
from MAPPO_positive import MAPPO as MAPPO_positive

from MAPPO import MAPPO as MAPPO
import torch
from util import rollout


def train_algorithm(algorithm_class, algorithm_name, repeat_times=5):
    """Train one multi-agent algorithm on the VMAS ``multi_give_way`` scenario.

    For each repeat a fresh learner is built from ``algorithm_class``, trained
    until at least ``num_step`` environment steps have been collected, and its
    per-rollout mean returns are logged to TensorBoard and saved to disk.

    Args:
        algorithm_class: Callable that builds the learner. It is invoked with
            ``(agent_num, state_dim_list, hidden_dim, action_num_list,
            actor_lr, critic_lr, epochs, eps, gamma, device)`` plus the
            keyword arguments ``sample_size`` and ``entropy_soft``, and the
            result must expose ``update(transition_dict)``.
        algorithm_name: Label used in console output, the TensorBoard run
            directory and the result file name.
        repeat_times: Number of independent repeats; repeat ``k`` is seeded
            with ``k``.

    Side effects:
        Writes TensorBoard logs to
        ``runs/{algorithm_name}_{env_name}_Repeat_{k}`` and saves the return
        history to ``result/{algorithm_name}_{env_name}_Repeat_{k}.npy``.
    """
    # Hyperparameters shared by every algorithm and every repeat.
    actor_lr = 5e-5
    critic_lr = 5e-4
    num_step = int(5e6)
    hidden_dim = 64
    gamma = 0.99
    epochs = 10
    eps = 0.2
    rollout_step = 100

    # Both the learner and the simulator run on CPU.
    device = torch.device("cpu")
    env_device = "cpu"

    env_name = "multi_give_way"
    # Use the parameter, not a module-level global, so the function also works
    # when it is imported and called from another module.
    print(f"Training {algorithm_name} on {env_name} environment...")
    env = vmas.make_env(
        scenario=env_name,
        num_envs=16,
        device=env_device,
        continuous_actions=False,
        wrapper=None,
        max_steps=rollout_step,
        seed=0,
        dict_spaces=False,
        grad_enabled=False,
        terminated_truncated=False,
    )
    agent_num = len(env.agents)

    # 2**n - 2 is the count of non-empty proper subsets of n agents.
    # NOTE(review): presumably the number of agent coalitions sampled by the
    # algorithm classes -- confirm against their implementation.
    sample_size = int(2**agent_num - 2)

    state_dim_list = [env.observation_space[i].shape[0] for i in range(agent_num)]
    action_num_list = [env.action_space[i].n for i in range(agent_num)]

    for repeat in range(repeat_times):
        writer = SummaryWriter(log_dir=f'runs/{algorithm_name}_{env_name}_Repeat_{repeat}')
        try:
            np.random.seed(repeat)
            torch.manual_seed(repeat)
            random.seed(repeat)
            # Call the seeding method; assigning to ``env.seed`` would only
            # replace the method with an int and leave the env unseeded.
            env.seed(repeat)

            agents = algorithm_class(
                agent_num,
                state_dim_list,
                hidden_dim,
                action_num_list,
                actor_lr,
                critic_lr,
                epochs,
                eps,
                gamma,
                device,
                sample_size=sample_size,
                entropy_soft=True
            )

            return_list = []
            total_steps = 0  # environment steps collected so far in this repeat
            with tqdm(total=int(num_step), desc=f'{algorithm_name}_{env_name} Repeat {repeat}') as pbar:
                while total_steps < num_step:
                    transition_dict, mean_return, env_step_num = rollout(env, agents, rollout_step)
                    total_steps += env_step_num

                    # Merge the per-step pieces of each trajectory field into
                    # a single tensor on the training device.
                    for key in transition_dict:
                        if key in ['states', 'actions', 'next_states', 'rewards', 'dones']:
                            transition_dict[key] = torch.cat(transition_dict[key], dim=0).to(device)

                    return_list.append(mean_return.cpu())
                    writer.add_scalar('Return', mean_return, total_steps)

                    agents.update(transition_dict)

                    pbar.set_postfix({
                        'step': f'{total_steps}',
                        # Smoothed over the ten most recent rollouts.
                        'return': f'{np.mean(return_list[-10:]):.3f}',
                        'env_step_num': f'{env_step_num:.3f}'
                    })
                    pbar.update(env_step_num)

            os.makedirs('result', exist_ok=True)
            np.save(f'result/{algorithm_name}_{env_name}_Repeat_{repeat}.npy', return_list)
        finally:
            # Flush and release the event file even if training raised.
            writer.close()


if __name__ == "__main__":
    # Algorithms to benchmark, keyed by the label used for logs and result files.
    algorithms = dict(
        MAPPO=MAPPO,
        CORAPPO_Qsa=CORAPPO_Qsa,
        HAPPO=HAPPO,
        MAPPO_positive=MAPPO_positive,
    )

    results = {}
    # Train every registered algorithm in declaration order, five repeats each.
    for algo_name, algo_class in algorithms.items():
        train_algorithm(
            algorithm_class=algo_class,
            algorithm_name=algo_name,
            repeat_times=5,
        )
