import random
import pickle_utils
import torch_utils
from algo.model import Model, DiscrimModel
from algo.verify import test_model, test_procgen
from algo.common.loss import cross_entropy_loss
import numpy as np
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
import os
import torchvision.transforms as T
import torch
import torch.nn.functional as F
import cv2
from pytorch_grad_cam import GradCAM, GradCAMPlusPlus
import torch.nn as nn


def sample_positive(obs, cam_model, n=2):
	'''
		Draw n random "hot" positions per image from a class activation map.

		:param obs: batched image tensor; dims 2 and 3 are read as height and
			width (presumably [B, C, H, W] — TODO confirm).
		:param cam_model: callable (e.g. a pytorch_grad_cam object) mapping obs
			to one activation map per image, indexable as cam[i]; presumably
			each map is 2-D [H, W].
		:param n: number of positions to sample per image (with replacement).
		:return: [B, n, 2] tensor of (row / H, col / W) pairs, built with
			torch_utils.numpy_to_tensor.
	'''
	heatmaps = cam_model(obs)
	height, width = obs.shape[2], obs.shape[3]

	def hot_region(heatmap):
		# Pixels whose value lies in the top 20% of this map's [min, max] range.
		lo = float(np.min(heatmap))
		hi = float(np.max(heatmap))
		return np.where(heatmap >= (lo + 0.80 * (hi - lo)))

	batch_positions = []
	for b in range(heatmaps.shape[0]):
		coords = hot_region(heatmaps[b])
		rows, cols = coords[0], coords[1]
		# Uniform sampling with replacement among the hot pixels.
		picks = np.random.randint(0, len(rows), n)
		batch_positions.append(
			[[float(rows[k]) / height, float(cols[k]) / width] for k in picks]
		)

	# [b, n, 2]
	return torch_utils.numpy_to_tensor(np.array(batch_positions))


