import copy
import numpy as np
from hgg.gcc_utils import gcc_load_lib, c_double, c_int
import torch
import torch.nn.functional as F
import time

def goal_distance(goal_a, goal_b):
	"""Return the 2-norm of the difference between two goal arrays."""
	difference = goal_a - goal_b
	return np.linalg.norm(difference, ord=2)
def goal_concat(obs, goal):
	"""Join an observation and a goal along their first axis."""
	parts = (obs, goal)
	return np.concatenate(parts, axis=0)

class TrajectoryPool:
	"""Fixed-capacity ring buffer of trajectories and their initial states.

	Once ``pool_length`` entries have been stored, each new insertion
	overwrites the oldest slot (slot index ``counter % length``).
	"""

	def __init__(self, pool_length):
		"""Create an empty pool holding at most ``pool_length`` entries."""
		self.length = pool_length
		self.pool = []
		self.pool_init_state = []
		# Never written by this class; kept so existing readers of the attribute still work.
		self.pool_corresponding_state = []
		# Total number of insertions ever made (not the number of stored entries).
		self.counter = 0

	def insert(self, trajectory, init_state, trajectory_corresponding_state = None):
		"""Store copies of ``trajectory`` and ``init_state``.

		Both arguments must provide a ``.copy()`` method.
		``trajectory_corresponding_state`` is accepted for interface
		compatibility but is currently ignored.
		"""
		if self.counter<self.length:
			self.pool.append(trajectory.copy())
			self.pool_init_state.append(init_state.copy())
		else:
			slot = self.counter%self.length
			self.pool[slot] = trajectory.copy()
			self.pool_init_state[slot] = init_state.copy()
		self.counter += 1

	def pad(self):
		"""Return deep copies of the stored data, repeated cyclically up to capacity.

		Returns:
			A ``(trajectories, init_states)`` pair of lists. When the pool is
			full, these are copies of the stored lists. When it is partially
			filled, the stored entries are repeated in order until exactly
			``self.length`` items are produced. When nothing has been stored
			yet, two empty lists are returned (there is nothing to repeat).
		"""
		if self.counter>=self.length:
			return copy.deepcopy(self.pool), copy.deepcopy(self.pool_init_state)

		stored = len(self.pool)
		if stored==0:
			# Repeating an empty list can never reach the target length.
			return [], []

		# One independent deep copy per returned slot, cycling through the stored entries.
		slots = [k%stored for k in range(self.length)]
		padded_pool = [copy.deepcopy(self.pool[k]) for k in slots]
		padded_init_state = [copy.deepcopy(self.pool_init_state[k]) for k in slots]
		return padded_pool, padded_init_state

