import multiprocessing as mp
import argparse

import numpy as np
# from perlin_noise import PerlinNoise

# import robosuite
import sys
from robosuite import make
from robosuite import load_controller_config
from controller.stack_policy import StackPolicy
from controller.lift_policy import LiftPolicy, LiftCausalPolicy
from collector.gym_wrapper import GymStackWrapper, GymLiftWrapper, GymLiftCausalWrapper


# Task name -> wrapper class (from collector.gym_wrapper) applied to the raw
# environment. `worker` looks this up as WRAPPER[task], so the keys must match
# the task names passed to `make`.
# NOTE(review): "LiftCausal" maps to GymLiftWrapper while "CausalPick" maps to
# GymLiftCausalWrapper -- the naming looks crossed; confirm this is intended.
WRAPPER = {
    "LiftCausal": GymLiftWrapper,
    "StackCausal": GymStackWrapper,
    "CausalPick": GymLiftCausalWrapper,
}

# Task name -> scripted policy class. `worker` instantiates the entry as
# POLICY[task](Kp=..., Kd=..., atol=...), so every class listed here must
# accept those keyword arguments. Keys mirror those of WRAPPER.
POLICY = {
    "LiftCausal": LiftPolicy,
    "StackCausal": StackPolicy,
    "CausalPick": LiftCausalPolicy,
}

def _run_episode(env, policy, task, random_type):
    """Roll out one episode and return it as a trajectory dict.

    The episode ends as soon as the environment reports `done` or
    `env.check_success()` is truthy. Every action is clipped to [-1, 1]
    before it is recorded and sent to the environment.

    Returns:
        dict with keys 'obs', 'acts', 'dones', 'rewards', 'obs_next' (each
        stacked along axis 0, one entry per step) and 'success' (bool).
    """
    observations = []
    actions = []
    next_observations = []
    rewards = []
    dones = []

    obs, info = env.reset(return_info=True)
    policy.reset()
    success = False
    time_step = 0

    # Random-action injection is only wired up for the LiftCausal task.
    allow_random = task == "LiftCausal"

    while True:
        if allow_random and random_type == "medium" and time_step <= 12:
            # "medium": random actions for the first 13 steps, scripted after.
            action = policy.random_action()
        elif allow_random and random_type == "random" and time_step >= 14:
            # "random": scripted for the first 14 steps, random actions after.
            action = policy.random_action()
        else:
            action = policy.step(info)

        action = np.clip(action, -np.ones_like(action), np.ones_like(action))

        # Record the pre-step observation together with the action taken.
        observations.append(obs)
        actions.append(action)

        obs, reward, done, info = env.step(action)
        next_observations.append(obs)
        rewards.append(reward)
        dones.append(1 if done else 0)
        time_step += 1

        # Query success once per step and reuse the answer for both the
        # termination test and the episode label.
        succeeded = env.check_success()
        if done or succeeded:
            success = bool(succeeded)
            break

    return {
        'obs': np.stack(observations, axis=0),
        'acts': np.stack(actions, axis=0),
        'dones': np.stack(dones, axis=0),
        'rewards': np.stack(rewards, axis=0),
        'obs_next': np.stack(next_observations, axis=0),
        'success': success,
    }


