import sys
sys.path.append("../ROT")
import metaworld
import random
import metaworld.policies as policies
import cv2
import numpy as np

import pickle
from pathlib import Path
from collections import deque

from video import VideoRecorder

# Tasks to generate demonstrations for. An alternative task list used with this
# script (every entry has POLICY/CAMERA/NUM_STEPS settings below):
#   hammer-v2, drawer-close-v2, drawer-open-v2, door-open-v2, bin-picking-v2,
#   button-press-topdown-v2, door-unlock-v2, basketball-v2, plate-slide-v2,
#   hand-insert-v2, peg-insert-side-v2, peg-insert-side-v3
env_names = [
	"door-lock-v2",
]

# Number of episodes to keep per task; only rollouts whose final reward passes
# the success threshold in the generation loop are counted.
num_demos = 10

# Scripted expert policy class for each supported task name. Many -v3/-v4 and
# basketball-* variants reuse the corresponding V2 policy class; only
# basketball-traj-v2 has its own (SawyerBasketballTrajV2Policy).
POLICY = {
	"hammer-v2": policies.SawyerHammerV2Policy,
	"drawer-close-v2": policies.SawyerDrawerCloseV2Policy,
	"drawer-open-v2": policies.SawyerDrawerOpenV2Policy,
	"door-open-v2": policies.SawyerDoorOpenV2Policy,
	"bin-picking-v2": policies.SawyerBinPickingV2Policy,
	"button-press-topdown-v2": policies.SawyerButtonPressTopdownV2Policy,
	"door-unlock-v2": policies.SawyerDoorUnlockV2Policy,
	# Basketball family.
	"basketball-v2": policies.SawyerBasketballV2Policy,
	"basketball-v3": policies.SawyerBasketballV2Policy,
	"basketball-traj-v2": policies.SawyerBasketballTrajV2Policy,
	"basketball-45-v2": policies.SawyerBasketballV2Policy,
	"basketball-easy-v2": policies.SawyerBasketballV2Policy,
	"plate-slide-v2": policies.SawyerPlateSlideV2Policy,
	"hand-insert-v2": policies.SawyerHandInsertV2Policy,
	"peg-insert-side-v2": policies.SawyerPegInsertionSideV2Policy,
	"peg-insert-side-v3": policies.SawyerPegInsertionSideV2Policy,
	"assembly-v3": policies.SawyerAssemblyV2Policy,
	"pick-place-wall-v3": policies.SawyerPickPlaceWallV2Policy,
	"push-wall-v2": policies.SawyerPushWallV2Policy,
	"soccer-v3": policies.SawyerSoccerV2Policy,
	"disassemble-v2": policies.SawyerDisassembleV2Policy,
	"disassemble-v4": policies.SawyerDisassembleV2Policy,
	"pick-place-wall-v4": policies.SawyerPickPlaceWallV2Policy,
	"pick-place-v4": policies.SawyerPickPlaceV2Policy,
	"pick-place-v3": policies.SawyerPickPlaceV2Policy,
	"push-wall-v4": policies.SawyerPushWallV2Policy,
	"lever-pull-v2": policies.SawyerLeverPullV2Policy,
	"stick-pull-v2": policies.SawyerStickPullV2Policy,
	"stick-pull-v3": policies.SawyerStickPullV2Policy,
	"shelf-place-v2": policies.SawyerShelfPlaceV2Policy,
	"window-close-v2": policies.SawyerWindowCloseV2Policy,
	"reach-v2": policies.SawyerReachV2Policy,
	"button-press-wall-v2": policies.SawyerButtonPressWallV2Policy,
	"box-close-v2": policies.SawyerBoxCloseV2Policy,
	"stick-push-v2": policies.SawyerStickPushV2Policy,
	"handle-pull-v2": policies.SawyerHandlePullV2Policy,
	"door-lock-v2": policies.SawyerDoorLockV2Policy,
}

