import os
import shutil
import numpy as np
import clip
import lpips
import torch
import torchvision
from torch import nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from criteria.parse_related_loss import bg_loss, average_lab_color_loss
import criteria.clip_loss as clip_loss
import criteria.image_embedding_loss as image_embedding_loss
from criteria import id_loss
from mapper.datasets.latents_dataset import LatentsDataset
from mapper.AT3D_mapper import AT3DMapper
from mapper.mapper_training.ranger import Ranger
from mapper.mapper_training import train_utils
from models.eg3d.training.camera_utils import FOV_to_intrinsics, LookAtPoseSampler
from .train_utils import get_fourier_descriptor
from PIL import Image
from deepface import DeepFace
from progress.bar import Bar
from models.third_party.BiSeNet import FaceParser


class Coach:
	def __init__(self, opts):
		"""Build the network, losses, optimizer, datasets and output folders.

		Args:
			opts: Training options object. This constructor reads
				``clip_models``, ``clip_model_weights``, ``face_parser_ckpt``,
				``exp_dir``, ``avg_camera_pivot``, ``avg_camera_radius``,
				``fov_deg``, ``batch_size``, ``test_batch_size``,
				``save_interval`` and ``max_steps`` from it, stores ``device``
				on it, and fills in ``save_interval`` when it is ``None``.

		Side effects:
			Creates ``<exp_dir>/backup``, ``<exp_dir>/logs`` and
			``<exp_dir>/checkpoints``, copies a few source files into the
			backup folder (paths are relative to the current working
			directory), and generates the train/test datasets.
		"""
		self.opts = opts
		self.global_step = 0
		self.device = 'cuda:0'
		self.opts.device = self.device

		# Initialize network
		self.net = AT3DMapper(self.opts).to(self.device)

		# Initialize loss
		self.id_loss = id_loss.IDLoss(self.opts).to(self.device).eval()
		# One CLIP loss module per requested CLIP model, and its weight in the attribute loss.
		self.clip_loss_models = {
			model_name: clip_loss.CLIPLoss(self.device, clip_model=model_name).to(self.device)
			for model_name in opts.clip_models
		}
		self.clip_model_weights = dict(zip(opts.clip_models, opts.clip_model_weights))

		self.mask_loss = clip_loss.MaskLoss(self.opts).to(self.device)
		self.background_loss = bg_loss.BackgroundLoss(self.opts).to(self.device).eval()

		# Initialize optimizer
		self.optimizer = self.configure_optimizers()

		self.parsenet = FaceParser(model_path=opts.face_parser_ckpt, device=self.device)

		# backup codes
		backup_dir = os.path.join(opts.exp_dir, 'backup')
		os.makedirs(backup_dir, exist_ok=True)
		criteria_backup = backup_dir+'/criteria'
		# copytree refuses to write into an existing directory, so drop a stale
		# copy left by a previous run that used the same exp_dir.
		if os.path.isdir(criteria_backup):
			shutil.rmtree(criteria_backup)
		shutil.copytree('../criteria', criteria_backup)
		os.makedirs(backup_dir+'/mapper', exist_ok=True)
		shutil.copyfile('./latent_mappers.py', backup_dir+'/mapper/latent_mappers.py')
		shutil.copyfile('./attribute_list.py', backup_dir+'/mapper/attribute_list.py')
		shutil.copyfile('./AT3D_mapper.py', backup_dir+'/mapper/AT3D_mapper.py')

		# Camera sampling parameters shared with configure_datasets().
		self.pitch_range = 0.25
		self.yaw_range = 0.35
		self.cam_pivot = torch.Tensor(self.opts.avg_camera_pivot).to(self.device)
		self.cam_radius = self.opts.avg_camera_radius
		self.intrinsics = FOV_to_intrinsics(self.opts.fov_deg, device=self.device)

		# Average mapped latent over 10000 random (z, camera) pairs; it is the
		# anchor of the truncation applied in configure_datasets().
		z = torch.randn(10000, 512, device=self.device)
		c = []
		for _ in range(10000):
			# NOTE(review): the two angles are presumably (yaw, pitch), judging
			# by the range names - confirm against LookAtPoseSampler.sample.
			cam2world_pose = LookAtPoseSampler.sample(3.14/2 + self.yaw_range * np.sin(2 * 3.14 * np.random.rand()),
													3.14/2 -0.05 + self.pitch_range * np.cos(2 * 3.14 * np.random.rand()),
													self.cam_pivot, radius=self.cam_radius, device=self.device)
			# Conditioning vector: flattened 4x4 pose (16) followed by flattened intrinsics (9).
			c_i = torch.cat([cam2world_pose.reshape(-1, 16), self.intrinsics.reshape(-1, 9)], 1)
			c.append(c_i)
		c = torch.cat(c, 0)

		with torch.no_grad():
			self.mean_latent = self.net.decoder.mapping(z, c).mean(0)

		# Initialize dataset
		self.train_dataset, self.test_dataset = self.configure_datasets()
		self.train_dataloader = DataLoader(self.train_dataset,
										   batch_size=self.opts.batch_size,
										   shuffle=True,
										   # num_workers=int(self.opts.workers),
										   num_workers=0,
										   drop_last=True)
		self.test_dataloader = DataLoader(self.test_dataset,
										  batch_size=self.opts.test_batch_size,
										  shuffle=False,
										  # num_workers=int(self.opts.test_workers),
										  num_workers=0,
										  drop_last=True)

		# Initialize logger
		log_dir = os.path.join(opts.exp_dir, 'logs')
		os.makedirs(log_dir, exist_ok=True)
		self.log_dir = log_dir
		self.logger = SummaryWriter(log_dir=log_dir)

		# Initialize checkpoint dir
		self.checkpoint_dir = os.path.join(opts.exp_dir, 'checkpoints')
		os.makedirs(self.checkpoint_dir, exist_ok=True)
		self.best_val_loss = None
		# Without an explicit interval, only save at step 0 and at the very end.
		if self.opts.save_interval is None:
			self.opts.save_interval = self.opts.max_steps

	def train(self):
		"""Optimise the mapper until ``global_step`` reaches ``opts.max_steps``.

		Each step decodes the source latent (without gradients), maps it to a
		target latent, decodes that, and back-propagates the loss returned by
		``calc_loss``. Metrics are printed/logged every ``opts.board_interval``
		steps, validation runs every ``opts.val_interval`` steps, and a
		checkpoint is written every ``opts.save_interval`` steps. The step at
		which ``global_step == opts.max_steps`` always runs validation and a
		checkpoint before the method returns.

		Raises:
			RuntimeError: If the training dataloader yields no batches, since
				the loop could otherwise never advance.
		"""
		self.net.train()
		# `<=` rather than `<`: the final validation/checkpoint happens on the
		# step where global_step == max_steps. With `<` that step was silently
		# skipped whenever an epoch ended exactly as the counter hit max_steps.
		while self.global_step <= self.opts.max_steps:
			saw_batch = False
			for batch in self.train_dataloader:
				saw_batch = True
				self.optimizer.zero_grad()
				src_latent, src_pose, selected_attributes, selected_description_tuple, ref_img, ref_pose, ref_latent = batch
				# The description arrives as a sequence of strings; merge it into one.
				selected_description = ''.join(selected_description_tuple)

				src_latent = src_latent.to(self.device)
				src_pose = src_pose.to(self.device)
				ref_img = ref_img.to(self.device)
				ref_pose = ref_pose.to(self.device)
				ref_latent = ref_latent.to(self.device)
				selected_attributes = selected_attributes[0].to(self.device)
				# The source image is only a comparison target, so no graph is needed for it.
				with torch.no_grad():
					src_img = self.net.decoder.synthesis(src_latent, src_pose)['image']

				tgt_latent, logits = self.net.mapper(src_latent, ref_latent, selected_attributes)
				tgt_img = self.net.decoder.synthesis(tgt_latent, src_pose)['image']

				loss, loss_dict = self.calc_loss(src_latent, src_img, tgt_latent, tgt_img, ref_img, ref_pose, selected_description, selected_attributes, logits)
				loss.backward()
				self.optimizer.step()

				# Logging related (image dumps via parse_and_log_images are currently disabled)
				if self.global_step % self.opts.board_interval == 0:
					self.print_metrics(loss_dict, prefix='train', selected_description=selected_description)
					self.log_metrics(loss_dict, prefix='train')

				is_last_step = self.global_step == self.opts.max_steps

				# Validation related
				val_loss_dict = None
				if self.global_step % self.opts.val_interval == 0 or is_last_step:
					val_loss_dict = self.validate()
					if val_loss_dict and (self.best_val_loss is None or val_loss_dict['loss'] < self.best_val_loss):
						self.best_val_loss = val_loss_dict['loss']
						self.checkpoint_me(val_loss_dict, is_best=True)

				if self.global_step % self.opts.save_interval == 0 or is_last_step:
					# Prefer validation metrics in the checkpoint log when they exist.
					if val_loss_dict is not None:
						self.checkpoint_me(val_loss_dict, is_best=False)
					else:
						self.checkpoint_me(loss_dict, is_best=False)

				if is_last_step:
					print('OMG, finished training!', flush=True)
					return

				self.global_step += 1

			if not saw_batch:
				raise RuntimeError('Training dataloader produced no batches; '
								   'check the dataset size against the batch size.')

	def validate(self):
		"""Evaluate the mapper on the test dataloader and report mean losses.

		At most 201 batches are evaluated (the loop stops once ``batch_idx``
		exceeds 200). At ``global_step == 0`` only a short sanity pass of five
		batches is run and nothing is logged.

		Returns:
			The aggregated loss dict produced by
			``train_utils.aggregate_loss_dict``, or ``None`` for the step-0
			sanity pass or when the test dataloader yielded no batches.
		"""
		self.net.eval()
		try:
			agg_loss_dict = []
			# Pre-bind so the final print never sees an unbound name.
			selected_description = ''
			for batch_idx, batch in enumerate(self.test_dataloader):
				if batch_idx > 200:
					break

				src_latent, src_pose, selected_attributes, selected_description_tuple, ref_img, ref_pose, ref_latent = batch
				selected_description = ''.join(selected_description_tuple)

				with torch.no_grad():
					src_latent = src_latent.to(self.device).float()
					src_pose = src_pose.to(self.device).float()
					ref_img = ref_img.to(self.device)
					ref_pose = ref_pose.to(self.device)
					ref_latent = ref_latent.to(self.device)
					selected_attributes = selected_attributes[0].to(self.device)

					src_img = self.net.decoder.synthesis(src_latent, src_pose)['image']
					tgt_latent, logits = self.net.mapper(src_latent, ref_latent, selected_attributes)
					tgt_img = self.net.decoder.synthesis(tgt_latent, src_pose)['image']
					loss, cur_loss_dict = self.calc_loss(src_latent, src_img, tgt_latent, tgt_img, ref_img, ref_pose, selected_description, selected_attributes, logits)

				agg_loss_dict.append(cur_loss_dict)

				# For first step just do sanity test on small amount of data
				if self.global_step == 0 and batch_idx >= 4:
					return None  # Do not log, inaccurate in first batch

			# Nothing was evaluated (e.g. the test set is smaller than one
			# batch): there are no metrics to aggregate or report.
			if not agg_loss_dict:
				return None

			loss_dict = train_utils.aggregate_loss_dict(agg_loss_dict)
			self.log_metrics(loss_dict, prefix='test')
			self.print_metrics(loss_dict, prefix='test', selected_description=selected_description)
			return loss_dict
		finally:
			# Always hand the network back in training mode, even on errors.
			self.net.train()

	def checkpoint_me(self, loss_dict, is_best):
		"""Write a checkpoint file and append a line to ``timestamp.txt``.

		Args:
			loss_dict: Metrics to record next to the checkpoint entry.
			is_best: When true the file is ``best_model.pt`` and the entry also
				records ``self.best_val_loss``; otherwise ``latest_model.pt``.
		"""
		target = os.path.join(self.checkpoint_dir,
							  'best_model.pt' if is_best else 'latest_model.pt')
		torch.save(self.__get_save_dict(), target)

		history_path = os.path.join(self.checkpoint_dir, 'timestamp.txt')
		# Append mode keeps the full history of saved steps.
		with open(history_path, 'a') as history:
			if not is_best:
				entry = 'Step - {}, \n{}\n'.format(self.global_step, loss_dict)
			else:
				entry = '**Best**: Step - {}, Loss - {:.3f} \n{}\n'.format(self.global_step, self.best_val_loss, loss_dict)
			history.write(entry)

	def configure_optimizers(self):
		"""Create the optimizer for the mapper's parameters.

		Returns:
			``torch.optim.Adam`` when ``opts.optim_name`` is ``'adam'``,
			otherwise a ``Ranger`` optimizer; both use ``opts.learning_rate``.
		"""
		# Only the mapper sub-module is optimised.
		mapper_params = list(self.net.mapper.parameters())
		if self.opts.optim_name != 'adam':
			return Ranger(mapper_params, lr=self.opts.learning_rate)
		return torch.optim.Adam(mapper_params, lr=self.opts.learning_rate)

	def configure_datasets(self):
		"""Synthesize the train and test latent datasets from random samples.

		Both splits are produced by ``_sample_latent_split``; they differ only
		in size (``opts.train_dataset_size`` / ``opts.test_dataset_size``) and
		in the point after which samples without glasses are rejected (after
		50% of the train split, after 70% of the test split).

		Returns:
			tuple: ``(train_dataset, test_dataset)``, two ``LatentsDataset``
			instances built from CPU copies of the sampled latents and poses.
		"""
		# ---- train dataset ----
		train_poses, train_latents, train_infos = self._sample_latent_split(
			'Processing train dataset', self.opts.train_dataset_size, 0.5)

		# ---- test dataset ----
		test_poses, test_latents, test_infos = self._sample_latent_split(
			'Processing test dataset', self.opts.test_dataset_size, 0.7)

		train_dataset = LatentsDataset(latents=train_latents.cpu(),
									   poses=train_poses.cpu(),
									   infos=train_infos,
									   opts=self.opts,
									   status='train',
									   net=self.net,
									   device=self.device)
		test_dataset = LatentsDataset(latents=test_latents.cpu(),
									  poses=test_poses.cpu(),
									  infos=test_infos,
									  opts=self.opts,
									  status='test',
									  net=self.net,
									  device=self.device)
		print("Number of training samples: {}".format(len(train_dataset)), flush=True)
		print("Number of test samples: {}".format(len(test_dataset)), flush=True)
		return train_dataset, test_dataset

	def _sample_latent_split(self, title, size, glasses_after):
		"""Rejection-sample ``size`` latents with their camera poses and face infos.

		Args:
			title: Caption of the progress bar.
			size: Number of accepted samples to collect.
			glasses_after: Fraction of ``size``; once more than this share has
				been accepted, samples whose parsing contains no label 3
				(glasses) are rejected.

		Returns:
			tuple: ``(poses, latents, infos)`` where ``poses`` and ``latents``
			are the accepted tensors concatenated along dim 0 and ``infos`` is
			the list of ``DeepFace.analyze`` results, one per sample.
		"""
		poses, latents, infos = [], [], []
		psi = self.opts.truncation_psi
		bar = Bar(title, max=size)
		while len(latents) < size:
			z = torch.randn(1, 512).to(self.device)
			cam2world_pose = LookAtPoseSampler.sample(3.14/2 + self.yaw_range * np.sin(2 * 3.14 * np.random.rand()),
													3.14/2 -0.05 + self.pitch_range * np.cos(2 * 3.14 * np.random.rand()),
													self.cam_pivot, radius=self.cam_radius, device=self.device)
			# Conditioning vector: flattened 4x4 pose (16) followed by flattened intrinsics (9).
			c = torch.cat([cam2world_pose.reshape(-1, 16), self.intrinsics.reshape(-1, 9)], 1)

			with torch.no_grad():
				latents_b = self.net.decoder.mapping(z, c)
				# Truncation: pull the sample towards the mean latent.
				latents_b = psi * latents_b + (1-psi) * self.mean_latent.unsqueeze(0)
				latents_s = self.net.decoder.get_styles(latents_b)
				# NOTE(review): get_styles presumably returns per-layer styles of
				# different lengths, hence the zero padding - confirm in decoder.
				latent = pad_sequence(latents_s, batch_first=True, padding_value=0)[None]

				img = self.net.decoder.synthesis(latent, c)['image']

				mask = self.parsenet.batch_run(img, pre_normalize=True, image_repr=False, compact_mask=True)
				if mask is None:
					continue
				# Per-pixel label = arg-max over dim 1 of the parser output.
				logits = torch.unsqueeze(torch.max(mask, 1)[1], 1)
				# filter out images with hat
				if (logits == 11).sum() > 0:
					continue
				# ensure a certain proportion of images with glasses
				if len(latents) > size*glasses_after and (logits == 3).sum() == 0:
					continue

				# DeepFace is fed through a scratch file in the working directory.
				# NOTE(review): `range` is the older save_image keyword; newer
				# torchvision releases call it `value_range` - check the pinned version.
				torchvision.utils.save_image(torch.cat([img.detach().cpu()]), 'test.png',
									normalize=True, scale_each=True, range=(-1, 1), nrow=3)
				try:
					info = DeepFace.analyze('test.png', actions = ['age', 'gender'], silent=True)
				except Exception:
					# Analysis failed for this image: draw a new sample. Catching
					# Exception (not everything) keeps Ctrl-C able to stop the loop.
					continue

			poses.append(c)
			latents.append(latent)
			infos.append(info)
			bar.next()

		bar.finish()
		return torch.cat(poses), torch.cat(latents), infos

	def calc_loss(self, src_latent, src_img, tgt_latent, tgt_img, ref_img, ref_pose, selected_description, selected_attributes, logits):
		"""Combine the enabled loss terms into one weighted total.

		A term contributes only when its ``opts.*_lambda`` weight is positive.
		``src_latent``, ``tgt_latent``, ``ref_pose`` and
		``selected_description`` are accepted but not used here.

		Returns:
			tuple: ``(loss, loss_dict)`` where ``loss`` is the weighted sum
			(the float ``0.0`` if no term is enabled) and ``loss_dict`` maps
			each term's name, plus ``'loss'``, to a Python float.
		"""
		opts = self.opts
		metrics = {}
		total = 0.0

		# Identity preservation between the edited and the source image.
		if opts.id_lambda > 0:
			id_term, id_gain = self.id_loss(tgt_img, src_img)
			metrics['loss_id'] = float(id_term)
			metrics['id_improve'] = float(id_gain)
			total = id_term * opts.id_lambda

		if opts.background_lambda > 0:
			background_term = self.background_loss(src_img, tgt_img, selected_attributes)
			metrics['loss_background'] = float(background_term)
			total += background_term * opts.background_lambda

		if opts.mask_loss_lambda > 0:
			mask_term = self.mask_loss(logits)
			metrics['loss_mask'] = float(mask_term)
			total += mask_term * opts.mask_loss_lambda

		# Attribute loss: weighted sum over every configured CLIP model.
		if opts.attribute_loss_lambda > 0:
			per_model = []
			for model_name, weight in self.clip_model_weights.items():
				clip_term = self.clip_loss_models[model_name](
					src_img, tgt_img, ref_img, target_attributes=selected_attributes)
				per_model.append(weight * clip_term)
			attribute_term = torch.stack(per_model).sum()
			metrics['loss_attribute'] = float(attribute_term)
			total += attribute_term * opts.attribute_loss_lambda

		metrics['loss'] = float(total)
		return total, metrics

	def log_metrics(self, metrics_dict, prefix):
		"""Send every metric to the summary writer as ``<prefix>/<name>`` at the current step."""
		for name, metric in metrics_dict.items():
			tag = f'{prefix}/{name}'
			self.logger.add_scalar(tag, metric, self.global_step)

	def print_metrics(self, metrics_dict, prefix, selected_description):
		"""Print a header line followed by one tab-indented line per metric.

		The description is appended to the header only when ``prefix`` is
		``'train'``.
		"""
		header = 'Metrics for {}, step {}'.format(prefix, self.global_step)
		trailing = (selected_description,) if prefix == 'train' else ()
		print(header, *trailing, flush=True)
		for name, metric in metrics_dict.items():
			print(f'\t{name} = ', metric, flush=True)

	def parse_and_log_images(self, src_img, tgt_img, ref_img, title, selected_description, index=None):
		"""Save source, reference and target images as one JPEG grid.

		The file goes to ``<log_dir>/<title>/`` and is named from the
		zero-padded step, the zero-padded ``index`` when given, and the
		description. Images are concatenated in the order source, reference,
		target and laid out three per row.
		"""
		step_tag = str(self.global_step).zfill(5)
		if index is not None:
			file_name = '{}-{}-{}.jpg'.format(step_tag, str(index).zfill(5), selected_description)
		else:
			file_name = '{}-{}.jpg'.format(step_tag, selected_description)
		path = os.path.join(self.log_dir, title, file_name)
		os.makedirs(os.path.dirname(path), exist_ok=True)

		panels = [img.detach().cpu() for img in (src_img, ref_img, tgt_img)]
		torchvision.utils.save_image(torch.cat(panels), path,
									 normalize=True, scale_each=True, range=(-1, 1), nrow=3)

	def __get_save_dict(self):
		"""Return the checkpoint payload: the network weights and the options as a dict."""
		return dict(
			state_dict=self.net.state_dict(),
			opts=vars(self.opts),
		)