from asyncio import FastChildWatcher
from turtle import Turtle, numinput, shape
import numpy as np
import torch
import argparse
import os
import math
import gym
import sys
import random
import time
import json
# import dmc2gym
import copy

import utils
from logger import Logger
from video import VideoRecorder

from crsfd_sac import crsfdSacAgent,potential
from torchvision import transforms
from datetime import datetime
from peginhole.peginhole import SinglePeginHole
from collections import deque


def parse_args(argv=None):
    """Build the command-line parser for this script and parse the options.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, which is
            what the existing zero-argument callers rely on.

    Returns:
        argparse.Namespace with one attribute per option declared below.
    """
    parser = argparse.ArgumentParser()
    add = parser.add_argument

    def flag(name):
        # Boolean switch: False unless the option is given on the command line.
        add(name, default=False, action='store_true')

    # environment
    add('--domain_name', default='peginhole')
    add('--task_name', default='4')
    add('--pre_transform_image_size', default=84, type=int)
    add('--image_size', default=84, type=int)
    add('--action_repeat', default=1, type=int)
    add('--frame_stack', default=3, type=int)
    # replay buffer
    add('--replay_buffer_capacity', default=200000, type=int)
    # train
    add('--agent', default='crsfd_sac', type=str)
    add('--init_steps', default=1000, type=int)
    add('--num_train_steps', default=200000, type=int)
    add('--batch_size', default=256, type=int)
    add('--hidden_dim', default=1024, type=int)
    # eval
    add('--eval_freq', default=10000, type=int)
    add('--num_eval_episodes', default=10, type=int)
    # critic
    add('--critic_lr', default=1e-3, type=float)
    add('--critic_beta', default=0.9, type=float)
    add('--critic_tau', default=0.01, type=float)  # try 0.05 or 0.1
    # try to change it to 1 and retain 0.01 above
    add('--critic_target_update_freq', default=2, type=int)
    # actor
    add('--actor_lr', default=1e-3, type=float)
    add('--actor_beta', default=0.9, type=float)
    add('--actor_log_std_min', default=-10, type=float)
    add('--actor_log_std_max', default=2, type=float)
    add('--actor_update_freq', default=1, type=int)
    # encoder
    add('--encoder_type', default='identity', type=str)
    add('--encoder_feature_dim', default=50, type=int)
    add('--encoder_lr', default=1e-3, type=float)
    add('--encoder_tau', default=0.05, type=float)
    add('--num_layers', default=4, type=int)
    add('--num_filters', default=32, type=int)
    add('--crsfd_latent_dim', default=128, type=int)
    # sac
    add('--discount', default=0.99, type=float)
    add('--init_temperature', default=0.1, type=float)
    add('--alpha_lr', default=1e-4, type=float)
    add('--alpha_beta', default=0.5, type=float)
    # misc
    add('--seed', default=-1, type=int)
    add('--work_dir', default='.', type=str)
    flag('--save_tb')
    flag('--save_buffer')
    flag('--save_video')
    flag('--save_model')
    flag('--detach_encoder')

    add('--log_interval', default=1000, type=int)
    add('--evaluate_dir', default=None, type=str)
    add('--evaluate_step', default=40000, type=int)
    add('--camera_id', default=5, type=int)
    add('--hole_height', default=0.8, type=float)

    add('--pretrain', default=None, type=str)
    add('--pretrain_step', default=40000, type=int)

    add('--load_buffer', default=None, type=str)
    add('--per_step_update', default=2, type=int)
    add('--pretrain_num', default=0, type=int)

    add('--demo_dir', default=None, type=str)
    add('--observation_mode', default='one', type=str)
    add('--action_mode', default='free', type=str)
    add('--horizon', default=50, type=int)
    add('--control_freq', default=10, type=int)
    flag('--large_hole')
    flag('--sparse')
    flag('--img_demo')
    flag('--stochastic')
    add('--demo_obs', default='standerd', type=str)
    flag('--collect_imgbuffer')
    add('--demo_ratio', default=0.25, type=float)
    flag('--demo_decay')
    flag('--intrinsic_r')

    add('--reward_scale', default=10, type=float)
    flag('--sparse3')

    add('--insert_depth', default=0.03, type=float)
    flag('--self_imitation')
    add('--sil_buffer_size', default=5000, type=int)
    add('--sil_threshold', default=30, type=int)
    add('--bc_lambda', default=0, type=float)
    add('--bc_decay', default=0.0001, type=float)

    flag('--use_potential')
    add('--potential_rate', default=0.2, type=float)

    flag('--stop_after_success')

    add('--in_time', default=50, type=int)

    add('--robot', default='Panda', type=str)
    flag('--obs_norm')

    flag('--demo_early_stop')
    flag('--evaluate_save_video')

    flag('--sqil')

    flag('--imitation_learning')

    add('--r_lambda', default=1, type=float)
    add('--hole_range', default=0, type=float)

    # argv=None makes argparse read sys.argv[1:], the original behaviour.
    return parser.parse_args(argv)


