import gymnasium as gym
import gymnasium_robotics

import torch
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from PPOContinuous import PPOContinuous as PPOContinuous
from COREPPOContinuous_Qsa import PPOContinuous as COREPPOContinuous_Qsa
from COREPPOContinuous import PPOContinuous as COREPPOContinuous
from HAPPOContinuous import HAPPOContinuous as HAPPOContinuous
from COREPPOContinuous_wostd import PPOContinuous as COREPPOContinuous_wostd
from tqdm import tqdm
import os
import random
from util import rollout_async, evaluate
from env_wrapper import EnvWrapper
from async_env_wrapper2 import ParallelEnvWrapper


def train_algorithm(algorithm_class, algorithm_name: str, repeat_times: int = 1) -> None:
    """Train one algorithm on MaMuJoCo ``Ant`` (``2x4``) and record its returns.

    Each repeat builds a new pool of ``env_num`` parallel environments and a
    new set of agents, then loops over: asynchronous rollout, evaluation via
    ``evaluate``, and one ``agents.update`` call, until at least ``num_step``
    environment steps have been collected. Every evaluation return is logged
    to TensorBoard under ``runs/`` and the full list is saved as a ``.npy``
    file under ``result/``.

    The NumPy, PyTorch and ``random`` seeds are set once, before the first
    repeat; they are not reset between repeats.

    Args:
        algorithm_class: Callable that builds the agents. It is called
            positionally with ``(agent_num, obs_dim, hidden_dim, action_dim,
            actor_lr, critic_lr, lmbda, epochs, eps, gamma, sample_size,
            bound, device)`` and the result must provide
            ``update(transition_dict)``.
        algorithm_name: Label embedded in the TensorBoard log directory, the
            progress-bar description and the result file name.
        repeat_times: Number of training runs to perform.
    """
    # Function-scope import so this block does not depend on a file-level edit.
    from contextlib import closing

    # Optimisation hyperparameters.
    actor_lr = 5e-4
    critic_lr = 5e-3
    num_step = int(5e4)
    hidden_dim = 64
    gamma = 0.99
    lmbda = 0.95
    epochs = 10
    eps = 0.3
    # Training is pinned to the CPU; use torch.device("cuda") to run on a GPU.
    device = torch.device("cpu")

    # Environment configuration.
    env_name = "Ant"
    setting = "2x4"
    env_num = 4

    np.random.seed(0)
    torch.manual_seed(0)
    random.seed(0)

    env = gymnasium_robotics.mamujoco_v1.parallel_env(env_name, setting)
    env = EnvWrapper(env, seed=0)
    bound = 1.0
    max_ep_len = int(1e3)

    agent_num = env.agent_num
    obs_dim = env.obs_dim
    action_dim = env.action_dim
    # 2**n - 2 is the number of non-empty proper subsets of n agents. The
    # two-agent case needs no special handling here: 2**2 - 2 == 2.
    # (Halving this count instead would need a floor of 2 for two agents.)
    sample_size = int(2 ** agent_num - 2)

    for repeat in range(repeat_times):
        run_tag = f'{algorithm_name}_{env_name}_{setting}_EnvNum_{env_num}_Async_Repeat_{repeat}'
        return_list = []

        # closing() releases the environment pool and then the writer even if
        # training raises, so a failed run leaves no workers or unflushed
        # event files behind.
        # NOTE(review): the factory returns the same `env` object on every
        # call, and that object is also used for evaluation below -- confirm
        # ParallelEnvWrapper copies or re-creates it per worker.
        with closing(SummaryWriter(log_dir=f'runs/{run_tag}')) as writer, \
                closing(ParallelEnvWrapper(lambda: env, env_num, seed=0)) as envs:
            agents = algorithm_class(
                agent_num,
                obs_dim,
                hidden_dim,
                action_dim,
                actor_lr,
                critic_lr,
                lmbda,
                epochs,
                eps,
                gamma,
                sample_size,
                bound,
                device
            )

            total_steps = 0
            with tqdm(total=num_step, desc=run_tag) as pbar:
                while total_steps < num_step:
                    transition_dict, env_step_num = rollout_async(envs, agents, max_ep_len)
                    # Evaluation happens before the update, so the logged
                    # return belongs to the policy that produced this rollout.
                    eval_episode_return = evaluate(env, agents)

                    total_steps += env_step_num
                    return_list.append(eval_episode_return)
                    writer.add_scalar('Return', eval_episode_return, total_steps)

                    # Convert every collected field to a float tensor on the
                    # training device before handing it to the learner.
                    for key in transition_dict:
                        transition_dict[key] = torch.tensor(
                            np.array(transition_dict[key]), dtype=torch.float
                        ).to(device)
                    agents.update(transition_dict)

                    pbar.set_postfix({
                        'step': '%d' % total_steps,
                        # Mean over the ten most recent evaluations.
                        'return': '%.3f' % np.mean(return_list[-10:])
                    })
                    pbar.update(env_step_num)

        os.makedirs('result', exist_ok=True)
        np.save(f'result/{run_tag}.npy', return_list)

# Algorithms to train, keyed by the label passed to train_algorithm (where it
# is embedded in the log-directory and result-file names).
algorithms = {
    "COREPPOContinuous_wostd": COREPPOContinuous_wostd,
    "COREPPOContinuous": COREPPOContinuous,
}

if __name__ == "__main__":
    # Train every registered algorithm in turn, five repeats each.
    for algo_name, algo_class in algorithms.items():
        print(f"Training {algo_name}...")
        train_algorithm(algo_class, algo_name, repeat_times=5)