# Camera name used for each task, both for the pixel observations rendered in
# the generation loop (env.render) and for the saved demo videos.
CAMERA = {
	"hammer-v2": "corner3",
	"drawer-close-v2": "corner",
	"drawer-open-v2": "corner",
	"door-open-v2": "corner3",
	"bin-picking-v2": "corner",
	"button-press-topdown-v2": "corner",
	"door-unlock-v2": "corner",
	# Basketball family.
	"basketball-v2": "corner",
	"basketball-v3": "corner",
	"basketball-traj-v2": "corner",
	"basketball-45-v2": "corner",
	"basketball-easy-v2": "corner",
	"plate-slide-v2": "corner",
	"hand-insert-v2": "corner",
	"peg-insert-side-v2": "corner3",
	"peg-insert-side-v3": "corner3",
	"assembly-v3": "corner",
	"pick-place-wall-v3": "corner3",
	"push-wall-v2": "corner",
	"soccer-v3": "corner",
	"disassemble-v2": "corner",
	"disassemble-v4": "corner",
	"pick-place-wall-v4": "corner3",
	"pick-place-v3": "corner3",
	"pick-place-v4": "corner3",
	"push-wall-v4": "corner",
	"lever-pull-v2": "corner4",
	"stick-pull-v2": "corner3",
	"stick-pull-v3": "corner3",
	"shelf-place-v2": "corner",
	"window-close-v2": "corner3",
	"reach-v2": "corner3",
	"button-press-wall-v2": "corner",
	"box-close-v2": "corner3",
	"stick-push-v2": "corner",
	"handle-pull-v2": "corner3",
	"door-lock-v2": "corner",
}

# Rollout length (number of env.step calls) per episode for each task.
NUM_STEPS = {
	"hammer-v2": 125,
	"drawer-close-v2": 125,
	"drawer-open-v2": 125,
	"door-open-v2": 125,
	"bin-picking-v2": 175,
	"button-press-topdown-v2": 125,
	"door-unlock-v2": 125,
	# Basketball family.
	"basketball-v2": 175,
	"basketball-v3": 175,
	# NOTE(review): far shorter than the other basketball variants (175);
	# confirm 45 is intentional and not a typo.
	"basketball-45-v2": 45,
	"basketball-traj-v2": 175,
	"basketball-easy-v2": 175,
	"plate-slide-v2": 125,
	"hand-insert-v2": 125,
	"peg-insert-side-v2": 125,
	"peg-insert-side-v3": 150,
	"assembly-v3": 175,
	"pick-place-wall-v3": 175,
	"push-wall-v2": 175,
	"soccer-v3": 125,
	"disassemble-v2": 125,
	"disassemble-v4": 125,
	"pick-place-wall-v4": 175,
	"pick-place-v4": 125,
	"pick-place-v3": 125,
	"push-wall-v4": 175,
	"lever-pull-v2": 125,
	"stick-pull-v2": 175,
	"stick-pull-v3": 175,
	"shelf-place-v2": 175,
	"window-close-v2": 125,
	"reach-v2": 125,
	"button-press-wall-v2": 125,
	"box-close-v2": 175,
	"stick-push-v2": 125,
	"handle-pull-v2": 175,
	"door-lock-v2": 125,
}


# An episode is kept only if its final-step reward is at least this value;
# other rollouts are discarded and the next train task is tried instead.
MIN_FINAL_REWARD = 7
# Number of most recent frames concatenated along the channel axis.
FRAME_STACK = 3


def _push_and_stack(stack, pixel, size):
	"""Resize *pixel*, push it onto *stack*, and return the stacked frames.

	The frame is resized to *size* (cv2 order: (width, height)) and transposed
	with axes (2, 0, 1), i.e. HWC -> CHW, assuming *pixel* is (H, W, C) —
	TODO confirm against env.render output. While *stack* is not yet full (the
	first steps of an episode) it is padded by repeating the current frame, so
	the result always holds ``stack.maxlen`` frames concatenated along axis 0.
	*stack* must be a deque created with a ``maxlen``.
	"""
	frame = np.transpose(cv2.resize(pixel, size), (2, 0, 1))
	stack.append(frame)
	while len(stack) < stack.maxlen:
		stack.append(frame)
	return np.concatenate(stack, axis=0)


