# coding: utf-8

import math,argparse,os
import torch
from utils.training_template import DDPLearner,Learner
from utils.logging import get_logger
from utils.yaml_parser import parse_yaml
from data.dataset import get_audio_dataset

class TaskSpecificMethods(object):
	"""Task-specific training hooks for wav2vec 2.0 pretraining.

	Intended as a mixin in front of a learner base class. The host class must
	provide ``checkpoint`` (a dict with a 'loss_weights' entry), ``logger``,
	``is_record_proc``, ``update_records`` and ``update_params``.
	"""

	def train_per_iteration(self, batch, records, modules, device, rank):
		"""Run one forward/backward/update step of wav2vec 2.0 pretraining.

		Args:
			batch: Sequence whose first two items are the waveforms and their
				lengths; any further items are ignored.
			records: Statistics accumulator handed to ``self.update_records``.
			modules: Mapping from 'wav2vec2', 'quantizer' and 'contrast_loss' to
				dicts holding the corresponding callable under the key 'module'.
			device: torch.device the batch is moved to; its ``type`` also selects
				the autocast backend.
			rank: Process rank; only the recording process logs statistics.
		"""
		waveform,wav_lengths,*_ = batch
		waveform = waveform.to(device, non_blocking=True)
		wav_lengths = wav_lengths.to(device, non_blocking=True)

		wav2vec2 = modules.get('wav2vec2').get('module')
		quantizer = modules.get('quantizer').get('module')
		contrast_loss_func = modules.get('contrast_loss').get('module')

		is_recorder = self.is_record_proc(rank=rank)

		# For debugging NaN/Inf gradients, wrap the forward in torch.autograd.detect_anomaly().
		with torch.autocast(device.type, dtype=torch.bfloat16):
			continuous,encoded,is_target,latent_lengths = wav2vec2(waveform,wav_lengths)
			continuous = continuous.transpose(1,2).unsqueeze(-1) # BxLxD -> BxDxLx1, with dummy width dimension for conv2d
			quantized,code,reg_loss,stats = quantizer(continuous)
			quantized = quantized.squeeze(-1).transpose(1,2) # BxDxLx1 -> BxLxD

			total_loss = 0.0
			if reg_loss is not None:
				loss_weights = self.checkpoint['loss_weights']
				for name,value in reg_loss.items():
					value = value.mean()
					if is_recorder:
						self.update_records(records, name, value.item())
					# Only regularizers listed in loss_weights are rescaled; others count with weight 1.
					if (loss_weights is not None) and (name in loss_weights):
						value = value*loss_weights[name]
					total_loss += value

			if (stats is not None) and is_recorder:
				code_freq = stats.pop('code_freq', None)
				if code_freq is not None:
					codebook_size = code.size(1)
					# NOTE: code_freq has codebook_size entries if DDP is used. Under DataParallel the
					# per-GPU vectors are gathered (concatenated along dim 0, GPU by GPU, by default),
					# giving codebook_size*#GPUs entries. Reshape to (#GPUs, codebook_size) so that
					# each column holds one code, then sum over GPUs.
					code_freq = code_freq.detach().reshape(-1, codebook_size).sum(dim=0)
					self.update_records(records, 'code_freq', code_freq.cpu().numpy())
				for name,value in stats.items():
					self.update_records(records, name, value.mean().item())

			contrast_loss = contrast_loss_func(encoded,quantized,is_target)
			if is_recorder:
				self.update_records(records, 'contrast_loss', contrast_loss.item())
			total_loss += contrast_loss

		# Backward runs outside the autocast region.
		total_loss.backward()
		self.update_params('wav2vec2', modules=modules)
		self.update_params('quantizer', modules=modules)

	def log_training_stats(self, records, num_iters_per_epoch):
		"""Log per-iteration averages of the recorded training statistics.

		Args:
			records: Accumulated statistics, keyed as written by ``train_per_iteration``.
				'contrast_loss', 'code_entropy', 'code_coverage' and 'code_freq' are
				required; the remaining keys are logged only if present.
			num_iters_per_epoch: Divisor that turns accumulated sums into means.
		"""
		def mean(key):
			return records[key]/num_iters_per_epoch

		def rms(key):
			# Clamp at zero before the square root so a negative average cannot raise.
			return math.sqrt(max(mean(key), 0.0))

		self.logger.info('Contrastive loss (perplexity): {:0.6f}'.format(math.exp(mean('contrast_loss'))))
		self.logger.info('Code-frequency entropy per sample (perplexity): {:0.6f}'.format(math.exp(mean('code_entropy'))))
		self.logger.info('Mean # of code types used per sample: {:0.6f}'.format(mean('code_coverage')))
		self.logger.info('# of code types used at least once since previous save: {:0.6f}'.format((records['code_freq']>0).sum()))
		if 'entropy' in records:
			self.logger.info('Classification entropy (perplexity): {:0.6f}'.format(math.exp(mean('entropy'))))
		if 'knn_l2' in records:
			self.logger.info('KNN L2 distance to onehot vectors: {:0.6f}'.format(mean('knn_l2')))
		if 'knn_ce' in records:
			self.logger.info('KNN cross entropy w/ respect to onehot vectors (perplexity): {:0.6f}'.format(math.exp(mean('knn_ce'))))
		if 'neg_global_perplexity' in records:
			self.logger.info('Global classification perplexity (normalized): {:0.6f}'.format(-mean('neg_global_perplexity')))
		if 'commitment_loss_l2' in records:
			self.logger.info('Commitment loss (rMSE): {:0.6f}'.format(rms('commitment_loss_l2')))
		if 'commitment_loss_dot' in records:
			self.logger.info('Commitment loss (dot): {:0.6f}'.format(mean('commitment_loss_dot')))
		if 'code_location_loss_l2' in records:
			self.logger.info('Code-location loss (rMSE): {:0.6f}'.format(rms('code_location_loss_l2')))
		if 'code_location_loss_dot' in records:
			self.logger.info('Code-location loss (dot): {:0.6f}'.format(mean('code_location_loss_dot')))


