# coding: utf-8

import math,argparse,os
import torch
import torch.nn.functional as F
from torchvision.transforms import Compose,CenterCrop,Normalize,ToTensor,Pad
from torchvision.transforms.functional import center_crop
from utils.training_template import DDPLearner,Learner
from utils.logging import get_logger
from utils.yaml_parser import parse_yaml
from data.dataset import get_img_dataset

class TaskSpecificMethods(object):
	"""Autoencoder-specific training hooks, mixed into a Learner / DDPLearner.

	NOTE(review): this mixin relies on members supplied by the host learner class
	(`is_record_proc`, `update_records`, `update_params`, `checkpoint`, `logger`);
	none of them are defined here.
	"""

	def train_per_iteration(self, batch, records, modules, device, rank):
		"""Run one encoder -> quantizer -> decoder training step and update params.

		Args:
			batch: Sequence whose first item is the input image batch. When it has
				three or more items, the second one is used as the reconstruction
				target; otherwise (images only, or images plus one extra item such as
				a label) the input image itself is reconstructed.
			records: Accumulator handed to `self.update_records`.
			modules: Mapping with keys 'encoder', 'quantizer' and 'decoder', each a
				mapping whose 'module' entry is the corresponding callable model.
			device: torch.device for the step; its `.type` selects the autocast backend.
			rank: Process rank; only the recording process (per
				`self.is_record_proc`) writes to `records`.
		"""
		img,*args = batch
		img = img.to(device, non_blocking=True)
		# With no extra item, or a single one, there is no separate target.
		# (An empty `args` used to raise IndexError.)
		recon_target = img if len(args)<=1 else args[0].to(device, non_blocking=True)

		encoder = modules.get('encoder').get('module')
		quantizer = modules.get('quantizer').get('module')
		decoder = modules.get('decoder').get('module')
		is_recorder = self.is_record_proc(rank=rank)

		# Forward pass under bfloat16 autocast; backward() runs outside the context.
		with torch.autocast(device.type, dtype=torch.bfloat16):
			continuous = encoder(img)
			quantized,code,reg_loss,stats = quantizer(continuous)
			recon = decoder(quantized)

			total_loss = 0.0
			if reg_loss is not None:
				loss_weights = self.checkpoint['loss_weights']
				for name,value in reg_loss.items():
					value = value.mean()
					if is_recorder:
						# The unweighted value is what gets recorded.
						self.update_records(records, name, value.item())
					if (loss_weights is not None) and (name in loss_weights):
						value = value*loss_weights[name]
					total_loss += value

			if (stats is not None) and is_recorder:
				code_freq = stats.pop('code_freq', None)
				if code_freq is not None:
					codebook_size = code.size(1)
					# NOTE: code_freq has codebook_size entries if DDP is used; otherwise
					# (DataParallel) it has codebook_size*#GPUs entries, laid out
					# replica-major because outputs are concatenated along dim 0.
					# Reshape to (#replicas, codebook_size) and sum over replicas.
					self.update_records(records, 'code_freq', code_freq.view(-1,codebook_size).sum(dim=0).detach().cpu().numpy())
				for name,value in stats.items():
					self.update_records(records, name, value.mean().item())

			# The decoder output may be spatially larger than the target (e.g. MNIST is
			# padded to 32x32 but only the 28x28 centre is reconstructed); crop to match.
			if recon_target.size(-1)<recon.size(-1):
				recon = center_crop(recon, recon_target.size()[-2:])
			recon_loss = F.mse_loss(recon, recon_target, reduction='mean')
			if is_recorder:
				self.update_records(records, 'recon_loss', recon_loss.item())
			total_loss += recon_loss

		total_loss.backward()
		for module_name in ('encoder', 'quantizer', 'decoder'):
			self.update_params(module_name, modules=modules)


	def log_training_stats(self, records, num_iters_per_epoch):
		"""Log averaged training statistics held in `records`.

		Each scalar record is divided by `num_iters_per_epoch` (records are presumably
		per-iteration sums accumulated by `update_records` — confirm), then shown on
		a readable scale: sqrt for squared errors (rMSE), exp for entropies
		(perplexity). Everything except 'recon_loss' is optional, because the
		quantizer may return no regularisation losses and no stats.
		"""
		self.logger.info('Reconstruction loss (rMSE): {:0.6f}'.format(math.sqrt(records['recon_loss']/num_iters_per_epoch)))
		if 'code_entropy' in records:
			self.logger.info('Code-frequency entropy per sample (perplexity): {:0.6f}'.format(math.exp(records['code_entropy']/num_iters_per_epoch)))
		if 'code_coverage' in records:
			self.logger.info('Mean # of code types used per sample: {:0.6f}'.format(records['code_coverage']/num_iters_per_epoch))
		if 'code_freq' in records:
			# code_freq is a per-code count array, not averaged over iterations.
			self.logger.info('# of code types used at least once since previous save: {:0.6f}'.format(int((records['code_freq']>0).sum())))
		if 'entropy' in records:
			self.logger.info('Classification entropy (perplexity): {:0.6f}'.format(math.exp(records['entropy']/num_iters_per_epoch)))
		if 'knn_l2' in records:
			self.logger.info('KNN L2 distance to onehot vectors: {:0.6f}'.format(records['knn_l2']/num_iters_per_epoch))
		if 'knn_ce' in records:
			self.logger.info('KNN cross entropy w/ respect to onehot vectors (perplexity): {:0.6f}'.format(math.exp(records['knn_ce']/num_iters_per_epoch)))
		if 'neg_global_perplexity' in records:
			self.logger.info('Global classification perplexity (normalized): {:0.6f}'.format(-records['neg_global_perplexity']/num_iters_per_epoch))
		if 'commitment_loss_l2' in records:
			self.logger.info('Commitment loss (rMSE): {:0.6f}'.format(math.sqrt(records['commitment_loss_l2']/num_iters_per_epoch)))
		if 'commitment_loss_dot' in records:
			self.logger.info('Commitment loss (dot): {:0.6f}'.format(records['commitment_loss_dot']/num_iters_per_epoch))
		if 'code_location_loss_l2' in records:
			self.logger.info('Code-location loss (rMSE): {:0.6f}'.format(math.sqrt(records['code_location_loss_l2']/num_iters_per_epoch)))
		if 'code_location_loss_dot' in records:
			self.logger.info('Code-location loss (dot): {:0.6f}'.format(records['code_location_loss_dot']/num_iters_per_epoch))