# For each task: roll out the scripted expert until `num_demos` episodes pass
# the final-reward threshold, record videos, and pickle the collected data to
# ./demos/<env_name>/expert_demos.pkl.
for env_name in env_names:
	print(f"Generating demo for: {env_name}")
	# Initialize the scripted expert policy for this task.
	policy = POLICY[env_name]()
	camera_name = CAMERA[env_name]
	num_steps = NUM_STEPS[env_name]

	# Construct the MT1 benchmark for this task and instantiate its environment.
	ml1 = metaworld.MT1(env_name)
	env = ml1.train_classes[env_name]()

	save_dir = Path("./demos") / env_name
	save_dir.mkdir(parents=True, exist_ok=True)
	video_recorder = VideoRecorder(save_dir, camera_name=camera_name)

	images_list = []
	moco_images_list = []
	observations_list = []
	actions_list = []
	rewards_list = []

	count = 0  # index of the next train task to try
	episode = 0  # number of episodes kept so far
	while episode < num_demos:
		# Failed rollouts consume a task each; stop with a clear message rather
		# than an IndexError once the task list is exhausted.
		if count >= len(ml1.train_tasks):
			raise RuntimeError(
				f"{env_name}: ran out of train tasks ({len(ml1.train_tasks)}) "
				f"after collecting {episode} of {num_demos} demos"
			)
		video_recorder.init(env)
		print(f"Episode {episode}")
		images = []
		moco_images = []
		observations = []
		actions = []
		rewards = []
		image_stack = deque([], maxlen=FRAME_STACK)
		moco_image_stack = deque([], maxlen=FRAME_STACK)

		# Tasks are taken in order rather than sampled at random.
		task = ml1.train_tasks[count]
		print(count)
		count += 1
		env.set_task(task)

		observation = env.reset()
		for _ in range(num_steps):
			# Copy the render once per step (the frame is resized twice). The
			# original copied before each resize; presumably cv2 needs a
			# fresh contiguous array here — TODO confirm.
			pixel = env.render(offscreen=True, camera_name=camera_name).copy()
			images.append(_push_and_stack(image_stack, pixel, (84, 84)))
			moco_images.append(_push_and_stack(moco_image_stack, pixel, (224, 224)))

			# The policy sees the unmasked observation.
			action = policy.get_action(observation)
			if 'basketball' in env_name:
				# Perturb basketball actions with Gaussian noise (std 0.1).
				action += np.random.randn(action.shape[0]) * 0.1
			action = np.clip(action, -1.0, 1.0)
			actions.append(action)

			# Store a copy with the last 3 entries zeroed, leaving the array
			# returned by the env untouched.
			stored_obs = observation.copy()
			stored_obs[-3:] = 0
			observations.append(stored_obs)

			observation, reward, _done, _info = env.step(action)
			rewards.append(reward)
			video_recorder.record(env)

		print(rewards[-1], np.max(rewards))
		if rewards[-1] < MIN_FINAL_REWARD:
			continue

		# Keep this trajectory.
		episode += 1
		images_list.append(np.array(images))
		moco_images_list.append(np.array(moco_images))
		observations_list.append(np.array(observations))
		actions_list.append(np.array(actions))
		rewards_list.append(np.array(rewards))

		video_recorder.save(f'demo{episode}.mp4')

	file_path = save_dir / 'expert_demos.pkl'
	payload = [images_list, observations_list, actions_list, rewards_list, moco_images_list]
	with open(file_path, 'wb') as f:
		pickle.dump(payload, f)