class DDPTaskLearner(TaskSpecificMethods, DDPLearner):
	"""DDPLearner (selected by --ddp) combined with the TaskSpecificMethods pretraining hooks."""

class TaskLearner(TaskSpecificMethods, Learner):
	"""Learner (used without --ddp) combined with the TaskSpecificMethods pretraining hooks."""



if __name__=='__main__':
	# Command-line interface: data location, YAML model configs, output directory and runtime options.
	parser = argparse.ArgumentParser()
	parser.add_argument('data_name', type=str, choices=['LibriSpeech'], help='Name of dataset.')
	parser.add_argument('data_root', type=str, help='Path to the directory where data are stored.')
	parser.add_argument('configs', type=str, nargs='+', help='Path to .yaml files where model configs are specified.')
	parser.add_argument('save_dir', type=str, help='Path to the directory where results are saved.')

	parser.add_argument('--split', type=str, default='train-*', help='Specify data split by "train-clean-*" etc.')

	parser.add_argument('--num_workers', type=int, default=0, help='# of dataloading workers.')

	parser.add_argument('--device', type=str, default='cpu', help='cpu or cuda.')
	parser.add_argument('--seed', type=int, default=111, help='Random seed.')
	parser.add_argument('--ddp', action='store_true', help='Use DistributedDataParallel instead of DataParallel.')

	args = parser.parse_args()

	# Create the output directory before the log file is opened inside it.
	os.makedirs(args.save_dir, exist_ok=True)
	logger = get_logger(os.path.join(args.save_dir, 'train.log'))

	model_configs,loss_weights,batch_size,num_epochs,extra_kwargs = parse_yaml(args.configs)
	# Presumably selects the pretraining variant of the wav2vec2 module by class name — confirm against the module registry.
	model_configs['wav2vec2']['module_name'] += "Pretrainer"

	logger.info('wav2vec 2.0 pretraining on {}.'.format(args.data_name))

	learner_cls = DDPTaskLearner if args.ddp else TaskLearner
	learner = learner_cls(logger, args.save_dir, model_configs, loss_weights,
							device=args.device, seed=args.seed)
	# The YAML configs must provide 'max_length' among their extra settings.
	dataset = get_audio_dataset(args.data_name, args.data_root, split=args.split, max_length=extra_kwargs['max_length'])
	learner(dataset, num_epochs, batch_size, args.num_workers)