import random
import numpy as np
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
from env_matrix import make_env
import os

from CORAPPO_Qsa import MAPPO as CORAPPO_Qsa
from MAPPO_positive import MAPPO as MAPPO_positive
from MAAC import MAAC as MAAC
from HAPPO import HAPPO as HAPPO
from COMA import COMA as COMA
from MAPPO import MAPPO as MAPPO
import torch
from util import rollout

def train_algorithm(algorithm_class, algorithm_name, repeat_times=5):
    """Train one algorithm on the ``matrix_game`` environment, once per seed.

    A single environment is built up front and shared by all repeats. For
    each repeat a fresh ``algorithm_class`` instance is constructed, the
    NumPy / PyTorch / ``random`` generators are seeded with the repeat index,
    and rollouts are collected and fed to ``agents.update`` until at least
    ``num_step`` environment steps have been consumed. The mean return of
    every rollout is written to TensorBoard under ``runs/`` and the full
    return history is saved under ``result/`` as a ``.npy`` file.

    Args:
        algorithm_class: Callable that builds the learner. It is invoked with
            ``(agent_num, state_dim_list, hidden_dim, action_num_list,
            actor_lr, critic_lr, epochs, eps, gamma, device)`` plus the
            keywords ``sample_size`` and ``entropy_soft``; the returned
            object must provide ``update(transition_dict)``.
        algorithm_name: Label used in console output, the TensorBoard run
            directory and the result file name.
        repeat_times: Number of independent runs; the run index doubles as
            the random seed.

    Returns:
        None. Results are written to disk.
    """
    # Hyper-parameters shared by every algorithm and every repeat.
    actor_lr = 5e-5
    critic_lr = 5e-4
    num_step = int(2e4)
    hidden_dim = 64
    gamma = 0.99
    epochs = 10
    eps = 0.3
    rollout_step = 10

    device = torch.device("cpu")

    env_name = "matrix_game"
    print(f"Training {algorithm_name} on {env_name} environment...")

    env = make_env(
        scenario=env_name,
        num_envs=4,
        n_agents=4,
        num_actions=5,
        seed=0,
        num_stages=rollout_step,
    )
    agent_num = len(env.agents)

    # 2**n - 2 for n agents; forwarded unchanged to the algorithm constructor.
    sample_size = int(2 ** agent_num - 2)
    state_dim_list = [env.observation_space[i].shape[0] for i in range(agent_num)]
    action_num_list = [env.action_space[i].n for i in range(agent_num)]

    # Only these entries of the rollout dict are lists of tensors that need
    # to be concatenated; any other key is passed through untouched.
    tensor_keys = ('states', 'actions', 'next_states', 'rewards', 'dones')

    for repeat in range(repeat_times):
        writer = SummaryWriter(log_dir=f'runs/{algorithm_name}_{env_name}_Repeat_{repeat}')
        # try/finally so the event file is flushed and closed even when a
        # rollout, an update or the final save raises.
        try:
            np.random.seed(repeat)
            torch.manual_seed(repeat)
            random.seed(repeat)
            # NOTE(review): this assigns an attribute instead of calling a
            # seeding method -- confirm the environment actually reads
            # `env.seed`, otherwise every repeat reuses the seed=0 env above.
            env.seed = repeat

            agents = algorithm_class(
                agent_num,
                state_dim_list,
                hidden_dim,
                action_num_list,
                actor_lr,
                critic_lr,
                epochs,
                eps,
                gamma,
                device,
                sample_size=sample_size,
                entropy_soft=True
            )

            return_list = []
            total_steps = 0  # environment steps consumed so far in this repeat
            with tqdm(total=num_step, desc=f'{algorithm_name}_{env_name} Repeat {repeat}') as pbar:
                while total_steps < num_step:
                    transition_dict, mean_return, env_step_num = rollout(env, agents, rollout_step)
                    total_steps += env_step_num

                    # Replacing values while iterating keys is safe: the set
                    # of keys does not change.
                    for key in transition_dict:
                        if key in tensor_keys:
                            transition_dict[key] = torch.cat(transition_dict[key], dim=0).to(device)

                    return_list.append(mean_return.cpu())
                    writer.add_scalar('Return', mean_return, total_steps)

                    agents.update(transition_dict)

                    pbar.set_postfix({
                        'step': f'{total_steps}',
                        # Average over the last 10 rollouts.
                        'return': f'{np.mean(return_list[-10:]):.3f}',
                        'env_step_num': f'{env_step_num:.3f}'
                    })
                    pbar.update(env_step_num)

            # exist_ok avoids the check-then-create race of os.path.exists().
            os.makedirs('result', exist_ok=True)
            np.save(f'result/{algorithm_name}_{env_name}_Repeat_{repeat}.npy', return_list)
        finally:
            writer.close()

from VDPPO import VDPPO as VDPPO
from QMIXPPO import QMIXPPO as QMIXPPO
from LICA import LICA as LICA


if __name__ == "__main__":
    # Algorithms to benchmark, keyed by the label that ends up in the
    # TensorBoard run directory and the result file name. Dicts keep
    # insertion order, so training happens in the order listed here.
    algorithms = {
        "MAPPO": MAPPO,
        "CORAPPO_Qsa": CORAPPO_Qsa,

        "MAPPO_positive": MAPPO_positive,
        "MAAC": MAAC,
        "HAPPO": HAPPO,
        "COMA": COMA,
        "VDPPO": VDPPO,
        "QMIXPPO": QMIXPPO,
        "LICA": LICA,
    }

    # Never populated in this block; kept as-is.
    results = {}
    for label, learner_cls in algorithms.items():
        # Five seeded repeats per algorithm, run sequentially.
        train_algorithm(learner_cls, label, repeat_times=5)
