import os
import numpy as np

import torch
import torch.nn.functional as F
from pytorch_lightning import LightningModule
import matplotlib
import numpy as np
import matplotlib.pyplot as plt
from global_vars import PIXEL_SIZE, CAMERA_CONFIG, BOUNDS, IN_SHAPE

from tasks import cameras
import utils.transporter_utils as utils
from models.core.attention import Attention
from models.core.transport import Transport
from models.streams.two_stream_attention import TwoStreamAttention
from models.streams.two_stream_transport import TwoStreamTransport

from models.streams.two_stream_attention import TwoStreamAttentionLat
from models.streams.two_stream_transport import TwoStreamTransportLat

class TransporterAgent(LightningModule):
	"""Base Transporter pick-and-place agent implemented as a LightningModule.

	Subclasses implement ``_build_model`` to create ``self.attention`` (pick
	head) and ``self.transport`` (place head). Optimisation is manual: the
	``*_criterion`` methods call ``self.manual_backward`` and step one Adam
	optimizer per head (``self._optimizers``). This is presumably why
	``optimizer_step`` and ``configure_optimizers`` are no-ops.
	"""

	def __init__(self, name, cfg, train_ds, test_ds):
		super().__init__()
		utils.set_seed(0)

		self.device_type = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # this is bad for PL :(
		self.name = name
		self.cfg = cfg
		self.train_ds = train_ds
		self.test_ds = test_ds

		self.task = cfg['train']['task']
		self.total_steps = 0
		self.crop_size = 64
		self.n_rotations = cfg['train']['n_rotations']

		# Workspace / camera constants imported from global_vars.
		self.pix_size = PIXEL_SIZE
		self.in_shape = IN_SHAPE
		self.cam_config = CAMERA_CONFIG
		self.bounds = BOUNDS

		self.val_repeats = cfg['train']['val_repeats']
		self.save_steps = cfg['train']['save_steps']

		# The model must exist before the optimizers wrap its parameters.
		self._build_model()
		self._optimizers = {
			'attn': torch.optim.Adam(self.attention.parameters(), lr=self.cfg['train']['lr']),
			'trans': torch.optim.Adam(self.transport.parameters(), lr=self.cfg['train']['lr'])
		}
		# Counter used to name the images written by draw_heat.
		self.cnt = 0
		print("Agent: {}, Logging: {}".format(name, cfg['train']['log']))

	def _build_model(self):
		"""Create ``self.attention`` and ``self.transport``; must be overridden."""
		self.attention = None
		self.transport = None
		raise NotImplementedError()

	def forward(self, x):
		raise NotImplementedError()

	def cross_entropy_with_logits(self, pred, labels, reduction='mean'):
		"""Soft-label cross entropy: ``-labels * log_softmax(pred)`` over the last dim.

		Args:
			pred: logits tensor.
			labels: target distribution with a shape broadcastable to ``pred``.
			reduction: 'sum' or 'mean' over every element.

		Raises:
			NotImplementedError: for any other ``reduction``.
		"""
		# Lucas found that both sum and mean work equally well
		x = (-labels * F.log_softmax(pred, -1))
		if reduction == 'sum':
			return x.sum()
		elif reduction == 'mean':
			return x.mean()
		else:
			raise NotImplementedError()

	def attn_forward(self, inp, softmax=True):
		"""Run the pick head on ``inp['inp_img']``."""
		inp_img = inp['inp_img']

		output = self.attention.forward(inp_img, softmax=softmax)
		return output

	def attn_training_step(self, frame, backprop=True, compute_err=False):
		"""Compute the pick-head loss for ``frame`` (keys 'img', 'p0', 'p0_theta').

		Returns ``(loss, err)`` as produced by ``attn_criterion``.
		"""
		inp_img = frame['img']
		p0, p0_theta = frame['p0'], frame['p0_theta']

		inp = {'inp_img': inp_img}
		out = self.attn_forward(inp, softmax=False)
		return self.attn_criterion(backprop, compute_err, inp, out, p0, p0_theta)

	def attn_criterion(self, backprop, compute_err, inp, out, p, theta):
		"""One-hot cross-entropy loss for the pick head, with an optional optimizer step.

		Args:
			backprop: if True, backpropagate and step the 'attn' optimizer.
			compute_err: if True, also run a softmax forward pass and measure the error.
			inp: dict holding 'inp_img'; its first two dims give the label's H, W.
			out: pick logits. Presumably already flattened to (1, R*H*W), since no
				reshape happens here. TODO confirm.
			p: target pixel (row, col).
			theta: target rotation in radians.

		Returns:
			(loss, err). ``err`` is empty unless ``compute_err`` is set.
		"""
		# Get label: quantise theta into one of n_rotations bins.
		theta_i = theta / (2 * np.pi / self.attention.n_rotations)
		theta_i = np.int32(np.round(theta_i)) % self.attention.n_rotations
		inp_img = inp['inp_img']
		label_size = inp_img.shape[:2] + (self.attention.n_rotations,)
		label = np.zeros(label_size)
		label[p[0], p[1], theta_i] = 1
		# (H, W, R) -> (R, H, W) -> (1, R*H*W), i.e. rotation-major flattening.
		label = label.transpose((2, 0, 1))
		label = label.reshape(1, np.prod(label.shape))
		label = torch.from_numpy(label).to(dtype=torch.float, device=out.device)

		# Get loss.
		loss = self.cross_entropy_with_logits(out, label)

		# Backpropagate.
		if backprop:
			attn_optim = self._optimizers['attn']
			self.manual_backward(loss, attn_optim)
			attn_optim.step()
			attn_optim.zero_grad()

		# Pixel and Rotation error (not used anywhere).
		err = {}
		if compute_err:
			pick_conf = self.attn_forward(inp)
			pick_conf = pick_conf.detach().cpu().numpy()
			argmax = np.argmax(pick_conf)
			argmax = np.unravel_index(argmax, shape=pick_conf.shape)
			p0_pix = argmax[:2]
			p0_theta = argmax[2] * (2 * np.pi / pick_conf.shape[2])

			err = {
				'dist': np.linalg.norm(np.array(p) - p0_pix, ord=1),
				'theta': np.absolute((theta - p0_theta) % np.pi)
			}
		return loss, err

	def trans_forward(self, inp, softmax=True):
		"""Run the place head on ``inp['inp_img']``, conditioned on the pick pixel ``inp['p0']``."""
		inp_img = inp['inp_img']
		p0 = inp['p0']

		output = self.transport.forward(inp_img, p0, softmax=softmax)
		return output

	def transport_training_step(self, frame, backprop=True, compute_err=False):
		"""Compute the place-head loss for ``frame`` (keys 'img', 'p0', 'p1', 'p1_theta').

		Returns ``(loss, err)``. Note that ``transport_criterion`` itself returns
		``(err, loss)``.
		"""
		inp_img = frame['img']
		p0 = frame['p0']
		p1, p1_theta = frame['p1'], frame['p1_theta']

		inp = {'inp_img': inp_img, 'p0': p0}
		output = self.trans_forward(inp, softmax=False)
		err, loss = self.transport_criterion(backprop, compute_err, inp, output, p0, p1, p1_theta)
		return loss, err

	def transport_criterion(self, backprop, compute_err, inp, output, p, q, theta):
		"""One-hot cross-entropy loss for the place head, with an optional optimizer step.

		Args:
			backprop: if True, backpropagate and step the 'trans' optimizer.
			compute_err: if True, also run a softmax forward pass and measure the error.
			inp: dict with 'inp_img' (first two dims give H, W) and 'p0'.
			output: place logits; flattened here to (1, N).
			p: pick pixel (unused in the loss).
			q: target place pixel (row, col).
			theta: target place rotation in radians.

		Returns:
			(err, loss). The order is reversed relative to ``attn_criterion``.
		"""
		itheta = theta / (2 * np.pi / self.transport.n_rotations)
		itheta = np.int32(np.round(itheta)) % self.transport.n_rotations

		# Get one-hot pixel label map.
		inp_img = inp['inp_img']
		label_size = inp_img.shape[:2] + (self.transport.n_rotations,)
		label = np.zeros(label_size)
		label[q[0], q[1], itheta] = 1

		# Get loss.
		label = label.transpose((2, 0, 1))
		label = label.reshape(1, np.prod(label.shape))
		label = torch.from_numpy(label).to(dtype=torch.float, device=output.device)
		output = output.reshape(1, np.prod(output.shape))
		loss = self.cross_entropy_with_logits(output, label)
		if backprop:
			transport_optim = self._optimizers['trans']
			self.manual_backward(loss, transport_optim)
			transport_optim.step()
			transport_optim.zero_grad()

		# Pixel and Rotation error (not used anywhere).
		err = {}
		if compute_err:
			place_conf = self.trans_forward(inp)
			place_conf = place_conf.permute(1, 2, 0)
			place_conf = place_conf.detach().cpu().numpy()
			argmax = np.argmax(place_conf)
			argmax = np.unravel_index(argmax, shape=place_conf.shape)
			p1_pix = argmax[:2]
			p1_theta = argmax[2] * (2 * np.pi / place_conf.shape[2])

			err = {
				'dist': np.linalg.norm(np.array(q) - p1_pix, ord=1),
				'theta': np.absolute((theta - p1_theta) % np.pi)
			}
		self.transport.iters += 1
		return err, loss

	def training_step(self, batch, batch_idx):
		"""Run one manual-optimisation step on both heads and log the losses."""
		self.attention.train()
		self.transport.train()

		frame, _ = batch

		# Get training losses. Each *_training_step also backpropagates and
		# steps its own optimizer (backprop defaults to True).
		step = self.total_steps + 1
		loss0, err0 = self.attn_training_step(frame)
		if isinstance(self.transport, Attention):
			loss1, err1 = self.attn_training_step(frame)
		else:
			loss1, err1 = self.transport_training_step(frame)
		total_loss = loss0 + loss1
		self.log('tr/attn/loss', loss0)
		self.log('tr/trans/loss', loss1)
		self.log('tr/loss', total_loss)
		self.total_steps = step

		# NOTE(review): relies on the trainer's internal train_loop object,
		# which only exists in some pytorch_lightning versions.
		self.trainer.train_loop.running_loss.append(total_loss)

		self.check_save_iteration()

		return dict(
			loss=total_loss,
		)

	def check_save_iteration(self):
		"""Save checkpoints into ``<train_dir>/checkpoints``.

		At every step listed in ``save_steps``, this runs evaluation and saves a
		step/val-loss named checkpoint. Every 1000 steps it refreshes ``last.ckpt``.
		"""
		global_step = self.trainer.global_step
		if (global_step + 1) in self.save_steps:
			self.trainer.run_evaluation()
			# NOTE(review): assumes 'val_loss' is present in callback_metrics
			# after run_evaluation (it is a key of validation_epoch_end's return).
			val_loss = self.trainer.callback_metrics['val_loss']
			steps = f'{global_step + 1:05d}'
			filename = f"steps={steps}-val_loss={val_loss:0.8f}.ckpt"
			checkpoint_path = os.path.join(self.cfg['train']['train_dir'], 'checkpoints')
			ckpt_path = os.path.join(checkpoint_path, filename)
			self.trainer.save_checkpoint(ckpt_path)

		if (global_step + 1) % 1000 == 0:
			# save latest checkpoint
			self.save_last_checkpoint()

	def save_last_checkpoint(self):
		"""Write ``<train_dir>/checkpoints/last.ckpt`` via the trainer."""
		checkpoint_path = os.path.join(self.cfg['train']['train_dir'], 'checkpoints')
		ckpt_path = os.path.join(checkpoint_path, 'last.ckpt')
		self.trainer.save_checkpoint(ckpt_path)

	def validation_step(self, batch, batch_idx):
		"""Average both heads' losses over ``val_repeats`` passes without backprop.

		The errors reported are the ones from the last repeat.
		"""
		self.attention.eval()
		self.transport.eval()

		loss0, loss1 = 0, 0
		assert self.val_repeats >= 1
		frame, _ = batch
		for _ in range(self.val_repeats):
			l0, err0 = self.attn_training_step(frame, backprop=False, compute_err=True)
			loss0 += l0
			if isinstance(self.transport, Attention):
				l1, err1 = self.attn_training_step(frame, backprop=False, compute_err=True)
			else:
				l1, err1 = self.transport_training_step(frame, backprop=False, compute_err=True)
			loss1 += l1
		loss0 /= self.val_repeats
		loss1 /= self.val_repeats
		val_total_loss = loss0 + loss1

		self.trainer.evaluation_loop.trainer.train_loop.running_loss.append(val_total_loss)

		return dict(
			val_loss=val_total_loss,
			val_loss0=loss0,
			val_loss1=loss1,
			val_attn_dist_err=err0['dist'],
			val_attn_theta_err=err0['theta'],
			val_trans_dist_err=err1['dist'],
			val_trans_theta_err=err1['theta'],
		)

	def training_epoch_end(self, all_outputs):
		"""Re-seed with ``current_epoch + 1`` after every training epoch."""
		super().training_epoch_end(all_outputs)
		utils.set_seed(self.trainer.current_epoch+1)

	def validation_epoch_end(self, all_outputs):
		"""Aggregate per-batch validation outputs: mean losses and summed errors.

		These are logged, printed and returned.
		"""
		mean_val_total_loss = np.mean([v['val_loss'].item() for v in all_outputs])
		mean_val_loss0 = np.mean([v['val_loss0'].item() for v in all_outputs])
		mean_val_loss1 = np.mean([v['val_loss1'].item() for v in all_outputs])
		total_attn_dist_err = np.sum([v['val_attn_dist_err'] for v in all_outputs])
		total_attn_theta_err = np.sum([v['val_attn_theta_err'] for v in all_outputs])
		total_trans_dist_err = np.sum([v['val_trans_dist_err'] for v in all_outputs])
		total_trans_theta_err = np.sum([v['val_trans_theta_err'] for v in all_outputs])

		self.log('vl/attn/loss', mean_val_loss0)
		self.log('vl/trans/loss', mean_val_loss1)
		self.log('vl/loss', mean_val_total_loss)
		self.log('vl/total_attn_dist_err', total_attn_dist_err)
		self.log('vl/total_attn_theta_err', total_attn_theta_err)
		self.log('vl/total_trans_dist_err', total_trans_dist_err)
		self.log('vl/total_trans_theta_err', total_trans_theta_err)

		print("\nAttn Err - Dist: {:.2f}, Theta: {:.2f}".format(total_attn_dist_err, total_attn_theta_err))
		print("Transport Err - Dist: {:.2f}, Theta: {:.2f}".format(total_trans_dist_err, total_trans_theta_err))

		# NOTE(review): key 'mean_val_loss1' is inconsistent with 'val_loss0';
		# kept as-is because callers may read it.
		return dict(
			val_loss=mean_val_total_loss,
			val_loss0=mean_val_loss0,
			mean_val_loss1=mean_val_loss1,
			total_attn_dist_err=total_attn_dist_err,
			total_attn_theta_err=total_attn_theta_err,
			total_trans_dist_err=total_trans_dist_err,
			total_trans_theta_err=total_trans_theta_err,
		)

	def draw_heat(self, pick_conf, place_conf, img_color):
		"""Save a 3-row figure: the image, the image with pick overlay, and with place overlay.

		The figure is written to ``'<self.cnt>.png'`` in the working directory,
		then ``self.cnt`` is incremented.

		Args:
			pick_conf: 3-D pick confidence array, presumably (H, W, rotations).
				TODO confirm.
			place_conf: 3-D place confidence array; summed over axis 2.
			img_color: colour image with 0-255 values (divided by 255 for display).
		"""
		fig, axs = plt.subplots(3, 1, figsize=(15, 10), squeeze=False)

		background = img_color.transpose(1, 0, 2) / 255
		for row in range(3):
			axs[row][0].imshow(background)

		scale = 30 # amplify
		# Clip before the uint8 cast. Amplified values above 255 would otherwise
		# wrap around (or be undefined), so the hottest pixels would look cold.
		pick_logits_disp = np.uint8(np.clip(pick_conf * 255 * scale, 0, 255)).transpose(1, 0, 2)
		place_sum = np.sum(place_conf, axis=2)[:, :, None]  # sum across rotations
		place_logits_disp = np.uint8(np.clip(place_sum * 255 * scale, 0, 255)).transpose(1, 0, 2)

		axs[1][0].imshow(pick_logits_disp, cmap='viridis', alpha=0.75)
		axs[2][0].imshow(place_logits_disp, cmap='viridis', alpha=0.75)

		fig.savefig('{}.png'.format(str(self.cnt)))
		# Close the figure so repeated calls do not accumulate open figures.
		plt.close(fig)
		self.cnt += 1

	def actdraw(self, obs, p0_xyz, p1_xyz):
		"""Debug variant of ``act`` that also saves a heat-map image via ``draw_heat``.

		The given ``p0_xyz`` / ``p1_xyz`` (world coordinates) are projected to
		pixels and stamped into the confidence maps as markers before drawing.
		The returned poses come from the model's own argmax. The markers are
		written after the argmax, so they do not affect the poses.
		"""
		img = self.test_ds.get_image(obs)

		# Attention model forward pass.
		pick_inp = {'inp_img': img}
		pick_conf = self.attn_forward(pick_inp)
		pick_conf = pick_conf.detach().cpu().numpy()
		argmax = np.argmax(pick_conf)
		argmax = np.unravel_index(argmax, shape=pick_conf.shape)
		p0_pix = argmax[:2]
		p0_theta = argmax[2] * (2 * np.pi / pick_conf.shape[2])

		p0_pix_fake = utils.xyz_to_pix(p0_xyz, self.bounds, self.pix_size)
		placexy = utils.xyz_to_pix(p1_xyz, self.bounds, self.pix_size)

		# Transport model forward pass.
		place_inp = {'inp_img': img, 'p0': p0_pix}
		place_conf = self.trans_forward(place_inp)
		place_conf = place_conf.permute(1, 2, 0)
		place_conf = place_conf.detach().cpu().numpy()
		argmax = np.argmax(place_conf)
		argmax = np.unravel_index(argmax, shape=place_conf.shape)
		p1_pix = argmax[:2]
		p1_theta = argmax[2] * (2 * np.pi / place_conf.shape[2])

		# Pixels to end effector poses.
		hmap = img[:, :, 3]
		p0_xyz = utils.pix_to_xyz(p0_pix, hmap, self.bounds, self.pix_size)
		p1_xyz = utils.pix_to_xyz(p1_pix, hmap, self.bounds, self.pix_size)
		p0_xyzw = utils.eulerXYZ_to_quatXYZW((0, 0, -p0_theta))
		p1_xyzw = utils.eulerXYZ_to_quatXYZW((0, 0, -p1_theta))

		# Visual markers at the reference pick location. NOTE(review): indices
		# near the image border can go negative and wrap around.
		pick_conf[p0_pix_fake] = 1
		for i in range(1, 5):
			for j in range(1, 5):
				pick_conf[p0_pix_fake[0]-5+i][p0_pix_fake[1]-5+j] = 0.1*(i-j)*(i-j)

		place_conf[placexy] = 1
		self.draw_heat(pick_conf, place_conf, img[:, :, :3])
		return {
			'pose0': (np.asarray(p0_xyz), np.asarray(p0_xyzw)),
			'pose1': (np.asarray(p1_xyz), np.asarray(p1_xyzw)),
			'pick': p0_pix,
			'place': p1_pix,
		}

	def act(self, obs, info=None, goal=None, p0_heat=None, p1_heat=None):  # pylint: disable=unused-argument
		"""Run inference and return best action given visual observations.

		Args:
			obs: observation turned into the model input by ``self.test_ds.get_image``.
			info, goal: unused.
			p0_heat: optional (row, col) pixel. If given, 0.1 is added to the pick
				confidence at that pixel before the argmax.
			p1_heat: optional (row, col) pixel. Same bias, applied to the place
				confidence.

		Returns:
			dict with 'pose0' / 'pose1' as (xyz, quaternion xyzw) tuples and
			the 'pick' / 'place' pixel coordinates.
		"""
		# Get heightmap from RGB-D images.
		img = self.test_ds.get_image(obs)

		# Attention model forward pass.
		pick_inp = {'inp_img': img}
		pick_conf = self.attn_forward(pick_inp)
		pick_conf = pick_conf.detach().cpu().numpy()

		# The heat hints are optional: the defaults are None, which previously
		# crashed with a TypeError on indexing.
		if p0_heat is not None:
			pick_conf[int(p0_heat[0]), int(p0_heat[1])] += 0.1

		argmax = np.argmax(pick_conf)
		argmax = np.unravel_index(argmax, shape=pick_conf.shape)
		p0_pix = argmax[:2]
		p0_theta = argmax[2] * (2 * np.pi / pick_conf.shape[2])

		# Transport model forward pass.
		place_inp = {'inp_img': img, 'p0': p0_pix}
		place_conf = self.trans_forward(place_inp)
		place_conf = place_conf.permute(1, 2, 0)
		place_conf = place_conf.detach().cpu().numpy()

		if p1_heat is not None:
			place_conf[int(p1_heat[0]), int(p1_heat[1])] += 0.1

		argmax = np.argmax(place_conf)
		argmax = np.unravel_index(argmax, shape=place_conf.shape)
		p1_pix = argmax[:2]
		p1_theta = argmax[2] * (2 * np.pi / place_conf.shape[2])

		# Pixels to end effector poses.
		hmap = img[:, :, 3]
		p0_xyz = utils.pix_to_xyz(p0_pix, hmap, self.bounds, self.pix_size)
		p1_xyz = utils.pix_to_xyz(p1_pix, hmap, self.bounds, self.pix_size)
		p0_xyzw = utils.eulerXYZ_to_quatXYZW((0, 0, -p0_theta))
		p1_xyzw = utils.eulerXYZ_to_quatXYZW((0, 0, -p1_theta))
		return {
			'pose0': (np.asarray(p0_xyz), np.asarray(p0_xyzw)),
			'pose1': (np.asarray(p1_xyz), np.asarray(p1_xyzw)),
			'pick': p0_pix,
			'place': p1_pix,
		}

	def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_i, second_order_closure, on_tpu, using_native_amp, using_lbfgs):
		# No-op: optimizers are stepped manually in the *_criterion methods.
		pass

	def configure_optimizers(self):
		# No-op: optimizers are created in __init__ and kept in self._optimizers.
		pass

	def train_dataloader(self):
		return self.train_ds

	def val_dataloader(self):
		return self.test_ds

	def load(self, model_path):
		"""Load weights from a checkpoint's 'state_dict' and move the model to ``self.device_type``."""
		# map_location lets a checkpoint saved on GPU load on a CPU-only machine
		# (and vice versa). torch.load unpickles the file, so only load
		# trusted checkpoints.
		checkpoint = torch.load(model_path, map_location=self.device_type)
		self.load_state_dict(checkpoint['state_dict'])
		self.to(device=self.device_type)


