# An example of generating expert demos by MetaWorld's scripted policies.
# Warning: The scripted policies may generate failure trajectories. You can write a task-dependent criterion to filter the generated trajectories.

import sys
sys.path.append("../RL")
import metaworld
import random
import metaworld.policies as policies
import cv2
import numpy as np

import pickle
from pathlib import Path
from collections import deque

from video import VideoRecorder

# Number of successful episodes to collect per environment.
num_demos = 200
# The first `num_train_demos` successful episodes are listed in label.txt
# (training split); the remaining ones are listed in label_val.txt.
num_train_demos = 100
# Give up on an environment after this many failed (unsuccessful) episodes.
max_fail_num = 20000

# Natural-language description of each task, written as row 0 of the
# per-environment text.csv. Keys are MetaWorld environment names.
# NOTE: the identifier is misspelled ("DEICRIPTION") but is kept as-is because
# the rest of the script refers to it by this name.
TEXT_DEICRIPTION = {
	'window-close-v2': 'Closing window',
	'drawer-open-v2': 'Opening drawer',
	'door-open-v2': 'Opening door',
	'bin-picking-v2': 'Picking cube from bin and placing it in another bin',
	'button-press-topdown-v2': 'Pressing button from top',
	'door-unlock-v2': 'Unlocking door',
	'basketball-v3': 'Moving ball to above basket',
	'plate-slide-v2': 'Sliding plate into gate',
	'hand-insert-v2': 'Inserting hand into hole',
	'peg-insert-side-v2': 'Inserting peg into hole',
	'assembly-v3': 'Assembling ring to rod',
	'push-wall-v2': 'Pushing object to other side of wall',
	'soccer-v2': 'Pushing soccer ball to goal',
	'disassemble-v2': 'Disassembling ring from rod',
	'pick-place-wall-v3': 'Picking object and placing it to other side of wall',
	'pick-place-v2': 'Picking object and placing it',
	'lever-pull-v2': 'Pulling lever',
	'stick-pull-v2': 'Pulling kettle with stick',
	'shelf-place-v2': 'Placing object on shelf',
	'button-press-wall-v2': 'Pressing button behind wall',
	'box-close-v2': 'Closing box',
	'stick-push-v2': 'Pushing kettle with stick',
	'handle-pull-v2': 'Pulling handle up',
	'door-lock-v2': 'Locking door',
	'button-press-v2': 'Pressing button',
	'window-open-v2': 'Opening window',
	# Tasks that have a scripted policy, camera and step budget configured in
	# this file but previously had no description, so selecting them failed
	# with a KeyError when text.csv was written.
	'hammer-v2': 'Hammering screw on wall',
	'drawer-close-v2': 'Closing drawer',
	'reach-v2': 'Reaching goal position',
}

# Environments to generate demos for, processed in this order.
env_names = [
	'drawer-open-v2',
	'basketball-v3',
	'disassemble-v2',
	'window-close-v2',
	'window-open-v2',
	'door-open-v2',
	'stick-push-v2',
	'button-press-topdown-v2',
	'plate-slide-v2',
	'lever-pull-v2',
]

# Scripted-policy class for each environment name; the generation loop
# instantiates the class with no arguments and queries it via `get_action`.
POLICY = {
	'hammer-v2': policies.SawyerHammerV2Policy,
	'drawer-close-v2': policies.SawyerDrawerCloseV2Policy,
	'drawer-open-v2': policies.SawyerDrawerOpenV2Policy,
	'door-open-v2': policies.SawyerDoorOpenV2Policy,
	'bin-picking-v2': policies.SawyerBinPickingV2Policy,
	'button-press-topdown-v2': policies.SawyerButtonPressTopdownV2Policy,
	'door-unlock-v2': policies.SawyerDoorUnlockV2Policy,
	'basketball-v3': policies.SawyerBasketballV2Policy,
	'plate-slide-v2': policies.SawyerPlateSlideV2Policy,
	'hand-insert-v2': policies.SawyerHandInsertV2Policy,
	'peg-insert-side-v2': policies.SawyerPegInsertionSideV2Policy,
	'assembly-v3': policies.SawyerAssemblyV2Policy,
	'push-wall-v2': policies.SawyerPushWallV2Policy,
	'soccer-v2': policies.SawyerSoccerV2Policy,
	'disassemble-v2': policies.SawyerDisassembleV2Policy,
	'pick-place-wall-v3': policies.SawyerPickPlaceWallV2Policy,
	'pick-place-v2': policies.SawyerPickPlaceV2Policy,
	'lever-pull-v2': policies.SawyerLeverPullV2Policy,
	'stick-pull-v2': policies.SawyerStickPullV2Policy,
	'shelf-place-v2': policies.SawyerShelfPlaceV2Policy,
	'window-close-v2': policies.SawyerWindowCloseV2Policy,
	'reach-v2': policies.SawyerReachV2Policy,
	'button-press-wall-v2': policies.SawyerButtonPressWallV2Policy,
	'box-close-v2': policies.SawyerBoxCloseV2Policy,
	'stick-push-v2': policies.SawyerStickPushV2Policy,
	'handle-pull-v2': policies.SawyerHandlePullV2Policy,
	'door-lock-v2': policies.SawyerDoorLockV2Policy,
	'button-press-v2': policies.SawyerButtonPressV2Policy,
	'window-open-v2': policies.SawyerWindowOpenV2Policy,
}

