import hydra
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import distributions as pyd
import matplotlib.pyplot as plt
import pickle

import utils
from agent.encoder import Encoder
from rewarder import optimal_transport_plan, cosine_distance, euclidean_distance
import time
import copy
from moco import MoCo

class Actor(nn.Module):
	"""Policy network: maps a (possibly countdown-augmented) observation to an action distribution.

	When ``countdown_scale`` is not None, a scaled countdown value is concatenated
	to the observation as one extra input feature. When ``countdown_noise`` is
	truthy, standard-normal noise is added to the countdown before scaling.
	"""

	def __init__(self, repr_dim, action_shape, feature_dim, hidden_dim, countdown_scale, countdown_noise):
		super().__init__()

		self.countdown_scale = countdown_scale
		self.countdown_noise = countdown_noise
		# The countdown occupies one additional input column.
		if countdown_scale is not None:
			repr_dim += 1
		self.trunk = nn.Sequential(nn.Linear(repr_dim, feature_dim),
								   nn.LayerNorm(feature_dim), nn.Tanh())

		self.policy = nn.Sequential(nn.Linear(feature_dim, hidden_dim),
									nn.ReLU(inplace=True),
									nn.Linear(hidden_dim, hidden_dim),
									nn.ReLU(inplace=True),
									nn.Linear(hidden_dim, action_shape[0]))

		self.apply(utils.weight_init)

	def forward(self, obs, countdown, std):
		"""Return a ``utils.TruncatedNormal`` centred on the tanh-squashed policy output.

		Args:
			obs: observation features; the last dimension is the feature axis.
			countdown: countdown tensor concatenated to ``obs`` along the last
				dimension. Ignored (may be None) when ``countdown_scale`` is None.
			std: standard deviation broadcast to the shape of the mean.
		"""
		if self.countdown_scale is not None:
			if self.countdown_noise:
				# Out-of-place add: an in-place ``+=`` would write the noise back
				# into the caller's tensor, so a countdown that is reused across
				# several network calls would accumulate noise on every call.
				countdown = countdown + torch.randn_like(countdown)
			obs = torch.cat([obs, countdown * self.countdown_scale], dim=-1)
		h = self.trunk(obs)

		mu = self.policy(h)
		mu = torch.tanh(mu)
		std = torch.ones_like(mu) * std

		dist = utils.TruncatedNormal(mu, std)
		return dist


class Critic(nn.Module):
	"""Twin Q-network over (observation, action), optionally conditioned on a countdown.

	When ``countdown_scale`` is not None, a scaled countdown value is concatenated
	to the observation as one extra input feature. When ``countdown_noise`` is
	truthy, standard-normal noise is added to the countdown before scaling.
	"""

	def __init__(self, repr_dim, action_shape, feature_dim, hidden_dim, countdown_scale, countdown_noise):
		super().__init__()

		self.countdown_scale = countdown_scale
		self.countdown_noise = countdown_noise
		# The countdown occupies one additional input column.
		if countdown_scale is not None:
			repr_dim += 1
		self.trunk = nn.Sequential(nn.Linear(repr_dim, feature_dim),
								   nn.LayerNorm(feature_dim), nn.Tanh())

		self.Q1 = nn.Sequential(
			nn.Linear(feature_dim + action_shape[0], hidden_dim),
			nn.ReLU(inplace=True), nn.Linear(hidden_dim, hidden_dim),
			nn.ReLU(inplace=True), nn.Linear(hidden_dim, 1))

		self.Q2 = nn.Sequential(
			nn.Linear(feature_dim + action_shape[0], hidden_dim),
			nn.ReLU(inplace=True), nn.Linear(hidden_dim, hidden_dim),
			nn.ReLU(inplace=True), nn.Linear(hidden_dim, 1))

		self.apply(utils.weight_init)

	def forward(self, obs, action, countdown):
		"""Return the two Q estimates ``(q1, q2)`` for ``(obs, action)``.

		Args:
			obs: observation features; the last dimension is the feature axis.
			action: actions, concatenated to the trunk output along the last dimension.
			countdown: countdown tensor concatenated to ``obs`` along the last
				dimension. Ignored (may be None) when ``countdown_scale`` is None.
		"""
		if self.countdown_scale is not None:
			if self.countdown_noise:
				# Out-of-place add: an in-place ``+=`` would write the noise back
				# into the caller's tensor, so a countdown that is reused across
				# several network calls would accumulate noise on every call.
				countdown = countdown + torch.randn_like(countdown)
			obs = torch.cat([obs, countdown * self.countdown_scale], dim=-1)
		h = self.trunk(obs)
		h_action = torch.cat([h, action], dim=-1)
		q1 = self.Q1(h_action)
		q2 = self.Q2(h_action)

		return q1, q2