class OriginalTransporterAgent(TransporterAgent):
	"""Transporter agent whose pick and place heads use a single 'plain_resnet' stream.

	The constructor is inherited unchanged from ``TransporterAgent``.
	"""

	def _build_model(self):
		shared = {
			'stream_fcn': ('plain_resnet', None),
			'in_shape': self.in_shape,
			'preprocess': utils.preprocess,
			'cfg': self.cfg,
			'device': self.device_type,
		}
		# Pick head: one rotation only.
		self.attention = Attention(n_rotations=1, **shared)
		# Place head: self.n_rotations rotations, cropped to self.crop_size.
		self.transport = Transport(n_rotations=self.n_rotations, crop_size=self.crop_size, **shared)


class ClipUNetTransporterAgent(TransporterAgent):
	"""Transporter agent whose pick and place heads use a single 'clip_unet' stream.

	The constructor is inherited unchanged from ``TransporterAgent``.
	"""

	def _build_model(self):
		shared = {
			'stream_fcn': ('clip_unet', None),
			'in_shape': self.in_shape,
			'preprocess': utils.preprocess,
			'cfg': self.cfg,
			'device': self.device_type,
		}
		# Pick head: one rotation only.
		self.attention = Attention(n_rotations=1, **shared)
		# Place head: self.n_rotations rotations, cropped to self.crop_size.
		self.transport = Transport(n_rotations=self.n_rotations, crop_size=self.crop_size, **shared)


