import numpy as np
import scipy.stats
from scipy.stats import norm
import os
from matplotlib import pyplot as plt
from matplotlib import rc
from itertools import cycle

from datetime import datetime
from sklearn.cluster import KMeans
from sklearn.mixture import GaussianMixture
from sklearn.cluster import AgglomerativeClustering
from sklearn.utils._testing import ignore_warnings
from sklearn.exceptions import ConvergenceWarning
import time
import gc


def normalize(x):
	"""Standardize ``x`` along axis 0 (zero mean, unit standard deviation).

	Any standard deviation below 1e-4 is replaced by 1, so near-constant
	columns (or a near-constant 1-D input) are only centred, never divided
	by a tiny number.
	"""
	center = np.mean(x, axis=0)
	scale = np.std(x, axis=0)
	if not np.isscalar(scale):
		# Per-column spread: neutralize the near-constant columns in place.
		scale[scale < 1e-4] = 1
	elif scale < 1e-4:
		scale = 1
	return (x - center) / scale


# If True, once a layer's replay memory is full, each new experience
# overwrites the lowest-reward entry instead of the oldest one.
REMOVE_WORST_FROM_REPLAY_MEMORY = False
# Fraction of a node's experiences used as the "elite" set when estimating
# its action mean; MCGS.__init__ overwrites it with config.elite_ratio.
ELITE_RATIO = 1 / 4