class InverseDynamics(nn.Module):
	"""Predicts an action from a pair of consecutive observation features."""

	def __init__(self, repr_dim, action_shape, feature_dim, hidden_dim):
		super().__init__()

		# Both observations are concatenated, hence the doubled input width.
		pair_dim = repr_dim * 2
		self.trunk = nn.Sequential(
			nn.Linear(pair_dim, feature_dim),
			nn.LayerNorm(feature_dim),
			nn.Tanh(),
		)

		action_dim = action_shape[0]
		self.predict = nn.Sequential(
			nn.Linear(feature_dim, hidden_dim),
			nn.ReLU(inplace=True),
			nn.Linear(hidden_dim, hidden_dim),
			nn.ReLU(inplace=True),
			nn.Linear(hidden_dim, action_dim),
		)

		self.apply(utils.weight_init)

	def forward(self, obs, next_obs):
		"""Return the predicted action for the transition ``obs -> next_obs``."""
		joined = torch.cat((obs, next_obs), dim=-1)
		return self.predict(self.trunk(joined))


class POTILAgent:
	def __init__(self, obs_shape, action_shape, device, lr, feature_dim,
				 hidden_dim, critic_target_tau, num_expl_steps,
				 update_every_steps, stddev_schedule, stddev_clip, use_tb, augment,
				 rewards, sinkhorn_rew_scale, update_target_every,
				 auto_rew_scale, auto_rew_scale_factor, suite_name, obs_type, bc_weight_type, bc_weight_schedule, bandwidth, use_trunk, name, timestamp, timestamp_onehot, ot_truncate, use_inverse_dynamics, alpha, use_timeout, update_policy_freq, cost_bias, auto_rew_scale_wo_truncate, expl_stddev, expl_noise, timestamp_dim, countdown_scale, nstep, timeout_mean_reward, countdown_noise, expl_mode, expl_offset):
		self.device = device
		self.lr = lr
		self.critic_target_tau = critic_target_tau
		self.update_every_steps = update_every_steps
		self.use_tb = use_tb
		self.num_expl_steps = num_expl_steps
		self.stddev_schedule = stddev_schedule
		self.stddev_clip = stddev_clip
		self.augment = augment
		self.rewards = rewards
		self.sinkhorn_rew_scale = sinkhorn_rew_scale
		self.update_target_every = update_target_every
		self.auto_rew_scale = auto_rew_scale
		self.auto_rew_scale_factor = auto_rew_scale_factor
		self.use_encoder = True if obs_type=='pixels' else False
		self.bc_weight_type = bc_weight_type
		self.bc_weight_schedule = bc_weight_schedule
		self.bandwidth = bandwidth
		self.use_trunk = use_trunk
		self.last_update_step = 0
		self.timestamp = timestamp
		self.ot_truncate = ot_truncate
		self.use_inverse_dynamics = use_inverse_dynamics
		self.alpha = alpha
		self.use_timeout = use_timeout
		self.obs_type = obs_type
		self.update_policy_freq = update_policy_freq
		self.cost_bias = cost_bias
		self.auto_rew_scale_wo_truncate = auto_rew_scale_wo_truncate
		self.expl_stddev = expl_stddev
		self.timestamp_onehot = timestamp_onehot
		self.expl_noise = expl_noise
		self.timestamp_dim = timestamp_dim
		self.countdown_scale = countdown_scale
		self.nstep = nstep
		self.timeout_mean_reward = timeout_mean_reward
		self.countdown_noise = countdown_noise
		self.expl_mode = expl_mode
		self.expl_offset = expl_offset

		# models
		if self.use_encoder:
			self.encoder = Encoder(obs_shape).to(device)
			self.encoder_target = Encoder(obs_shape).to(device)
			repr_dim = self.encoder.repr_dim
		else:
			repr_dim = obs_shape[0]

		self.trunk_target = nn.Sequential(
			nn.Linear(repr_dim, feature_dim),
			nn.LayerNorm(feature_dim), nn.Tanh()).to(device)

		self.actor = Actor(repr_dim, action_shape, feature_dim,
						   hidden_dim, countdown_scale, countdown_noise).to(device)

		self.critic = Critic(repr_dim, action_shape, feature_dim,
							 hidden_dim, countdown_scale, countdown_noise).to(device)
		self.critic_target = Critic(repr_dim, action_shape,
									feature_dim, hidden_dim, countdown_scale, countdown_noise).to(device)
		self.critic_target.load_state_dict(self.critic.state_dict())

		if self.use_inverse_dynamics:
			raise NotImplementedError
			self.inverse_dynamics = InverseDynamics(repr_dim, action_shape, feature_dim, hidden_dim).to(device)
			self.actor_bc = Actor(repr_dim, action_shape, feature_dim, hidden_dim).to(device)

		# optimizers
		if self.use_encoder:
			self.encoder_opt = torch.optim.Adam(self.encoder.parameters(), lr=lr)
		self.actor_opt = torch.optim.Adam(self.actor.parameters(), lr=lr)
		self.critic_opt = torch.optim.Adam(self.critic.parameters(), lr=lr)
		if self.use_inverse_dynamics:
			self.inverse_dynamics_opt = torch.optim.Adam(self.inverse_dynamics.parameters(), lr=lr)
			self.actor_bc_opt = torch.optim.Adam(self.actor_bc.parameters(), lr=lr)

		# data augmentation
		self.aug = utils.RandomShiftsAug(pad=4)

		self.train()
		self.critic_target.train()

	def __repr__(self):
		"""Return the fixed short identifier of this agent."""
		agent_name = "potil"
		return agent_name

	def train(self, training=True):
		self.training = training
		if self.use_encoder:
			self.encoder.train(training)
		self.actor.train(training)
		self.critic.train(training)
		if self.use_inverse_dynamics:
			self.inverse_dynamics.train(training)

	def act(self, obs, step, ts, progress, eval_mode, expl_mode=False):
		"""Select an action for a single observation and return it as a numpy array.

		Args:
			obs: a single observation (no batch dimension); a batch axis is added here.
			step: global step, used for the noise schedule and the initial
				uniform-exploration phase.
			ts: current timestep; ``progress - ts`` is fed to the actor as the countdown.
			progress: progress value the countdown is measured against.
			eval_mode: when truthy, return the distribution mean instead of a sample.
			expl_mode: when truthy (and ``self.expl_mode`` is not ``'linear'``),
				use ``self.expl_stddev`` as the sampling std.
		"""
		obs = torch.as_tensor(obs, device=self.device).unsqueeze(0)
		if self.use_encoder:
			obs = self.encoder(obs)
		countdown = torch.tensor([[progress - ts]], device=self.device, dtype=obs.dtype)

		stddev = utils.schedule(self.expl_noise, step)
		if self.expl_mode == 'linear':
			# Interpolate from `high` (countdown <= 0) down to the scheduled std
			# (countdown >= expl_offset).
			low = stddev
			high = max(stddev, self.expl_stddev)
			fraction = min(max(0, progress - ts), self.expl_offset) / self.expl_offset
			stddev = high - fraction * (high - low)
		elif expl_mode:
			stddev = self.expl_stddev

		dist = self.actor(obs, countdown, stddev)

		if eval_mode:
			return dist.mean.cpu().numpy()[0]

		action = dist.sample(clip=None)
		if step < self.num_expl_steps:
			# Early in training the sampled action is overwritten with uniform noise.
			action.uniform_(-1.0, 1.0)
		return action.cpu().numpy()[0]

	def update_critic(self, obs, ts, countdown, action, reward, discount, next_obs, timeout, mean_reward, step):
		"""Run one TD update of the critic (and the encoder, when one is used).

		The bootstrap target is the minimum of the two target-critic estimates at
		``next_obs``, evaluated with the countdown shifted by ``self.nstep``. With
		``self.use_timeout`` the bootstrap term is masked by ``1 - timeout``, and
		with ``self.timeout_mean_reward`` an extra ``(1 - timeout) * (-ts) * mean_reward``
		term is added.

		Returns:
			dict of scalar metrics (empty unless ``self.use_tb`` is set).
		"""
		with torch.no_grad():
			stddev = utils.schedule(self.stddev_schedule, step)
			# The shifted countdown is deliberately computed separately for each
			# network call so each call receives its own fresh tensor.
			next_dist = self.actor(next_obs, countdown - self.nstep, stddev)
			next_action = next_dist.sample(clip=self.stddev_clip)
			target_pair = self.critic_target(next_obs, next_action, countdown - self.nstep)
			target_V = torch.min(*target_pair)

			if not self.use_timeout:
				target_Q = reward + (discount * target_V)
			else:
				target_Q = reward + (discount * target_V) * (1 - timeout)
				if self.timeout_mean_reward:
					target_Q += (1 - timeout) * (-ts) * mean_reward

		Q1, Q2 = self.critic(obs, action, countdown)
		loss_q1 = F.mse_loss(Q1, target_Q)
		loss_q2 = F.mse_loss(Q2, target_Q)
		critic_loss = loss_q1 + loss_q2

		# optimize encoder and critic
		if self.use_encoder:
			self.encoder_opt.zero_grad(set_to_none=True)
		self.critic_opt.zero_grad(set_to_none=True)
		critic_loss.backward()
		self.critic_opt.step()
		if self.use_encoder:
			self.encoder_opt.step()

		if not self.use_tb:
			return dict()
		return {
			'critic_target_q': target_Q.mean().item(),
			'critic_q1': Q1.mean().item(),
			'critic_q2': Q2.mean().item(),
			'critic_loss': critic_loss.item(),
		}

	def update_actor(self, obs, countdown, obs_bc, obs_qfilter, action_bc, bc_regularize, step, countdown_bc=None):
		"""Run one policy-gradient update of the actor, optionally BC-regularised.

		Args:
			obs: (encoded) observations of the replay batch.
			countdown: countdown values aligned with ``obs``.
			obs_bc: (encoded) expert observations, or None when ``bc_regularize`` is falsy.
			obs_qfilter: observations used for soft Q-filtering (derived from the
				same replay batch as ``obs``), or None.
			action_bc: expert actions matching ``obs_bc``, or None.
			bc_regularize: whether to add the behaviour-cloning term.
			step: global step, used for the std and BC-weight schedules.
			countdown_bc: countdown values aligned with ``obs_bc``. Only needed when
				the actor is countdown-conditioned (``self.countdown_scale`` is not
				None); the networks ignore the countdown otherwise.

		Returns:
			dict of scalar metrics (empty unless ``self.use_tb`` is set).

		Raises:
			ValueError: for an unknown ``self.bc_weight_type`` while BC is enabled,
				or when BC is enabled on a countdown-conditioned actor without
				``countdown_bc``.
		"""
		metrics = dict()

		stddev = utils.schedule(self.stddev_schedule, step)

		dist = self.actor(obs, countdown, stddev)
		action = dist.sample(clip=self.stddev_clip)
		log_prob = dist.log_prob(action).sum(-1, keepdim=True)

		Q1, Q2 = self.critic(obs, action, countdown)
		Q = torch.min(Q1, Q2)

		# Compute bc weight
		if not bc_regularize:
			bc_weight = 0.0
		elif self.bc_weight_type == "linear":
			bc_weight = utils.schedule(self.bc_weight_schedule, step)
		elif self.bc_weight_type == "qfilter":
			# Soft Q-filtering inspired from
			# Nair, Ashvin, et al. "Overcoming exploration in reinforcement
			# learning with demonstrations." 2018 IEEE international
			# conference on robotics and automation (ICRA). IEEE, 2018.
			with torch.no_grad():
				# obs_qfilter comes from the same replay rows as obs, so the batch
				# countdown applies. Both networks require it as an argument.
				dist_qf = self.actor_bc(obs_qfilter, countdown, 0.1)
				action_qf = dist_qf.mean
				Q1_qf, Q2_qf = self.critic(obs_qfilter.clone(), action_qf, countdown)
				Q_qf = torch.min(Q1_qf, Q2_qf)
				bc_weight = (Q_qf > Q).float().mean().detach()
		else:
			raise ValueError(f"unknown bc_weight_type: {self.bc_weight_type!r}")

		actor_loss = - Q.mean() * (1 - bc_weight)

		if bc_regularize:
			if countdown_bc is None and self.countdown_scale is not None:
				# The expert batch carries no countdown of its own; refuse to guess one.
				raise ValueError(
					"bc_regularize with a countdown-conditioned actor requires countdown_bc")
			dist_bc = self.actor(obs_bc, countdown_bc, 0.1)
			log_prob_bc = dist_bc.log_prob(action_bc).sum(-1, keepdim=True)
			actor_loss += - log_prob_bc.mean() * bc_weight * 0.03

		if self.use_inverse_dynamics:
			alpha = self.alpha * torch.abs(Q).mean().detach()
			action = dist.mean
			with torch.no_grad():
				action_bc = self.actor_bc(obs, countdown, 0).mean
				Q1_bc, _ = self.critic(obs, action_bc, countdown)
				filter_bc = (Q1_bc > Q1).float()
			# keepdim so the per-sample loss lines up with filter_bc's trailing
			# singleton axis instead of broadcasting to a square matrix.
			bc_loss = ((action - action_bc) ** 2).sum(-1, keepdim=True)
			if filter_bc.sum() > 0:
				bc_loss = (filter_bc * bc_loss).sum() / filter_bc.sum()
			else:
				bc_loss = (filter_bc * bc_loss).sum()
			pg_loss = actor_loss
			actor_loss = pg_loss + alpha * bc_loss

		# optimize actor
		self.actor_opt.zero_grad(set_to_none=True)
		actor_loss.backward()
		self.actor_opt.step()
		if self.use_tb:
			# Log plain floats; in q-filter mode bc_weight is a 0-d tensor.
			bc_weight_value = float(bc_weight)
			metrics['actor_loss'] = actor_loss.item()
			metrics['actor_logprob'] = log_prob.mean().item()
			metrics['actor_ent'] = dist.entropy().sum(dim=-1).mean().item()
			metrics['actor_q'] = Q.mean().item()
			if bc_regularize and self.bc_weight_type == "qfilter":
				metrics['actor_qf'] = Q_qf.mean().item()
			metrics['bc_weight'] = bc_weight_value
			metrics['regularized_rl_loss'] = -Q.mean().item() * (1 - bc_weight_value)
			metrics['rl_loss'] = -Q.mean().item()
			if bc_regularize:
				metrics['regularized_bc_loss'] = - log_prob_bc.mean().item() * bc_weight_value * 0.03
				metrics['bc_loss'] = - log_prob_bc.mean().item() * 0.03
			if self.use_inverse_dynamics:
				metrics['pg_loss'] = pg_loss.item()
				metrics['bc_loss'] = (alpha * bc_loss).item()
				metrics['filter_bc'] = filter_bc.mean().item()
		return metrics

	def update_inverse_dynamics(self, obs, action, next_obs, step):
		metrics = dict()

		prediction = self.inverse_dynamics(obs, next_obs)
		inverse_dynamics_loss = F.mse_loss(prediction, action)

		# optimize inverse dynamics model
		self.inverse_dynamics_opt.zero_grad(set_to_none=True)
		inverse_dynamics_loss.backward()
		self.inverse_dynamics_opt.step()
		if self.use_tb:
			metrics['inverse_dynamics_error'] = torch.abs(prediction - action).mean()
			metrics['inverse_dynamics_loss'] = inverse_dynamics_loss.item()
		
		return metrics

	def update_actor_bc(self, expert_replay_iter, step):
		"""Regress the BC actor onto actions inferred by the inverse-dynamics model.

		Args:
			expert_replay_iter: iterator yielding expert batches that unpack into
				``(obs, _, next_obs)`` after ``utils.to_torch``.
			step: unused.

		Returns:
			dict with the BC actor loss (empty unless ``self.use_tb`` is set).

		Raises:
			NotImplementedError: when the actor is countdown-conditioned, because
				the expert batch provides no countdown.
		"""
		metrics = dict()

		batch = next(expert_replay_iter)
		obs, _, next_obs = utils.to_torch(batch, self.device)

		obs = obs.float()
		next_obs = next_obs.float()
		if self.use_encoder:
			# Augmentation is optional, encoding is not: previously the encoder was
			# skipped whenever augmentation was disabled, feeding raw observations
			# to networks sized for encoder features.
			if self.augment:
				obs = self.aug(obs)
				next_obs = self.aug(next_obs)
			obs = self.encoder(obs)
			next_obs = self.encoder(next_obs)

		if self.countdown_scale is not None:
			raise NotImplementedError(
				"expert batches carry no countdown for a countdown-conditioned BC actor")

		with torch.no_grad():
			action_inv = self.inverse_dynamics(obs, next_obs)
		# Actor.forward takes (obs, countdown, std); the countdown is ignored here
		# because countdown_scale is None.
		action_bc = self.actor_bc(obs, None, 0).mean
		actor_loss = F.mse_loss(action_bc, action_inv)

		self.actor_bc_opt.zero_grad(set_to_none=True)
		actor_loss.backward()
		self.actor_bc_opt.step()

		if self.use_tb:
			metrics['actor_bc_loss'] = actor_loss.item()
		return metrics

	def update(self, replay_iter, expert_replay_iter, step, progress, bc_regularize=False):
		"""Run one full agent update (critic, optionally actor, target critic).

		Nothing happens on steps that are not a multiple of
		``self.update_every_steps``; the actor is additionally only updated every
		``self.update_policy_freq``-th update.

		Returns:
			dict of metrics collected from the individual update steps.
		"""
		metrics = dict()
		if step % self.update_every_steps != 0:
			return metrics

		(obs, action, reward, discount, next_obs, next_one_obs,
		 timeout, ts, mean_reward, _) = utils.to_torch(next(replay_iter), self.device)
		countdown = progress - ts

		# Cast to float and, for pixel inputs with augmentation enabled, augment.
		if self.use_encoder and self.augment:
			def prepare(x):
				return self.aug(x.float())
		else:
			def prepare(x):
				return x.float()

		# The call order is kept fixed: the augmentation draws random shifts.
		obs_qfilter = prepare(obs.clone())
		obs = prepare(obs)
		next_obs = prepare(next_obs)
		next_one_obs = prepare(next_one_obs)

		if self.use_encoder:
			# encode; only `obs` keeps gradients for the encoder update
			obs = self.encoder(obs)
			with torch.no_grad():
				next_obs = self.encoder(next_obs)
				next_one_obs = self.encoder(next_one_obs)

		if not bc_regularize:
			obs_qfilter = None
			obs_bc = None
			action_bc = None
		else:
			obs_bc, action_bc = utils.to_torch(next(expert_replay_iter), self.device)
			obs_bc = prepare(obs_bc)
			if self.bc_weight_type == "qfilter":
				if self.use_encoder:
					obs_qfilter = self.encoder_bc(obs_qfilter)
				obs_qfilter = obs_qfilter.detach()
			else:
				obs_qfilter = None
			if self.use_encoder:
				obs_bc = self.encoder(obs_bc)
			# Detach grads
			obs_bc = obs_bc.detach()

		if self.use_tb:
			metrics['batch_reward'] = reward.mean().item()

		# update critic
		critic_metrics = self.update_critic(
			obs, ts, countdown, action, reward, discount, next_obs, timeout, mean_reward, step)
		metrics.update(critic_metrics)

		# update actor
		actor_period = self.update_every_steps * self.update_policy_freq
		if step % actor_period == 0:
			actor_metrics = self.update_actor(
				obs.detach(), countdown, obs_bc, obs_qfilter, action_bc, bc_regularize, step)
			metrics.update(actor_metrics)

		if self.use_inverse_dynamics:
			metrics.update(self.update_inverse_dynamics(
				obs.detach(), action.detach(), next_one_obs.detach(), step))
			metrics.update(self.update_actor_bc(expert_replay_iter, step))

		# update critic target
		utils.soft_update_params(self.critic, self.critic_target,
								 self.critic_target_tau)

		return metrics

	# def tendency_score(self, cost):
	# 	assert cost.shape[0] == cost.shape[1]
	# 	min_cost = np.ones(cost.shape[0]) * 10000.0
	# 	min_pos = np.zeros(cost.shape[0], dtype=int)
	# 	mark = np.zeros(cost.shape[0], dtype=int)
	# 	score = np.zeros(cost.shape[0])

	# 	count = 1
	# 	mark[0] = 1
	# 	min_cost[0] = cost[0, 0]
	# 	min_pos[0] = 0
	# 	score[0] = 1

	# 	for i in range(1, cost.shape[0]):
	# 		min_pos[i] = np.argmin(cost[i, :i + 1])
	# 		min_cost[i] = cost[i, min_pos[i]]
	# 		if mark[min_pos[i]] == 0:
	# 			count += 1
	# 		mark[min_pos[i]] += 1
	# 		for j in range(0, i):
	# 			if cost[j, i] < cost[j, min_pos[j]]:
	# 				mark[min_pos[j]] -= 1
	# 				if mark[min_pos[j]] == 0:
	# 					count -= 1
	# 				min_pos[j] = i
	# 				min_cost[j] = cost[j, i]
	# 				if mark[min_pos[j]] == 0:
	# 					count += 1
	# 				mark[min_pos[j]] += 1
	# 		score[i] = count
	# 	return score

	def init_demos(self, cfg, demos):
		"""Store the expert demonstrations and precompute reference statistics.

		With an encoder, demos are first embedded by a frozen ``MoCo`` feature
		extractor (also kept as ``self.reward_encoder``). Depending on
		``cfg.adaptive_progress_mode`` (and ``cfg.adaptive_discount_mode``) a
		reference curve is computed from every-other-frame demo pairs, stored as
		``self.ref_score`` or ``self.ref_distance`` and saved as a bar plot. For
		any other mode only ``self.demos`` is set.

		Args:
			cfg: config object providing ``adaptive_progress_mode`` and, as needed,
				``adaptive_discount_mode``, ``num_demos``, ``ref_score_percentile``
				and ``adaptive_progress_offset``.
			demos: sequence of expert demonstrations.

		Raises:
			NotImplementedError: for an unsupported ``self.rewards`` value, or for
				the ``'tendency'`` mode when no ``tendency_score`` method exists.
		"""
		if self.use_encoder:
			self.reward_encoder = MoCo((9, 224, 224)).to(self.device)
			self.reward_encoder.eval()
			self.reward_encoder = self.reward_encoder.get_feature
			with torch.no_grad():
				self.demos = [self.reward_encoder(torch.tensor(demo).to(self.device)) for demo in demos]
		else:
			self.demos = demos

		def subsampled(index):
			# Every other frame of one demo, as a float tensor on the agent device.
			return torch.as_tensor(self.demos[index][::2]).to(self.device).float()

		def demo_pairs():
			# All ordered (i, j) demo pairs with i != j, j limited to cfg.num_demos.
			for i in range(len(self.demos)):
				for j in range(cfg.num_demos):
					if j != i:
						yield subsampled(i), subsampled(j)

		def cost_between(first, second):
			# Pairwise cost matrix under the configured reward metric.
			if self.rewards == 'sinkhorn_cosine':
				return cosine_distance(first, second)
			if self.rewards == 'sinkhorn_euclidean':
				return euclidean_distance(first, second)
			raise NotImplementedError(f"unsupported rewards type: {self.rewards!r}")

		def save_bar_plot(values, filename):
			plt.clf()
			plt.bar(range(values.shape[0]), values)
			plt.savefig(filename)

		mode = cfg.adaptive_progress_mode

		if mode == 'transport_plan':
			length = subsampled(0).shape[0]
			# score_matrix[i, j] = |i - j|
			index = torch.arange(length, device=self.device)
			self.score_matrix = (index.unsqueeze(1) - index.unsqueeze(0)).abs().float()

			scores = []
			for demo1, demo2 in demo_pairs():
				cost_matrix = cost_between(demo1, demo2)
				score = np.zeros(length)
				for k in range(length):
					transport_plan = optimal_transport_plan(demo1[:k + 1], demo2[:k + 1], cost_matrix[:k + 1, :k + 1], method='sinkhorn', niter=100).float()
					score[k] = torch.sum(self.score_matrix[:k + 1, :k + 1] * transport_plan).item()
				scores.append(score)
			self.ref_score = np.mean(np.stack(scores, axis=0), axis=0)
			save_bar_plot(self.ref_score, 'ref_score')
			return

		if mode == 'lis' or mode == 'lis_minus' or cfg.adaptive_discount_mode == 'lis':
			scores = []
			for demo1, demo2 in demo_pairs():
				cost_matrix = cost_between(demo1, demo2)
				score = torch.zeros(cost_matrix.shape[0])
				for k in range(cost_matrix.shape[0]):
					# Column index of the cheapest match for each of the first k+1 rows.
					pos = cost_matrix[:k + 1, :k + 1].min(1)[1]
					score[k] = utils.longest_increasing_subsequence(pos)
				scores.append(score.numpy())
			scores = np.stack(scores, axis=0)
			self.ref_score = np.percentile(scores, cfg.ref_score_percentile, axis=0)
			save_bar_plot(self.ref_score, 'ref_score')
			return

		if mode == 'tendency':
			if not hasattr(self, 'tendency_score'):
				raise NotImplementedError(
					"adaptive_progress_mode='tendency' requires a tendency_score method, "
					"which this agent does not define")
			scores = []
			for demo1, demo2 in demo_pairs():
				cost_matrix = cost_between(demo1, demo2)
				scores.append(self.tendency_score(cost_matrix.cpu().numpy()))
			scores = np.stack(scores, axis=0)
			self.ref_score = np.percentile(scores, 10, axis=0)
			save_bar_plot(self.ref_score, 'ref_score')
			return

		if mode == 'closest_distance':
			demo = subsampled(0)
			cost_matrix = cost_between(demo, demo).detach().cpu().numpy()
			self.ref_distance = np.zeros(demo.shape[0])
			for i in range(1, demo.shape[0]):
				# Largest cost between frame i and its preceding frames within the offset window.
				self.ref_distance[i] = np.max(cost_matrix[i, max(0, i - cfg.adaptive_progress_offset):i])
			save_bar_plot(self.ref_distance, 'ref_distance')
			return
		
		# rewards = []
		# for i in range(len(expert_demos)):
		# 	for j in range(cfg.num_demos):
		# 		if j != i:
		# 			demo1 = torch.tensor(expert_demos[i][::2]).to(self.device).float()
		# 			demo2 = torch.tensor(expert_demos[j][::2]).to(self.device).float()
		# 			if self.rewards == 'sinkhorn_cosine':
		# 				cost_matrix = cosine_distance(demo1, demo2) 
		# 			elif self.rewards == 'sinkhorn_euclidean':
		# 				cost_matrix = euclidean_distance(demo1, demo2)
		# 			for k in range(demo1.shape[0]):
		# 				for l in range(demo2.shape[0]):
		# 					if abs(k - l) > self.bandwidth:
		# 						cost_matrix[k, l] += 1
		# 			transport_plan = optimal_transport_plan(demo1, demo2, cost_matrix, method='sinkhorn', niter=100).float()
		# 			reward = -torch.diag(torch.mm(transport_plan, cost_matrix.T)).detach().cpu().numpy()
		# 			# TODO: cost bias
		# 			rewards.append(reward)
		# if len(rewards) == 0:
		# 	return
		# self.reward_bound = np.stack(rewards, axis=0).mean(axis=0)
		# if cfg.reward_bound_smooth:
		# 	reward_bound_copy = copy.deepcopy(self.reward_bound)
		# 	for i in range(self.reward_bound.shape[0]):
		# 		self.reward_bound[i] = np.mean(reward_bound_copy[max(0, i - 5):min(self.reward_bound.shape[0], i + 6)])
		# import matplotlib.pyplot as plt
		# plt.clf()
		# plt.bar(range(len(self.reward_bound)), self.reward_bound)
		# plt.savefig(f'reward_bound')
		# plt.clf()
		# plt.bar(range(len(self.reward_bound)), self.reward_bound.cumsum())
		# plt.savefig(f'reward_bound_sum')

	def ot_rewarder(self, observations, step, truncate=None, episode=None, return_infos=False, progress=None):
		"""Compute per-step optimal-transport rewards of an episode against the demos.

		For every demo (subsampled to every other frame) a cost matrix to the
		episode is built and penalised by +1 wherever the two indices differ by
		more than ``self.bandwidth`` or, with truncation, wherever either index is
		at or beyond the truncation point. A Sinkhorn transport plan over that
		matrix yields the rewards. The demo with the highest summed reward
		(optionally only over the first ``progress`` fraction of steps) is selected.

		Args:
			observations: episode observations, one row per step.
			step: unused.
			truncate: fraction of the episode to keep; defaults to ``self.ot_truncate``.
				Rewards from the truncation point on are set to NaN.
			episode: unused.
			return_infos: when truthy, also return the raw (un-penalised) cost
				matrix, the transport plan and the best score.
			progress: optional fraction of the episode used when scoring demos.

		Returns:
			The reward array of the best-matching demo, or the tuple
			``(cost_matrix, transport_plan, ot_rewards, best_score)`` when
			``return_infos`` is truthy.

		Raises:
			NotImplementedError: for an unsupported ``self.rewards`` value.
		"""
		scores_list = list()
		ot_rewards_list = list()
		if return_infos:
			cost_matrix_list = list()
			transport_plan_list = list()

		obs = torch.tensor(observations).to(self.device).float()
		obs = obs.detach()
		if self.use_encoder:
			with torch.no_grad():
				obs = self.reward_encoder(obs)
		# Strip trailing timestamp features before matching against the demos.
		if self.timestamp is not None:
			obs = obs[:, :-1]
		if self.timestamp_onehot:
			obs = obs[:, :-100]
		if self.timestamp_dim is not None:
			obs = obs[:, :-self.timestamp_dim]
		episode_len = obs.shape[0]
		if truncate is None and self.ot_truncate is not None:
			truncate = self.ot_truncate
		if truncate is not None:
			# Convert the fraction into a step index.
			truncate = int(truncate * obs.shape[0])

		for demo in self.demos:
			exp = torch.as_tensor(demo).to(self.device).float()
			exp = exp.detach()[::2]

			if self.rewards == 'sinkhorn_cosine':
				cost_matrix = cosine_distance(obs, exp)
			elif self.rewards == 'sinkhorn_euclidean':
				cost_matrix = euclidean_distance(obs, exp)
			else:
				raise NotImplementedError()

			if return_infos:
				cost_matrix_list.append(cost_matrix.detach().cpu().numpy())

			# Vectorised penalty mask (replaces a per-element Python double loop).
			rows = torch.arange(obs.shape[0], device=cost_matrix.device).unsqueeze(1)
			cols = torch.arange(exp.shape[0], device=cost_matrix.device).unsqueeze(0)
			penalised = (rows - cols).abs() > self.bandwidth
			if truncate is not None:
				penalised = penalised | (rows >= truncate) | (cols >= truncate)
			# Out-of-place add so the raw matrix stored above is never modified,
			# regardless of whether the numpy copy shares memory with the tensor.
			cost_matrix = cost_matrix + penalised.to(cost_matrix.dtype)

			transport_plan = optimal_transport_plan(
				obs, exp, cost_matrix, method='sinkhorn',
				niter=100).float()  # Getting optimal coupling
			ot_rewards = -self.sinkhorn_rew_scale * torch.diag(
				torch.mm(transport_plan,
							(cost_matrix - self.cost_bias).T)).detach().cpu().numpy()

			if return_infos:
				transport_plan_list.append(transport_plan.detach().cpu().numpy())
				if 'score_matrix' not in self.__dict__:
					# score_matrix[i, j] = |i - j|, built once and cached on the agent.
					index = torch.arange(episode_len, device=self.device)
					self.score_matrix = (index.unsqueeze(1) - index.unsqueeze(0)).abs().float()

			if progress is None:
				scores_list.append(np.sum(ot_rewards) / self.sinkhorn_rew_scale)
			else:
				scores_list.append(np.sum(ot_rewards[:int(progress * obs.shape[0])]) / self.sinkhorn_rew_scale)
			if truncate is not None:
				# Scores above use the full rewards; only the returned array is masked.
				ot_rewards[truncate:] = float('nan')
			ot_rewards_list.append(ot_rewards)

		closest_demo_index = np.argmax(scores_list)
		if return_infos:
			return cost_matrix_list[closest_demo_index], transport_plan_list[closest_demo_index], ot_rewards_list[closest_demo_index], np.max(scores_list)
		return ot_rewards_list[closest_demo_index]

	def eval_inverse_dynamics(self, expert_demo, expert_action, obs, action, global_frame):
		"""Plot inverse-dynamics errors and return the mean error over the expert data.

		A bar plot of the per-step error on the agent trajectory (``obs``,
		``action``) and one for the first expert demo are written under
		``figures/``. Expert demos and actions are subsampled to every other frame.

		Returns:
			Mean per-step absolute prediction error across all expert demos.
		"""
		def step_errors(states, actions):
			# Mean absolute action error for every consecutive state pair.
			states = torch.tensor(states).to(self.device).float()
			actions = torch.tensor(actions).to(self.device).float()
			predicted = self.inverse_dynamics(states[:-1], states[1:])
			return torch.abs(predicted - actions[:-1]).mean(-1)

		def save_bars(values, path):
			plt.clf()
			plt.bar(range(len(values)), values.cpu().numpy())
			plt.savefig(path)

		expert_errors = []
		with torch.no_grad():
			agent_error = step_errors(obs, action)
			save_bars(agent_error, f'figures/agent_inverse_dynamics_error_{global_frame}')

			for index, demo in enumerate(expert_demo):
				demo_error = step_errors(demo[::2], expert_action[index][::2])
				expert_errors.extend(demo_error.tolist())
				if index == 0:
					save_bars(demo_error, f'figures/expert_inverse_dynamics_error_{global_frame}')
		return np.mean(expert_errors)

	def save_snapshot(self):
		"""Return a dict of the attributes that make up a checkpoint.

		Encoder and inverse-dynamics entries are only included when the
		corresponding feature flag is set.
		"""
		snapshot_keys = ['actor', 'critic', 'actor_opt', 'critic_opt', 'sinkhorn_rew_scale']
		if self.use_encoder:
			snapshot_keys.extend(['encoder', 'encoder_opt'])
		if self.use_inverse_dynamics:
			snapshot_keys.extend(['actor_bc', 'actor_bc_opt'])
			snapshot_keys.extend(['inverse_dynamics', 'inverse_dynamics_opt'])

		payload = {}
		for key in snapshot_keys:
			payload[key] = self.__dict__[key]
		return payload

	def load_snapshot(self, payload):
		for k, v in payload.items():
			self.__dict__[k] = v
		self.critic_target.load_state_dict(self.critic.state_dict())
		if self.use_encoder:
			self.encoder_target.load_state_dict(self.encoder.state_dict())
		# self.trunk_target.load_state_dict(self.actor.trunk.state_dict())

		if not self.use_inverse_dynamics and self.bc_weight_type == "qfilter":
			# Store a copy of the BC policy with frozen weights
			if self.use_encoder:
				self.encoder_bc = copy.deepcopy(self.encoder)
				for param in self.encoder_bc.parameters():
					param.requires_grad = False
			self.actor_bc = copy.deepcopy(self.actor)
			for param in self.actor_bc.parameters():
				param.required_grad = False

		# Update optimizers
		if self.use_encoder:
			self.encoder_opt = torch.optim.Adam(self.encoder.parameters(), lr=self.lr)
			self.encoder_opt.load_state_dict(payload['encoder_opt'].state_dict())
		self.actor_opt = torch.optim.Adam(self.actor.parameters(), lr=self.lr)
		self.actor_opt.load_state_dict(payload['actor_opt'].state_dict())
		self.critic_opt = torch.optim.Adam(self.critic.parameters(), lr=self.lr)
		self.critic_opt.load_state_dict(payload['critic_opt'].state_dict())
		if self.use_inverse_dynamics:
			self.inverse_dynamics_opt = torch.optim.Adam(self.inverse_dynamics.parameters(), lr=self.lr)
			self.inverse_dynamics_opt.load_state_dict(payload['inverse_dynamics_opt'].state_dict())
			self.actor_bc_opt = torch.optim.Adam(self.actor_bc.parameters(), lr=self.lr)
			self.actor_bc_opt.load_state_dict(payload['actor_bc_opt'].state_dict())

		print('sinkhorn_rew_scale of the loaded model:', payload['sinkhorn_rew_scale'])