class TwoStreamClipUNetTransporterAgent(TransporterAgent):
	"""Two-stream Transporter agent: 'plain_resnet' first stream, 'clip_unet' second stream.

	The constructor is inherited unchanged from ``TransporterAgent``.
	"""

	def _build_model(self):
		shared = {
			'stream_fcn': ('plain_resnet', 'clip_unet'),
			'in_shape': self.in_shape,
			'preprocess': utils.preprocess,
			'cfg': self.cfg,
			'device': self.device_type,
		}
		# Pick head: one rotation only.
		self.attention = TwoStreamAttention(n_rotations=1, **shared)
		# Place head: self.n_rotations rotations, cropped to self.crop_size.
		self.transport = TwoStreamTransport(n_rotations=self.n_rotations, crop_size=self.crop_size, **shared)


class TwoStreamClipUNetLatTransporterAgent(TransporterAgent):
	"""Two-stream agent with lateral variants: 'plain_resnet_lat' and 'clip_unet_lat'.

	The constructor is inherited unchanged from ``TransporterAgent``.
	"""

	def _build_model(self):
		shared = {
			'stream_fcn': ('plain_resnet_lat', 'clip_unet_lat'),
			'in_shape': self.in_shape,
			'preprocess': utils.preprocess,
			'cfg': self.cfg,
			'device': self.device_type,
		}
		# Pick head: one rotation only.
		self.attention = TwoStreamAttentionLat(n_rotations=1, **shared)
		# Place head: self.n_rotations rotations, cropped to self.crop_size.
		self.transport = TwoStreamTransportLat(n_rotations=self.n_rotations, crop_size=self.crop_size, **shared)


