import numpy as np
import gym
import os
from config.arguments import get_args as get_args
from config.arguments2 import get_args as get_args2
from config.arguments3 import get_args as get_args3
from config.arguments4 import get_args as get_args4
from config.arguments7 import get_args as get_args7
from mpi4py import MPI
from rl_modules.ddpg_CounterHER import ddpg_CounterHER
from rl_modules.ddpg_ConsHER import ddpg_ConsHER
from rl_modules.ddpg_HERIGA import ddpg_HERIGA
from config.arguments5 import get_args as get_args5
import random
import torch
# import  multiworld
# multiworld.register_all_envs()
import  multiprocessing as mp
from torch.utils.tensorboard import SummaryWriter
from pointEnv.PointEnv import PointEnv
import time
from pointEnv.PointEnv import GymPointEnv
# from pointEnv.MultiGoalEnvSingle import MultiGoalEnvSingle

import gym_fetch_stack

from goal_env import *
from goal_env.mujoco import *

# from pointmass import PointmassEnv
"""
train the agent, the MPI part code is copy from openai baselines(https://github.com/openai/baselines/blob/master/baselines/her)

"""
def get_env_params(env_name, env, testenv=None):
    """Collect the sizes and episode limits the trainer reads from an environment.

    Args:
        env_name: Environment identifier. Its prefix decides where the episode
            length comes from.
        env: Training environment. ``reset()`` must return a mapping whose
            ``'observation'`` and ``'desired_goal'`` entries expose ``shape``.
        testenv: Optional evaluation environment. When given it is reset too,
            and for ``Ant*`` environments it supplies ``max_timesteps2``.

    Returns:
        dict with the keys ``obs``, ``goal``, ``action``, ``action_max``,
        ``l1``, ``l2``, ``env_name`` and ``max_timesteps``; ``Ant*``
        environments additionally get ``max_timesteps2``.
    """
    obs = env.reset()
    if testenv is not None:
        # Reset only for its side effect; the returned observation is unused.
        testenv.reset()

    params = {
        'obs': obs['observation'].shape[0],
        'goal': obs['desired_goal'].shape[0],
        'action': env.action_space.shape[0],
        # Only the first component of the upper action bound is recorded.
        'action_max': env.action_space.high[0],
        # Fixed sizes forwarded to the trainer; their meaning is not visible
        # in this file.
        'l1': 16,
        'l2': 128,
        'env_name': env_name,
    }
    # Printed before the episode limits are filled in, as before.
    print(env_name, params)

    if env_name.startswith('Sawyer') or env_name == 'MultiGoal':
        # These environments get a hard-coded horizon instead of reading
        # `_max_episode_steps`.
        params['max_timesteps'] = 50
    elif env_name.startswith('Ant'):
        params['max_timesteps'] = env._max_episode_steps
        # Without a separate evaluation env, reuse the training horizon
        # rather than failing on `None._max_episode_steps`.
        horizon_env = env if testenv is None else testenv
        params['max_timesteps2'] = horizon_env._max_episode_steps
    else:
        params['max_timesteps'] = env._max_episode_steps
    return params

    
def launch(args, tag):
    """Build the environments and the trainer selected by ``tag``, then train.

    Args:
        args: Parsed configuration. This function reads ``env_name``, ``seed``
            and ``cuda`` and forwards the whole object to the trainer.
        tag: Experiment tag of the form ``<alg>_...``. The text before the
            first underscore selects the trainer class and is also passed to
            ``learn``.

    Raises:
        ValueError: If the algorithm prefix of ``tag`` has no trainer. This
            is checked before any environment is created.
    """
    trainers = {
        'ConsHER': ddpg_ConsHER,
        'HER': ddpg_CounterHER,
        'CounterHER': ddpg_CounterHER,
        'CounterHERv2': ddpg_CounterHER,
        'CounterHERFar': ddpg_CounterHER,
        'CounterHERShort': ddpg_CounterHER,
        'HERIGA': ddpg_HERIGA,
    }
    alg = tag.split('_')[0]
    if alg not in trainers:
        # Previously this surfaced later as an unbound `ddpg_trainer`.
        raise ValueError(
            'unknown algorithm %r in tag %r; expected one of %s'
            % (alg, tag, sorted(trainers)))

    env_name = args.env_name
    if env_name.startswith('Point'):
        env = GymPointEnv(env_name.split('Point')[1], max_episode_steps=25, resize_factor=1)
        # Point environments evaluate on the same instance they train on.
        testenv = env
        difficulty = 0.75
        max_goal_dist = env.max_goal_dist
        # Pass a distance band of difficulty +/- 0.05 (as a fraction of the
        # maximum goal distance) to the goal sampler.
        env.set_sample_goal_args(
            prob_constraint=1.0,
            min_dist=max(0, max_goal_dist * (difficulty - 0.05)),
            max_dist=max_goal_dist * (difficulty + 0.05))
    elif env_name.startswith('Ant'):
        env = gym.make(env_name, maze_size_scaling=4)
        # The evaluation variant is registered under '<name>Test-<version>'.
        testenv = gym.make(env_name.replace('-', 'Test-'), maze_size_scaling=4)
    elif env_name == 'MultiGoal':
        # NOTE(review): the explicit import of MultiGoalEnvSingle is commented
        # out in this file; confirm one of the wildcard imports provides it,
        # otherwise this branch raises NameError.
        env = MultiGoalEnvSingle()
        testenv = MultiGoalEnvSingle()
    else:
        env = gym.make(env_name)
        testenv = gym.make(env_name)

    # One seed per MPI rank, applied to every RNG used here. The environments
    # themselves are not seeded by this function.
    seed = args.seed + MPI.COMM_WORLD.Get_rank()
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if args.cuda:
        torch.cuda.manual_seed(seed)

    env_params = get_env_params(env_name, env, testenv)
    ddpg_trainer = trainers[alg](args, alg, env, env_name, env_params)

    # The log root is hard-coded; runs are grouped by tag, then env and seed.
    logpath = os.path.join('/xx/code/HER/run', tag, tag + '-' + env_name + '-seed=' + str(args.seed))
    writer = SummaryWriter(logpath)
    ddpg_trainer.learn(writer, logpath, alg)