def evaluate(env, agent, video, num_episodes, L, step, args, obs_mean, obs_std):
    """Roll out evaluation episodes with the deterministic policy and log them.

    Args:
        env: Environment exposing ``reset()``, ``step(action)`` and
            ``get_peg_pos_to_hole()``.
        agent: Agent exposing ``select_action`` / ``sample_action``.
        video: Recorder exposing ``init``, ``record`` and ``save``; it is only
            driven when ``step`` is a multiple of 20000.
        num_episodes: Number of episodes to roll out.
        L: Logger exposing ``log(key, value, step)`` and ``dump(step)``.
        step: Training step used as the x-axis of every logged value.
        args: Parsed options; reads ``encoder_type``, ``image_size`` and
            ``obs_norm``.
        obs_mean: Mean subtracted from observations when ``args.obs_norm``.
        obs_std: Divisor applied to observations when ``args.obs_norm``.

    Returns:
        Mean episode reward over the evaluated episodes.
    """
    all_ep_rewards = []

    def run_eval_loop(sample_stochastically=True):
        start_time = time.time()
        prefix = 'stochastic_' if sample_stochastically else ''
        # Video capture is limited to every 20000th training step.
        record_video = step % 20000 == 0
        part_insert = 0
        total_insert = 0
        for i in range(num_episodes):
            obs = env.reset()
            # Normalise the reset observation too, so the first action of the
            # episode sees the same input scaling as every later one (and as
            # the training loop, which normalises right after reset).
            if args.obs_norm:
                obs = (obs - obs_mean) / obs_std
            if record_video:
                # Only the first episode of the batch is actually recorded.
                video.init(enabled=(i == 0))
            done = False
            episode_reward = 0
            while not done:
                # center crop image
                if args.encoder_type == 'pixel':
                    obs = utils.center_crop_image(obs, args.image_size)
                with utils.eval_mode(agent):
                    if sample_stochastically:
                        action = agent.sample_action(obs)
                    else:
                        action = agent.select_action(obs)
                obs, reward, done, _ = env.step(action)
                if args.obs_norm:
                    obs = (obs - obs_mean) / obs_std
                if record_video:
                    video.record(env)
                episode_reward += reward
            if record_video:
                video.save('%d.mp4' % step)
            L.log('eval/' + prefix + 'episode_reward', episode_reward, step)
            all_ep_rewards.append(episode_reward)

            # Index 2 of the peg-to-hole position at episode end decides the
            # two insertion counters (thresholds -0.045 and -0.002).
            peg_z = env.get_peg_pos_to_hole()[2]
            if peg_z < -0.045:
                total_insert += 1
            if peg_z < -0.002:
                part_insert += 1

        L.log('eval/' + prefix + 'eval_time', time.time() - start_time, step)
        mean_ep_reward = np.mean(all_ep_rewards)
        best_ep_reward = np.max(all_ep_rewards)
        L.log('eval/' + prefix + 'mean_episode_reward', mean_ep_reward, step)
        L.log('eval/' + prefix + 'best_episode_reward', best_ep_reward, step)
        L.log('eval/' + prefix + 'total_insert', total_insert, step)
        L.log('eval/' + prefix + 'part_insert', part_insert, step)
        return mean_ep_reward

    score = run_eval_loop(sample_stochastically=False)
    L.dump(step)

    return score


