from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

import numpy as np
from scipy.io import savemat

from DotmapUtils import get_required_argument
from optimizers import CEMOptimizer

from tqdm import trange

import torch
import utils
TORCH_DEVICE = utils.TORCH_DEVICE



class Controller:
	"""Abstract controller interface.

	Every method except the constructor raises NotImplementedError and is meant
	to be overridden by a concrete subclass.
	"""

	def __init__(self, *args, **kwargs):
		"""Accepts arbitrary arguments and ignores them."""

	def train(self, obs_trajs, acs_trajs, rews_trajs):
		"""Fits the controller on lists of trajectories.

		Raises:
			NotImplementedError: Always; subclasses provide the implementation.
		"""
		raise NotImplementedError("Must be implemented in subclass.")

	def reset(self):
		"""Restores the controller to its initial state.

		Raises:
			NotImplementedError: Always; subclasses provide the implementation.
		"""
		raise NotImplementedError("Must be implemented in subclass.")

	def act(self, obs, t, get_pred_cost=False):
		"""Chooses an action for observation obs at timestep t.

		Raises:
			NotImplementedError: Always; subclasses provide the implementation.
		"""
		raise NotImplementedError("Must be implemented in subclass.")

	def dump_logs(self, primary_logdir, iter_logdir):
		"""Writes logs to the primary and the per-iteration log directories.

		Raises:
			NotImplementedError: Always; subclasses provide the implementation.
		"""
		raise NotImplementedError("Must be implemented in subclass.")


def shuffle_rows(arr):
	"""Returns a new array in which every row of arr is permuted independently.

	One uniform random key is drawn per element (from NumPy's global RNG);
	sorting the keys along the last axis gives a separate random permutation
	for each row. The indexing below is written for 2-D input.
	"""
	random_keys = np.random.uniform(size=arr.shape)
	column_order = random_keys.argsort(axis=-1)
	# Column vector of row numbers so it broadcasts against column_order.
	row_ids = np.arange(arr.shape[0]).reshape(-1, 1)
	return arr[row_ids, column_order]


