import numpy as np
# import gymnasium as gym
import gym

import os, sys
from arguments import get_args
from mpi4py import MPI
# from rl_modules.ddpg_agent_panda import ddpg_agent
from rl_modules.ddpg_agent_panda_0 import ddpg_agent_0
from rl_modules.ddpg_agent_panda_1 import ddpg_agent_1
from rl_modules.ddpg_agent_panda_2 import ddpg_agent_2
from rl_modules.ddpg_agent_panda_3 import ddpg_agent_3

import random
import torch
import panda_gym
from dfa import DFA

"""
train the agent, the MPI part code is copy from openai baselines(https://github.com/openai/baselines/blob/master/baselines/her)

"""
def get_env_params(env):
    """Reset ``env`` once and derive the dimensions the agent needs.

    Args:
        env: object exposing ``reset()`` and ``action_space``. ``reset()``
            must return a mapping with ``"observation"`` and
            ``"desired_goal"`` entries that each have a ``shape``.

    Returns:
        dict with keys ``'obs'``, ``'goal'``, ``'action'``, ``'action_max'``
        and ``'max_timesteps'`` (the latter is fixed at 100).
    """
    first_obs = env.reset()
    return {
        # leading dimension of the observation / goal returned by reset()
        'obs': first_obs["observation"].shape[0],
        'goal': first_obs["desired_goal"].shape[0],
        'action': env.action_space.shape[0],
        # only the first component of the upper bound is read
        'action_max': env.action_space.high[0],
        'max_timesteps': 100,
    }

def make_envs():
    """Instantiate the four Panda task environments and describe each one.

    Returns:
        A tuple ``(envs, env_params, env_list_names)`` of three parallel
        lists: the environments created with ``gym.make``, one parameter
        dict per environment (keys ``'obs'``, ``'goal'``, ``'action'``,
        ``'action_max'``, ``'max_timesteps'``), and the environment ids.
    """
    env_list_names = [
        'PandaPush-v2',
        'PandaStackSmallonBig-v2',
        'PandaPickAndPlace-v2',
        'PandaStackSmallBigonTarget-v2',
    ]
    # Build every environment first; they are only reset afterwards.
    envs = [gym.make(task_id) for task_id in env_list_names]

    env_params = []
    for task_env in envs:
        first_obs = task_env.reset()
        env_params.append({
            'obs': first_obs["observation"].shape[0],
            'goal': first_obs["desired_goal"].shape[0],
            'action': task_env.action_space.shape[0],
            'action_max': task_env.action_space.high[0],
            'max_timesteps': 100,
        })
    return envs, env_params, env_list_names

# def make_agents(args, envs, params, env_list_names):
#     ddpg_trainers = []
#     for i in range(len(envs)):
#         ddpg_trainers.append(ddpg_agent(args, envs[i], params[i], env_list_names[i]))
#     return ddpg_trainers

def make_agents(args, envs, params, env_list_names):
    """Build one DDPG trainer per task, each from its own agent class.

    Task ``i`` is handled by ``ddpg_agent_<i>`` and receives ``envs[i]``,
    ``params[i]`` and ``env_list_names[i]``; exactly four agents are created.

    Args:
        args: run configuration, passed unchanged to every agent constructor.
        envs: sequence of environments, indexed 0..3.
        params: sequence of per-environment parameter dicts, indexed 0..3.
        env_list_names: sequence of environment ids, indexed 0..3.

    Returns:
        list of the four constructed agents, in task order.
    """
    agent_classes = (ddpg_agent_0, ddpg_agent_1, ddpg_agent_2, ddpg_agent_3)
    return [
        agent_cls(args, envs[task_idx], params[task_idx], env_list_names[task_idx])
        for task_idx, agent_cls in enumerate(agent_classes)
    ]


def launch(args):
    """Run the multi-task training loop until the DFA reports the final task.

    Each iteration asks the DFA for a task index, trains that task's agent,
    records the returned success rates, and then either marks the task as
    learned (mean success rate above 0.85) or feeds the mean back to the DFA
    via ``update_teacher``. The loop ends when ``learned_task`` returns a
    value equal to 1.

    Args:
        args: parsed run configuration. ``args.seed`` and ``args.cuda`` are
            read here; the whole object is forwarded to the agent
            constructors.
    """
    envs, env_params, env_list_names = make_envs()

    # Every RNG is seeded with args.seed offset by this process's MPI rank.
    # NOTE: the environments themselves are not seeded here.
    rank_seed = args.seed + MPI.COMM_WORLD.Get_rank()
    random.seed(rank_seed)
    np.random.seed(rank_seed)
    torch.manual_seed(rank_seed)
    if args.cuda:
        torch.cuda.manual_seed(rank_seed)

    # one agent per environment, in the same order as `envs`
    ddpg_trainers = make_agents(args, envs, env_params, env_list_names)
    dfa_instance = DFA()
    epochs_per_task = [0] * len(envs)
    total_success_rate = [[] for _ in envs]
    success_threshold = 0.85

    while True:
        current_task = dfa_instance.choose_task()
        success_rate_arr = ddpg_trainers[current_task].learn(current_task)
        # NOTE(review): assumes every learn() call runs 10 epochs - confirm
        # against the agent implementation.
        epochs_per_task[current_task] += 10
        total_success_rate[current_task].extend(success_rate_arr)

        mean_success = np.mean(success_rate_arr)
        if mean_success > success_threshold:
            # Task learned: the DFA's return value says whether it was the
            # final one.
            if dfa_instance.learned_task(current_task) == 1:
                print(total_success_rate)
                print(epochs_per_task)
                break
        else:
            # Not learned yet: report progress to the DFA and log it.
            dfa_instance.update_teacher(current_task, mean_success)
            print(total_success_rate)
            print(epochs_per_task)


if __name__ == '__main__':
    # Environment configuration for the HER/MPI run: OMP and MKL thread
    # counts are set to 1 and IN_MPI is set to '1'.
    os.environ.update({
        'OMP_NUM_THREADS': '1',
        'MKL_NUM_THREADS': '1',
        'IN_MPI': '1',
    })
    # parse the run configuration and start training
    args = get_args()
    launch(args)