class MatchSampler:
	"""Maintains a pool of exploration goals chosen by min-cost bipartite matching.

	Candidate goals are states taken from previously achieved trajectories
	(``achieved_trajectory_pool``). ``update`` scores them with the agent's
	classifier output (``agent.get_prob_by_d2c``), optionally combined with a
	learned value estimate when ``cost_type`` contains ``'vf'``, and assigns
	candidates to desired goals through the library loaded from
	``cost_flow.c``. ``sample`` serves goals from the resulting pool.
	"""

	# Goal-space clipping box (low, high) used by add_noise for the 2-D maze environments.
	_MAZE_GOAL_BOUNDS = {
		'Point4WayComplexVer2Maze-v0': ((-18, -18), (18, 18)),
		'Point4WayFarmlandMaze-v0': ((-18, -18), (18, 18)),
		'Point2WaySpiralMaze-v0': ((-14, -18), (14, 18)),
		'AntMazeComplex2Way-v0': ((-6, -10), (6, 10)),
	}

	# Noise standard deviation used by sample() when add_noise_to_goal is set.
	_SAMPLE_NOISE_STD = {
		'AntMazeComplex2Way-v0': 0.5,
		'Point2WaySpiralMaze-v0': 0.5,
		'Point4WayComplexVer2Maze-v0': 0.5,
		'Point4WayFarmlandMaze-v0': 0.5,
		'sawyer_peg_push': 0.05,
		'sawyer_peg_pick_and_place': 0.05,
	}

	def __init__(self, goal_env, goal_eval_env, env_name, achieved_trajectory_pool, num_episodes,
				agent = None, max_episode_timesteps =None, split_ratio=0.1, split_type='last',
				add_noise_to_goal= False, cost_type='d2c', gamma=0.99, hgg_c=3.0, hgg_L=5.0, device = 'cuda', hgg_gcc_path = None,
				goal_condition = False,

				):
		"""Set up the sampler and seed the goal pool around the initial achieved goal.

		Args:
			goal_env: Goal environment; must provide ``reset()`` and ``convert_obs_to_dict()``.
			goal_eval_env: Evaluation environment with the same interface; its
				initial achieved goal seeds the pool.
			env_name: Key into the per-environment success threshold, noise and
				clipping tables.
			achieved_trajectory_pool: Object providing ``counter`` and ``pad()``.
			num_episodes: Number of goals kept in the pool.
			agent: Object queried in ``update`` for classifier probabilities and
				reward/normalisation settings.
			max_episode_timesteps: Episode length, used to select the tail of each
				trajectory when ``split_type == 'last'``.
			split_ratio: Fraction of the episode (from the end) considered as candidates.
			split_type: ``'last'`` is supported; ``'uniform'`` raises
				``NotImplementedError`` in ``update``.
			add_noise_to_goal: If true, ``sample`` perturbs the returned goal.
			cost_type: Must contain ``'d2c'``; containing ``'vf'`` adds the value term.
			gamma: Discount factor, used to bound and scale value estimates.
			hgg_c: Stored but not used by this class.
			hgg_L: Scale applied to the value term of the matching cost.
			device: Torch device for network inference.
			hgg_gcc_path: Directory containing ``cost_flow.c``. Required.
			goal_condition: If true (and ``agent.use_d2c``), classifier inputs are
				conditioned on each desired goal.

		Raises:
			ValueError: If ``hgg_gcc_path`` is not given.
			KeyError: If ``env_name`` has no entry in the success-threshold table.
		"""
		if hgg_gcc_path is None:
			# Fail before touching the environments instead of with an opaque
			# TypeError from string concatenation further down.
			raise ValueError('hgg_gcc_path must be the directory that contains cost_flow.c')

		# Assume goal env
		self.env = goal_env
		self.eval_env = goal_eval_env
		self.env_name = env_name

		self.add_noise_to_goal = add_noise_to_goal
		self.cost_type = cost_type
		self.agent = agent

		# Networks are attached later through set_networks().
		self.vf = None
		self.critic = None
		self.policy = None

		self.max_episode_timesteps = max_episode_timesteps
		self.split_ratio = split_ratio
		self.split_type = split_type

		self.goal_condition = goal_condition

		self.gamma = gamma
		self.hgg_c = hgg_c
		self.hgg_L = hgg_L
		self.device = device

		target_radius = getattr(self.env, 'TARGET_RADIUS', None)
		self.success_threshold = {'sawyer_peg_push' : target_radius,
								  'sawyer_peg_pick_and_place' : target_radius,
								  'Point4WayComplexVer2Maze-v0' : 0.5,
								  'Point4WayFarmlandMaze-v0' : 0.5,
								  'Point2WaySpiralMaze-v0' : 0.5,
								  'AntMazeComplex2Way-v0' : 1.0,
								}
		# Element-wise loss so that each timestep keeps its own cost.
		self.loss_function = torch.nn.BCELoss(reduction='none')

		self.dim = np.prod(self.env.convert_obs_to_dict(self.env.reset())['achieved_goal'].shape)

		self.delta = self.success_threshold[env_name] #self.env.distance_threshold
		self.goal_distance = goal_distance

		self.length = num_episodes # args.episodes

		init_goal = self.eval_env.convert_obs_to_dict(self.eval_env.reset())['achieved_goal'].copy()

		# Initial pool: the starting achieved goal plus Gaussian jitter of scale delta.
		self.pool = np.tile(init_goal[np.newaxis,:],[self.length,1])+np.random.normal(0,self.delta,size=(self.length,self.dim))

		self.match_lib = gcc_load_lib(hgg_gcc_path+'/cost_flow.c')

		self.achieved_trajectory_pool = achieved_trajectory_pool

		# Largest achieved-to-desired distance seen over 1000 resets; scales the value term.
		self.max_dis = 0
		for _ in range(1000):
			obs_dict = self.env.convert_obs_to_dict(self.env.reset())
			dis = self.goal_distance(obs_dict['achieved_goal'],obs_dict['desired_goal'])
			self.max_dis = max(self.max_dis, dis)

	def set_networks(self, vf=None, critic=None, policy=None):
		"""Attach value/critic/policy networks; arguments left as ``None`` are not changed."""
		if vf is not None:
			self.vf = vf
		if critic is not None:
			self.critic = critic
		if policy is not None:
			self.policy = policy


	def add_noise(self, pre_goal, noise_std=None):
		"""Return a noisy copy of ``pre_goal``, clipped to the environment's goal range.

		Args:
			pre_goal: Goal array; it is not modified.
			noise_std: Gaussian standard deviation; defaults to ``self.delta``.

		Raises:
			NotImplementedError: If ``self.env_name`` has no noise rule.
		"""
		goal = pre_goal.copy()
		if noise_std is None:
			noise_std = self.delta

		if self.env_name == 'sawyer_peg_pick_and_place':
			goal += np.random.normal(0, noise_std, size=goal.shape[-1])
		elif self.env_name == 'sawyer_peg_push':
			noise = np.random.normal(0, noise_std, size=goal.shape[-1])
			# The third coordinate is left unperturbed for the push task.
			noise[2] = 0
			goal += noise
			goal[..., -3:] = np.clip(goal[..., -3:], (-0.6, 0.2, 0.0147), (0.6, 1.0, 0.0148))
		elif self.env_name in self._MAZE_GOAL_BOUNDS:
			low, high = self._MAZE_GOAL_BOUNDS[self.env_name]
			goal += np.random.normal(0, noise_std, size=2)
			goal = np.clip(goal, low, high)
		else:
			raise NotImplementedError

		return goal.copy()

	def sample(self, idx):
		"""Return a copy of pool entry ``idx``, perturbed if ``add_noise_to_goal`` is set.

		Raises:
			NotImplementedError: If noise is requested for an environment
				without a configured noise scale.
		"""
		if not self.add_noise_to_goal:
			return self.pool[idx].copy()
		if self.env_name not in self._SAMPLE_NOISE_STD:
			raise NotImplementedError('Should consider noise scale env by env')
		return self.add_noise(self.pool[idx], noise_std = self._SAMPLE_NOISE_STD[self.env_name])

	def _subsample_trajectories(self, trajectories):
		"""Reduce each trajectory to the candidate states considered for matching."""
		if self.split_type=='uniform':
			raise NotImplementedError
		if self.split_type=='last':
			# Strided samples from the last split_ratio fraction of the episode,
			# always followed by the final state.
			interval = 3 if 'Point' in self.env_name else 6
			tail_start = int(-self.split_ratio*self.max_episode_timesteps)
			return [np.concatenate([traj[tail_start::interval], traj[-1:]], axis=0) for traj in trajectories]
		# Any other split_type leaves the trajectories untouched.
		return trajectories

	def _value_to_cost(self, value):
		"""Clip a value estimate to the range implied by the reward sign and negate it."""
		bound = 1.0/(1.0-self.gamma)
		if self.agent.rl_reward_type=='d2c':
			reward_sign = self.agent.d2c_reward_type
		elif self.agent.rl_reward_type=='sparse':
			reward_sign = self.agent.sparse_reward_type
		else:
			return value
		if reward_sign=='positive':
			return -np.clip(value, 0, bound)
		if reward_sign=='negative':
			return -np.clip(value, -bound, 0)
		return value

	def _estimate_values(self, achieved_pool, achieved_pool_init_state):
		"""Evaluate every candidate state with the attached networks.

		Returns one array of shape [ts] per trajectory when ``cost_type``
		contains ``'vf'``, otherwise an empty list. Inference still runs when
		networks are attached but ``'vf'`` is absent, as before.
		"""
		achieved_value = []
		for init_state, traj in zip(achieved_pool_init_state, achieved_pool):
			# Network input: the trajectory's initial state followed by the candidate state.
			obs = [goal_concat(init_state, state) for state in traj] # list of [dim] (len = ts)

			with torch.no_grad():
				obs_t = torch.from_numpy(np.stack(obs, axis =0)).to(self.device).float() #[ts, dim]

				if self.vf is not None:
					if self.agent.normalize_rl_obs:
						obs_t = self.agent.normalize_obs(obs_t, self.env_name)
					value = self._value_to_cost(self.vf(obs_t).detach().cpu().numpy()[:,0])

				elif self.critic is not None and self.policy is not None:
					if self.agent.normalize_rl_obs:
						obs_t = self.agent.normalize_obs(obs_t, self.env_name)

					# Monte-Carlo estimate of V(s): average min(Q1, Q2) over sampled actions.
					n_sample = 10
					tiled_obs_t = torch.tile(obs_t, (n_sample, 1, 1)).view((-1, obs_t.shape[-1])) #[ts, dim] -> [n_sample*ts, dim]
					dist = self.policy(obs_t) # obs : [ts, dim]
					action = dist.rsample((n_sample,)) # [n_sample, ts, dim]
					action = action.view((-1, action.shape[-1])) # [n_sample*ts, dim]
					actor_Q1, actor_Q2 = self.critic(tiled_obs_t, action)
					actor_Q = torch.min(actor_Q1, actor_Q2).view(n_sample, -1, actor_Q1.shape[-1]) # [n_sample, ts, 1]
					value = self._value_to_cost(torch.mean(actor_Q, dim = 0).detach().cpu().numpy()[:,0]) # [ts]

				elif self.cost_type=='d2c': # d2c only
					value = None
				else:
					raise NotImplementedError

			if 'vf' in self.cost_type:
				achieved_value.append(value.copy())
		return achieved_value

	def _normalize_values(self, achieved_value):
		"""Rescale all value arrays jointly to approximately [-1, 1]."""
		# Global extrema across trajectories, which may have different lengths.
		vf_outputs_max = max([-np.inf]+[value.max() for value in achieved_value])
		vf_outputs_min = min([np.inf]+[value.min() for value in achieved_value])
		span = vf_outputs_max - vf_outputs_min+0.00001
		return [((value-vf_outputs_min)/span-0.5)*2 for value in achieved_value] #[0, 1] -> [-1,1]

	def _split_by_lengths(self, flat, lengths, squeeze):
		"""Cut a concatenated array back into per-trajectory float tensors on ``self.device``."""
		chunks = []
		start_idx = 0
		for length in lengths:
			chunk = torch.from_numpy(flat[start_idx:start_idx+length])
			if squeeze:
				chunk = chunk.squeeze()
			chunks.append(chunk.float().to(self.device))
			start_idx += length
		return chunks

	def _matching_cost(self, probs, value):
		"""Per-timestep cost: BCE against label 1, minus the scaled value term if given."""
		# assume dg label is 1 (approximately)
		labels = torch.ones_like(probs).to(self.device)
		cost = self.loss_function(probs, labels).detach().cpu().numpy()
		if value is not None:
			cost = cost - value/(self.hgg_L/self.max_dis/(1-self.gamma))
		return cost

	def update(self, initial_goals, desired_goals):
		"""Rebuild the goal pool by matching achieved states to ``desired_goals``.

		If no trajectory has been recorded yet, the pool becomes a deep copy of
		``desired_goals``. Otherwise every achieved trajectory contributes its
		lowest-cost candidate state per desired goal as an edge of a bipartite
		graph, the flow library selects the assignment, and the matched states
		become the new pool.

		Args:
			initial_goals: Unused; kept for interface compatibility.
			desired_goals: Sequence of desired goals, one pool slot each.

		Raises:
			NotImplementedError: If ``cost_type`` does not contain ``'d2c'``,
				``split_type`` is ``'uniform'``, or no network is attached
				while ``cost_type`` requires one.
			AssertionError: If fewer trajectories than ``self.length`` are
				available or the matching does not cover ``self.length`` goals.
		"""
		if self.achieved_trajectory_pool.counter==0:
			# Nothing achieved yet: explore the desired goals themselves.
			self.pool = copy.deepcopy(desired_goals)
			return

		if 'd2c' not in self.cost_type:
			# Only classifier-based costs are implemented; fail before any inference.
			raise NotImplementedError

		achieved_pool, achieved_pool_init_state = self.achieved_trajectory_pool.pad()
		achieved_pool = self._subsample_trajectories(achieved_pool) # list of reduced ts

		assert len(achieved_pool)>=self.length, 'If not, errors at assert match_count==self.length, e.g. len(achieved_pool)=5, self.length=25, match_count=5'

		use_value = 'vf' in self.cost_type
		achieved_value = self._estimate_values(achieved_pool, achieved_pool_init_state)

		# Node ids: 0 = source, 1..A = achieved trajectories, A+1..A+D = desired goals, n = sink.
		num_achieved = len(achieved_pool)
		num_desired = len(desired_goals)
		graph_id = {
			'achieved': list(range(1, num_achieved+1)),
			'desired': list(range(num_achieved+1, num_achieved+num_desired+1)),
		}
		n = num_achieved+num_desired+1
		self.match_lib.clear(n)

		if use_value:
			achieved_value = self._normalize_values(achieved_value)

		goal_conditioned = self.goal_condition and self.agent.use_d2c
		if goal_conditioned:
			assert 'd2c' in self.cost_type and self.agent.d2c_gcrl

		# achieved_pool : list of [ts, dim] (ts could be different)
		achieved_pool_traj_lengths = [traj.shape[0] for traj in achieved_pool] # list of ts
		reshaped_achieved_pool = np.concatenate(achieved_pool, axis =0) # [ts_1+ts_2+ ... , dim]

		if goal_conditioned:
			# One classifier call per desired goal, each on [state, desired_goal] inputs.
			per_goal_probs = []
			for dg in desired_goals:
				tiled_dg = np.tile(dg, (reshaped_achieved_pool.shape[0],1)) # [ts_1+ts_2+ ... , dg_dim]
				gc_input = np.concatenate([reshaped_achieved_pool, tiled_dg], axis=-1) # [ts_1+ts_2+ ... , ag_dim+dg_dim]
				per_goal_probs.append(self.agent.get_prob_by_d2c(gc_input)) # list of [ts_1+ts_2+ ... ]
			reshaped_classification_probs = np.stack(per_goal_probs, axis =1) # [(ts_1+ts_2+ ... ), num_dg]
		else:
			reshaped_classification_probs = self.agent.get_prob_by_d2c(reshaped_achieved_pool)

		# Goal-conditioned chunks stay [ts, num_dg]; unconditioned chunks are squeezed to [ts].
		classification_probs = self._split_by_lengths(
			reshaped_classification_probs, achieved_pool_traj_lengths, squeeze=not goal_conditioned)

		candidate_goals = []
		candidate_edges = []
		candidate_id = []

		for node in graph_id['achieved']:
			self.match_lib.add(0, node, 1, 0)

		for i, traj in enumerate(achieved_pool):
			value = achieved_value[i] if use_value else None
			if goal_conditioned:
				for j in range(num_desired): # args.episodes(=50)
					res = self._matching_cost(classification_probs[i][:, j], value) # [ts]
					assert len(res.shape)==1, 'assume res.shape is [ts]'
					match_dis = np.min(res)
					match_idx = np.argmin(res)
					edge = self.match_lib.add(graph_id['achieved'][i], graph_id['desired'][j], 1, c_double(match_dis))
					candidate_goals.append(traj[match_idx])
					candidate_edges.append(edge)
					candidate_id.append(j)
			else:
				# The cost does not depend on the desired goal, so compute it once per trajectory.
				res = self._matching_cost(classification_probs[i], value)
				match_dis = np.min(res)
				match_idx = np.argmin(res)
				for j in range(num_desired):
					edge = self.match_lib.add(graph_id['achieved'][i], graph_id['desired'][j], 1, c_double(match_dis))
					candidate_goals.append(traj[match_idx])
					candidate_edges.append(edge)
					candidate_id.append(j)

		for node in graph_id['desired']:
			self.match_lib.add(node, n, 1, 0)

		match_count = self.match_lib.cost_flow(0,n)
		assert match_count==self.length

		explore_goals = [0]*self.length

		for goal, edge, goal_id in zip(candidate_goals, candidate_edges, candidate_id):
			if self.match_lib.check_match(edge)==1:
				explore_goals[goal_id] = goal.copy()

		assert len(explore_goals)==self.length
		self.pool = np.array(explore_goals)
		


