import copy
import functools
import os
import random

import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from PPOContinuous import PPOContinuous as PPOContinuous
from HAPPOContinuous import HAPPOContinuous as HAPPOContinuous
from COREPPOContinuous import PPOContinuous as COREPPOContinuous
from OptiPPOContinuous import PPOContinuous as OptiPPOContinuous
from COREPPOContinuous_wostd import PPOContinuous as COREPPOContinuous_wostd
from util import rollout_async, evaluate_dg
from env_wrapper import EnvWrapper
from async_env_wrapper2 import ParallelEnvWrapper
from env_SingleStepDG import DifferentialGameEnv


def train_algorithm(algorithm_class, algorithm_name, repeat_times=1):
    """Train ``algorithm_class`` on the differential game and save its results.

    For each repeat, a fresh agent is built. Transitions are then collected
    from ``env_num`` parallel environments with ``rollout_async``. After every
    rollout the agent is evaluated with ``evaluate_dg`` on a separate
    evaluation environment, then updated on the collected transitions. This
    continues until at least ``num_step`` environment steps have been taken.

    Per repeat the function writes:

    * the ``Return`` scalar to TensorBoard under ``runs/``;
    * one actor ``state_dict`` per agent under ``models/``;
    * the evaluation-return curve and evaluation points under ``result/``.

    Args:
        algorithm_class: Multi-agent algorithm class. It is constructed
            positionally with the hyperparameters below and must provide
            ``update(transition_dict)`` and an iterable ``actors`` attribute
            of modules with ``state_dict()``.
        algorithm_name: Label used in log-directory, progress-bar and
            output file names.
        repeat_times: Number of consecutive training runs. The RNGs are
            seeded once before the first repeat, so later repeats continue
            the same random streams.
    """
    actor_lr = 5e-5
    critic_lr = 5e-4
    num_step = int(1e2)  # 1e5 for a full run
    hidden_dim = 64
    gamma = 0.99
    lmbda = 0.95
    epochs = 10
    eps = 0.2
    # CPU only; use torch.device("cuda") here to train on a GPU.
    device = torch.device("cpu")

    env_name = "DifferentialGame"
    setting = "2x1"
    env_num = 4

    np.random.seed(2)
    torch.manual_seed(2)
    random.seed(2)

    # Evaluation environment; also the source of the problem dimensions.
    env = DifferentialGameEnv()
    bound = 5.0
    max_ep_len = env.horizon

    agent_num = env.agent_num
    obs_dim = env.obs_dim
    action_dim = env.action_dim
    # One formula covers every case: 2**2 - 2 == 2 for two agents.
    sample_size = int(2 ** agent_num - 2)

    # Each worker gets its own deep copy of the evaluation env: same game
    # configuration, independent state. A partial (unlike a lambda) stays
    # picklable for process-based wrappers, and closing the workers can no
    # longer touch the env used for evaluation and by later repeats.
    env_factory = functools.partial(copy.deepcopy, env)

    for repeat in range(repeat_times):
        run_tag = f'{algorithm_name}_{env_name}_{setting}_EnvNum_{env_num}_Async_Repeat_{repeat}'
        writer = SummaryWriter(log_dir=f'runs/{run_tag}')
        envs = ParallelEnvWrapper(env_factory, env_num, seed=0)
        try:
            agents = algorithm_class(
                agent_num,
                obs_dim,
                hidden_dim,
                action_dim,
                actor_lr,
                critic_lr,
                lmbda,
                epochs,
                eps,
                gamma,
                sample_size,
                bound,
                device
            )

            return_list = []
            point_list = []
            total_steps = 0
            with tqdm(total=int(num_step), desc=run_tag) as pbar:
                while total_steps < num_step:
                    transition_dict, env_step_num = rollout_async(envs, agents, max_ep_len)
                    # Evaluate before this rollout's update is applied.
                    eval_episode_return, point = evaluate_dg(env, agents)
                    point_list.append(point)

                    total_steps += env_step_num
                    return_list.append(eval_episode_return)
                    writer.add_scalar('Return', eval_episode_return, total_steps)

                    # Replace each value in place (no keys added/removed) with a float tensor on `device`.
                    for key in transition_dict:
                        transition_dict[key] = torch.tensor(
                            np.array(transition_dict[key]), dtype=torch.float
                        ).to(device)
                    agents.update(transition_dict)

                    pbar.set_postfix({
                        'step': '%d' % total_steps,
                        'return': '%.3f' % np.mean(return_list[-10:])
                    })
                    pbar.update(env_step_num)
        finally:
            # Release worker environments and flush TensorBoard even if training fails.
            envs.close()
            writer.close()

        os.makedirs('models', exist_ok=True)
        for i, actor in enumerate(agents.actors):
            # NOTE: model file names intentionally keep their original form,
            # which omits the "_Async" tag used for runs/ and result/.
            torch.save(
                actor.state_dict(),
                f'models/{algorithm_name}_{env_name}_{setting}_EnvNum_{env_num}_Repeat_{repeat}_agent_{i}.pth'
            )

        os.makedirs('result', exist_ok=True)
        np.save(f'result/{run_tag}.npy', return_list)
        np.save(f'result/{run_tag}_points.npy', np.array(point_list))

# Algorithms trained by the script entry point, in training order.
# Each key is the label used in log/result file names; each value is the class to train.
algorithms = dict(
    OptiPPOContinuous=OptiPPOContinuous,
    COREPPOContinuous_wostd=COREPPOContinuous_wostd,
    PPOContinuous=PPOContinuous,
    COREPPOContinuous=COREPPOContinuous,
    HAPPOContinuous=HAPPOContinuous,
)

if __name__ == "__main__":
    # Train every registered algorithm in turn, with 5 repeats each.
    for algo_name, algo_class in algorithms.items():
        print(f"Training {algo_name}...")
        train_algorithm(algo_class, algo_name, repeat_times=5)