class MCGS:
	"""Monte-Carlo Graph Search (MCGS) planner.

	The search graph is stored in ``self.layers``. Layer ``l`` holds:
	- the Q nodes reached after ``l`` steps from the root;
	- a fixed-size per-layer replay memory of experiences (``states``,
	  ``actions``, ``states2``, ``rewards``);
	- ``exp_owner``, the index of the node each experience belongs to
	  (-1 = unowned).
	When a layer (other than the root layer) has enough data, its states are
	clustered and each cluster becomes one Q node.

	Args:
		config: hyperparameter object. The attributes read are
			``min_graph_length``, ``rollout_length``, ``simulation_budget``,
			``clustering_alg`` ("kmeans", "agglomerative" or "gmm"),
			``optimal_prob``, ``optimal_n_top``, ``optimal_range`` and
			``elite_ratio``.
		model: simulator used for planning. It must provide:
			- ``action_space`` with ``low``, ``high`` and ``shape``;
			- ``observation_space.low``;
			- ``reset()``;
			- ``step(action)`` returning ``(obs, reward, done, info)``;
			- ``save_checkpoint()`` and ``load_checkpoint()``.
	"""

	def __init__(self, config, model):
		self.config = config

		self.env = model

		self.action_space = model.action_space
		self.action_dim = self.action_space.shape[0]
		self.observation_space = model.observation_space
		self.state_dim = model.observation_space.low.shape[0]

		# set hyperparams
		self.min_graph_length = config.min_graph_length
		self.rollout_length = config.rollout_length
		# Each layer's replay memory holds a fifth of the simulation budget.
		# Clamp to 1 so a budget below 5 cannot produce a zero-sized ring
		# buffer: add_exp_to_replay_memory indexes it modulo this value.
		self.max_n_exps = max(1, config.simulation_budget // 5)
		self.min_n_data_per_q_node = max(25, self.max_n_exps // 10)
		self.clustering_alg = config.clustering_alg
		self.simulation_budget = config.simulation_budget

		self.optimal_prob = config.optimal_prob
		self.optimal_n_top = config.optimal_n_top
		self.optimal_range = config.optimal_range

		# Module-level setting read by Q_Node.update; it is shared by every
		# planner instance in the process.
		global ELITE_RATIO
		ELITE_RATIO = config.elite_ratio

		self.layers = None
		self.best_action = None
		self.best_rew = None
		self.exp_owner = None
		self.states = None
		self.states2 = None
		self.actions = None
		self.rewards = None
		self.n_exps = None
		self.postpone_clustering = None

		# Accumulated wall-clock time (seconds) spent in each phase.
		self.time_budget_total = 0
		self.time_budget_env_step = 0
		self.time_budget_clustering = 0
		self.time_budget_update_nodes = 0
		self.time_action_bandit = 0
		self.time_q_bandit = 0

		self.stats_n_clustering_tryouts = 0
		self.stats_n_clustering_success = 0

		self.two_d_nav_plot_time_tag = datetime.now().strftime("%Y-%m-%d--%H-%M-%S")[:-3]
		self.two_d_nav_plot_counter = 0

		self.reset()

	def reset(self):
		"""Reset the model and save a checkpoint of its initial state."""
		self.env.reset()
		self.env.save_checkpoint()

	def add_node_to_layer(self, l):
		"""Append a fresh Q node to layer ``l``.

		Any missing layers up to ``l`` are created first, each with an empty
		replay memory.
		"""
		if self.layers is None:
			self.layers = []
		while len(self.layers) < l + 1:
			self.layers.append({
				'nodes': []
				, 'exp_owner': -np.ones(self.max_n_exps, dtype=np.int64)
				, 'states': np.zeros((self.max_n_exps, self.state_dim))
				, 'actions': np.zeros((self.max_n_exps, self.action_dim))
				, 'states2': np.zeros((self.max_n_exps, self.state_dim))
				, 'rewards': np.zeros(self.max_n_exps)
				, 'n_exps': 0
				, 'postpone_clustering': 0
			})
		self.layers[l]['nodes'].append(Q_Node(action_space=self.env.action_space, state_dim=self.state_dim
		                                      , min_n_data_per_q_node=self.min_n_data_per_q_node))

	def add_exp_to_replay_memory(self, l, q, exp, rew):
		"""Store ``exp = [s1, a, s2]`` with return ``rew`` in layer ``l``, owned by node ``q``.

		The memory is a ring buffer: once it is full, the oldest slot is
		overwritten. If REMOVE_WORST_FROM_REPLAY_MEMORY is set, the
		lowest-reward slot is overwritten instead.
		"""
		idx = self.layers[l]['n_exps'] % self.max_n_exps
		if REMOVE_WORST_FROM_REPLAY_MEMORY and self.layers[l]['n_exps'] >= self.max_n_exps:
			idx = np.argmin(self.layers[l]['rewards'])
		self.layers[l]['exp_owner'][idx] = q
		self.layers[l]['states'][idx, :] = exp[0][:]
		self.layers[l]['actions'][idx, :] = exp[1][:]
		self.layers[l]['states2'][idx, :] = exp[2][:]
		self.layers[l]['rewards'][idx] = rew
		self.layers[l]['n_exps'] += 1

	def act(self, observation=None):
		"""Plan from the model's current state and return the action to take.

		Simulations are run until ``simulation_budget`` environment steps have
		been used. Each simulation has three parts: the graph policy, then a
		random rollout, then a backup. The checkpoint saved at the start of the
		call is restored after each simulation. Finally, the best first action
		found is applied to the model once, so it stays in sync with the real
		environment.

		Args:
			observation: current observation, used as the root state of every
				simulated trajectory. If None, a zero vector is used instead.

		Returns:
			A copy of the best first action found.
		"""
		time_total_start = time.perf_counter()

		self.env.save_checkpoint()

		# Graph reuse between calls is disabled: the search graph is rebuilt
		# from a single root node on every call.
		reuse_old_graph = False

		if not reuse_old_graph or self.layers is None:
			self.layers = None
			self.add_node_to_layer(0)

		# Until a simulation beats -inf, fall back to the centre of the action box.
		self.best_action = (self.env.action_space.low + self.env.action_space.high) / 2
		self.best_rew = -np.inf

		timesteps = 0
		while timesteps < self.simulation_budget:
			# Graph policy: descend the graph, one node per layer.
			exps, done, ep_ret, traj_len = self.graph_policy(self.env, observation)
			if not done:
				# Default policy: extend the trajectory with random actions.
				ep_ret, traj_len = self.default_policy(self.env, ep_ret, traj_len)

			# Backpropagation: the whole return is credited to every visited layer.
			self.backpropagate(exps, ep_ret)
			timesteps += traj_len

			# Load the saved model state
			self.env.load_checkpoint()

		action = np.copy(self.best_action)

		# Important!! update the model to keep it in sync with env.
		s2, _, _, info = self.env.step(action)
		if reuse_old_graph:
			# NOTE(review): unreachable while reuse_old_graph is False. If it is
			# enabled, the graph needs at least two layers (q_bandit(1, ...)).
			# The final line also maps every unowned (-1) replay entry of the
			# new root layer to node 0.
			q = self.q_bandit(1, s2)
			del self.layers[0]
			del self.layers[0]['nodes'][:q]
			del self.layers[0]['nodes'][1:]
			self.layers[0]['exp_owner'] *= 0

		self.time_budget_total += time.perf_counter() - time_total_start

		return action

	def graph_policy(self, env, cur_obs):
		"""Descend the search graph, taking one node-sampled action per layer.

		At each layer, the current node's ``action_bandit`` picks the action,
		and ``q_bandit`` picks the next layer's node from the resulting state.
		If the walk reaches the deepest layer without terminating, a new layer
		is appended when either:
		- that layer has more than ``min_n_data_per_q_node`` experiences, or
		- the graph is still shorter than ``min_graph_length``.

		Args:
			env: model to step.
			cur_obs: root observation, or None (a zero vector is used).

		Returns:
			(exps, done, sum_r, traj_len):
			- exps: one ``{'q': node index, 'exp': [s1, a, s2]}`` record per
			  visited layer;
			- done: whether the episode ended;
			- sum_r: the summed reward;
			- traj_len: the number of steps taken.
		"""
		q = 0
		exps = []
		# A missing observation would make add_exp_to_replay_memory fail on
		# None[:]. A zero placeholder is harmless here for two reasons:
		# layer-0 states are never clustered (backpropagate only clusters
		# l > 0), and q_bandit only routes into layers >= 1.
		s1 = np.zeros(self.state_dim) if cur_obs is None else cur_obs
		sum_r = 0
		traj_len = 0
		done = False
		n_layers = len(self.layers)
		l = 0
		# One coin flip per simulation decides whether every node on this path
		# exploits its best stored actions.
		is_optimal = np.random.rand() < self.optimal_prob
		while l < n_layers:
			time_action_bandit_start = time.perf_counter()
			a = self.layers[l]['nodes'][q].action_bandit(is_optimal, self.optimal_n_top, self.optimal_range)
			self.time_action_bandit += time.perf_counter() - time_action_bandit_start
			time_step_start = time.perf_counter()
			s2, r, done, info = env.step(a)
			self.time_budget_env_step += time.perf_counter() - time_step_start
			exps.append({'q': q, 'exp': [s1, a, s2]})
			sum_r += r
			traj_len += 1
			if not done and l + 1 == n_layers and (
					self.layers[l]['n_exps'] > self.min_n_data_per_q_node or n_layers < self.min_graph_length
			):
				self.add_node_to_layer(l + 1)
				n_layers += 1
			if done or l + 1 == n_layers:
				break
			time_q_bandit_start = time.perf_counter()
			q = self.q_bandit(l + 1, s2)
			self.time_q_bandit += time.perf_counter() - time_q_bandit_start
			s1 = s2
			l += 1
		return exps, done, sum_r, traj_len

	def q_bandit(self, l, s):
		"""Return the index of the node in layer ``l`` whose state model best fits ``s``.

		Each node is scored by the mean of its per-dimension Gaussian densities
		at ``s``. A uniform jitter below 0.01 breaks ties randomly.
		"""
		n_q = len(self.layers[l]['nodes'])
		if n_q == 1:
			return 0
		max_score = -np.inf
		max_score_i = -1
		for q in range(n_q):
			pdf = self.layers[l]['nodes'][q].pdf.pdf(s)
			score = np.mean(pdf)
			score += np.random.random_sample() * 0.01  # Break ties randomly
			if score > max_score:
				max_score = score
				max_score_i = q
		return max_score_i

	def default_policy(self, env, ep_ret, traj_len):
		"""Continue the trajectory with up to ``rollout_length`` uniform random actions.

		Returns the updated ``(ep_ret, traj_len)``. The rollout stops early when
		the episode ends.
		"""
		n_layers = len(self.layers)
		# Called only after a non-terminated graph walk, which visits every layer.
		assert traj_len == n_layers
		for t in range(self.rollout_length):
			time_step_start = time.perf_counter()
			ac = np.random.uniform(low=self.action_space.low, high=self.action_space.high, size=self.action_space.shape)
			ob, r, done, info = env.step(ac)
			self.time_budget_env_step += time.perf_counter() - time_step_start
			ep_ret += r
			traj_len += 1
			if done:
				break
		return ep_ret, traj_len

	@ignore_warnings(category=ConvergenceWarning)
	def backpropagate(self, exps, rew):
		"""Record a finished simulation and refit or re-cluster each visited layer.

		Args:
			exps: per-layer records produced by ``graph_policy``.
			rew: total return of the simulation (graph part plus rollout). It
				is stored as the reward of every recorded experience.
		"""
		# Update the best found action if needed
		if rew > self.best_rew:
			self.best_rew = rew
			self.best_action[:] = exps[0]['exp'][1][:]

		for l in range(len(exps)):
			q = exps[l]['q']

			self.add_exp_to_replay_memory(l, q, exps[l]['exp'], rew)

			if self.layers[l]['postpone_clustering'] > 0:
				self.layers[l]['postpone_clustering'] -= 1
			# Number of valid entries in the (possibly wrapped) replay memory.
			n_exps = min(self.layers[l]['n_exps'], self.max_n_exps)

			# While a layer is still small, relax the per-node data requirement
			# so that it can be split earlier.
			min_n_data_per_q_node = self.min_n_data_per_q_node
			if 25 <= n_exps < min_n_data_per_q_node * 2:
				min_n_data_per_q_node = max(10, n_exps // 4)

			# At most one new cluster can be added per update.
			desired_n_clusters = (n_exps - min_n_data_per_q_node // 3) // min_n_data_per_q_node
			desired_n_clusters = min(len(self.layers[l]['nodes']) + 1, desired_n_clusters)
			clustered = True
			if l > 0 and len(self.layers[l]['nodes']) < desired_n_clusters and self.layers[l]['postpone_clustering'] == 0:
				assert len(self.layers[l]['nodes']) + 1 == desired_n_clusters
				time_cluster_start = time.perf_counter()
				self.stats_n_clustering_tryouts += 1
				# Try to make a new cluster
				cur_layer_data = self.layers[l]['states'][:n_exps, :]

				if self.clustering_alg == "kmeans":
					clusters_idx = KMeans(n_clusters=desired_n_clusters, random_state=int(time.time())).fit(
						cur_layer_data).labels_
				elif self.clustering_alg == "agglomerative":
					clusters_idx = AgglomerativeClustering(n_clusters=desired_n_clusters).fit(cur_layer_data).labels_
				else:
					assert self.clustering_alg == "gmm"
					clusters_idx = GaussianMixture(
						n_components=desired_n_clusters, covariance_type="full", random_state=int(time.time())).fit\
							(cur_layer_data).predict(cur_layer_data)
				# Reject the split if any cluster is too small to support a node.
				for c in range(desired_n_clusters):
					if np.sum(clusters_idx == c) < min_n_data_per_q_node / 2:
						clustered = False
						break
				self.time_budget_clustering += time.perf_counter() - time_cluster_start
				if clustered:
					self.stats_n_clustering_success += 1
					self.add_node_to_layer(l)
					# Cluster labels become the new owners of the layer's experiences.
					self.layers[l]['exp_owner'][:n_exps] = clusters_idx
					assert len(self.layers[l]['nodes']) == desired_n_clusters
					time_update_nodes_start = time.perf_counter()
					for c in range(desired_n_clusters):
						self._refit_node(l, c)
					self.time_budget_update_nodes += time.perf_counter() - time_update_nodes_start
				else:
					# Postpone the clustering for better performance
					self.layers[l]['postpone_clustering'] = min_n_data_per_q_node // 2
			else:
				clustered = False
			if not clustered:
				# No new clustering needed. Update the Q node regularly.
				n_owned = np.count_nonzero(self.layers[l]['exp_owner'] == q)
				if n_owned >= min_n_data_per_q_node // 2:
					time_update_nodes_start = time.perf_counter()
					self._refit_node(l, q)
					self.time_budget_update_nodes += time.perf_counter() - time_update_nodes_start

	def _refit_node(self, l, q):
		"""Refit node ``q`` of layer ``l`` on every replay-memory entry it owns."""
		layer = self.layers[l]
		owned = layer['exp_owner'] == q
		layer['nodes'][q].update(layer['states'][owned, :], layer['actions'][owned, :], layer['rewards'][owned])

	def report_time_budget(self):
		"""Print time ratios per phase and the clustering success rate.

		Prints nothing if no search time has been recorded yet.
		"""
		if self.time_budget_total == 0:
			return
		print('MCGS Stats Report:')
		print('\tTotal search time: %.2f' % self.time_budget_total)
		print('\t\tEnv step ratio: %d%%' % (100 * self.time_budget_env_step / self.time_budget_total))
		print('\t\tClustering ratio: %d%%' % (100 * self.time_budget_clustering / self.time_budget_total))
		print('\t\tNode update ratio: %d%%' % (100 * self.time_budget_update_nodes / self.time_budget_total))
		print('\t\tAction bandit ratio: %d%%' % (100 * self.time_action_bandit / self.time_budget_total))
		print('\t\tQ bandit ratio: %d%%' % (100 * self.time_q_bandit / self.time_budget_total))

		print('\tClustering success rate: %d%% (%d out of %d)' % (
				100 * self.stats_n_clustering_success / max(self.stats_n_clustering_tryouts, 1)
				, self.stats_n_clustering_success, self.stats_n_clustering_tryouts))


class Q_Node:
	"""A node of one MCGS layer: a Gaussian state model plus an action sampler.

	The node keeps:
	- the ``actions`` and ``rewards`` it was last fitted on;
	- a per-dimension Gaussian over the states it owns (``pdf``);
	- a Gaussian action proposal (``action_mean``, ``action_sd``).
	"""

	def __init__(self, action_space, state_dim, min_n_data_per_q_node):
		"""Create an unfitted node.

		Args:
			action_space: box-like space exposing array-valued ``low`` and ``high``.
			state_dim: dimensionality of the states routed to this node.
			min_n_data_per_q_node: the node samples actions uniformly until it
				has been fitted on at least half this many experiences.
		"""
		self.action_dim = action_space.low.shape[0]
		self.action_min = action_space.low
		self.action_max = action_space.high

		self.state_dim = state_dim

		self.state_mean = np.zeros(self.state_dim)
		self.state_std = 0.1 * np.ones(self.state_dim)
		self.pdf = scipy.stats.norm(self.state_mean, self.state_std)

		# Initial proposal: centred on the action box, half its width as sd.
		self.action_mean = 0.5 * (self.action_min + self.action_max)
		self.action_sd = 0.5 * (self.action_max - self.action_min)

		self.n_data = 0

		self.min_n_data_per_q_node = min_n_data_per_q_node

		self.actions = None
		self.rewards = None

		self.gpr = None

	def action_bandit(self, is_optimal, optimal_n_top, optimal_range):
		"""Sample an action for this node.

		- Too little data (unfitted, or fewer than
		  ``min_n_data_per_q_node / 2`` stored rewards): sample uniformly
		  over the action box.
		- ``is_optimal``: pick one of the ``optimal_n_top`` highest-reward
		  stored actions uniformly at random. Perturb it with Gaussian noise of
		  scale ``optimal_range * (action_max - action_min)``. If fewer rewards
		  are stored, all of them are candidates.
		- Otherwise: sample from N(action_mean, action_sd).

		Samples from the Gaussian branches are not clipped to the action bounds.
		"""
		if (self.rewards is None or len(self.rewards) == 0
				or len(self.rewards) < self.min_n_data_per_q_node / 2):
			return self.action_min + np.random.rand(self.action_dim) * (self.action_max - self.action_min)

		if is_optimal:
			# Clamp to [1, n]. This keeps argpartition's kth in range and gives
			# randint a non-empty interval even when optimal_n_top exceeds the
			# number of stored rewards (or is non-positive).
			n_top = int(min(max(optimal_n_top, 1), len(self.rewards)))
			# With kth = n_top - 1, the first n_top entries are the n_top largest rewards.
			top_indices = np.argpartition(-self.rewards, n_top - 1)[:n_top]
			selected_idx = top_indices[np.random.randint(0, n_top)]
			return np.random.normal(self.actions[selected_idx, :], optimal_range * (self.action_max - self.action_min))

		return np.random.normal(self.action_mean, self.action_sd)

	def update(self, states, actions, rewards, elite_ratio=None):
		"""Refit the node on the experiences it owns.

		Args:
			states: (n, state_dim) array of states owned by this node.
			actions: (n, action_dim) array of the actions taken from them.
			rewards: (n,) array of the returns credited to them.
			elite_ratio: fraction of experiences treated as elite when
				estimating the action mean. Defaults to the module-level
				``ELITE_RATIO``.
		"""
		self.n_data = states.shape[0]

		self.actions = actions
		self.rewards = rewards

		self.state_mean = np.mean(states, axis=0)
		# Keep the state Gaussian neither degenerate nor too diffuse.
		self.state_std = np.clip(np.std(states, axis=0), 0.1, 0.5)
		self.pdf = scipy.stats.norm(self.state_mean, self.state_std)

		self.action_mean = None
		weights = np.copy(rewards)
		M, m = np.max(weights), np.min(weights)
		if M - m > 0.01:
			# Min-max normalized rewards serve as averaging weights.
			weights = (weights - m) / (M - m)
			n = weights.shape[0]
			ratio = ELITE_RATIO if elite_ratio is None else elite_ratio
			elite_num = max(5, int(n * ratio))
			if elite_num < n:
				# kth = n - elite_num places the elite_num largest rewards in the
				# tail. (kth = elite_num would only guarantee the tail holds
				# values above the elite_num-th smallest.)
				elite_idx = np.argpartition(rewards, n - elite_num, axis=None)[-elite_num:]
				if np.sum(weights[elite_idx]) > 1e-2:
					self.action_mean = np.average(actions[elite_idx, :], axis=0, weights=weights[elite_idx])
		if self.action_mean is None:
			self.action_mean = np.mean(actions, axis=0)

		# Exploration width shrinks from 0.5 towards 0.15 of the action range as data grows.
		sd = 0.15 + 0.35 * np.exp(-0.0005 * np.power(self.n_data, 2))
		self.action_sd = sd * (self.action_max - self.action_min)
