"""
Training behavior policies for FOCAL

"""

import click
import json
import os

import gym, gym.wrappers

import argparse
import multiprocessing as mp
from multiprocessing import Pool
from itertools import product


import metaworld,random
import numpy as np
import metaworld.policies as p
import copy


# Maps each supported Meta-World task name to the name of its scripted policy
# class in ``metaworld.policies``.  The classes are stored by name and resolved
# lazily, so only the policy that is actually requested has to exist in the
# installed Meta-World version.
POLICY_CLASS_NAMES = {
    'push-v2': 'SawyerPushV2Policy',
    'reach-v2': 'SawyerReachV2Policy',
    'pick-place-v2': 'SawyerPickPlaceV2Policy',
    'basketball-v2': 'SawyerBasketballV2Policy',
    'push-wall-v2': 'SawyerPushWallV2Policy',
    'pick-place-wall-v2': 'SawyerPickPlaceWallV2Policy',
    'window-open-v2': 'SawyerWindowOpenV2Policy',
    'drawer-close-v2': 'SawyerDrawerCloseV2Policy',
    'handle-pull-side-v2': 'SawyerHandlePullSideV2Policy',
    'handle-pull-v2': 'SawyerHandlePullV2Policy',
    'lever-pull-v2': 'SawyerLeverPullV2Policy',
    'peg-insert-side-v2': 'SawyerPegInsertionSideV2Policy',
    'pick-out-of-hole-v2': 'SawyerPickOutOfHoleV2Policy',
    'push-back-v2': 'SawyerPushBackV2Policy',  # author note: bad data collection
    'plate-slide-v2': 'SawyerPlateSlideV2Policy',
    'plate-slide-side-v2': 'SawyerPlateSlideSideV2Policy',
    'plate-slide-back-v2': 'SawyerPlateSlideBackV2Policy',  # author note: bad data collection
    'plate-slide-back-side-v2': 'SawyerPlateSlideBackSideV2Policy',  # author note: bad data collection
    'peg-unplug-side-v2': 'SawyerPegUnplugSideV2Policy',  # author note: bad data collection
    'soccer-v2': 'SawyerSoccerV2Policy',
    'stick-push-v2': 'SawyerStickPushV2Policy',  # author note: bad data collection
    'stick-pull-v2': 'SawyerStickPullV2Policy',  # author note: bad data collection
    'reach-wall-v2': 'SawyerReachWallV2Policy',
    'shelf-place-v2': 'SawyerShelfPlaceV2Policy',
    'sweep-into-v2': 'SawyerSweepIntoV2Policy',  # author note: bad data collection
    'sweep-v2': 'SawyerSweepV2Policy',
    'window-close-v2': 'SawyerWindowCloseV2Policy',
    'hammer-v2': 'SawyerHammerV2Policy',
}

# Total number of episodes rolled out per dataset.
NUM_EPISODES = 1000
# An episode is force-terminated after this many environment steps.
MAX_EPISODE_STEPS = 500
# Standard deviation of the Gaussian noise added to every scripted action.
SCRIPTED_NOISE_STD = 0.05
# Episodes [0, MIXED_START_EPISODE) use the noisy scripted policy only.
MIXED_START_EPISODE = 10
# Episodes [MIXED_START_EPISODE, RANDOM_START_EPISODE) replace each action by a
# uniform random one with this probability.
MIXED_RANDOM_ACTION_PROB = 0.8
# Episodes from RANDOM_START_EPISODE onwards use uniform random actions only.
RANDOM_START_EPISODE = 800


def collect_data(env_name):
    """Roll out a scripted Meta-World policy and save the transitions to disk.

    The function runs ``NUM_EPISODES`` episodes of at most
    ``MAX_EPISODE_STEPS`` steps each.  The behaviour changes with the episode
    index: first the noisy scripted policy, then a mixture of scripted and
    uniform random actions, and finally uniform random actions only.  All
    transitions are stored in a single dict with the keys ``observations``,
    ``actions``, ``rewards``, ``next_observations`` and ``terminals`` and
    written with ``np.save`` to
    ``./data/<env_name>/data_randgoal_medium_rand2.npy``.

    Args:
        env_name: Meta-World task name; must be a key of
            ``POLICY_CLASS_NAMES``.

    Raises:
        NotImplementedError: if no scripted policy is registered for
            ``env_name``.  This is checked before anything is created on disk.
    """
    # Validate first so an unsupported task fails loudly instead of creating
    # an output directory and crashing later on an undefined policy.
    try:
        policy_class_name = POLICY_CLASS_NAMES[env_name]
    except KeyError:
        raise NotImplementedError(
            f'no scripted policy registered for environment {env_name!r}'
        ) from None

    out_dir = os.path.join('./data', env_name)
    os.makedirs(out_dir, exist_ok=True)

    ml1 = metaworld.MT1(env_name, seed=1337)  # Construct the benchmark, sampling tasks
    env = ml1.train_classes[env_name]()  # Create an environment with task
    env.train_tasks = ml1.train_tasks
    env.set_task(ml1.train_tasks[0])
    # Unfreezing the random vector is presumably what makes every reset draw a
    # new goal -- TODO confirm against the installed Meta-World version.
    env._freeze_rand_vec = False

    # Sanity check kept from the original script: print the last three
    # observation entries of a few resets so they can be inspected by eye.
    for _ in range(10):
        first_obs = env.reset()
        print(first_obs[-3:])

    # Use a policy instance rather than calling the method on the class with
    # the class itself passed as ``self``.
    policy = getattr(p, policy_class_name)()

    observations, actions, rewards = [], [], []
    next_observations, terminals = [], []

    for episode in range(NUM_EPISODES):
        # Probability of overriding the scripted action with a uniform random
        # one; depends only on the episode index.
        if episode >= RANDOM_START_EPISODE:
            random_action_prob = 1.0
        elif episode >= MIXED_START_EPISODE:
            random_action_prob = MIXED_RANDOM_ACTION_PROB
        else:
            random_action_prob = 0.0

        obs = env.reset()
        done = False
        episode_reward = 0
        success = 0
        step = 0
        while not done:
            action = policy.get_action(obs)
            noise = np.random.normal(0.0, SCRIPTED_NOISE_STD, action.shape[0])
            action = (action + noise).clip(-1, 1)
            # np.random.rand() lies in [0, 1), so a probability of 0.0 never
            # triggers and a probability of 1.0 always does.
            if np.random.rand() < random_action_prob:
                action = np.random.rand(action.shape[0]) * 2 - 1

            new_obs, reward, done, info = env.step(action)
            step += 1
            if step == MAX_EPISODE_STEPS:
                done = True

            observations.append(obs)
            actions.append(action)
            rewards.append(reward)
            next_observations.append(new_obs)
            # Stored as float so the saved array has one consistent dtype.
            terminals.append(float(done))

            obs = new_obs
            episode_reward += reward
            success += info['success']

        print(episode_reward, success, episode)

    dataset = {
        'observations': np.array(observations),
        'actions': np.array(actions),
        'rewards': np.array(rewards),
        'next_observations': np.array(next_observations),
        'terminals': np.array(terminals),
    }
    # np.save pickles the dict; load it back with allow_pickle=True and .item().
    np.save(os.path.join(out_dir, 'data_randgoal_medium_rand2.npy'), dataset)

    return


if __name__ == '__main__':
    # Script entry point: collect the dataset for the pick-and-place task.
    collect_data('pick-place-v2')
