from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import pprint
from dotmap import DotMap
from MBExperiment import MBExperiment
from MPC import MPC
from config import create_config
import env # We run this so that the env is registered
import torch
import random
import tensorflow as tf
from config.cartpole import CartpoleConfigModule
import gym
from time import localtime, strftime
from dotmap import DotMap
from Agent import Agent
from DotmapUtils import get_required_argument
from utils import to_json, read_json, to_pickle, read_pickle
from tqdm import trange
import matplotlib
import numpy as np
import pandas as pd
import sys
import os
import copy
import pickle
import utils
import time
from scipy.stats import wasserstein_distance
# Torch device used for tensors built in this script; re-exported from utils.
TORCH_DEVICE = utils.TORCH_DEVICE
# Tiny constant added to the predicted cost standard deviations in sample_P
# before they are used as the scale of a Normal distribution.
VARIATION_NOISE = 1e-12

def set_global_seeds(seed):
	"""Seed every random number generator this script touches.

	Arguments:
		seed: Value handed unchanged to Python's `random`, NumPy, PyTorch
			(CPU, and all CUDA devices when CUDA is available) and
			TensorFlow (through `tf.set_random_seed`).
	"""
	# Pure-Python and NumPy generators.
	random.seed(seed)
	np.random.seed(seed)

	# PyTorch: the CUDA generators are only seeded when a device is usable.
	torch.manual_seed(seed)
	cuda_ready = torch.cuda.is_available()
	if cuda_ready:
		torch.cuda.manual_seed_all(seed)

	# TensorFlow is seeded last, as in the original ordering.
	tf.set_random_seed(seed)

def print_gpu_memory():
	"""Print total, currently allocated and peak allocated memory of CUDA device 0.

	All figures are reported in GB (bytes / 1e9). When CUDA is not available
	a short notice is printed instead, so that importing or launching this
	script on a CPU-only machine does not fail.
	"""
	if not torch.cuda.is_available():
		print('CUDA is not available; skipping GPU memory report.')
		return

	total_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
	allocated_gb = torch.cuda.memory_allocated(0) / 1e9
	# Query device 0 explicitly so all three figures refer to the same device.
	max_allocated_gb = torch.cuda.max_memory_allocated(0) / 1e9
	print('Total memory: %.3f GB. Memory allocated: %.5f GB. Max allocated: %.5f GB ' % (
		total_gb, allocated_gb, max_allocated_gb))


# Report once at start-up.
print_gpu_memory()