class MPC(Controller):
	optimizers = {"CEM": CEMOptimizer}

	def __init__(self, params):
		"""Creates class instance.

		Arguments:
			params
				.env (gym.env): Environment for which this controller will be used.
				.ac_ub (np.ndarray): (optional) An array of action upper bounds.
					Defaults to environment action upper bounds.
				.ac_lb (np.ndarray): (optional) An array of action lower bounds.
					Defaults to environment action lower bounds.
				.per (int): (optional) Determines how often the action sequence will be optimized.
					Defaults to 1 (reoptimizes at every call to act()).
				.update_fns (list): (optional) Zero-argument callables invoked on every reset().
					Defaults to an empty list.
				.prop_cfg
					.model_init_cfg (DotMap): A DotMap of initialization parameters for the model.
						.model_constructor (func): A function which constructs an instance of this
							model, given model_init_cfg.
						.num_nets (int): Ensemble size; npart must be a multiple of it.
					.model_train_cfg (dict): (optional) A DotMap of training parameters that will be
						used every time the model is trained. train() reads the 'epochs' entry.
						Defaults to an empty dict.
					.model_pretrained (bool): (optional) If True, assumes that the model
						has been trained upon construction.
					.mode (str): Propagation method. Only TSinf is supported by this implementation
						(enforced by an assertion below).
						See https://arxiv.org/abs/1805.12114 for details.
					.npart (int): Number of particles used for propagation.
						This is not the population in the CEM!
					.ign_var (bool): (optional) Determines whether or not variance output of the model
						will be ignored. Defaults to False unless deterministic propagation is being used.
					.obs_preproc (func): (optional) A function which modifies observations (in a 2D matrix)
						before they are passed into the model. Defaults to lambda obs: obs.
						Note: Must be able to process both NumPy arrays and PyTorch tensors.
					.obs_postproc (func): (optional) A function which returns vectors calculated from
						the previous observations and model predictions, which will then be passed into
						the provided cost function on observations. Defaults to lambda obs, model_out: model_out.
						Note: Is called with PyTorch tensors during planning.
					.obs_postproc2 (func): (optional) A function which takes the vectors returned by
						obs_postproc and (possibly) modifies it into the predicted observations for the
						next time step. Defaults to lambda obs: obs.
						Note: Is called with PyTorch tensors during planning.
					.targ_proc (func): (optional) A function which takes current observations and next
						observations and returns the array of targets (so that the model learns the mapping
						obs -> targ_proc(obs, next_obs)). Defaults to lambda obs, next_obs: next_obs.
						Note: Only needs to process NumPy arrays.
				.opt_cfg
					.mode (str): Internal optimizer that will be used. Choose between [CEM].
					.cfg (DotMap): A map of optimizer initializer parameters.
					.plan_hor (int): The planning horizon that will be used in optimization.
					.obs_cost_fn (func): A function which computes the cost of every observation
						in a 2D matrix.
						Note: Is called with PyTorch tensors during planning.
					.ac_cost_fn (func): A function which computes the cost of every action
						in a 2D matrix.
				.log_cfg
					.save_all_models (bool): (optional) If True, saves models at every iteration.
						Defaults to False (only most recent model is saved).
						Warning: Can be very memory-intensive.
					.log_traj_preds (bool): (optional) If True, saves the mean and variance of predicted
						particle trajectories. Defaults to False.
					.log_particles (bool) (optional) If True, saves all predicted particles trajectories.
						Defaults to False. Note: Takes precedence over log_traj_preds.
						Warning: Can be very memory-intensive
		"""
		super().__init__(params)
		self.dO, self.dU = params.env.observation_space.shape[0], params.env.action_space.shape[0]
		self.ac_ub, self.ac_lb = params.env.action_space.high, params.env.action_space.low
		# User-supplied bounds may only tighten the environment's bounds.
		self.ac_ub = np.minimum(self.ac_ub, params.get("ac_ub", self.ac_ub))
		self.ac_lb = np.maximum(self.ac_lb, params.get("ac_lb", self.ac_lb))
		self.update_fns = params.get("update_fns", [])
		self.per = params.get("per", 1)

		# NOTE: "model_init_cig" is a historical misspelling; it is kept so existing readers of
		# the attribute keep working, and the correctly spelled alias is provided next to it.
		self.model_init_cig = params.prop_cfg.get("model_init_cfg", {})
		self.model_init_cfg = self.model_init_cig
		self.model_train_cfg = params.prop_cfg.get("model_train_cfg", {})
		self.prop_mode = get_required_argument(params.prop_cfg, "mode", "Must provide propagation method.")
		self.npart = get_required_argument(params.prop_cfg, "npart", "Must provide number of particles.")
		self.ign_var = params.prop_cfg.get("ign_var", False) or self.prop_mode == "E"

		self.obs_preproc = params.prop_cfg.get("obs_preproc", lambda obs: obs)
		self.obs_postproc = params.prop_cfg.get("obs_postproc", lambda obs, model_out: model_out)
		self.obs_postproc2 = params.prop_cfg.get("obs_postproc2", lambda next_obs: next_obs)
		self.targ_proc = params.prop_cfg.get("targ_proc", lambda obs, next_obs: next_obs)

		self.opt_mode = get_required_argument(params.opt_cfg, "mode", "Must provide optimization method.")
		self.plan_hor = get_required_argument(params.opt_cfg, "plan_hor", "Must provide planning horizon.")
		self.obs_cost_fn = get_required_argument(params.opt_cfg, "obs_cost_fn", "Must provide cost on observations.")
		self.ac_cost_fn = get_required_argument(params.opt_cfg, "ac_cost_fn", "Must provide cost on actions.")

		self.save_all_models = params.log_cfg.get("save_all_models", False)
		self.log_traj_preds = params.log_cfg.get("log_traj_preds", False)
		self.log_particles = params.log_cfg.get("log_particles", False)

		# Retraining schedule used by train(): with method "UARF" the network is only refit
		# once enough new transitions (relative to the buffer size at the last fit) arrived.
		self.new_data_counter = 0
		self.cold_start_steps = 0
		self.max_buffer_length = 2500
		self.buffer_last_train = 0
		self.new_data_train_threshold = 0.01
		self.method = "UARF"

		# Perform argument checks
		assert self.opt_mode == 'CEM', 'only the CEM optimizer is supported'
		assert self.prop_mode == 'TSinf', 'only TSinf propagation mode is supported'
		assert self.npart % self.model_init_cig.num_nets == 0, "Number of particles must be a multiple of the ensemble size."

		# Create action sequence optimizer
		opt_cfg = params.opt_cfg.get("cfg", {})
		self.optimizer = CEMOptimizer(
			sol_dim=self.plan_hor * self.dU,
			lower_bound=np.tile(self.ac_lb, [self.plan_hor]),
			upper_bound=np.tile(self.ac_ub, [self.plan_hor]),
			cost_function=self._compile_cost,
			**opt_cfg
		)

		# Controller state variables
		self.has_been_trained = params.prop_cfg.get("model_pretrained", False)
		self.ac_buf = np.array([]).reshape(0, self.dU)
		# Warm start for the optimizer: mid-point of the action range at every step.
		self.prev_sol = np.tile((self.ac_lb + self.ac_ub) / 2, [self.plan_hor])
		self.init_var = np.tile(np.square(self.ac_ub - self.ac_lb) / 16, [self.plan_hor])
		# Empty buffers whose widths are probed by running the (pre)processing functions once.
		self.train_in = np.array([]).reshape(0, self.dU + self.obs_preproc(np.zeros([1, self.dO])).shape[-1])
		self.train_targs = np.array([]).reshape(
			0, self.targ_proc(np.zeros([1, self.dO]), np.zeros([1, self.dO])).shape[-1]
		)
		# Raw (unprocessed) buffers; all three are filled by train().
		self.raw_observations = None
		self.raw_actions = None
		self.maneuvers = None

		print("Created an MPC controller, prop mode %s, %d particles. " % (self.prop_mode, self.npart) +
			  ("Ignoring variance." if self.ign_var else ""))

		if self.save_all_models:
			print("Controller will save all models. (Note: This may be memory-intensive.)")
		if self.log_particles:
			print("Controller is logging particle predictions (Note: This may be memory-intensive).")
			self.pred_particles = []
		elif self.log_traj_preds:
			print("Controller is logging trajectory prediction statistics (mean+var).")
			self.pred_means, self.pred_vars = [], []
		else:
			print("Trajectory prediction logging is disabled.")

		# Set up pytorch model
		self.model = get_required_argument(
			params.prop_cfg.model_init_cfg, "model_constructor", "Must provide a model constructor."
		)(params.prop_cfg.model_init_cfg)

	def train(self, obs_trajs, acs_trajs, rews_trajs, next_obs_trajs, maneuvers):
		"""Trains the internal model of this controller. Once trained,
		this controller switches from applying random actions to using MPC.

		Arguments:
			obs_trajs: A list of observation matrices, observations in rows.
			acs_trajs: A list of action matrices, actions in rows.
			rews_trajs: A list of reward arrays.

		Returns: None.
		"""
		# Construct new training points and add to training set
		new_train_in, new_train_targs = [], []
		skipped_training = True
		for obs, acs, obs_ in zip(obs_trajs, acs_trajs, next_obs_trajs):
			new_train_in.append(np.concatenate([self.obs_preproc(obs), acs], axis=-1))
			new_train_targs.append(self.targ_proc(obs, obs_))
		self.new_data_counter += len(new_train_in[0])
		self.train_in = np.concatenate([self.train_in] + new_train_in, axis=0)
		self.train_targs = np.concatenate([self.train_targs] + new_train_targs, axis=0)
		if self.raw_observations is not None :
			self.raw_observations = np.concatenate([self.raw_observations] + obs_trajs)
			self.raw_actions = np.concatenate([self.raw_actions] +  acs_trajs)
			self.maneuvers = np.concatenate([self.maneuvers] +  maneuvers)
		else:
			self.raw_observations = np.array(obs_trajs).squeeze()
			self.raw_actions = np.array(acs_trajs).squeeze(0)
			self.maneuvers = np.array(maneuvers).squeeze()

		if self.max_buffer_length is not None and len(self.train_in) > self.max_buffer_length:
			self.train_in = self.train_in[len(self.train_in) - self.max_buffer_length:]
			self.train_targs = self.train_targs[len(self.train_targs) - self.max_buffer_length:]
			self.raw_observations = self.raw_observations[len(self.raw_observations) - self.max_buffer_length:]
			self.raw_actions = self.raw_actions[len(self.raw_actions) - self.max_buffer_length:]
			self.maneuvers = self.maneuvers[len(self.maneuvers) - self.max_buffer_length:]


		
		# Train the model
		if not self.has_been_trained:
			self.buffer_last_train = len(self.train_in)
		self.has_been_trained = True

		# Train the pytorch model
		self.model.fit_input_stats(self.train_in)

		idxs = np.random.randint(self.train_in.shape[0], size=[self.model.num_nets, self.train_in.shape[0]])

		epochs = self.model_train_cfg['epochs']

		# TODO: double-check the batch_size for all env is the same
		batch_size = 32

		epoch_range = trange(epochs, unit="epoch(s)", desc="Network training")
		num_batch = int(np.ceil(idxs.shape[-1] / batch_size))
		if self.method != "UARF" or self.new_data_counter >= self.buffer_last_train * self.new_data_train_threshold or len(self.train_in) < self.cold_start_steps:
			old_buffer_last_train = self.buffer_last_train
			self.buffer_last_train = len(self.train_in)
			self.new_data_counter = 0
			skipped_training = False
			for _ in epoch_range:
				for batch_num in range(num_batch):
					batch_idxs = idxs[:, batch_num * batch_size : (batch_num + 1) * batch_size]

					loss = 0.01 * (self.model.max_logvar.sum() - self.model.min_logvar.sum())
					loss += self.model.compute_decays()

					# TODO: move all training data to GPU before hand
					train_in = torch.from_numpy(self.train_in[batch_idxs]).to(TORCH_DEVICE).float()
					train_targ = torch.from_numpy(self.train_targs[batch_idxs]).to(TORCH_DEVICE).float()

					mean, logvar = self.model(train_in, ret_logvar=True)
					inv_var = torch.exp(-logvar)

					train_losses = ((mean - train_targ) ** 2) * inv_var + logvar
					train_losses = train_losses.mean(-1).mean(-1).sum()
					# Only taking mean over the last 2 dimensions
					# The first dimension corresponds to each model in the ensemble

					loss += train_losses

					self.model.optim.zero_grad()
					loss.backward()
					self.model.optim.step()

			idxs = shuffle_rows(idxs)

			val_in = torch.from_numpy(self.train_in[idxs[:,:5000]]).to(TORCH_DEVICE).float()
			val_targ = torch.from_numpy(self.train_targs[idxs[:,:5000]]).to(TORCH_DEVICE).float()

			mean, _ = self.model(val_in)
			mse_losses = ((mean - val_targ) ** 2).mean(-1).mean(-1)

			epoch_range.set_postfix({
				"Training loss(es)": mse_losses.detach().cpu().numpy()
			})
			return mse_losses.detach().cpu().numpy(), True, old_buffer_last_train
		old_buffer_last_train = self.buffer_last_train
		return np.array([0.0]), False, old_buffer_last_train
	def reset(self):
		"""Resets this controller (clears previous solution, calls all update functions).

		Returns: None
		"""
		self.prev_sol = np.tile((self.ac_lb + self.ac_ub) / 2, [self.plan_hor])
		self.optimizer.reset()

		for update_fn in self.update_fns:
			update_fn()

	def act(self, obs, t, get_pred_cost=False):
		"""Returns the action that this controller would take at time t given observation obs.

		Arguments:
			obs: The current observation
			t: The current timestep
			get_pred_cost: If True, returns the predicted cost for the action sequence found by
				the internal optimizer.

		Returns: An action (and possibly the predicted cost)
		"""
		if not self.has_been_trained:
			return np.random.uniform(self.ac_lb, self.ac_ub, self.ac_lb.shape)
		if self.ac_buf.shape[0] > 0:
			# this is done for skipping steps
			# if there is something in ac_bub (actions buffer) the we simply remove from the list and return the action
			action, self.ac_buf = self.ac_buf[0], self.ac_buf[1:]
			return action

		self.sy_cur_obs = obs

		# returns array (horizon,) -> with the best actions found
		#   start with the previous means (solutions) and always the initial variance
		soln = self.optimizer.obtain_solution(self.prev_sol, self.init_var)
		'''
		if get_traj:
			obs, var = get_traj(init_obs, soln)
			traj = sorn, obs, var (25, 26, 26)
		'''
		# dU -> actions dimension
		#
		self.prev_sol_full = np.copy(soln)
		self.prev_sol = np.concatenate([np.copy(soln)[self.per * self.dU:], np.zeros(self.per * self.dU)]) # set last to zero
		self.ac_buf = soln[:self.per * self.dU].reshape(-1, self.dU)  # gives back one action (1,1)

		return self.act(obs, t)

	def dump_logs(self, primary_logdir, iter_logdir):
		"""Saves logs to either a primary log directory or another iteration-specific directory.
		See __init__ documentation to see what is being logged.

		Arguments:
			primary_logdir (str): A directory path. This controller assumes that this directory
				does not change every iteration.
			iter_logdir (str): A directory path. This controller assumes that this directory
				changes every time dump_logs is called.

		Returns: None
		"""
		# TODO: implement saving model for pytorch
		# self.model.save(iter_logdir if self.save_all_models else primary_logdir)
		if self.log_particles:
			savemat(os.path.join(iter_logdir, "predictions.mat"), {"predictions": self.pred_particles})
			self.pred_particles = []
		elif self.log_traj_preds:
			savemat(
				os.path.join(iter_logdir, "predictions.mat"),
				{"means": self.pred_means, "vars": self.pred_vars}
			)
			self.pred_means, self.pred_vars = [], []

	@torch.no_grad()
	def _compile_cost(self, ac_seqs, return_seq=False):
		"""Scores candidate action sequences by rolling them out through the model.

		Every rollout starts from self.sy_cur_obs. Each candidate is replicated
		self.npart times, propagated for self.plan_hor steps, and its accumulated
		observation + action cost is averaged over the particles.

		Arguments:
			ac_seqs (np.ndarray): One flat action sequence of length plan_hor * dU per row.
			return_seq (bool): If True, also returns the per-step rollout data.

		Returns: A NumPy array with one mean cost per candidate; if return_seq is set,
			a tuple of that array and a dict of per-step lists (current observations,
			actions, next observations, costs).
		"""
		nopt = ac_seqs.shape[0]

		# (nopt, plan_hor * dU) -> (plan_hor, nopt * npart, dU): split the flat sequences
		# per time step, move time to the front and repeat each candidate npart times.
		ac_seqs = torch.from_numpy(ac_seqs).float().to(TORCH_DEVICE)
		ac_seqs = (
			ac_seqs.view(-1, self.plan_hor, self.dU)
			.transpose(0, 1)[:, :, None]
			.expand(-1, -1, self.npart, -1)
			.contiguous()
			.view(self.plan_hor, -1, self.dU)
		)

		# Every particle of every candidate starts from the same observation.
		start_obs = torch.from_numpy(self.sy_cur_obs).float().to(TORCH_DEVICE)
		cur_obs = start_obs[None].expand(nopt * self.npart, -1)

		costs = torch.zeros(nopt, self.npart, device=TORCH_DEVICE)

		rollout = None
		if return_seq:
			rollout = {'traj_cur_obs': [],
					   'traj_cur_acs': [],
					   'traj_next_obs': [],
					   'traj_next_cost': []}

		# Iterating over the first axis yields the actions of one time step at a time.
		for cur_acs in ac_seqs:
			if return_seq:
				rollout['traj_cur_obs'].append(cur_obs.detach().cpu().numpy())
				rollout['traj_cur_acs'].append(cur_acs.detach().cpu().numpy())

			next_obs = self._predict_next_obs(cur_obs, cur_acs)

			if return_seq:
				rollout['traj_next_obs'].append(next_obs.detach().cpu().numpy())

			step_cost = self.obs_cost_fn(next_obs) + self.ac_cost_fn(cur_acs)
			if return_seq:
				rollout['traj_next_cost'].append(step_cost.detach().cpu().numpy())

			costs += step_cost.view(-1, self.npart)
			cur_obs = self.obs_postproc2(next_obs)

		# Replace nan with high cost
		costs[torch.isnan(costs)] = 1e6

		# Average over the particles of each candidate: (nopt, npart) -> (nopt,)
		mean_costs = costs.mean(dim=1).detach().cpu().numpy()
		if return_seq:
			return mean_costs, rollout
		return mean_costs

	def _predict_next_obs(self, obs, acs):
		"""Samples the next observation for every particle from the ensemble model.

		Arguments:
			obs: Current observations, one particle per row.
			acs: Actions, row-aligned with obs.

		Returns: The result of self.obs_postproc(obs, sampled model outputs).
		"""
		proc_obs = self.obs_preproc(obs)

		assert self.prop_mode == 'TSinf'

		# Regroup the rows so that each ensemble member receives its own share of
		# particles: (rows, dim) -> (num_nets, rows / num_nets, dim).
		model_in = torch.cat(
			(self._expand_to_ts_format(proc_obs), self._expand_to_ts_format(acs)),
			dim=-1
		)

		# The two model outputs are treated as mean and variance of a Gaussian.
		mean, var = self.model(model_in)
		noise = torch.randn_like(mean, device=TORCH_DEVICE)

		# Draw one sample per particle, then go back to the flat one-row-per-particle layout.
		sampled = self._flatten_to_matrix(mean + noise * torch.sqrt(var))

		return self.obs_postproc(obs, sampled)

	def _expand_to_ts_format(self, mat):
		dim = mat.shape[-1]

		# Before, [10, 5] in case of proc_obs
		reshaped = mat.view(-1, self.model.num_nets, self.npart // self.model.num_nets, dim)
		# After, [2, 5, 1, 5]

		transposed = reshaped.transpose(0, 1)
		# After, [5, 2, 1, 5]

		reshaped = transposed.contiguous().view(self.model.num_nets, -1, dim)
		# After. [5, 2, 5]

		return reshaped

	def _flatten_to_matrix(self, ts_fmt_arr):
		dim = ts_fmt_arr.shape[-1]

		reshaped = ts_fmt_arr.view(self.model.num_nets, -1, self.npart // self.model.num_nets, dim)

		transposed = reshaped.transpose(0, 1)

		reshaped = transposed.contiguous().view(-1, dim)

		return reshaped