def make_agent(obs_shape, action_shape, args, device):
    """Construct the agent named by ``args.agent``.

    Args:
        obs_shape: Observation shape forwarded to the agent.
        action_shape: Action shape forwarded to the agent.
        args: Parsed options supplying the agent hyper-parameters.
        device: Device forwarded to the agent.

    Returns:
        A ``crsfdSacAgent`` instance.

    Raises:
        ValueError: If ``args.agent`` is not ``'crsfd_sac'``. The previous
            ``assert '<message>'`` was always true, so an unknown agent name
            silently produced ``None``.
    """
    if args.agent != 'crsfd_sac':
        raise ValueError('agent is not supported: %s' % args.agent)

    return crsfdSacAgent(
        obs_shape=obs_shape,
        action_shape=action_shape,
        device=device,
        hidden_dim=args.hidden_dim,
        discount=args.discount,
        init_temperature=args.init_temperature,
        alpha_lr=args.alpha_lr,
        alpha_beta=args.alpha_beta,
        actor_lr=args.actor_lr,
        actor_beta=args.actor_beta,
        actor_log_std_min=args.actor_log_std_min,
        actor_log_std_max=args.actor_log_std_max,
        actor_update_freq=args.actor_update_freq,
        critic_lr=args.critic_lr,
        critic_beta=args.critic_beta,
        critic_tau=args.critic_tau,
        critic_target_update_freq=args.critic_target_update_freq,
        encoder_type=args.encoder_type,
        encoder_feature_dim=args.encoder_feature_dim,
        encoder_lr=args.encoder_lr,
        encoder_tau=args.encoder_tau,
        num_layers=args.num_layers,
        num_filters=args.num_filters,
        log_interval=args.log_interval,
        detach_encoder=args.detach_encoder,
        crsfd_latent_dim=args.crsfd_latent_dim,
        reward_scale=args.reward_scale,
        bc_lambda=args.bc_lambda,
        bc_decay=args.bc_decay,
    )

