#!/usr/bin/env python3

import warnings
import os

# These are set before the remaining imports run — presumably so that MKL and
# MuJoCo pick them up at import time; confirm before reordering the imports.
os.environ['MKL_SERVICE_FORCE_INTEL'] = '1'
os.environ['MUJOCO_GL'] = 'egl'
from pathlib import Path

import hydra
import numpy as np
import torch
from dm_env import specs
import math

import utils
from logger import Logger
from replay_buffer import ReplayBufferStorage, make_replay_loader, make_expert_replay_loader
from video import TrainVideoRecorder, VideoRecorder
import pickle
import matplotlib.pyplot as plt
import seaborn as sns
from rewarder import optimal_transport_plan

# Suppress DeprecationWarning messages for the whole process.
warnings.filterwarnings('ignore', category=DeprecationWarning)
# Let cuDNN benchmark convolution algorithms and reuse the fastest one.
torch.backends.cudnn.benchmark = True

def make_agent(obs_spec, action_spec, cfg):
	"""Instantiate the agent described by ``cfg``.

	The observation shape (taken from ``obs_spec[cfg.obs_type]``) and the action
	shape are written into ``cfg`` in place before it is handed to
	``hydra.utils.instantiate``.
	"""
	selected_obs_spec = obs_spec[cfg.obs_type]
	cfg.obs_shape = selected_obs_spec.shape
	cfg.action_shape = action_spec.shape
	agent = hydra.utils.instantiate(cfg)
	return agent