# Camera name used for each environment, both for offscreen rendering of the
# image observations and for the VideoRecorder.
CAMERA = {
	'hammer-v2': 'corner3',
	'drawer-close-v2': 'corner',
	'drawer-open-v2': 'corner',
	'door-open-v2': 'corner3',
	'bin-picking-v2': 'corner',
	'button-press-topdown-v2': 'corner',
	'door-unlock-v2': 'corner',
	'basketball-v3': 'corner',
	'plate-slide-v2': 'corner',
	'hand-insert-v2': 'corner',
	'peg-insert-side-v2': 'corner3',
	'assembly-v3': 'corner',
	'push-wall-v2': 'corner',
	'soccer-v2': 'corner',
	'disassemble-v2': 'corner',
	'pick-place-wall-v3': 'corner3',
	'pick-place-v2': 'corner3',
	'lever-pull-v2': 'corner4',
	'stick-pull-v2': 'corner3',
	'shelf-place-v2': 'corner',
	'window-close-v2': 'corner3',
	'reach-v2': 'corner3',
	'button-press-wall-v2': 'corner',
	'box-close-v2': 'corner3',
	'stick-push-v2': 'corner',
	'handle-pull-v2': 'corner3',
	'door-lock-v2': 'corner',
	'button-press-v2': 'corner',
	'window-open-v2': 'corner3',
}

NUM_STEPS = {
	'hammer-v2': 125,
	'drawer-close-v2': 125,
	'drawer-open-v2': 125,
	'door-open-v2': 125,
	'bin-picking-v2': 175,
	'button-press-topdown-v2': 125,
	'door-unlock-v2': 125,
	'basketball-v3': 175,
	'plate-slide-v2': 125,
	'hand-insert-v2': 125,
	'peg-insert-side-v2': 150,
	'assembly-v3': 175,
	'push-wall-v2': 175,
	'soccer-v2': 125,
	'disassemble-v2': 125,
	'pick-place-wall-v3': 175,
	'pick-place-v2': 125,
	'lever-pull-v2': 125,
	'stick-pull-v2': 175,
	'shelf-place-v2': 175,
	'window-close-v2': 125,
	'reach-v2': 125,
	'button-press-wall-v2': 125,
	'box-close-v2': 175,
	'stick-push-v2': 125,
	'handle-pull-v2': 175,
	'door-lock-v2': 125,
	'button-press-v2': 125,
	'window-open-v2': 125,
}

def _push_frame(stack, pixel, size):
	"""Resize `pixel` to `size` x `size`, push it (channels-first) onto `stack`,
	and return all frames in the stack concatenated along the channel axis.

	On the first frame of an episode the stack is padded by repeating that
	frame until it holds `stack.maxlen` entries.
	"""
	frame = np.transpose(cv2.resize(pixel, (size, size)), (2, 0, 1))
	stack.append(frame)
	while len(stack) < stack.maxlen:
		stack.append(frame)
	return np.concatenate(stack, axis=0)


def _append_line(path, line):
	"""Append `line` to the existing file `path`, separating entries with a
	newline but never leaving a trailing newline at the end of the file."""
	separator = '\n' if path.stat().st_size else ''
	with open(path, 'a') as f:
		f.write(separator + line)