def main():
    """Train the agent on the peg-in-hole environment.

    Parses the command line, builds a training and an evaluation environment,
    creates the experiment directory, replay buffer, agent and logger, then
    runs ``args.num_train_steps`` environment steps with periodic evaluation.

    Raises:
        ValueError: If ``--obs_norm`` is set but no normalisation statistics
            were obtained (they only come from ``--load_buffer``).
    """
    args = parse_args()
    if args.seed == -1:
        # -1 means "pick a random seed"; it is stored back so args.json records it.
        args.seed = np.random.randint(1, 1000000)
    utils.set_seed_everywhere(args.seed)

    hole_pos = [-0.15, -0.01, 0.9] if args.robot == 'Jaco' else None
    # NOTE(review): this override disables the Jaco-specific hole position above.
    hole_pos = None

    def make_env():
        # Training and evaluation environments use identical settings.
        return SinglePeginHole(
            observation_mode=args.observation_mode,
            action_mode=args.action_mode,
            robots=[args.robot],
            peg_class=args.task_name,
            horizon=args.horizon,
            large_hole=args.large_hole,
            hole_height=args.hole_height,
            control_freq=args.control_freq,
            sparse=args.sparse,
            sparse3=args.sparse3,
            depth=args.insert_depth,
            has_offscreen_renderer=False,
            stop_after_success=args.stop_after_success,
            hole_pos=hole_pos,
            hole_range=args.hole_range,
        )

    env = make_env()
    env_eval = make_env()

    env._max_episode_steps = args.horizon
    env_eval._max_episode_steps = args.horizon

    # env.seed(args.seed)

    # stack several consecutive frames together
    if args.encoder_type == 'pixel' or args.observation_mode == 'ee':
        env = utils.FrameStack(env, k=args.frame_stack)
        env_eval = utils.FrameStack(env_eval, k=args.frame_stack)

    # make directory
    env_name = args.domain_name + '-' + args.task_name
    exp_name = env_name + '_' + datetime.now().strftime('%Y-%m-%d_%H_%M_%S')
    folder = './tmp_demo/' if args.load_buffer is not None else './tmp_ee/'
    folder = './tmp_pixel/' if args.encoder_type == 'pixel' else folder
    # NOTE(review): this override makes the two selections above dead code.
    folder = './tmp602/'
    work_dir = folder + args.domain_name

    args.work_dir = work_dir + '/' + exp_name
    # makedirs also creates the missing parents (folder and work_dir);
    # plain os.mkdir failed when ./tmp602/ did not exist yet.
    os.makedirs(args.work_dir, exist_ok=True)
    utils.make_dir(args.work_dir)
    video_dir = utils.make_dir(os.path.join(args.work_dir, 'video'))
    model_dir = utils.make_dir(os.path.join(args.work_dir, 'model'))
    buffer_dir = utils.make_dir(os.path.join(args.work_dir, 'buffer'))
    video = VideoRecorder(video_dir if args.save_video else None, height=256, width=256)
    with open(os.path.join(args.work_dir, 'args.json'), 'w') as f:
        json.dump(vars(args), f, sort_keys=True, indent=4)

    # action, state
    action_shape = env.action_space.shape
    if args.encoder_type == 'pixel':
        obs_shape = (3*args.frame_stack, args.image_size, args.image_size)
        pre_aug_obs_shape = (3*args.frame_stack, args.pre_transform_image_size, args.pre_transform_image_size)
    else:
        obs_shape = env.observation_space.shape
        pre_aug_obs_shape = obs_shape
    print(pre_aug_obs_shape)
    print(action_shape)

    # replaybuffer, agent, logger
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    replay_buffer = utils.ReplayBuffer(
        obs_shape=pre_aug_obs_shape,
        action_shape=action_shape,
        capacity=args.replay_buffer_capacity,
        batch_size=args.batch_size,
        device=device,
        image_size=args.image_size,
        demo_ratio=args.demo_ratio,
        demo_decay=args.demo_decay,
        self_imitate_num=args.sil_buffer_size if args.self_imitation else 0,
    )
    trajectory = utils.OneTrajectory(obs_shape=pre_aug_obs_shape, action_shape=action_shape,)
    agent = make_agent(
        obs_shape=obs_shape,
        action_shape=action_shape,
        args=args,
        device=device
    )
    L = Logger(args.work_dir, use_tb=args.save_tb)

    # pretrain or load buffer
    obs_mean, obs_std = None, None
    if args.pretrain is not None:
        agent.load(args.pretrain, args.pretrain_step)
    if args.load_buffer is not None:
        obs_mean, obs_std = replay_buffer.load_peg2(
            args.load_buffer,
            norm=args.obs_norm,
            sqil=args.sqil,
            imitation_learning=args.imitation_learning,
            r_lambda=args.r_lambda,
        )
        for i in range(args.pretrain_num):
            agent.update(replay_buffer, L, -1)
    if args.obs_norm and (obs_mean is None or obs_std is None):
        # Without statistics the first normalisation would fail with an
        # opaque TypeError on `obs - None`; fail early with a clear message.
        raise ValueError('--obs_norm needs observation statistics; none were loaded (use --load_buffer)')
    if args.use_potential:
        potential_f = potential(obs_dim=7).to(device)
        potential_f.load_state_dict(torch.load('potential_weight/value.pt'))

    # start training
    best_score = 5
    log_time = 0
    # done starts True so the first loop iteration resets the environment.
    episode, episode_reward, done = 0, 0, True
    start_time = time.time()
    good_traj_count = 0
    in_time = 0

    for step in range(args.num_train_steps):
        # evaluate agent periodically
        if step % args.eval_freq == 0:
            L.log('eval/episode', episode, step)
            score = evaluate(env_eval, agent, video, args.num_eval_episodes, L, step, args, obs_mean, obs_std)
            if args.save_model and score > best_score:
                # agent.save_crsfd(model_dir, step)
                agent.save(model_dir, step)
                best_score = score
            if args.save_buffer:
                replay_buffer.save(buffer_dir)

        # Start a new episode when the env reports done or the peg has stayed
        # below the insert depth for args.in_time steps.
        if done or in_time >= args.in_time:
            if step > 0:
                if step >= args.log_interval*log_time:
                    L.log('train/duration', time.time() - start_time, step)
                    L.dump(step)
                start_time = time.time()
            if step >= args.log_interval*log_time:
                L.log('train/episode_reward', episode_reward, step)
                log_time += 1

            # collect better traj
            if args.self_imitation:
                if episode_reward > args.sil_threshold:
                    good_traj_count += 1
                    if good_traj_count % 10 == 0:
                        print("good_traj_num:{}".format(good_traj_count))
                    obses, actions, rewards, next_obses, dones = trajectory.get_all()
                    if args.sqil:
                        rewards += 2
                    replay_buffer.add_batch(obses, actions, rewards, next_obses, dones)
                trajectory.clear()
            obs = env.reset()
            if args.obs_norm:
                obs = (obs - obs_mean) / obs_std
            done = False
            episode_reward = 0
            episode_step = 0
            episode += 1
            in_time = 0
            if step % args.log_interval == 0:
                L.log('train/episode', episode, step)

        # sample action for data collection
        if step < args.init_steps:
            action = env.action_space.sample()
        else:
            with utils.eval_mode(agent):
                action = agent.sample_action(obs)

        # run training update
        if step >= args.init_steps:
            for _ in range(args.per_step_update):
                agent.update(replay_buffer, L, step)

        next_obs, reward, done, _ = env.step(action)
        if args.imitation_learning:
            reward = 0
        if args.obs_norm:
            next_obs = (next_obs - obs_mean) / obs_std

        # allow infinite bootstrap: a step that hits the episode step limit
        # is stored as non-terminal.
        done_bool = 0 if episode_step + 1 == env._max_episode_steps else float(done)
        # episode_reward is accumulated before any shaping term is added.
        episode_reward += reward
        pos = env.get_peg_pos_to_hole()
        if pos[2] < -args.insert_depth:
            in_time += 1
        # intrinsic reward, around demo state
        if args.intrinsic_r:
            reward += env.intrinsic_reward(r_scale=args.r_lambda)
        if args.use_potential:
            pos = env.get_peg_pos_to_hole()
            if pos[0] < 5e-2 and pos[1] < 5e-2:
                with torch.no_grad():
                    # The potential network is fed the last 7 observation entries.
                    phi_next = potential_f(torch.from_numpy(next_obs[-7:][np.newaxis, :]).to(device).float()).detach().cpu().numpy()
                    phi_cur = potential_f(torch.from_numpy(obs[-7:][np.newaxis, :]).to(device).float()).detach().cpu().numpy()
                    delta_r = 0.99*phi_next - phi_cur
                    # Only the sign of the potential difference is used.
                    delta_r_norm = 2 if delta_r > 0 else 0
                reward += delta_r_norm
                if pos[2] < 0.005:
                    reward += 2

        replay_buffer.add(obs, action, reward, next_obs, done_bool)
        if args.self_imitation:
            trajectory.add(obs, action, reward, next_obs, done_bool)
        obs = next_obs
        episode_step += 1

