import copy
import random
import torch
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from torch.utils.tensorboard import SummaryWriter
from env_matrix_multiopt import make_env
import os

from CORAPPO_Qsa import MAPPO as CORAPPO_Qsa
from MAPPO_positive import MAPPO as MAPPO_positive
from MAAC import MAAC as MAAC
from HAPPO import HAPPO as HAPPO
from COMA import COMA as COMA
from MAPPO import MAPPO as MAPPO
import torch
from util import rollout


def train_algorithm(algorithm_class, algorithm_name, repeat_times=5):
    """Train one algorithm on the ``matrix_game_multiopt`` environment.

    A single environment is built once and reused for every repeat. For each
    repeat the NumPy, PyTorch and ``random`` generators are seeded with the
    repeat index, a fresh learner is constructed, and rollouts alternate with
    ``agents.update`` until at least ``num_step`` environment steps have been
    collected. The per-rollout mean return is logged to TensorBoard under the
    tag ``'Return'`` and the full return history is saved to
    ``result/<algorithm_name>_<env_name>_n_peaks=<n>_Repeat_<k>.npy``.

    Args:
        algorithm_class: Callable that builds the learner. It is called with
            ``(agent_num, state_dim_list, hidden_dim, action_num_list,
            actor_lr, critic_lr, epochs, eps, gamma, device)`` positionally
            plus the keywords ``sample_size`` and ``entropy_soft``. The
            returned object must provide ``update(transition_dict)`` and be
            accepted by ``rollout``.
        algorithm_name: Label used in the console output, the TensorBoard
            run directory and the result file name.
        repeat_times: Number of independently seeded training runs.

    Returns:
        None. Results are written to ``runs/`` and ``result/``.
    """
    # Hyper-parameters shared by every algorithm and every repeat.
    actor_lr = 5e-5
    critic_lr = 5e-4
    num_step = int(2e4)
    hidden_dim = 64
    gamma = 0.99
    epochs = 10
    eps = 0.3
    rollout_step = 10
    n_peaks = 10  # 5, 10, 15

    device = torch.device("cpu")

    env_name = "matrix_game_multiopt"
    print(f"Training {algorithm_name} on {env_name} environment...")

    env = make_env(
        scenario=env_name,
        num_envs=4,
        n_agents=3,
        num_actions=4,
        n_peaks=n_peaks,
        peak_base=15,
        seed=0,
        num_stages=rollout_step,
        share_reward="True",
    )
    agent_num = len(env.agents)

    # NOTE(review): 2**n - 2 looks like the count of agent subsets excluding
    # the empty and the full set -- confirm against the algorithm classes.
    sample_size = int(2 ** agent_num - 2)
    state_dim_list = [env.observation_space[i].shape[0] for i in range(agent_num)]
    action_num_list = [env.action_space[i].n for i in range(agent_num)]

    # Entries of the rollout dict that arrive as lists of tensors and are
    # concatenated along dim 0 before the update.
    tensor_keys = ('states', 'actions', 'next_states', 'rewards', 'dones')

    for repeat in range(repeat_times):
        run_tag = f'{algorithm_name}_{env_name}_n_peaks={n_peaks}_Repeat_{repeat}'
        writer = SummaryWriter(log_dir=f'runs/{run_tag}')
        # try/finally so the event file is flushed and closed even when a
        # rollout, an update or the final save raises.
        try:
            np.random.seed(repeat)
            torch.manual_seed(repeat)
            random.seed(repeat)
            # NOTE(review): this only rebinds the ``seed`` attribute; whether
            # the environment actually re-seeds from it is not visible here.
            env.seed = repeat

            agents = algorithm_class(
                agent_num,
                state_dim_list,
                hidden_dim,
                action_num_list,
                actor_lr,
                critic_lr,
                epochs,
                eps,
                gamma,
                device,
                sample_size=sample_size,
                entropy_soft=True
            )

            return_list = []
            total_steps = 0  # environment steps consumed so far in this repeat
            with tqdm(total=num_step, desc=f'{algorithm_name}_{env_name}_n_peaks={n_peaks} Repeat {repeat}') as pbar:
                while total_steps < num_step:
                    transition_dict, mean_return, env_step_num = rollout(env, agents, rollout_step)
                    total_steps += env_step_num

                    for key in tensor_keys:
                        if key in transition_dict:
                            transition_dict[key] = torch.cat(transition_dict[key], dim=0).to(device)

                    return_list.append(mean_return.cpu())
                    writer.add_scalar('Return', mean_return, total_steps)

                    agents.update(transition_dict)

                    pbar.set_postfix({
                        'step': f'{total_steps}',
                        # Smoothed over the last (up to) 10 rollouts.
                        'return': f'{np.mean(return_list[-10:]):.3f}',
                        'env_step_num': f'{env_step_num:.3f}'
                    })
                    pbar.update(env_step_num)

            # exist_ok avoids the check-then-create race of os.path.exists.
            os.makedirs('result', exist_ok=True)
            np.save(f'result/{run_tag}.npy', return_list)
        finally:
            writer.close()

from VDPPO import VDPPO as VDPPO
from QMIXPPO import QMIXPPO as QMIXPPO
from LICA import LICA as LICA

if __name__ == "__main__":
    # Maps the label used for log/result file names to the learner class
    # that train_algorithm instantiates.
    algorithms = {
        "MAAC": MAAC,
        "MAPPO": MAPPO,
        "CORAPPO_Qsa": CORAPPO_Qsa,
        "MAPPO_positive": MAPPO_positive,
        "HAPPO": HAPPO,
        "COMA": COMA,
        "VDPPO": VDPPO,
        "QMIXPPO": QMIXPPO,
        "LICA": LICA,
    }

    # Train the algorithms one after another, in dict insertion order.
    # train_algorithm returns nothing, so there is no result to collect here.
    for algo_name, algo_class in algorithms.items():
        train_algorithm(algo_class, algo_name, repeat_times=5)