class TwoStreamClipWithoutSkipsTransporterAgent(TransporterAgent):
	"""Two-stream agent: 'plain_resnet' first stream, 'clip_woskip' second stream.

	The constructor is inherited unchanged from ``TransporterAgent``.
	"""

	def _build_model(self):
		# TODO: lateral version
		shared = {
			'stream_fcn': ('plain_resnet', 'clip_woskip'),
			'in_shape': self.in_shape,
			'preprocess': utils.preprocess,
			'cfg': self.cfg,
			'device': self.device_type,
		}
		# Pick head: one rotation only.
		self.attention = TwoStreamAttention(n_rotations=1, **shared)
		# Place head: self.n_rotations rotations, cropped to self.crop_size.
		self.transport = TwoStreamTransport(n_rotations=self.n_rotations, crop_size=self.crop_size, **shared)


class TwoStreamRN50BertUNetTransporterAgent(TransporterAgent):
	"""Two-stream agent: 'plain_resnet' first stream, 'rn50_bert_unet' second stream.

	The constructor is inherited unchanged from ``TransporterAgent``.
	"""

	def _build_model(self):
		# TODO: lateral version
		shared = {
			'stream_fcn': ('plain_resnet', 'rn50_bert_unet'),
			'in_shape': self.in_shape,
			'preprocess': utils.preprocess,
			'cfg': self.cfg,
			'device': self.device_type,
		}
		# Pick head: one rotation only.
		self.attention = TwoStreamAttention(n_rotations=1, **shared)
		# Place head: self.n_rotations rotations, cropped to self.crop_size.
		self.transport = TwoStreamTransport(n_rotations=self.n_rotations, crop_size=self.crop_size, **shared)