#########################################################################
#########################################################################
def evaluate_peg(sample_stochastically=False, num_episodes=1,step=40000,save_video=True):
    """Load a saved agent from ``args.evaluate_dir`` and roll it out.

    Args:
        sample_stochastically: Use ``agent.sample_action`` instead of
            ``agent.select_action``.
        num_episodes: Number of episodes to roll out.
        step: Checkpoint step passed to ``agent.load``; also used in the
            printed summary and the video file name.
        save_video: When true, save the recorded frames at the end.
    """
    args = parse_args()
    video_dir = os.path.join(args.evaluate_dir, 'video')
    model_dir = os.path.join(args.evaluate_dir, 'model')
    video = VideoRecorder(video_dir, camera_id=args.camera_id, fps=args.control_freq)
    video.init(enabled=1)

    # create env
    hole_pos = [-0.15, -0.01, 0.9] if args.robot == 'Jaco' else None
    # NOTE(review): this override disables the Jaco-specific hole position above.
    hole_pos = None
    env = SinglePeginHole(
        observation_mode=args.observation_mode,
        action_mode=args.action_mode,
        robots=[args.robot],
        peg_class=args.task_name,
        has_offscreen_renderer=True,
        horizon=args.horizon,
        large_hole=args.large_hole,
        hole_height=args.hole_height,
        control_freq=args.control_freq,
        hole_pos=hole_pos,
    )
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    env._max_episode_steps = args.horizon
    action_shape = env.action_space.shape
    obs_shape = env.observation_space.shape

    # agent
    agent = make_agent(
        obs_shape=obs_shape,
        action_shape=action_shape,
        args=args,
        device=device
    )
    agent.load(model_dir, step)

    # start evaluate
    for _ in range(num_episodes):
        obs = env.reset()
        done = False
        episode_reward = 0
        while not done:
            with utils.eval_mode(agent):
                if sample_stochastically:
                    action = agent.sample_action(obs)
                else:
                    action = agent.select_action(obs)
            obs, reward, done, _ = env.step(action)
            if done:
                # Index 2 of the peg-to-hole position at episode end.
                print(env.get_peg_pos_to_hole()[2])
            video.record(env)
            episode_reward += reward
        print('{},{}'.format(step, episode_reward))

    # save video
    if save_video:
        # strftime yields the same 'YYYY-MM-DD_HH_MM' tag as the former string
        # slicing, but stays correct when microseconds are 0 and str(datetime)
        # drops the fractional part.
        timenow = datetime.now().strftime('%Y-%m-%d_%H_%M')
        video.save('{}_{}_{}_{}.mp4'.format(step, timenow, args.task_name, args.large_hole))