class BehaviorCloning:
	'''
		Behavior-cloning agent with an adversarial attention regularizer.

		The policy `self.model` is trained on a cross-entropy BC loss plus
		`a0 * logsigmoid(D(features, positions))`. The `positions` are points
		sampled by `sample_positive` from Grad-CAM++ maps of a target model
		(`self.cam_model1`), which is periodically reloaded with an earlier
		weight snapshot of the policy. The discriminator `D`
		(`self.discrim_model`) is trained to separate matched
		(features, positions) pairs from pairs whose positions were shuffled
		across the batch.
	'''

	def __init__(self, config):
		'''
			:param config: experiment configuration. Fields read by this class
				include task, lr, weight_decay, sample_n, results_path, a0,
				scale, n_epoch, n_epoch_steps, batch_size, repeat, env_name,
				deterministic and action_shape, plus the make_dataset() and
				proc_obs_np() helpers.
		'''
		self.config = config
		self.model = Model(config).cuda()
		self.discrim_model = DiscrimModel(config).cuda()

		# The target model: its Grad-CAM++ maps supply the attention positions,
		# and train() periodically reloads it with an earlier policy snapshot.
		self.cam_model1 = Model(config).cuda()

		encoder = self.cam_model1.encoder
		if config.task == 'magical':
			target_layers = [encoder.layer3[-1]]
		else:
			target_layers = [encoder.layer2[-1], encoder.layer3[-1]]
		self.cam1 = GradCAMPlusPlus(model=self.cam_model1,
		                            target_layers=target_layers,
		                            use_cuda=True)

		self.main_optimizer = optim.Adam(self.model.parameters(),
		                                 lr=config.lr,
		                                 weight_decay=config.weight_decay)
		self.discrim_optimizer = optim.Adam(self.discrim_model.parameters(),
		                                    lr=config.lr)

		self.sample_n = config.sample_n
		self.writer = SummaryWriter(self.config.results_path)
		self.hist_max_mean_reward = 0
		self.a0 = config.a0

		# Rolling history of up to 10 independent copies of self.model's weights.
		self.all_past_models = []

	def append_model(self):
		'''
			Snapshot the current policy weights into the history buffer,
			keeping only the 10 most recent snapshots.

			state_dict() returns tensors that share storage with the live
			parameters, and the optimizer updates those in place. Without the
			deep copy, every stored entry would silently track the current
			weights, and the target model would never lag behind the policy.
		'''
		import copy

		self.all_past_models.append(copy.deepcopy(self.model.state_dict()))
		if len(self.all_past_models) > 10:
			self.all_past_models = self.all_past_models[1:]

	def get_past_model(self, index):
		'''
			Return a stored weight snapshot.
			:param index: list index into the history. A negative index that
				reaches past the oldest snapshot is clamped to 0 (the oldest).
			:return: the stored state dict.
		'''
		if index < 0 and -index > len(self.all_past_models):
			index = 0

		return self.all_past_models[index]

	def prepare(self):
		'''
			Things to prepare before inference,
			such as loading some reference demo.
			:return:
		'''
		self.model.eval()

	def train(self):
		'''
			Main training routine.

			Each step runs one policy update (BC loss + a0 * energy term),
			followed by int(config.repeat) discriminator updates. Every
			config.scale steps, the current weights are snapshotted and the
			target CAM model is reloaded with the second-newest snapshot. After
			each epoch, the model is saved and evaluated in the environment.
			:return:
		'''
		dataset = self.config.make_dataset()
		snapshot_interval = self.config.scale
		n_epoch_steps = self.config.n_epoch_steps
		# torchvision draws a new random angle on every call, so one shared
		# transform instance behaves the same as building a new one per step.
		augmentation = None
		if self.config.task == 'magical':
			augmentation = T.RandomRotation(degrees=(0, 180))

		training_step = 0
		for epoch in range(self.config.n_epoch):
			print("Epoch {} starts.".format(epoch))
			self.model.train()
			for _ in range(n_epoch_steps):
				total_loss, bc_loss, energy_loss = self._policy_step(dataset, augmentation)

				# Stays None when config.repeat < 1 (no discriminator update).
				discrim_loss = None
				for _ in range(int(self.config.repeat)):
					discrim_loss = self._discriminator_step(dataset, augmentation)

				self.writer.add_scalar("Training/BC_Policy_loss", bc_loss.item(), training_step)
				self.writer.add_scalar("Training/BC_Energy_loss", energy_loss.item(), training_step)
				disc_value = float('nan')
				if discrim_loss is not None:
					disc_value = discrim_loss.item()
					self.writer.add_scalar("Training/Disc_loss", disc_value, training_step)

				if training_step % 10 == 0:
					print('Epoch {}|Step:{} Current Loss:{} BC_MAIN: {} BC_ENG: {} DISC: {}'.format(
					    epoch, training_step, total_loss.item(),
					    bc_loss.item(), energy_loss.item(), disc_value
					))

				training_step += 1

				if training_step % snapshot_interval == 0:
					self.append_model()
					self.cam_model1.load_state_dict(self.get_past_model(-2))

			self._end_of_epoch(epoch)

		return

	def _sample_batch(self, dataset, augmentation):
		'''
			Sample one training batch and apply the optional augmentation to obs.
			:return: (obs, act)
		'''
		obs, act = dataset.sample(self.config.batch_size)
		if augmentation is not None:
			obs = augmentation(obs)
		return obs, act

	def _policy_step(self, dataset, augmentation):
		'''
			Run one update of self.model on the BC loss plus the weighted energy term.
			:return: (total loss, BC loss, energy loss) tensors.
		'''
		obs, act = self._sample_batch(dataset, augmentation)
		act_predict, features = self.model(obs, ret_z=True)
		cam_positions_pos = sample_positive(obs, self.cam1, self.sample_n)

		# POINT regularizer loss. For a0 > 0, minimizing it pushes the
		# discriminator's score on matched (features, positions) pairs down.
		energy_loss = F.logsigmoid(self.discrim_model(features, cam_positions_pos)).mean()
		# BC loss.
		bc_loss = cross_entropy_loss(act_predict, act).mean()
		total_loss = bc_loss + self.a0 * energy_loss

		# This backward also deposits gradients in discrim_model. They are
		# cleared by discrim_optimizer.zero_grad() before the discriminator updates.
		self.main_optimizer.zero_grad()
		total_loss.backward()
		self.main_optimizer.step()
		return total_loss, bc_loss, energy_loss

	def _discriminator_step(self, dataset, augmentation):
		'''
			Run one update of the discriminator on matched vs. shuffled positions.
			:return: discriminator loss tensor.
		'''
		obs, _ = self._sample_batch(dataset, augmentation)
		# Gradients into self.model from this loss would be wiped by
		# main_optimizer.zero_grad() before the next policy update, so no graph
		# through the policy is built. The Grad-CAM call below still needs
		# autograd, so it stays outside this context.
		with torch.no_grad():
			_, features = self.model(obs, ret_z=True)
		cam_positions_pos = sample_positive(obs, self.cam1, self.sample_n)
		# Negatives: the same positions shuffled across the batch dimension.
		cam_positions_neg = cam_positions_pos[torch.randperm(cam_positions_pos.size(0))]
		discrim_loss = - F.logsigmoid(self.discrim_model(features, cam_positions_pos)).mean() \
		               - F.logsigmoid(- self.discrim_model(features, cam_positions_neg)).mean()

		self.discrim_optimizer.zero_grad()
		discrim_loss.backward()
		self.discrim_optimizer.step()
		return discrim_loss

	def _end_of_epoch(self, epoch):
		'''
			Save this epoch's checkpoint, then evaluate in the environment and
			log the results.
		'''
		model_name = 'model_epoch_{}.pth'.format(epoch)
		model_save_path = os.path.join(self.config.results_path, model_name)
		self.save(model_save_path)

		print("Start evaluation of epoch {}".format(epoch))
		all_reward = self.rollout_evaluation()
		mean_reward = sum(all_reward) / len(all_reward)
		max_reward = max(all_reward)

		self.writer.add_scalar("Test/Mean Performance", mean_reward, epoch)
		self.writer.add_scalar("Test/Max Performance", max_reward, epoch)
		self.hist_max_mean_reward = max(self.hist_max_mean_reward, mean_reward)
		print("Evaluation finished. Mean reward = {}; Max reward = {}; Hist Max = {}".format(
			mean_reward, max_reward, self.hist_max_mean_reward))

	def rollout_evaluation(self):
		'''
			Switch the model to eval mode and roll it out in the task's environment.
			:return: whatever the task's test helper returns; train() uses it as
				a list of per-episode rewards.
			:raises ValueError: if config.task is neither 'magical' nor 'procgen'.
		'''
		self.model.eval()
		if self.config.task == 'magical':
			return test_model(self, self.config.env_name, 25)
		if self.config.task == 'procgen':
			return test_procgen(self, self.config.env_name, torch.device('cuda'))
		raise ValueError('Unsupported task: {}'.format(self.config.task))

	def test(self, checkpoint, n_test):
		'''
			Load a checkpoint and print mean/max reward from test_model.
			:param checkpoint: path passed to self.load.
			:param n_test: passed to test_model (presumably the episode count).

			NOTE(review): this unpacks test_model's result as a 2-tuple, while
			rollout_evaluation uses the same function's result directly as the
			reward list; one of the two call sites must disagree with
			test_model's actual return value — confirm. It also always calls
			test_model, even when config.task is 'procgen'.
		'''
		self.load(checkpoint)
		self.model.eval()
		all_reward, _ = test_model(self, self.config.env_name, n_test)
		mean_reward = sum(all_reward) / len(all_reward)
		max_reward = max(all_reward)
		print("Test finished. Mean reward = {}; Max reward = {};".format(mean_reward, max_reward))

	def inference(self, obs):
		'''
			Dispatch to the task-specific inference routine.
			:raises ValueError: if config.task is neither 'magical' nor 'procgen'.
		'''
		if self.config.task == 'magical':
			return self.inference_magical(obs)
		if self.config.task == 'procgen':
			return self.inference_procgen(obs)
		raise ValueError('Unsupported task: {}'.format(self.config.task))

	def _select_action(self, action_prob):
		'''
			Pick an action index from the model's per-action scores.
			:param action_prob: 1-D numpy array of non-negative scores; it is
				renormalized to sum to 1 before use.
			:return: int; the argmax when config.deterministic is set, otherwise
				a sample from the normalized distribution.
		'''
		action_prob = action_prob / np.sum(action_prob)
		if self.config.deterministic:
			return int(np.argmax(action_prob))
		# With an int first argument and no size, choice samples from
		# arange(action_shape) and returns a scalar (one uniform draw, the same
		# draw as requesting size=1).
		return int(np.random.choice(self.config.action_shape, p=action_prob))

	def inference_magical(self, obs):
		'''
			Infer an action for the current observation.
			:param obs: Numpy Array, [H, W, C]; Un-normalized.
			:return: integer
		'''
		obs = self.config.proc_obs_np(obs)

		obs = torch_utils.numpy_to_tensor(obs)
		# Add a batch dim and move channels first: [1, H, W, C] -> [1, C, H, W].
		obs = obs.reshape(1, *obs.shape).permute(0, 3, 1, 2)
		with torch.no_grad():
			action_prob = torch_utils.tensor_to_numpy(self.model(obs)[0])

		return self._select_action(action_prob)

	def inference_procgen(self, obs):
		'''
			Infer an action for the current observation.
			:param obs: passed to self.model unchanged, so presumably already a
				preprocessed, batched tensor — TODO confirm against the caller.
			:return: int32 numpy array holding a single action index.
		'''
		with torch.no_grad():
			action_prob = torch_utils.tensor_to_numpy(self.model(obs)[0])

		return np.array([self._select_action(action_prob)], dtype=np.int32)

	def save(self, path):
		'''
			Save the current model to disk.
			:param path: str; path to save the model.
			:return:
		'''
		self.model.save(path)

	def load(self, path):
		'''
			Load the model from disk.
			:param path: str; path to load the model.
			:return:
		'''
		self.model.load(path)