# For every environment: roll out the scripted policy until `num_demos`
# successful episodes are collected (or `max_fail_num` episodes failed), save
# a video per successful episode and write the label/text index files.
for env_name in env_names:
	save_dir = Path("./demos") / env_name
	save_dir.mkdir(parents=True, exist_ok=True)
	text_csv = save_dir / 'text.csv'
	label_txt = save_dir / 'label_all.txt'
	label_text_train = save_dir / 'label.txt'
	label_text_val = save_dir / 'label_val.txt'

	print(f"Generating demo for: {env_name}")
	policy = POLICY[env_name]()
	mt1 = metaworld.MT1(env_name)
	env = mt1.train_classes[env_name]()
	video_recorder = VideoRecorder(save_dir, camera_name=CAMERA[env_name])

	# NOTE(review): these per-environment buffers are filled below but never
	# written to disk in this part of the script — confirm whether the demos
	# are meant to be persisted (e.g. pickled) somewhere.
	images_list = list()
	large_images_list = list()
	observations_list = list()
	actions_list = list()
	rewards_list = list()

	count = 0
	episode = 0
	fail_episode = 0

	# Start every index file from scratch. Episode numbering (and therefore the
	# video file names) restarts on each run, so appending to files left over
	# from a previous run would duplicate the CSV header and label entries.
	text_csv.write_text(f'id,name\n0,{TEXT_DEICRIPTION[env_name]}\n')
	for label_file in (label_txt, label_text_train, label_text_val):
		label_file.write_text('')

	while episode < num_demos and fail_episode < max_fail_num:
		video_recorder.init(env)
		print(f"Episode {episode}")
		images = list()
		large_images = list()
		observations = list()
		actions = list()
		rewards = list()
		image_stack = deque([], maxlen=3)
		large_image_stack = deque([], maxlen=3)
		goal_achieved = 0

		# Cycle through the available training tasks.
		task = mt1.train_tasks[count % len(mt1.train_tasks)]
		count += 1
		env.set_task(task)

		observation = env.reset()
		# Number of initial warm-up steps during which a fixed noisy action is
		# executed instead of the policy's action; warm-up steps are not
		# recorded as part of the demo.
		move_step = random.randint(0, 50)
		noise_level = 2
		num_steps = NUM_STEPS[env_name] + move_step

		for step in range(num_steps):
			# Copy once before resizing — presumably so OpenCV gets its own
			# contiguous buffer; TODO confirm against the renderer's output.
			pixel = env.render(offscreen=True, camera_name=CAMERA[env_name]).copy()
			# NOTE(review): frames are stored for every step, including the
			# warm-up steps, while observations/actions/rewards are stored only
			# after the warm-up — confirm this misalignment is intended.
			images.append(_push_frame(image_stack, pixel, 84))
			large_images.append(_push_frame(large_image_stack, pixel, 224))
			action = policy.get_action(observation)

			if step == 0:
				print('first_action:', action)
				# Warm-up action: the policy's first action, clipped, plus
				# Gaussian noise, clipped again to the [-1, 1] range.
				noise = np.random.normal(0, noise_level, size=action.shape)
				move_action = np.clip(np.clip(action, -1.0, 1.0) + noise, -1.0, 1.0)
			if step <= move_step:
				action = move_action
			action = np.clip(action, -1.0, 1.0)

			# Zero the last three observation entries before storing them
			# (presumably the goal position — TODO confirm).
			observation[-3:] = 0
			recording = step > move_step
			if recording:
				actions.append(action)
				observations.append(observation)
			observation, reward, _, info = env.step(action)
			if recording:
				rewards.append(reward)
				video_recorder.record(env)
				goal_achieved += info['success']
			# Stop once success has accumulated over enough recorded steps.
			if goal_achieved >= 10:
				break

		if not goal_achieved:
			fail_episode += 1
			continue

		episode += 1
		is_train = episode <= num_train_demos
		if is_train:
			images_list.append(np.array(images))
			large_images_list.append(np.array(large_images))
			observations_list.append(np.array(observations))
			actions_list.append(np.array(actions))
			rewards_list.append(np.array(rewards))
			video_recorder.save(f'demo{episode}.mp4')

		video_name = f'{env_name}_{episode}.mp4'
		video_recorder.save(video_name)
		label_line = f'{video_name} 0'
		_append_line(label_txt, label_line)
		_append_line(label_text_train if is_train else label_text_val, label_line)