class DDPTaskLearner(TaskSpecificMethods, DDPLearner):
	"""Autoencoding learner: TaskSpecificMethods hooks combined with DDPLearner (selected by --ddp)."""

class TaskLearner(TaskSpecificMethods, Learner):
	"""Autoencoding learner: TaskSpecificMethods hooks combined with the non-DDP Learner."""



if __name__=='__main__':
	arg_parser = argparse.ArgumentParser()
	# Positional arguments (`configs` takes every path between data_root and save_dir).
	arg_parser.add_argument('data_name', type=str, choices=['ImageNet','MNIST','CIFAR10'], help='Name of dataset.')
	arg_parser.add_argument('data_root', type=str, help='Path to the directory where data are stored.')
	arg_parser.add_argument('configs', type=str, nargs='+', help='Path to .yaml files where model configs are specified.')
	arg_parser.add_argument('save_dir', type=str, help='Path to the directory where results are saved.')
	# Data options.
	arg_parser.add_argument('--preprocessed', action='store_true', help='Use preprocessed data, saved as .npy files.')
	arg_parser.add_argument('--num_workers', type=int, default=0, help='# of dataloading workers.')
	# Runtime options.
	arg_parser.add_argument('--device', type=str, default='cpu', help='cpu or cuda.')
	arg_parser.add_argument('--seed', type=int, default=111, help='Random seed.')
	arg_parser.add_argument('--ddp', action='store_true', help='Use DistributedDataParallel instead of DataParallel.')
	args = arg_parser.parse_args()

	os.makedirs(args.save_dir, exist_ok=True)
	logger = get_logger(os.path.join(args.save_dir, 'train.log'))

	model_configs,loss_weights,batch_size,num_epochs,_ = parse_yaml(args.configs)
	logger.info('Autoencoding on {}.'.format(args.data_name))

	learner_cls = DDPTaskLearner if args.ddp else TaskLearner
	learner = learner_cls(logger, args.save_dir, model_configs, loss_weights, device=args.device, seed=args.seed)

	# Per-dataset factories returning (input transform, reconstruction-target transform).
	transform_factories = {
		'ImageNet': lambda: (
			Compose([
				CenterCrop(size=[256,256]),
				ToTensor(),
				Normalize(mean=torch.tensor([0.4815, 0.4578, 0.4082]), std=torch.tensor([0.2686, 0.2613, 0.2758])),
			]),
			None,
		),
		'MNIST': lambda: (
			Compose([
				ToTensor(),
				Pad((32-28)//2, fill=0, padding_mode='constant'), # Zero-pad 28x28 up to 32x32.
				# Normalize(mean=torch.tensor([0.1306]), std=torch.tensor([0.308]))
			]),
			CenterCrop(size=[28,28]), # Reconstruct only the original, non-padded pixels.
		),
		'CIFAR10': lambda: (
			Compose([
				ToTensor(),
				Normalize(mean=torch.tensor([0.491, 0.482, 0.446]), std=torch.tensor([0.247, 0.243, 0.261])),
			]),
			None,
		),
	}
	if args.preprocessed:
		# No transforms are applied to the preprocessed (.npy) data here.
		transform, target_transform = None, None
	else:
		transform, target_transform = transform_factories[args.data_name]()

	dataset = get_img_dataset(
		args.data_name, args.data_root, split='train',
		transform=transform, target_transform=target_transform,
		preprocessed=args.preprocessed,
	)
	learner(dataset, num_epochs, batch_size, args.num_workers)