def collect_demo(collet_num=20, model_step=400001, threshold=-0.05, sparse=False,sparse3=False,early_stop=True, demo_obs='standerd', sample_stochastically=True):
    """Roll out a saved agent and store its transitions as a demo ``.npz`` file.

    Args:
        collet_num: Number of transitions to collect.
        model_step: Checkpoint step passed to ``agent.load``.
        threshold: Accepted for compatibility but not read here; the success
            test uses the fixed value -0.045.
        sparse: Replace the env reward with ``env.sparse_reward()``.
        sparse3: Replace the env reward with ``env.sparse_reward3(...)``
            (only consulted when ``sparse`` is false).
        early_stop: Reset the env as soon as the success test passes.
        demo_obs: Which observation to store: ``'pixel'`` (stacked rendered
            frames), ``'eef'``, ``'eef2'``, ``'without_torque'``, or anything
            else for the env's own observation.
        sample_stochastically: Use ``agent.sample_action`` instead of
            ``agent.select_action``.
    """
    args = parse_args()
    model_dir = os.path.join(args.demo_dir, 'model')
    img_demo = demo_obs == 'pixel'
    # Index 2 of the peg-to-hole position below this value counts as success.
    success_z = -0.045

    # load env, agent
    hole_pos = [-0.15, -0.01, 0.9] if args.robot == 'Jaco' else None
    # NOTE(review): this override disables the Jaco-specific hole position above.
    hole_pos = None
    env = SinglePeginHole(
        observation_mode=args.observation_mode,
        action_mode=args.action_mode,
        robots=[args.robot],
        peg_class=args.task_name,
        has_offscreen_renderer=img_demo,
        horizon=args.horizon,
        large_hole=args.large_hole,
        control_freq=args.control_freq,
        depth=args.insert_depth,
        sparse3=sparse3,
        hole_pos=hole_pos,
    )
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    action_shape = env.action_space.shape
    obs_shape = env.observation_space.shape
    agent = make_agent(
        obs_shape=obs_shape,
        action_shape=action_shape,
        args=args,
        device=device
    )
    agent.load(model_dir, model_step)

    # Alternative low-dimensional observations: env method name and length.
    low_dim_sources = {
        'eef': ('get_eef_info', 18),
        'eef2': ('get_eef_info2', 19),
        'without_torque': ('without_torque', 33),
    }

    # create demo buffer
    read_low_dim = None
    if img_demo:
        size = args.image_size
        obs_shape = [3*args.frame_stack, size, size]
        # The deque length follows frame_stack so it matches the buffer shape.
        frames = deque([], maxlen=args.frame_stack)
    elif demo_obs in low_dim_sources:
        method_name, dim = low_dim_sources[demo_obs]
        read_low_dim = getattr(env, method_name)
        obs_shape = [dim]
    obs_dtype = np.float32 if len(obs_shape) == 1 else np.uint8
    print(obs_shape)
    print(action_shape)
    obs_array = np.zeros((collet_num, *obs_shape), dtype=obs_dtype)
    action_array = np.zeros((collet_num, *action_shape))
    r_array = np.zeros((collet_num, 1))
    nexto_array = np.zeros((collet_num, *obs_shape), dtype=obs_dtype)
    done_array = np.zeros((collet_num, 1))

    def render_frame():
        # Rendered image moved from (H, W, C) to (C, H, W).
        img = env.render(width=size, height=size, camera_id=5)
        return np.transpose(img, (2, 0, 1)).copy()

    def initial_demo_obs():
        # Demo observation for the state right after env.reset(); None when
        # the env's own observation is stored.
        if img_demo:
            first = render_frame()
            for _ in range(args.frame_stack):
                frames.append(first)
            return np.concatenate(list(frames), axis=0)
        if read_low_dim is not None:
            return read_low_dim()
        return None

    # start collection
    obs = env.reset()
    done = False
    step = 0
    demo_o = initial_demo_obs()

    while step < collet_num:
        with utils.eval_mode(agent):
            if sample_stochastically:
                action = agent.sample_action(obs)
            else:
                action = agent.select_action(obs)
        next_obs, reward, done, _ = env.step(action)
        peg_z = env.get_peg_pos_to_hole()[2]
        if sparse:
            reward = env.sparse_reward()
        elif sparse3:
            reward = env.sparse_reward3(depth=args.insert_depth)
        elif peg_z < success_z:
            reward = 50

        # choose the demo model
        if img_demo:
            frames.append(render_frame())
            next_demo_o = np.concatenate(list(frames), axis=0)
        elif read_low_dim is not None:
            next_demo_o = read_low_dim()
        else:
            demo_o, next_demo_o = obs, next_obs
        obs_array[step] = demo_o
        action_array[step] = action
        r_array[step] = reward
        nexto_array[step] = next_demo_o
        done_array[step] = done
        demo_o = next_demo_o
        obs = next_obs
        step += 1
        print('reward:{},done:{},action:{}'.format(reward,done,action))

        # check done
        if (early_stop and peg_z < success_z) or done == 1:
            text = 'success' if peg_z < success_z else 'time_limit'
            print(text+str(peg_z))
            obs = env.reset()
            # Refresh the demo observation as well; otherwise the first
            # transition of the new episode would start from the previous
            # episode's final frame / end-effector reading.
            demo_o = initial_demo_obs()
            done = False

    # save demo buffer
    # The tags describe what this call actually did, so they come from the
    # function arguments rather than from the command line.
    s = 'stochastic' if sample_stochastically else 'deterministic'
    a = 'sparse' if sparse else 'dense'
    b = 'large' if args.large_hole else 'small'
    c = demo_obs
    save_path = './demo_eef_jaco/{}/panda_{}_{}_{}_{}_stop.npz'.format(args.task_name, s,a, b, c)
    # Create the directory that is really written to (the old code created
    # ./demo/<task> but saved under ./demo_eef_jaco/<task>).
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    np.savez(save_path, obs_array,action_array,r_array,nexto_array,done_array)


if __name__ == '__main__':
    # Script entry point: choose between training, checkpoint evaluation and
    # demo collection depending on which directory option was supplied.
    torch.multiprocessing.set_start_method('spawn')
    args = parse_args()
    if args.evaluate_dir is None and args.demo_dir is None:
        # Neither directory given: run training.
        main()
    elif args.evaluate_dir is not None:
        # --evaluate_dir takes precedence over --demo_dir.
        evaluate_peg(
            num_episodes=10,
            step=args.evaluate_step,
            save_video=args.evaluate_save_video,
        )
    else:
        # Only --demo_dir given: collect demonstrations.
        collect_demo(
            collet_num=2000,
            model_step=args.evaluate_step,
            demo_obs=args.demo_obs,
            sample_stochastically=args.stochastic,
            sparse=args.sparse,
            sparse3=args.sparse3,
            early_stop=args.demo_early_stop,
        )
        