if __name__ == "__main__":

	# ------------------------------------------------------------------
	# Command-line interface
	# ------------------------------------------------------------------
	parser = argparse.ArgumentParser()

	# Environment, controller arguments, config overrides and log directory.
	parser.add_argument(
		'-env',
		type=str,
		required=True,
		help='Environment name: select from [cartpole, reacher, pusher, halfcheetah]',
	)
	parser.add_argument(
		'-ca', '--ctrl_arg',
		action='append',
		nargs=2,
		default=[],
		help='Controller arguments, see https://github.com/kchua/handful-of-trials#controller-arguments',
	)
	parser.add_argument(
		'-o', '--override',
		action='append',
		nargs=2,
		default=[],
		help='Override default parameters, see https://github.com/kchua/handful-of-trials#overrides',
	)
	parser.add_argument(
		'-logdir',
		type=str,
		default='log',
		help='Directory to which results will be logged (default: ./log)',
	)

	# Experiment selection.
	parser.add_argument(
		'--METHOD',
		type=str,
		required=True,
		help="Experiment method: UARF, BICHO, or BASELINE",
	)
	parser.add_argument('--MANEUVER', type=str, default='straight')

	# Checkpointing.
	parser.add_argument('--SAVE', action='store_true')
	parser.add_argument('--LOAD', type=str)

	# Replay-buffer / retraining knobs (parsed as floats, cast to int later
	# where an integer is required).
	parser.add_argument('--COLD_START_STEPS', default=0, type=float)
	parser.add_argument('--NEW_DATA_TRAIN_THRESHOLD', default=0.01, type=float)
	parser.add_argument('--MAX_BUFFER_LENGTH', default=None, type=float)

	# Number of future steps compared when measuring trajectory divergence.
	parser.add_argument('--S_FUT_LASTEPS', default=20, type=int)

	# Divergence threshold and the signal it is measured on.
	parser.add_argument('--S_FUT_KL_CST', default=32, type=float)
	parser.add_argument('--S_FUT_E_INPUT', default="cost", type=str, help='[obs, cost]')

	args = parser.parse_args()

	# ------------------------------------------------------------------
	# Output locations
	# ------------------------------------------------------------------
	# Joining with "" leaves a trailing separator, so file names can later
	# be appended to output_path by plain string concatenation.
	output_path = os.path.join(args.logdir, "")
	output_path_by_episode = os.path.join(args.logdir, "episodes")
	for directory in (output_path, output_path_by_episode):
		os.makedirs(directory, exist_ok=True)


	# ---- Copy the parsed options into the script-level names used below ----

	# Which experiment variant to run, and on which environment.
	# NOTE: `env` rebinds the name of the `env` module imported at the top of
	# the file; that import is only needed for its registration side effect.
	METHOD = args.METHOD
	env = args.env

	# Future-trajectory divergence check.
	S_FUT_LASTEPS = args.S_FUT_LASTEPS
	S_FUT_KL_CST = args.S_FUT_KL_CST
	S_FUT_E_INPUT = args.S_FUT_E_INPUT

	# Replay-buffer / retraining settings.
	COLD_START_STEPS = args.COLD_START_STEPS
	NEW_DATA_TRAIN_THRESHOLD = args.NEW_DATA_TRAIN_THRESHOLD
	MAX_BUFFER_LENGTH = args.MAX_BUFFER_LENGTH

	# Steps allowed between replans; not exposed on the command line.
	MAX_PREDICTION_DISTANCE = 1

	# ---- Build the configuration, the MPC policy and the experiment ----
	ctrl_type = 'MPC'

	# Forward the -ca/--ctrl_arg and -o/--override pairs given on the command
	# line. Both default to an empty list, so runs that do not pass them are
	# configured exactly as before. (Previously the -ca pairs were parsed but
	# replaced by an empty list and therefore never reached create_config.)
	ctrl_args = DotMap(**dict(args.ctrl_arg))
	overrides = args.override

	# Directory handed to create_config as its log directory.
	logdir = os.path.join(args.logdir, "buffers")
	os.makedirs(logdir, exist_ok=True)

	cfg = create_config(env, ctrl_type, ctrl_args, overrides, logdir)
	cfg.pprint()

	assert ctrl_type == 'MPC'

	# overwrites
	cfg.ctrl_cfg.per = 10
	cfg.exp_cfg.exp_cfg.policy = MPC(cfg.ctrl_cfg)
	exp = MBExperiment(cfg.exp_cfg)

	# Keep a human-readable copy of the final configuration next to the logs.
	with open(os.path.join(exp.logdir, "config.txt"), "w") as f:
		f.write(pprint.pformat(cfg.toDict()))

	# From here on the experiment object is referred to as `self`.
	self = exp

	# -1 means "no maneuver"; it is only replaced for the pointbot environment.
	maneuver_index = -1
	if env == "pointbot":
		# (maneuver name, episode length in steps). The position of a name in
		# this table is the maneuver index passed on to the model/policy, so
		# the order must not change.
		maneuver_table = (
			("c1", 100),
			("c1_i", 100),
			("chicane", 50),
			("chicane_i", 50),
			("c14", 100),
			("c14_i", 100),
			("straight", 100),
			("sector_1", 290),
			("sector_1_i", 290),
			("full_track_barcelona", 1200),
		)
		maneuver_lengths = dict(maneuver_table)

		# An unknown --MANEUVER raises KeyError here.
		self.task_hor = maneuver_lengths[args.MANEUVER]
		self.env.set_maneuver(args.MANEUVER, TORCH_DEVICE)

		maneuvers = list(maneuver_lengths)
		maneuver_index = maneuvers.index(args.MANEUVER)

	# ---- Start-up banner summarising the run ----
	banner = "*" * 80
	print(banner)
	print("### Starting")
	for label, value in (
		("output_path", output_path),
		("env", env),
		("task hor:", self.task_hor),
		("self.ntrain_iters:", self.ntrain_iters),
	):
		print(label, value)

	print("Enables:")

	print("FUT:")
	print("S_FUT_LASTEPS", S_FUT_LASTEPS)
	print('S_FUT_KL_CST', S_FUT_KL_CST)
	print(banner)

	# Make sure the experiment's own log directory exists.
	os.makedirs(self.logdir, exist_ok=True)

	'''
	New version: simplified
	'''
	def act_dyn(self, obs, t, recalculate=True, get_pred_cost=False):
		"""Return the next action of controller `self`, replanning only on request.

		Arguments:
			self: The controller whose plan is consumed.
			obs: The current observation.
			t: The current timestep (not used by this implementation).
			recalculate: If True, run the internal optimizer from `obs` and
				replace the stored plan; if False, keep consuming the plan
				left over from the previous call.
			get_pred_cost: Accepted for signature compatibility; not used here.

		Returns: The first action of the stored plan, reshaped to (-1, dU).
			While the controller has not been trained, a uniformly random
			action between `ac_lb` and `ac_ub` is returned instead.
		"""
		# No model yet: act randomly inside the action bounds.
		if not self.has_been_trained:
			return np.random.uniform(self.ac_lb, self.ac_ub, self.ac_lb.shape)

		if recalculate:
			# The observation is stored so the cost compilation can start from it.
			self.sy_cur_obs = obs

			# Warm-start the optimizer with the previous plan and the initial variance.
			recalculate_time = time.time()
			soln = self.optimizer.obtain_solution(self.prev_sol, self.init_var)
			print(f"recalculation time: {time.time() - recalculate_time}")

			# Remember the fresh plan (complete and consumable copies) and the
			# observation it was computed from.
			self.prev_sol_full = np.copy(soln)
			self.prev_sol = np.copy(soln)
			self.prev_sol_obs = np.copy(obs)

		# Snapshot of the plan that still contains the action about to be taken.
		self.prev_sol_with_action = self.prev_sol.copy()

		# Pop the first action off the plan and pad the tail with zeros so the
		# plan keeps its length.
		width = 1 * self.dU
		remaining_plan = self.prev_sol
		action = remaining_plan[:width].reshape(-1, self.dU)
		self.prev_sol = np.concatenate([remaining_plan[width:], np.zeros(width)])
		return action

	def project_trajectory(self, obs, actions, start_index):
		"""Evaluate a fixed action sequence with the controller's model and summarise it.

		Arguments:
			self: The controller; must provide `has_been_trained`, `dU` and
				`_compile_cost`.
			obs: Observation the projection starts from (stored in
				`self.sy_cur_obs` before the cost is compiled).
			actions: Flat action sequence.
			start_index: Number of leading actions to drop. The sequence is
				padded with the same number of zeros at the end, so its length
				is unchanged.

		Returns: A tuple `(obs_mean, obs_std, cost_mean, cost_std, costs_raw)`.
			The means and standard deviations are taken over axis 1 of
			`traj['traj_cur_obs']` and `traj['traj_next_cost']` (presumably
			the particle axis; confirm against `_compile_cost`), and
			`costs_raw` is the unreduced `traj['traj_next_cost']` array.

		Raises:
			RuntimeError: If the controller has not been trained. Previously
				this case returned a single random action, which does not
				match the five-element result callers unpack.
		"""
		if not self.has_been_trained:
			raise RuntimeError(
				"project_trajectory requires a trained model "
				"(has_been_trained is False)")

		self.sy_cur_obs = obs

		# Drop the first `start_index` actions and pad with zeros at the end.
		n_dropped = start_index * self.dU
		shifted = np.concatenate([np.copy(actions)[n_dropped:], np.zeros(n_dropped)])

		# The compiled cost expects a batch of action sequences; use a batch of one.
		_, traj = self._compile_cost(shifted.reshape(1, -1), return_seq=True)

		# Convert each trajectory once and reduce over axis 1.
		traj_obs = np.array(traj['traj_cur_obs'])
		traj_cost = np.array(traj['traj_next_cost'])
		return (
			traj_obs.mean(axis=1),
			traj_obs.std(axis=1),
			traj_cost.mean(axis=1),
			traj_cost.std(axis=1),
			traj_cost,
		)

	#
	#  load model
	#
	# Initialise the model; only the pointbot environment passes a real
	# maneuver index, every other environment passes -1.
	self.initialize_model(maneuver_index if env == "pointbot" else -1)

	# model_path = r'results\experiments\CP_200\pre_train\2020-12-01--133433'
	# self.load_model(model_path)

	# Statistics of the training targets, intended for scaling. When the
	# buffer holds at least 500 rows the first 500 are skipped.
	# NOTE(review): in this file these statistics are only referenced by
	# commented-out code, and train_in is not read again.
	first_row = 0 if len(self.policy.train_in) < 500 else 500
	train_in = self.policy.train_in[first_row:].copy()
	train_targs = self.policy.train_targs[first_row:].copy()

	train_targs_mu = np.mean(train_targs, axis=0, keepdims=True)
	train_targs_sigma = np.std(train_targs, axis=0, keepdims=True)

	# Replace (near-)zero deviations by 1 so a later division cannot blow up.
	degenerate = train_targs_sigma < 1e-12
	train_targs_sigma[degenerate] = 1.0

	def sample_P(self, horizon, verbose=False):
		"""Run one episode with `self.policy` and record per-step diagnostics.

		At every step the function decides whether the policy replans
		(`recalculate`) and whether the step is flagged for the training
		buffer (`add_to_buffer`), executes the action in `self.env`, and
		compares the trajectory predicted at the last replan with the one
		predicted from the newest observation. The resulting divergence
		(`kl_loss`) is used on the NEXT step to force a replan when it exceeds
		S_FUT_KL_CST.

		Besides its arguments the function reads the script-level names
		METHOD, S_FUT_KL_CST, S_FUT_E_INPUT, S_FUT_LASTEPS,
		MAX_PREDICTION_DISTANCE, VARIATION_NOISE and TORCH_DEVICE.

		Arguments:
			self: Experiment object; `self.policy` and `self.env` are used.
			horizon: Maximum number of environment steps.
			verbose: If True, print a one-line summary after every step.

		Returns: A dict with the keys
			"obs" / "obs_": observations before / after each step,
			"ac": actions reshaped to (-1, policy.dU),
			"reward_sum", "rewards": total and per-step rewards,
			"porc_recalc": fraction of steps on which the policy replanned,
			"add_to_buffer", "recalculated": per-step flags stored as floats.
		"""
		run_start = time.time()

		policy = self.policy

		# NOTE(review): `times` and `errors` are never filled or read.
		times, rewards = [], []
		errors = []
		O, A, reward_sum, done = [self.env.reset()], [], 0, False

		policy.reset()
		episode_info = []
		skip_step = 0 # NOTE(review): never read; skip_t is the counter actually maintained
		skip_t = 0 # steps taken since the last replan

		#for t in range(10):
		for t in range(horizon):
			recalculate = False
			# Every step is kept for training unless METHOD is UARF and the
			# buffer already holds at least cold_start_steps rows.
			if METHOD != "UARF" or len(policy.train_in) < policy.cold_start_steps:
				add_to_buffer = True
			else:
				add_to_buffer = False
			if (t==0):
				# The first step always needs a plan.
				recalculate = True
			else:
				# kl_loss is the divergence computed at the end of the previous
				# iteration; above the threshold the step is replanned and kept.
				if kl_loss > S_FUT_KL_CST:
					recalculate = True
					add_to_buffer = True
					print('# kl skip', S_FUT_KL_CST, kl_loss)

			# NOTE(review): MAX_PREDICTION_DISTANCE is only ever assigned 1 in
			# this file, which makes this test `skip_t >= 0`, i.e. always
			# true, so the policy replans on every step.
			if skip_t >= (MAX_PREDICTION_DISTANCE-1):
				recalculate = True

			if (skip_t >= (self.policy.plan_hor - 2)):
				""" If we used all the planned actions we need to recalculate """
				print('recalculate because of plan hor %d' % self.policy.plan_hor)
				recalculate = True

			if METHOD == 'BASELINE':
				''' disable skip steps by recalculating always '''
				recalculate = True
				add_to_buffer = True

			a_t = act_dyn(policy, O[t], t, recalculate=recalculate)
			A.append(a_t)

			if recalculate:
				# Remember where the current plan started from.
				obs_when_recalc = np.copy( O[t] )
				skip_t = 0
			else:
				skip_t += 1

			obs, reward, done, info = self.env.step(A[t])
			O.append(obs.copy())
			reward_sum += reward
			rewards.append(reward)

			# Prediction of the full stored plan, started at the observation
			# the plan was computed from.
			a = np.copy(policy.prev_sol_full)
			obs_mean, obs_std, cost_mean, cost_std, cost_raw= project_trajectory(policy, obs_when_recalc, a, start_index=0)
			cost_pred = cost_mean.copy()
			# Prediction of the rest of the plan (first skip_t actions dropped),
			# started at the observation the action was chosen from.
			a = np.copy(policy.prev_sol_full)
			obs_mean_fut, obs_std_fut, cost_mean_fut, cost_std_fut, cost_raw_fut = project_trajectory(policy, O[t+0], a, start_index=skip_t+0)

			# Cost of the step that was actually executed, from the policy's
			# own observation and action cost functions.
			next_obs = torch.tensor( obs.reshape(1,-1), device=TORCH_DEVICE )
			cur_acs = torch.tensor( a_t.reshape(1,-1), device=TORCH_DEVICE )
			cost = policy.obs_cost_fn(next_obs) + policy.ac_cost_fn(cur_acs)
			cost = cost.detach().cpu().numpy()

			# Model prediction for the observation just received.
			pred_obs = obs_mean[skip_t+1].copy()
			pred_obs_std = obs_std[skip_t+1].copy() # NOTE(review): not used afterwards

			# Prediction error (the normalisation by train_targs_sigma is disabled)
			e =(pred_obs[:] - obs[:])  #/ train_targs_sigma[0][1:] # skip position and normalize

			error_obs_mean = np.mean(e)
			error_obs_euc_d = np.sqrt( np.sum(np.square(e)))


			if S_FUT_E_INPUT == 'obs':
				# Divergence on observations: KL between per-step Normals.
				obs_mean = obs_mean[skip_t+1:][0:S_FUT_LASTEPS]
				obs_std =  obs_std[skip_t+1:][0:S_FUT_LASTEPS]

				# NOTE(review): obs_*_fut come from a sequence that was already
				# shifted by skip_t in project_trajectory and are sliced by
				# skip_t again here - confirm this is intended.
				obs_mean_fut = obs_mean_fut[skip_t+0:][0:S_FUT_LASTEPS]
				obs_std_fut = obs_std_fut[skip_t+0:][0:S_FUT_LASTEPS]

				# crop when we are reaching the end of the trajectory
				crop = min(obs_mean.shape[0], S_FUT_LASTEPS)
				obs_mean_fut = obs_mean_fut[:crop]
				obs_std_fut = obs_std_fut[:crop]

				# NOTE(review): unlike the cost branch, VARIATION_NOISE is not
				# added to the standard deviations here.
				dist_1 = torch.distributions.Normal(torch.tensor(obs_mean),
													   torch.tensor(obs_std))

				dist_2 = torch.distributions.Normal(torch.tensor(obs_mean_fut),
													torch.tensor(obs_std_fut) )

				kl_loss = torch.distributions.kl_divergence(dist_1, dist_2)#.mean(axis=1)
				kl_loss = kl_loss.mean().numpy()
			elif  S_FUT_E_INPUT == 'cost':
				# Divergence on costs: mean per-step Wasserstein distance
				# between the raw predicted costs of both projections.
				cost_mean = cost_mean[skip_t+0:][0:S_FUT_LASTEPS]
				cost_std =  cost_std[skip_t+0:][0:S_FUT_LASTEPS] + VARIATION_NOISE
				cost_raw = cost_raw[skip_t:skip_t + S_FUT_LASTEPS]

				cost_mean_fut = cost_mean_fut[skip_t:][0:S_FUT_LASTEPS]
				cost_std_fut = cost_std_fut[skip_t:][0:S_FUT_LASTEPS]
				cost_raw_fut = cost_raw_fut[skip_t:skip_t + S_FUT_LASTEPS]
				crop = min(cost_mean.shape[0], S_FUT_LASTEPS)
				cost_raw = cost_raw[:crop]
				cost_raw_fut = cost_raw_fut[:crop]
				total_wasserstein = 0
				for step in range(len(cost_raw)):
					total_wasserstein += wasserstein_distance(cost_raw[step], cost_raw_fut[step])
				# NOTE(review): raises ZeroDivisionError if cost_raw is empty.
				mean_wasserstein = total_wasserstein/len(cost_raw)
				print("W cost: ", mean_wasserstein)
				# crop when we are reaching the end of the trajectory
				cost_mean_fut = cost_mean_fut[:crop]
				cost_std_fut = cost_std_fut[:crop] + VARIATION_NOISE
				# NOTE(review): dist_1 / dist_2 are leftovers of the disabled KL
				# computation below; their values are not used.
				dist_1 = torch.distributions.Normal(torch.tensor(cost_mean),
													   torch.tensor(cost_std))
				try:
					dist_2 = torch.distributions.Normal(torch.tensor(cost_mean_fut),
													torch.tensor(cost_std_fut) )
				except ValueError:
					print(cost_mean_fut, cost_std_fut)
				#kl_loss = torch.distributions.kl_divergence(dist_1, dist_2)#.mean(axis=1)
				#kl_loss = kl_loss.mean().numpy()
				# Despite its name, kl_loss holds the mean Wasserstein distance here.
				kl_loss = mean_wasserstein
				#print('cost', kl_loss)

				if np.isnan(kl_loss):
					# Dump the inputs and force a replan on the next step.
					print(cost_mean)
					print(cost_std)
					print(cost_mean_fut)
					print(cost_std_fut)
					kl_loss = 1e9
			else:
				# Unknown input type: never triggers a divergence replan.
				kl_loss = 0


			#train_set_euc_dist_error_mean, train_set_euc_dist_error_std = e_disc.mean(), e_disc.std()
			# One row of per-step diagnostics.
			r = {'error_obs_mean': error_obs_mean.copy(),
				 'pred_cost': cost_pred[skip_t],
				 'sim_cost':cost[0],
				 'error_obs_euc_d':error_obs_euc_d,
				 'recalculate':float(recalculate),
				 'add_to_buffer':float(add_to_buffer),
				 'run_wall_time':time.time() - run_start,
				 'kl_loss':kl_loss,
				 'skip_t':skip_t
				}
			# One column per observation dimension, simulated and predicted.
			r.update( {"sim_obs_%.2d" % f:obs[f] for f in range( obs.shape[0] )} )
			r.update( {"pred_obs_%.2d" % f:pred_obs[f] for f in range( pred_obs.shape[0] )} )

			if verbose:
				print("r_t:%.2f" % reward, "reward_sum:%.2f" % reward_sum, "%.2d %.2d" % (t, skip_t), 
                 "adding to buffer: %d" % int(add_to_buffer), "erros_obs_euc: %.4f" % (error_obs_euc_d))

			episode_info.append(r.copy())
			if done: break

		df = pd.DataFrame(episode_info)
		# Fraction of steps on which the policy replanned.
		porc_recalc = (np.sum( df.recalculate ) / len( df.recalculate))

		ret = {
			"obs": np.array(O)[:-1],
			"obs_": np.array(O)[1:],
			"ac": np.array(A).reshape(-1,self.policy.dU),
			"reward_sum": reward_sum,
			"rewards": np.array(rewards),
			"porc_recalc":porc_recalc,
			"add_to_buffer":df["add_to_buffer"].values,
			"recalculated":df["recalculate"].values
		}
		print(ret['reward_sum'], porc_recalc)
		return ret

	# ---- Hand the buffer / retraining settings to the policy ----
	#N_STEPS = cfg.ctrl_cfg.per
	# The two size options were parsed as floats; the policy receives
	# integers, or None when the option is unset.
	if COLD_START_STEPS is None:
		self.policy.cold_start_steps = None
	else:
		self.policy.cold_start_steps = int(COLD_START_STEPS)

	if MAX_BUFFER_LENGTH is None:
		self.policy.max_buffer_length = None
	else:
		self.policy.max_buffer_length = int(MAX_BUFFER_LENGTH)

	self.policy.new_data_train_threshold = NEW_DATA_TRAIN_THRESHOLD
	self.policy.method = METHOD

	# Steps alike
	S_ALIKE_C = 2.0

	# ---- State shared by the training loop ----
	verbose = True
	ep_stats = []
	all_traj = []
	horizon = self.task_hor
	run_path = output_path

	# Optionally start from a saved model and treat the policy as trained.
	# NOTE(review): the second load_model argument is hard-coded to "29";
	# presumably a saved train-iteration tag - confirm.
	if args.LOAD:
		print("LOADING\n\n\n\n\n\n\n")
		self.load_model(args.LOAD, "29")
		self.policy.has_been_trained = True

	def pre_process_samples(samples):
		"""Keep, in place, only the steps of each rollout flagged for the buffer.

		Arguments:
			samples: List of rollout dicts as returned by sample_P. The
				"obs", "obs_", "ac", "rewards" and "add_to_buffer" entries
				must have one row per step.

		Returns: The same list; every listed entry is replaced by a new array
			that contains only the rows whose "add_to_buffer" flag is non-zero.
		"""
		for sample in samples:
			keep = np.asarray(sample["add_to_buffer"]) != 0
			for key in ("obs", "obs_", "ac", "rewards", "add_to_buffer"):
				sample[key] = sample[key][keep]
		return samples

	# Unfiltered rollouts of the current iteration are pickled here; the file
	# is overwritten on every iteration.
	filepath = os.path.join(args.logdir, 'replay_buffer_no_filtering.pkl')

	for i in trange(self.ntrain_iters):
		# ---- Collect rollouts ----
		samples = []
		episode_start_time = time.time()
		for _ in range(max(self.neval, self.nrollouts_per_iter)):
			s = sample_P(self, horizon=horizon, verbose=verbose)
			samples.append(s)
			# Shallow copy: later filtering rebinds keys of `s` and therefore
			# does not affect the archived trajectory.
			all_traj.append(s.copy())

		collection_time = time.time() - episode_start_time

		# Episode statistics are taken from the first evaluation rollout.
		# These references are captured before any filtering, so the per-step
		# CSV below always covers every step of the episode.
		first_eval = samples[:self.neval][0]
		ep_reward = first_eval["reward_sum"]
		ep_perc_recalc = first_eval["porc_recalc"]
		ep_step_rewards = first_eval["rewards"]
		ep_recalculated = first_eval["recalculated"]
		ep_added_to_buffer = first_eval["add_to_buffer"]

		samples = samples[:self.nrollouts_per_iter]

		print("Start training episode %d" % i)

		# ---- Train ----
		losses = []
		train_time = time.time()
		utils.to_pickle(obj=samples, file=filepath, verbose=True)

		# Only UARF trains on a filtered subset of the collected steps.
		if METHOD == "UARF":
			samples = pre_process_samples(samples)

		l, trained, buffer_last_train = self.policy.train(
			[sample["obs"] for sample in samples],
			[sample["ac"] for sample in samples],
			[sample["rewards"] for sample in samples],
			[sample["obs_"] for sample in samples],
			[np.full_like(sample["rewards"], maneuver_index) for sample in samples]
		)
		if trained:
			losses.append(l.copy())
		buffer_size = self.policy.train_in.shape[0]
		if args.SAVE:
			self.save(train_iteration = i)

		# Replan distance used by sample_P on the next iteration (same value
		# for every METHOD).
		#MAX_PREDICTION_DISTANCE = max(int(1 / ep_perc_recalc) - 1, 1)
		MAX_PREDICTION_DISTANCE = 1

		train_time = time.time() - train_time

		# No training this iteration -> report NaN without triggering NumPy's
		# "mean of empty slice" warning.
		if losses:
			losses_arr = np.array(losses)
			losses_mean, losses_std = losses_arr.mean(), losses_arr.std()
		else:
			losses_mean = losses_std = float("nan")

		# ---- Per-step CSV for this episode ----
		by_step_stats = {"Rewards" : ep_step_rewards,
						"Recalculated" : ep_recalculated,
						"Added to Buffer" : ep_added_to_buffer}
		by_step_df = pd.DataFrame(by_step_stats)
		by_step_df.to_csv(os.path.join(output_path_by_episode, f"episode{i}.csv"))

		# ---- Per-iteration summary ----
		ep_stat = {}
		# NOTE(review): the (i+2) factor presumably accounts for rollouts
		# collected before this loop - confirm.
		ep_stat['total_steps'] = (i+2) * len(ep_step_rewards)
		ep_stat['ep_reward'] = ep_reward
		ep_stat['buffer_size'] = buffer_size
		ep_stat['losses_mean'] = losses_mean
		ep_stat['losses_std'] = losses_std
		ep_stat['perc_recalc'] = ep_perc_recalc
		ep_stat['run_wall_time'] = time.time() - episode_start_time
		ep_stat['training_time'] = train_time
		ep_stat['trained'] = int(trained)
		ep_stat['buffer_last_train'] = buffer_last_train
		# Buffer growth since the previous iteration, relative to
		# buffer_last_train; 0 on the first iteration.
		ep_stat['percent_new_experience'] = (buffer_size - ep_stats[-1]['buffer_size'])/buffer_last_train if len(ep_stats) > 0 else 0

		ep_stats.append(ep_stat.copy())

		print("buffer size %d" %buffer_size , "Rewards obtained:", ep_reward, "loss: %.7f+-%.5f" % (losses_mean, losses_std),
			  "Collection time: %.2fs" % collection_time, "")

		# Rewrite the summary CSV after every iteration so partial runs keep
		# their results.
		ep_stats_ = pd.DataFrame(ep_stats)

		ep_stats_.to_csv(run_path + "Run%d.csv" % (0))
	# Report where the per-iteration summary was written.
	stats_csv_path = run_path + "Run%d.csv" % (0)
	print(stats_csv_path)

	# Final save once every training iteration has finished.
	if args.SAVE:
		self.save()

	print('done')