def worker(process_id, control_freq, task, horizon, episode_per_job, spurious_type, seed, random_type='random'):
    """Collect `episode_per_job` episodes of `task` in a fresh environment.

    Args:
        process_id: Identifier used only in the progress log lines.
        control_freq: Forwarded to `make` as `control_freq`.
        task: Environment name; also the key into WRAPPER and POLICY.
        horizon: Forwarded to `make` as `horizon`.
        episode_per_job: Number of episodes to collect before returning.
        spurious_type: Forwarded to `make` as `spurious_type`.
        seed: Seeds both the wrapped environment and NumPy's global RNG.
        random_type: "medium" or "random" mix random actions into LiftCausal
            rollouts (see `_run_episode`); any other value uses the scripted
            policy for every step.

    Returns:
        A list of trajectory dicts, one per episode. Both successful and
        unsuccessful episodes are kept; the 'success' flag tells them apart.
    """
    # Headless environment: no on-screen or offscreen renderer, no camera obs.
    env = make(
        task,
        'Kinova3',
        horizon=horizon,
        control_freq=control_freq,
        has_renderer=False,
        has_offscreen_renderer=False,
        ignore_done=False,
        use_camera_obs=False,
        use_object_obs=True,
        controller_configs=load_controller_config(default_controller='OSC_POSITION'),
        spurious_type=spurious_type,
    )

    trajectories = []
    try:
        env = WRAPPER[task](env)
        env.seed(seed)
        np.random.seed(seed)

        policy = POLICY[task](Kp=10, Kd=1.0, atol=1e-3)

        for _ in range(episode_per_job):
            trajectory = _run_episode(env, policy, task, random_type)
            trajectories.append(trajectory)
            print('Worker {} - Success: {} - Length: {}'.format(
                process_id, trajectory['success'], len(trajectory['obs'])))
    finally:
        # Always release the simulator, even if a rollout raised.
        env.close()

    return trajectories


if __name__ == "__main__":
    # Script entry point: fan out `num_traj` worker jobs over a process pool,
    # gather their trajectories, report the success rate and save everything
    # to a single .npy file.
    import os  # only needed here, to create the output directory

    parser = argparse.ArgumentParser()
    # NOTE(review): WRAPPER/POLICY also define "CausalPick", which is not
    # offered here -- confirm whether it should be selectable.
    parser.add_argument("--task", type=str, default="LiftCausal", choices=["LiftCausal", "StackCausal"])
    parser.add_argument("--horizon", type=int, default=30) #200
    parser.add_argument("--control_freq", type=int, default=5)
    parser.add_argument("--num_job", type=int, default=100)
    parser.add_argument("--num_traj", type=int, default=1000)
    parser.add_argument("--episode_per_job", type=int, default=1)
    parser.add_argument("--spurious_type", type=str, default="xnr", choices=["xnr","xpr"])
    parser.add_argument("--type", type=str, default="expert", choices=["expert","medium", "random"])
    args = parser.parse_args()

    collect_number = args.num_traj*args.episode_per_job
    print("Number of cpu : ", mp.cpu_count())
    print("Collect number: ", collect_number)

    # One distinct seed per job. Passing the population size directly samples
    # from range(10000000) without building a 10M-element Python list first.
    seed_list = np.random.choice(10000000, args.num_traj, replace=False)

    # The context manager terminates the pool on exit, so worker processes are
    # cleaned up even if a job raises or times out below.
    with mp.Pool(args.num_job) as pool:
        function_list = [
            pool.apply_async(
                worker,
                args=(i, args.control_freq, args.task, args.horizon, args.episode_per_job,
                      args.spurious_type, seed_list[i], args.type),
            )
            for i in range(args.num_traj)
        ]

        # get results (each job returns a list of trajectory dicts)
        trajectories = []
        for f in function_list:
            trajectories += f.get(timeout=1200)

    print('Total number of trajectories: {}'.format(len(trajectories)))
    if trajectories:
        success_count = sum(1 for traj in trajectories if traj['success'])
        success_rate = success_count / len(trajectories) * 100
        print('success_rate:',success_rate)

        # NOTE(review): the file name always starts with "lift_", whatever
        # --task was used -- confirm this is intended.
        output_path = './data/robosuite/lift_' + args.type + '_' + str(collect_number) + '.npy'
        # Make sure the target directory exists so the collected data is not
        # lost to a missing-directory error at the very last step.
        os.makedirs(os.path.dirname(output_path), exist_ok=True)
        np.save(output_path, trajectories, allow_pickle=True)
    else:
        # Nothing was collected: there is no success rate to compute, and
        # saving would only risk overwriting an existing dataset.
        print('No trajectories collected; nothing saved.')