def run(env, seed, tag, activate, w, alg, gpuid):
    """Parse the configuration that matches ``env`` and start training.

    The environment name picks one of the ``get_args*`` parsers. The checks
    run in the order written, and the first match wins. The parsed arguments
    and ``tag`` are then handed to ``launch``.
    """
    if env.startswith('Sawyer'):
        parse_args = get_args2
    elif env in ('FetchReach-v1', 'PointFourRooms'):
        parse_args = get_args3
    elif env.startswith('Ant'):
        parse_args = get_args5
    elif env == 'MultiGoal':
        parse_args = get_args4
    elif env.startswith('FetchStack'):
        parse_args = get_args7
    else:
        # Every other environment uses the default configuration module.
        parse_args = get_args

    launch(parse_args(env, seed, activate, w, alg, gpuid), tag)

if __name__ == '__main__':
    # Environment variables are set before any worker process is started.
    for env_var in ('OMP_NUM_THREADS', 'MKL_NUM_THREADS', 'IN_MPI'):
        os.environ[env_var] = '2'
    # Workers are created with the 'spawn' start method.
    torch.multiprocessing.set_start_method('spawn')

    # Environment presets that are not part of the current run.
    hand_envs_all = ['HandManipulateBlockRotateXYZ-v0', 'HandManipulateEggFull-v0',
                     'HandManipulatePenRotate-v0']
    hand_envs = ['HandManipulateEggFull-v0', 'HandManipulatePenRotate-v0']
    sawyer_envs = ['SawyerDoorHookResetFreeEnv-v0', 'SawyerPushAndReachArenaEnv-v0',
                   'SawyerDoorOpen-v0', 'SawyerPushAndReachEnvHard-v0']
    fetch_envs = ['FetchPickAndPlace-v1', 'FetchSlide-v1', 'FetchPush-v1']
    # Further alternatives kept for reference:
    #   ['FetchPush-v1', 'fetch:Drawer-open-v0', 'fetch:Bin-place-v0']
    #   ['FetchStack2Stage1-v1', 'FetchStack2Stage2-v1', 'FetchStack2Stage3-v1', 'FetchStack3Stage1-v1']
    #   ['AntMaze-v1', 'AntMazeL-v1'], ['AntFall-v1', 'AntPush-v1']
    #   ['PointFourRooms'], ['MultiGoal']

    # Environments that are actually run.
    envs = ['FetchReach-v1']

    # Each tag list repeats one tag so it can be indexed by environment
    # position below. Tag layout: '<alg>_<activate>_<w>'.
    tags1 = ['CounterHER_0_1'] * 4
    tags2 = ['CounterHERv2_0_1'] * 4
    tags3 = ['CounterHERFar_0_1'] * 4
    tags4 = ['CounterHERShort_0_1'] * 4
    tags5 = ['HER_1_1'] * 4

    tags6 = ['HERIGA_1_0.2'] * 4
    tags7 = ['HERIGA_1_0.05'] * 4
    tags8 = ['HERIGA_1_0.1'] * 4
    tags9 = ['ConsHER_1_0.5'] * 5

    # Alternative selection (HERIGA sweep): [tags6, tags7, tags8]
    alltags = [tags1, tags5, tags9, tags8]

    gpuids = [0, 0, 0, 0, 0, 0]

    for env_idx, env in enumerate(envs):
        workers = []
        for tag_group in alltags:
            tag = tag_group[env_idx]
            gpuid = gpuids[env_idx]
            fields = tag.split('_')
            w = fields[-1]
            activate = fields[1]
            # NOTE(review): this passes the whole list of tag fields, not just
            # the algorithm name (fields[0]). Confirm what get_args* expects
            # before changing it.
            alg = fields
            for seed in range(5):
                worker = mp.Process(target=run, args=(env, seed, tag, activate, w, alg, gpuid))
                worker.start()
                workers.append(worker)
            # Wait for every worker started so far for this environment
            # before moving on to the next tag group.
            for worker in workers:
                worker.join()