class SimpleBipartiteMatching:
	"""Assigns single-state achieved trajectories to desired goals by Euclidean cost.

	The assignment is delegated to the library loaded from ``cost_flow.c``;
	the matched states are stored in ``self.pool`` and their costs in
	``self.total_cost``.
	"""

	def __init__(self,  env_name, num_episodes, goal_dim, hgg_gcc_path = None,
				):
		"""Create a zero-filled pool of ``num_episodes`` goals and load the matching library."""
		self.env_name = env_name
		self.length = num_episodes # args.episodes
		self.pool = np.zeros([self.length, goal_dim])
		self.match_lib = gcc_load_lib(hgg_gcc_path+'/cost_flow.c')

	def sample(self, idx):
		"""Return a copy of pool entry ``idx``."""
		goal = self.pool[idx]
		return goal.copy()

	def update(self, achieved_pool, desired_goals):
		"""Match achieved states to desired goals and store the result.

		Args:
			achieved_pool: list of arrays of shape [1, dim], one per trajectory.
			desired_goals: list of arrays of shape [dim], one per goal.
		"""
		assert len(achieved_pool)>=self.length, 'If not, errors at assert match_count==self.length, e.g. len(achieved_pool)=5, self.length=25, match_count=5'

		# Node ids: 0 = source, then achieved trajectories, then desired goals, then the sink.
		num_achieved = len(achieved_pool)
		num_desired = len(desired_goals)
		achieved_nodes = list(range(1, num_achieved+1))
		desired_nodes = list(range(num_achieved+1, num_achieved+num_desired+1))
		sink = num_achieved+num_desired+1
		self.match_lib.clear(sink)

		for node in achieved_nodes:
			self.match_lib.add(0, node, 1, 0)

		# One record per (trajectory, goal) pair: (edge handle, goal index, state, cost).
		candidates = []
		for achieved_node, traj in zip(achieved_nodes, achieved_pool):
			assert traj.shape[0]==1, 'assume ts=1'
			for goal_index, desired_node in enumerate(desired_nodes):
				distances = np.sqrt(np.sum(np.square(traj-desired_goals[goal_index]),axis=1))
				cost = np.min(distances)
				nearest = np.argmin(distances)
				edge = self.match_lib.add(achieved_node, desired_node, 1, c_double(cost))
				candidates.append((edge, goal_index, traj[nearest], cost))

		for node in desired_nodes:
			self.match_lib.add(node, sink, 1, 0)

		match_count = self.match_lib.cost_flow(0,sink)
		assert match_count==self.length

		# Slots keep the placeholder 0 unless a matched edge fills them.
		explore_goals = [0]*self.length
		explore_costs = [0]*self.length
		for edge, goal_index, state, cost in candidates:
			if self.match_lib.check_match(edge)==1:
				explore_goals[goal_index] = state.copy()
				explore_costs[goal_index] = cost.copy()

		assert len(explore_goals)==self.length
		self.pool = np.array(explore_goals)
		self.total_cost = np.stack(explore_costs)
		