class WorkspaceIL:
	def __init__(self, cfg):
		"""Build the workspace: envs, replay buffers, agent and expert demonstrations.

		Args:
			cfg: run configuration; it is stored on ``self.cfg`` and a few
				``cfg.suite`` fields are overwritten depending on the agent type.
		"""
		self.work_dir = Path.cwd()
		print(f'workspace: {self.work_dir}')

		self.cfg = cfg
		utils.set_seed_everywhere(cfg.seed)
		self.device = torch.device(cfg.device)
		# creates the logger, both envs (and env_horizon), replay storage/loader and recorders
		self.setup()

		env = self.train_env
		self.agent = make_agent(env.observation_spec(), env.action_spec(), cfg.agent)

		# agent-specific frame budgets
		agent_name = repr(self.agent)
		if agent_name == 'drqv2':
			self.cfg.suite.num_train_frames = self.cfg.num_train_frames_drq
		elif agent_name == 'bc':
			self.cfg.suite.num_train_frames = self.cfg.num_train_frames_bc
			self.cfg.suite.num_seed_frames = 0

		# expert batches are half the size of the regular batch
		expert_batch_size = self.cfg.batch_size // 2
		self.expert_replay_loader = make_expert_replay_loader(
			self.cfg.expert_dataset, expert_batch_size, self.cfg.num_demos, self.cfg.obs_type)
		self.expert_replay_iter = iter(self.expert_replay_loader)

		self.timer = utils.Timer()
		self._global_step = 0
		self._global_episode = 0

		num_demos = self.cfg.num_demos
		horizon = self.env_horizon
		obs_type = self.cfg.obs_type

		# NOTE: pickle executes arbitrary code on load; only point this at trusted files.
		with open(self.cfg.expert_dataset, 'rb') as f:
			data = pickle.load(f)
		if len(data) == 5:
			# optional fifth entry: kept as expert_pixel, truncated like the demos below
			self.expert_pixel = [pixel[:horizon] for pixel in data[4][:num_demos]]
			data = data[:4]
		# first entry is used for 'pixels', second for 'features'
		if obs_type == 'pixels':
			self.expert_demo, _, self.expert_action, self.expert_reward = data
		elif obs_type == 'features':
			_, self.expert_demo, self.expert_action, self.expert_reward = data

		# keep the first num_demos demonstrations, each cut to the env horizon
		self.expert_demo = [demo[:horizon] for demo in self.expert_demo[:num_demos]]
		self.expert_action = self.expert_action[:num_demos]
		self.expert_reward = np.mean(self.expert_reward[:num_demos])

		demos_for_agent = self.expert_pixel if obs_type == 'pixels' else self.expert_demo
		self.agent.init_demos(self.cfg, demos_for_agent)
		
	def setup(self):
		"""Create the logger, train/eval environments, replay buffer and video recorders."""
		cfg = self.cfg
		buffer_dir = self.work_dir / 'buffer'

		# logger
		self.logger = Logger(self.work_dir, use_tb=cfg.use_tb)

		# environments; both calls also return the horizon, the second one wins
		self.train_env, self.env_horizon = hydra.utils.call(cfg.suite.task_make_fn)
		self.eval_env, self.env_horizon = hydra.utils.call(cfg.suite.task_make_fn)

		# replay buffer: observation, action, reward, discount
		data_specs = [
			self.train_env.observation_spec()[cfg.obs_type],
			self.train_env.action_spec(),
			specs.Array((1, ), np.float32, 'reward'),
			specs.Array((1, ), np.float32, 'discount'),
		]

		self.replay_storage = ReplayBufferStorage(
			data_specs,
			buffer_dir,
			cfg.adaptive_truncate,
			cfg.adaptive_truncate_alpha,
			cfg.adaptive_truncate_offset,
			cfg.buffer_truncate_by_progress,
			cfg.adaptive_discount,
		)

		# TODO: separate progress_guide and adaptive_progress
		self.replay_loader = make_replay_loader(
			buffer_dir,
			cfg.replay_buffer_size,
			cfg.batch_size,
			cfg.replay_buffer_num_workers,
			cfg.suite.save_snapshot,
			cfg.nstep,
			cfg.suite.discount,
			cfg.progress_guide or cfg.adaptive_progress,
			cfg.progress_truncate,
			cfg.buffer_truncate_by_progress,
			cfg.biased_sampling,
			cfg.mask_reward_by_progress,
			cfg.oversample_timeout,
			cfg.agent.use_timeout,
			cfg.mask_reward_by_progress_offset,
			cfg.sample_by_length,
		)
		# mask_reward_by_progress requires use_timeout off; oversample_timeout requires it on
		assert cfg.mask_reward_by_progress is None or cfg.agent.use_timeout == False
		assert cfg.agent.use_timeout == True or cfg.oversample_timeout is None

		# iterators are created later (replay_iter lazily, expert_replay_iter in __init__)
		self._replay_iter = None
		self.expert_replay_iter = None

		# a None directory is passed when the corresponding video option is off
		eval_video_dir = self.work_dir if cfg.save_video else None
		self.video_recorder = VideoRecorder(eval_video_dir)
		train_video_dir = self.work_dir if cfg.save_train_video else None
		self.train_video_recorder = TrainVideoRecorder(train_video_dir)

	@property
	def global_step(self):
		"""Read-only view of the ``_global_step`` counter."""
		step_count = self._global_step
		return step_count

	@property
	def global_episode(self):
		"""Read-only view of the ``_global_episode`` counter."""
		episode_count = self._global_episode
		return episode_count

	@property
	def global_frame(self):
		return self.global_step * self.cfg.suite.action_repeat

	@property
	def replay_iter(self):
		"""Iterator over ``self.replay_loader``, created on first access and then reused."""
		cached = self._replay_iter
		if cached is not None:
			return cached
		self._replay_iter = iter(self.replay_loader)
		return self._replay_iter

	def eval(self):
		"""Evaluate the current agent, log the results and optionally advance progress.

		Runs ``cfg.suite.num_eval_episodes`` episodes on ``self.eval_env`` with the
		agent in eval mode and records a video of the first episode. For a
		'potil' agent the OT rewarder is queried after every episode and
		diagnostic figures are written for the first one; for a 'dac' agent only
		the cost matrices are collected. When adaptive progress or adaptive
		discount runs in 'lis' / 'lis_minus' mode, ``self.progress`` (and
		``self.discount``) may be advanced here and pushed to the replay storage.
		Metrics are logged with ``ty='eval'`` and, if ``cfg.save_every_model`` is
		set, a snapshot is written to ``<work_dir>/models``.

		Reads ``self.progress`` (and, for 'potil', ``self.ot_rewards`` plus the
		per-type trajectory statistics), all of which are created by ``train_il``.

		Raises:
			NotImplementedError: adaptive progress is enabled with a mode other
				than 'lis' / 'lis_minus'.
		"""
		step, episode, total_reward = 0, 0, 0
		eval_until_episode = utils.Until(self.cfg.suite.num_eval_episodes)

		agent_name = repr(self.agent)
		suite_name = self.cfg.suite.name
		is_metaworld = suite_name == 'metaworld'
		tracks_success = suite_name in ('openaigym', 'metaworld')
		# progress is only modified after the episode loop, so this is constant below
		progress_steps = int(self.progress * ((self.env_horizon + 1) // 2))

		if tracks_success:
			paths = []
			obj_ups = []
			obj_moves = []
		scores = []
		costs = []
		while eval_until_episode(episode):
			if is_metaworld:
				path = []
				obj_up = []
				obj_move = []
			time_step = self.eval_env.reset()
			observations = []
			pixels = []
			actions = []
			# only the first episode is recorded
			self.video_recorder.init(self.eval_env, enabled=(episode == 0))
			episode_step = 0
			while not time_step.last():
				with torch.no_grad(), utils.eval_mode(self.agent):
					action = self.agent.act(time_step.observation[self.cfg.obs_type], self.global_step, episode_step, progress_steps, eval_mode=True)
				observations.append(time_step.observation[self.cfg.obs_type])
				pixels.append(time_step.observation['pixels_large'])
				actions.append(action)
				time_step = self.eval_env.step(action)
				if is_metaworld:
					path.append(time_step.observation['goal_achieved'])
					obj_up.append(time_step.observation['obj_up'])
					obj_move.append(time_step.observation['obj_move'])
				self.video_recorder.record(self.eval_env)
				total_reward += time_step.reward
				step += 1
				episode_step += 1

			episode += 1
			self.video_recorder.save(f'{self.global_frame}.mp4')
			if suite_name == 'openaigym':
				paths.append(time_step.observation['goal_achieved'])
			elif is_metaworld:
				# an episode counts as success / up / moved once the flag was set often enough
				paths.append(1 if np.sum(path)>3 else 0)
				obj_ups.append(1 if np.sum(obj_up)>10 else 0)
				obj_moves.append(1 if np.sum(obj_move)>10 else 0)

			if agent_name == 'potil':
				# get infos
				observations = np.stack(observations, 0)
				pixels = np.stack(pixels, 0)
				if self.cfg.obs_type == 'features':
					reward_obs = observations
				else:
					reward_obs = pixels
				actions = np.stack(actions, 0)
				if self.cfg.ot_truncate_by_progress:
					cost_matrix, transport_plan, ot_rewards, score = self.agent.ot_rewarder(reward_obs, self.global_step, truncate=self.progress, return_infos=True, progress=self.progress)
				else:
					cost_matrix, transport_plan, ot_rewards, score = self.agent.ot_rewarder(reward_obs, self.global_step, return_infos=True, progress=self.progress)
				scores.append(score)
				costs.append(cost_matrix)
				if episode == 1:
					self._save_eval_figures(cost_matrix, transport_plan, ot_rewards, observations, progress_steps)
			elif agent_name == 'dac':
				observations = np.stack(observations, 0)
				pixels = np.stack(pixels, 0)
				if self.cfg.obs_type == 'features':
					reward_obs = observations
				else:
					reward_obs = pixels
				costs += self.agent.compute_cost_matrixs(reward_obs)

		if agent_name == 'potil' and self.cfg.agent.use_inverse_dynamics:
			# uses the observations / actions of the last evaluated episode
			inverse_dynamics_error = self.agent.eval_inverse_dynamics(self.expert_demo, self.expert_action, observations, actions, self.global_frame)

		# Stay None unless an LIS pass below actually computes them, so that the
		# logging code never touches an undefined name.
		match_score = None
		ref_score = None
		if self.cfg.adaptive_progress or (self.cfg.adaptive_discount and self.cfg.adaptive_discount_mode != 'exp'):
			if self.cfg.adaptive_progress:
				mode = self.cfg.adaptive_progress_mode
			else:
				mode = self.cfg.adaptive_discount_mode
			if mode == 'lis' or mode == 'lis_minus':
				# progress may advance by at most max_progress_delta * 0.01 per evaluation
				for _ in range(self.cfg.max_progress_delta):
					prefix_len = max(1, int(self.progress * costs[0].shape[0]) - self.cfg.adaptive_progress_offset)
					match_scores = [
						utils.longest_increasing_subsequence(cost[:prefix_len, :prefix_len].argmin(1))
						for cost in costs
					]
					match_score = np.percentile(match_scores, self.cfg.agent_score_percentile)
					ref_score = self.agent.ref_score[prefix_len - 1]
					if self.progress >= 1:
						break
					if mode == 'lis':
						advance = match_score >= int(self.cfg.adaptive_progress_threshold * ref_score)
					else:
						advance = match_score >= ref_score - self.cfg.adaptive_progress_threshold
					if not advance:
						# nothing changes between iterations once the check fails,
						# so repeating it would only recompute the same scores
						break
					# clamp so that progress lands exactly on 1.0 instead of
					# overshooting or missing it through float accumulation
					self.progress = min(self.progress + 0.01, 1.0)
					self.replay_storage.update_parameters({'progress': self.progress})
					if self.cfg.adaptive_discount:
						self.discount = math.exp(math.log(self.cfg.adaptive_discount_paras) / int(self.progress * costs[0].shape[0]))
						self.replay_storage.update_parameters({'_discount': self.discount})
			elif self.cfg.adaptive_progress:
				raise NotImplementedError(f'adaptive_progress_mode {mode!r} is not supported')

		with self.logger.log_and_dump_ctx(self.global_frame, ty='eval') as log:
			log('episode_reward', total_reward / episode)
			log('episode_length', step * self.cfg.suite.action_repeat / episode)
			log('episode', self.global_episode)
			log('step', self.global_step)
			if agent_name != 'drqv2':
				log('expert_reward', self.expert_reward)
			if tracks_success:
				log("success_percentage", np.mean(paths))
				log("obj_up", np.mean(obj_ups))
				log("obj_move", np.mean(obj_moves))

			if agent_name == 'potil' and self.cfg.agent.use_inverse_dynamics:
				log('expert_inverse_dynamics_error', inverse_dynamics_error)

			log('imitation_score', np.mean(scores))

			logs_match_score = self.cfg.adaptive_progress or (self.cfg.adaptive_discount and self.cfg.adaptive_discount_mode == 'lis')
			if logs_match_score and match_score is not None:
				log('match_score', match_score)
				log('ref_score', ref_score)

		if self.cfg.save_every_model:
			save_dir = self.work_dir / 'models'
			save_dir.mkdir(exist_ok=True)
			self.save_snapshot(save_dir / f'snapshot{self.global_frame}.pt')

	def _save_eval_figures(self, cost_matrix, transport_plan, ot_rewards, observations, progress_steps):
		"""Write the diagnostic figures of one evaluation episode into ``<work_dir>/figures``."""
		save_dir = self.work_dir / 'figures'
		save_dir.mkdir(exist_ok=True)
		frame = self.global_frame
		figures = []
		try:
			figures.append(plt.figure(figsize=(7.5, 7.5)))
			sns.heatmap(data=cost_matrix, cmap=plt.get_cmap('Greens'))
			plt.savefig(save_dir / f'cost_matrix_{frame}')

			# every later plot clears and reuses this second figure
			figures.append(plt.figure(figsize=(7.5, 7.5)))
			sns.heatmap(data=transport_plan, cmap=plt.get_cmap('Blues'))
			plt.savefig(save_dir / f'transport_plan_{frame}')

			plt.clf()
			plt.bar(range(len(ot_rewards)), ot_rewards / self.agent.sinkhorn_rew_scale)
			plt.savefig(save_dir / f'ot_reward_{frame}')

			# critic values along the first expert demonstration, with the
			# configured timestamp features appended
			obs = self.expert_demo[0]
			if self.cfg.suite.timestamp is not None:
				obs = np.concatenate([obs, np.expand_dims(np.arange(obs.shape[0]), -1) * self.cfg.suite.timestamp], axis=-1)
			if self.cfg.suite.timestamp_onehot:
				obs = np.concatenate([obs, np.eye(obs.shape[0], 100)], axis=-1)
			if self.cfg.suite.timestamp_dim is not None:
				# sinusoidal encoding of the time index
				pos = np.arange(self.env_horizon + 1)[:, np.newaxis]
				dims = np.arange(self.cfg.suite.timestamp_dim)[np.newaxis, :]
				angle_rates = 1 / np.power(10000, (2 * (dims // 2)) / np.float32(self.cfg.suite.timestamp_dim))
				angle_rads = pos * angle_rates
				sines = np.sin(angle_rads[:, 0::2])
				cosines = np.cos(angle_rads[:, 1::2])
				pos_encoding = np.concatenate([sines, cosines], axis=-1)
				obs = np.concatenate([obs, pos_encoding[:obs.shape[0]]], axis=-1).astype(np.float32)
			# only every second expert step is evaluated
			obs = obs[::2]
			self._plot_critic_values(obs, progress_steps, save_dir / f'value_{frame}')

			# critic values along the agent's own trajectory
			value = self._plot_critic_values(observations, progress_steps, save_dir / f'agent_value_{frame}')

			# mean of the stored OT rewards over (up to) the last 500 episodes
			plt.clf()
			plt.bar(range(len(value)), self.ot_rewards[max(self._global_episode - 500, 0):self._global_episode + 1, :len(value)].mean(0) / self.agent.sinkhorn_rew_scale)
			plt.savefig(save_dir / f'recent_ot_rewards_{frame}')

			if self.cfg.task_name in ('pick-place-wall-v3', 'pick-place-wall-v4', 'bin-picking-v2'):
				for i in range(3):
					if self.count[i] > 0:
						# NOTE(review): auto_rew_scale is used as an on/off flag elsewhere in
						# this file; the divisor was possibly meant to be
						# sinkhorn_rew_scale — confirm. Behaviour kept as is.
						self._plot_critic_values(self.best_traj[i], progress_steps, save_dir / f'type_{i}_best_traj_value_{frame}', scale=self.agent.auto_rew_scale)

						plt.clf()
						plt.bar(range(len(self.best_traj_reward[i])), self.best_traj_reward[i])
						plt.savefig(save_dir / f'type_{i}_best_traj_reward_{frame}')
		finally:
			# release the figures; otherwise two more stay open after every evaluation
			for fig in figures:
				plt.close(fig)

	def _plot_critic_values(self, obs, progress_steps, fig_path, scale=None):
		"""Bar-plot the critic value of the actor's mean action for every row of ``obs`` and return the values."""
		obs = torch.tensor(obs).to(self.device).float()
		if self.cfg.obs_type == 'pixels':
			obs = self.agent.encoder(obs)
		# remaining-step countdown fed to actor and critic, one entry per row
		countdown = progress_steps - torch.arange(1, obs.shape[0] + 1).unsqueeze(-1).to(self.device).float()
		with torch.no_grad():
			action = self.agent.actor(obs, countdown, 0).mean
			value = self.agent.critic(obs, action, countdown)[0].cpu().numpy()
		value = value.squeeze(-1)
		if scale is not None:
			value = value / scale
		plt.clf()
		plt.bar(range(len(value)), value)
		plt.savefig(fig_path)
		return value

	def train_il(self):
		"""Run the imitation-learning training loop until the frame budget is spent.

		Every loop iteration samples an action, optionally updates the agent
		(once the seed phase is over) and steps ``self.train_env``. When an
		episode ends, its rewards are replaced by the agent's imitation rewards
		('potil' / 'dac'), the transitions are added to the replay storage, the
		progress / discount schedules are advanced, training metrics are logged
		and the environment is reset. ``self.eval()`` is invoked whenever the
		evaluation schedule fires.

		Initialises ``self.progress``, ``self.discount`` (with adaptive discount),
		``self.ot_rewards`` and, for selected tasks, the per-type trajectory
		statistics that ``eval`` reads.

		Raises:
			NotImplementedError: ``cfg.adaptive_discount`` is set with a mode other
				than 'exp' or 'lis'.
		"""
		# predicates
		train_until_step = utils.Until(self.cfg.suite.num_train_frames,
									   self.cfg.suite.action_repeat)
		seed_until_step = utils.Until(self.cfg.suite.num_seed_frames,
									  self.cfg.suite.action_repeat)
		eval_every_step = utils.Every(self.cfg.suite.eval_every_frames,
									  self.cfg.suite.action_repeat)

		agent_name = repr(self.agent)
		is_potil = agent_name == 'potil'
		# agents whose environment reward is replaced by an imitation reward
		relabels_reward = agent_name in ('potil', 'dac')
		# tasks for which per-trajectory-type statistics are kept
		tracks_traj_types = self.cfg.task_name in ('pick-place-wall-v3', 'pick-place-wall-v4', 'bin-picking-v2')

		episode_step, episode_reward = 0, 0
		if self.cfg.progress_truncate is not None:
			self.progress = self.cfg.progress_truncate
		elif self.cfg.progress_guide or self.cfg.adaptive_progress:
			self.progress = self.cfg.progress_guide_start
		else:
			self.progress = 1.0
		self.replay_storage.update_parameters({'_discount': self.cfg.suite.discount, 'progress': self.progress})
		if self.cfg.adaptive_discount:
			if self.cfg.adaptive_discount_mode == 'exp':
				self.discount = self.cfg.adaptive_discount_paras[0]
			elif self.cfg.adaptive_discount_mode == 'lis':
				self.progress = self.cfg.progress_guide_start
				self.discount = math.exp(math.log(self.cfg.adaptive_discount_paras) / int(self.progress * ((self.env_horizon + 1) // 2)))
			else:
				raise NotImplementedError
			self.replay_storage.update_parameters({'_discount': self.discount, 'progress': self.progress})
		# per-episode buffers; only ot_rewards is read elsewhere in this file (by eval)
		self.ot_rewards = np.zeros((50000, 100), dtype=np.float32)
		self.closest_distances = np.zeros((50000, 100), dtype=np.float32)
		self.scores = np.zeros((50000, 100), dtype=np.float32)

		time_steps = list()
		observations = list()
		pixels = list()
		actions = list()

		time_step = self.train_env.reset()
		time_steps.append(time_step)
		# the reset observation is deliberately not appended to `observations`
		actions.append(time_step.action)

		if is_potil and self.agent.auto_rew_scale:
			self.agent.sinkhorn_rew_scale = 1.  # Set after first episode

		if tracks_traj_types:
			self.best_reward = [-10000, -10000, -10000]
			self.ema_reward = [0, 0, 0]
			self.count = [0, 0, 0]
			self.best_traj = [None, None, None]
			self.best_traj_reward = [None, None, None]

		self.train_video_recorder.init(time_step.observation['pixels'])
		metrics = None
		while train_until_step(self.global_step):
			if time_step.last():
				self._global_episode += 1
				self.train_video_recorder.save(f'{self.global_frame}.mp4')

				observations = np.stack(observations, 0)
				pixels = np.stack(pixels, 0)
				if self.cfg.obs_type == 'features':
					reward_obs = observations
				else:
					reward_obs = pixels
				actions = np.stack(actions, 0)

				# compute the imitation rewards of the finished episode
				if is_potil:
					if self.cfg.ot_truncate_by_progress:
						_, _, new_rewards, _ = self.agent.ot_rewarder(reward_obs, self.global_step, truncate=self.progress, return_infos=True, progress=self.progress)
					else:
						_, _, new_rewards, _ = self.agent.ot_rewarder(reward_obs, self.global_step, return_infos=True, progress=self.progress)
					new_rewards_sum = np.sum(new_rewards)
				elif agent_name == 'dac':
					new_rewards = self.agent.dac_rewarder(observations, actions)
					new_rewards_sum = np.sum(new_rewards)

				# after the very first episode, fix the reward scale from its reward sum
				if is_potil and self.agent.auto_rew_scale and self._global_episode == 1:
					if self.agent.auto_rew_scale_wo_truncate:
						new_rewards = self.agent.ot_rewarder(reward_obs, self.global_step, truncate=1.0)
						new_rewards_sum = np.sum(new_rewards)
					self.agent.sinkhorn_rew_scale = self.agent.sinkhorn_rew_scale * self.agent.auto_rew_scale_factor / float(np.abs(new_rewards_sum))
					new_rewards = self.agent.ot_rewarder(reward_obs, self.global_step) #TODO:fix
					new_rewards_sum = np.sum(new_rewards)

				# store the episode, keeping only the configured observation type
				for i, elt in enumerate(time_steps):
					elt = elt._replace(observation=elt.observation[self.cfg.obs_type])
					if relabels_reward:
						if i == 0:
							# the reset step has no reward
							elt = elt._replace(reward=float('nan'))
						else:
							elt = elt._replace(reward=new_rewards[i - 1])
					self.replay_storage.add(elt)

				if is_potil and tracks_traj_types:
					length = int(self.progress * observations.shape[0])
					obs = observations[:length]
					reward = new_rewards[:length] / self.agent.sinkhorn_rew_scale
					reward_sum = reward.sum(-1)
					# Bucket the trajectory by observation features 4-6 (presumably
					# the object position — TODO confirm against the env).
					if np.max(obs[:, 6]) > 0.1:
						t = 0
					elif np.max(np.abs(obs[:, 4] - obs[0, 4]) + np.abs(obs[:, 5] - obs[0, 5])) > 0.01:
						t = 1
					else:
						t = 2
					if self.count[t] == 0:
						self.ema_reward[t] = reward_sum
					else:
						self.ema_reward[t] = self.ema_reward[t] * 0.99 + reward_sum * 0.01
					self.count[t] += 1
					if reward_sum > self.best_reward[t]:
						self.best_reward[t] = reward_sum
						self.best_traj[t] = obs
						self.best_traj_reward[t] = reward

				# scheduled (non-adaptive) progress: linear in the global frame, capped at 1
				if self.cfg.progress_guide:
					if isinstance(self.cfg.progress_guide_len, int):
						self.progress = min(self.cfg.progress_guide_start + (self.cfg.progress_guide_end - self.cfg.progress_guide_start) * self.global_frame / self.cfg.progress_guide_len, 1.0)
					else:
						# piecewise schedule: entries of (start value, end value, first frame, last frame)
						for schedule in self.cfg.progress_guide_len:
							if self.global_frame >= schedule[2] and self.global_frame <= schedule[3]:
								self.progress = min(schedule[0] + (schedule[1] - schedule[0]) * (self.global_frame - schedule[2]) / (schedule[3] - schedule[2]), 1.0)
					self.replay_storage.update_parameters({'progress': self.progress})
				if self.cfg.adaptive_discount and self.cfg.adaptive_discount_mode == 'exp':
					# interpolate the discount linearly in log space
					start = math.log(self.cfg.adaptive_discount_paras[0])
					end = math.log(self.cfg.adaptive_discount_paras[1])
					current = start + (end - start) * min(self.global_frame / self.cfg.adaptive_discount_paras[2], 1.0)
					self.discount = math.exp(current)
					self.replay_storage.update_parameters({'_discount': self.discount})

				# NOTE: the per-episode adaptive-progress variants ('tendency',
				# 'closest_distance', 'expert_sum') that used to live here as
				# commented-out code were dropped; adaptive progress is advanced
				# in eval().

				# wait until all the metrics schema is populated
				if metrics is not None:
					# log stats
					elapsed_time, total_time = self.timer.reset()
					episode_frame = episode_step * self.cfg.suite.action_repeat
					with self.logger.log_and_dump_ctx(self.global_frame,
													  ty='train') as log:
						log('fps', episode_frame / elapsed_time)
						log('total_time', total_time)
						log('episode_reward', episode_reward)
						log('episode_length', episode_frame)
						log('episode', self.global_episode)
						log('buffer_size', len(self.replay_storage))
						log('step', self.global_step)
						if relabels_reward:
							log('expert_reward', self.expert_reward)
							log('imitation_reward', new_rewards_sum)
						if self.cfg.progress_guide or self.cfg.adaptive_progress or (self.cfg.adaptive_discount and self.cfg.adaptive_discount_mode == 'lis'):
							log('progress', self.progress)
							log('adaptive_progress_threshold', self.cfg.adaptive_progress_threshold)
							# The mode-specific scores that were once logged here only
							# existed in the dropped legacy code.
						if self.cfg.adaptive_discount:
							log('discount', self.discount)
							log('discount_log', math.log(self.discount))

						if is_potil and tracks_traj_types:
							for i in range(3):
								log(f'type_{i}_count', self.count[i])
								log(f'type_{i}_best_reward', self.best_reward[i])
								log(f'type_{i}_ema_reward', self.ema_reward[i])

				# reset env
				time_steps = list()
				observations = list()
				pixels = list()
				actions = list()

				time_step = self.train_env.reset()
				time_steps.append(time_step)
				actions.append(time_step.action)
				self.train_video_recorder.init(time_step.observation['pixels'])
				# try to save snapshot
				if self.cfg.suite.save_snapshot:
					self.save_snapshot()
				episode_step = 0
				episode_reward = 0

			# try to evaluate
			if eval_every_step(self.global_step):
				self.logger.log('eval_total_time', self.timer.total_time(),
								self.global_frame)
				self.eval()

			# sample action
			with torch.no_grad(), utils.eval_mode(self.agent):
				# exploration mode starts once the episode passes the current progress point
				if episode_step >= math.ceil(self.env_horizon / self.cfg.suite.action_repeat) * self.progress - self.cfg.expl_mode_offset:
					expl_mode = True
				else:
					expl_mode = False
				# Once the curriculum is complete, decay the exploration stddev
				# towards expl_noise. `>= 1` rather than `== 1`: progress built up
				# in 0.01 increments need not hit 1.0 exactly.
				if (self.cfg.adaptive_progress or self.cfg.progress_guide) and self.progress >= 1 and self.agent.expl_stddev is not None and self.global_step % 5000 == 0:
					self.agent.expl_stddev = max(self.agent.expl_stddev - 0.01, self.agent.expl_noise)
				action = self.agent.act(time_step.observation[self.cfg.obs_type], self.global_step, episode_step, int(self.progress * ((self.env_horizon + 1) // 2)), eval_mode=False, expl_mode=expl_mode)

			# try to update the agent
			if not seed_until_step(self.global_step):
				# Update
				metrics = self.agent.update(self.replay_iter, self.expert_replay_iter,
											self.global_step, int(self.progress * ((self.env_horizon + 1) // 2)), self.cfg.bc_regularize)
				self.logger.log_metrics(metrics, self.global_frame, ty='train')

			# take env step
			time_step = self.train_env.step(action)
			episode_reward += time_step.reward

			time_steps.append(time_step)
			observations.append(time_step.observation[self.cfg.obs_type])
			pixels.append(time_step.observation['pixels_large'])
			actions.append(time_step.action)

			self.train_video_recorder.record(time_step.observation['pixels'])
			episode_step += 1
			self._global_step += 1

	def save_snapshot(self, save_dir=None):
		snapshot = self.work_dir / 'snapshot.pt'
		if save_dir is not None:
			snapshot = save_dir
		keys_to_save = ['timer', '_global_step', '_global_episode']
		payload = {k: self.__dict__[k] for k in keys_to_save}
		payload.update(self.agent.save_snapshot())
		if self.cfg.biased_sampling and self.global_frame > self.cfg.suite.num_seed_frames:
			_, _, _, _, _, _, _, _, _, count = next(self.replay_iter)
			payload['buffer_idxs_count'] = count[-1]
		with snapshot.open('wb') as f:
			torch.save(payload, f)

	def load_snapshot(self, snapshot):
		with snapshot.open('rb') as f:
			payload = torch.load(f)
		agent_payload = {}
		for k, v in payload.items():
			if k == 'buffer_idxs_count':
				self.replay_loader = make_replay_loader(
					self.work_dir / 'buffer', self.cfg.replay_buffer_size,
					self.cfg.batch_size, self.cfg.replay_buffer_num_workers,
					self.cfg.suite.save_snapshot, self.cfg.nstep, self.cfg.suite.discount, self.cfg.progress_guide or self.cfg.adaptive_progress, self.cfg.progress_truncate, self.cfg.buffer_truncate_by_progress, self.cfg.biased_sampling, self.cfg.mask_reward_by_progress, self.cfg.oversample_timeout, self.cfg.agent.use_timeout, self.cfg.mask_reward_by_progress_offset, self.cfg.sample_by_length, idxs_count=v)
				self._replay_iter = None
			elif k not in self.__dict__:
				agent_payload[k] = v
		self.agent.load_snapshot(agent_payload)

	def test(self):
		"""Debug helper: plot critic values along the first expert demo, then run ``eval``.

		Prints the values, saves a bar plot to ``value`` in the current directory
		and finally calls ``self.eval()``.
		"""
		import matplotlib.pyplot as plt

		# every second step of the first expert demonstration
		expert_obs = torch.tensor(self.expert_demo[0][::2]).to(self.device).float()
		with torch.no_grad():
			# NOTE(review): eval() calls actor(obs, countdown, 0) and
			# critic(obs, action, countdown); the shorter signatures used here
			# may be stale — confirm against the agent.
			mean_action = self.agent.actor(expert_obs, 0).mean
			value = self.agent.critic(expert_obs, mean_action)[0].cpu().numpy()
		value = value.squeeze(-1)
		print(value)

		plt.clf()
		plt.bar(range(len(value)), value)
		plt.savefig('value')

		# NOTE(review): eval() reads self.progress (and self.ot_rewards for 'potil'),
		# which in this file are only created by train_il() — confirm this path
		# still works when test() is run on a fresh workspace.
		self.eval()

		# step, episode = 0, 0
		# success_count, fail_count = 0, 0
		# while episode < 1:
		# 	if self.cfg.suite.name == 'metaworld' or self.cfg.suite.name == 'd4rl' or self.cfg.suite.name == 'antmaze' or self.cfg.suite.name == 'adroit':
		# 		path = []
		# 	time_step = self.eval_env.reset()
		# 	self.video_recorder.init(self.eval_env, enabled=True)
			
		# 	total_reward = 0
		# 	observations = list()
		# 	states = list()
		# 	actions = list()
		# 	observations.append(time_step.observation[self.cfg.obs_type])
		# 	states.append(time_step.observation['features'])
		# 	actions.append(time_step.action)
		# 	while not time_step.last():
		# 		with torch.no_grad(), utils.eval_mode(self.agent):
		# 			action = self.agent.act(time_step.observation[self.cfg.obs_type],
		# 									self.global_step,
		# 									eval_mode=True)
		# 		time_step = self.eval_env.step(action)
		# 		if self.cfg.suite.name == 'metaworld' or self.cfg.suite.name == 'd4rl' or self.cfg.suite.name == 'antmaze' or self.cfg.suite.name == 'adroit':
		# 			path.append(time_step.observation['goal_achieved'])
		# 		self.video_recorder.record(self.eval_env)
		# 		total_reward += time_step.reward
		# 		step += 1
		# 		observations.append(time_step.observation[self.cfg.obs_type])
		# 		states.append(time_step.observation['features'])

		# 	episode += 1
		# 	print('episode', episode, ': success=', 1 if np.sum(path)>3 else 0, 'goal_achieved_count=', np.sum(path), 'real_total_reward=', total_reward)
		# 	observations = np.stack(observations, 0)
		# 	self.video_recorder.save(f'run{episode}.mp4')
		# 	if np.sum(path) > 5:
		# 		success_count += 1

		# 	import matplotlib.pyplot as plt
		# 	rewards = self.agent.ot_rewarder(observations, self.expert_demo, self.global_step, episode=episode)
		# 	print(rewards)
		# 	plt.clf()
		# 	plt.bar(range(len(rewards)), rewards)
		# 	plt.ylim(-2, 0)
		# 	plt.savefig(f'reward {episode}')
		# 	rewards = self.agent.ot_rewarder(observations, self.expert_demo, self.global_step, truncate=26)
		# 	print(rewards)
		# 	import matplotlib.pyplot as plt
		# 	plt.clf()
		# 	plt.bar(range(len(rewards)), rewards)
		# 	plt.ylim(-2, 0)
		# 	plt.savefig(f'reward truncate {episode}')
		# print(success_count)

@hydra.main(config_path='cfgs', config_name='config')
def main(cfg):
	"""Hydra entry point: build the workspace, optionally load BC weights, then train.

	With ``cfg.test_debug`` set, only ``WorkspaceIL.test`` is run. After training,
	the replay buffer's ``*.npz`` files are deleted unless
	``cfg.suite.save_snapshot`` is set.
	"""
	# imported through the `train` module rather than used directly
	from train import WorkspaceIL as W
	workspace = W(cfg)

	# Load weights
	if cfg.load_bc:
		snapshot = Path(cfg.bc_weight)
		if snapshot.exists():
			print(f'resuming bc: {snapshot}')
			workspace.load_snapshot(snapshot)
		else:
			# make the fallback visible instead of silently starting from scratch
			print(f'bc snapshot not found, continuing without it: {snapshot}')

	if cfg.test_debug:
		workspace.test()
		return

	workspace.train_il()

	# remove *.npz files
	if not cfg.suite.save_snapshot:
		remove_dir = workspace.work_dir / 'buffer'
		for fn in remove_dir.glob('*.npz'):
			os.remove(fn)


# Script entry point; main() is called without arguments because the
# @hydra.main decorator supplies ``cfg``.
if __name__ == '__main__':
	main()
