# coding: utf-8

import math,argparse,os
import torch
from utils.training_template import DDPLearner,Learner
from utils.logging import get_logger
from utils.yaml_parser import parse_yaml
from data.dataset import get_audio_dataset

class TaskSpecificMethods(object):
	"""
	Task-specific training and logging steps for wav2vec 2.0 pretraining with two codebooks.

	Written as a mixin: the methods below rely on `is_record_proc`, `update_records`,
	`update_params`, `checkpoint`, and `logger` being provided by the class it is combined with.
	"""

	def train_per_iteration(self, batch, records, modules, device, rank):
		"""
		Run one training iteration: forward pass, loss accumulation, backward pass, parameter update.

		The total loss is the contrastive loss plus, for every regularization term returned by
		the quantizers, the average of the two codebooks' (optionally weighted) values.

		Args:
			batch: Sequence whose first two items are the waveform tensor and its lengths;
				any further items are ignored.
			records: Container of running statistics, passed through to `self.update_records`.
			modules: Mapping with entries 'wav2vec2', 'quantizer1', 'quantizer2', and
				'contrast_loss', each of which is a mapping holding the callable under 'module'.
			device: torch.device on which the computation is performed.
			rank: Process rank, forwarded to `self.is_record_proc` to decide whether to record.
		"""
		waveform,wav_lengths,*_ = batch
		waveform = waveform.to(device, non_blocking=True)
		wav_lengths = wav_lengths.to(device, non_blocking=True)

		wav2vec2 = modules.get('wav2vec2').get('module')
		quantizer1 = modules.get('quantizer1').get('module')
		quantizer2 = modules.get('quantizer2').get('module')
		contrast_loss_func = modules.get('contrast_loss').get('module')

		# with torch.autograd.detect_anomaly():
		with torch.autocast(device.type, dtype=torch.bfloat16):
			continuous,encoded,is_target,latent_lengths = wav2vec2(waveform,wav_lengths)
			continuous = continuous.transpose(1,2).unsqueeze(-1) # BxLxD -> BxDxLx1, with dummy width dimension for conv2d
			quantized1,code1,reg_loss1,stats1 = quantizer1(continuous)
			quantized2,code2,reg_loss2,stats2 = quantizer2(continuous)
			quantized1 = quantized1.squeeze(-1).transpose(1,2) # BxDxLx1 -> BxLxD
			quantized2 = quantized2.squeeze(-1).transpose(1,2)

			total_loss = 0.0
			if reg_loss1 is not None:
				loss_weights = self.checkpoint['loss_weights']
				for name,value1 in reg_loss1.items():
					value1 = value1.mean()
					value2 = reg_loss2[name].mean()
					if self.is_record_proc(rank=rank):
						# Record the unweighted value of each codebook under its own suffix.
						self.update_records(records, name+'_1', value1.item())
						self.update_records(records, name+'_2', value2.item())
					if (loss_weights is not None) and (name in loss_weights):
						value1 = value1*loss_weights[name]
						value2 = value2*loss_weights[name]
					total_loss += (value1+value2)*0.5

			if (stats1 is not None) and self.is_record_proc(rank=rank):
				# 'code_freq' is removed from the stats so that the loop below only sees scalar-like entries.
				code_freq1 = stats1.pop('code_freq', None)
				code_freq2 = stats2.pop('code_freq', None)
				if code_freq1 is not None:
					codebook_size = code1.size(1)
					# NOTE: code_freq is in codebook_size if DDP is used; otherwise, it's of size code_freq*#GPUs
					self.update_records(records, 'code_freq_1', code_freq1.view(codebook_size,-1).sum(dim=1).data.cpu().numpy())
					self.update_records(records, 'code_freq_2', code_freq2.view(codebook_size,-1).sum(dim=1).data.cpu().numpy())
				for name,value1 in stats1.items():
					self.update_records(records, name+'_1', value1.mean().item())
					value2 = stats2[name]
					self.update_records(records, name+'_2', value2.mean().item())

			# The contrastive targets are the two codebooks' outputs concatenated along the feature dimension.
			quantized_cat = torch.cat([quantized1,quantized2], dim=-1)
			contrast_loss = contrast_loss_func(encoded,quantized_cat,is_target)
			if self.is_record_proc(rank=rank):
				self.update_records(records, 'contrast_loss', contrast_loss.item())
			total_loss += contrast_loss

		total_loss.backward()
		self.update_params('wav2vec2', modules=modules)
		self.update_params('quantizer1', modules=modules)
		self.update_params('quantizer2', modules=modules)

	def log_training_stats(self, records, num_iters_per_epoch):
		"""
		Log the recorded training statistics through `self.logger`.

		Most entries of `records` are divided by `num_iters_per_epoch` before being reported
		(presumably they are sums accumulated over that many iterations -- confirm against
		`update_records`). 'code_freq_*' is instead reported as the number of positive entries.

		Args:
			records: Mapping from statistic name to its recorded value. Must contain
				'contrast_loss' and, for each codebook index 1 and 2, 'code_entropy_*',
				'code_coverage_*', and 'code_freq_*'; the other statistics are logged only if present.
			num_iters_per_epoch: Divisor applied to the recorded values.
		"""
		self.logger.info('Contrastive loss (perplexity): {:0.6f}'.format(math.exp(records['contrast_loss']/num_iters_per_epoch)))

		# Optional statistics: (record-key prefix, message template, transform applied to the mean).
		# A transform of None reports the mean as is.
		optional_stats = (
			('entropy', 'Classification entropy for Codebook {codebook_idx} (perplexity): {score:0.6f}', math.exp),
			('knn_l2', 'KNN L2 distance to onehot vectors for Codebook {codebook_idx}: {score:0.6f}', None),
			('knn_ce', 'KNN cross entropy w/ respect to onehot vectors for Codebook {codebook_idx} (perplexity): {score:0.6f}', math.exp),
			# Recorded with a flipped sign, so the sign is restored for reporting.
			('neg_global_perplexity', 'Global classification perplexity for Codebook {codebook_idx} (normalized): {score:0.6f}', lambda mean: -mean),
			('commitment_loss_l2', 'Commitment loss for Codebook {codebook_idx} (rMSE): {score:0.6f}', math.sqrt),
			('commitment_loss_dot', 'Commitment loss for Codebook {codebook_idx} (dot): {score:0.6f}', None),
			('code_location_loss_l2', 'Code-location loss for Codebook {codebook_idx} (rMSE): {score:0.6f}', math.sqrt),
			('code_location_loss_dot', 'Code-location loss for Codebook {codebook_idx} (dot): {score:0.6f}', None),
		)

		for codebook_idx in range(1,3):
			self.logger.info('Code-frequency entropy per sample for Codebook {codebook_idx} (perplexity): {score:0.6f}'.format(codebook_idx=codebook_idx,score=math.exp(records['code_entropy_{}'.format(codebook_idx)]/num_iters_per_epoch)))
			self.logger.info('Mean # of code types used per sample for Codebook {codebook_idx}: {score:0.6f}'.format(codebook_idx=codebook_idx,score=records['code_coverage_{}'.format(codebook_idx)]/num_iters_per_epoch))
			self.logger.info('# of code types used at least once since previous save for Codebook {codebook_idx}: {score:0.6f}'.format(codebook_idx=codebook_idx,score=(records['code_freq_{}'.format(codebook_idx)]>0).sum()))
			for prefix,template,transform in optional_stats:
				key = '{}_{}'.format(prefix, codebook_idx)
				if key not in records:
					continue
				score = records[key]/num_iters_per_epoch
				if transform is not None:
					score = transform(score)
				self.logger.info(template.format(codebook_idx=codebook_idx,score=score))


class DDPTaskLearner(TaskSpecificMethods, DDPLearner):
	"""Learner combining the task-specific methods with the DDPLearner training template."""

class TaskLearner(TaskSpecificMethods, Learner):
	"""Learner combining the task-specific methods with the Learner training template."""



if __name__=='__main__':
	# Script entry point: parse the command line, build the dual-codebook model configs, and start training.
	import copy

	parser = argparse.ArgumentParser()
	parser.add_argument('data_name', type=str, choices=['LibriSpeech'], help='Name of dataset.')
	parser.add_argument('data_root', type=str, help='Path to the directory where data are stored.')
	parser.add_argument('configs', type=str, nargs='+', help='Path to .yaml files where model configs are specified.')
	parser.add_argument('save_dir', type=str, help='Path to the directory where results are saved.')

	parser.add_argument('--split', type=str, default='train-*', help='Specify data split by "train-clean-*" etc.')

	parser.add_argument('--num_workers', type=int, default=0, help='# of dataloading workers.')

	parser.add_argument('--device', type=str, default='cpu', help='cpu or cuda.')
	parser.add_argument('--seed', type=int, default=111, help='Random seed.')
	parser.add_argument('--ddp', action='store_true', help='Use DistributedDataParallel instead of DataParallel.')

	args = parser.parse_args()

	# The log file lives inside save_dir, so the directory must exist before the logger is created.
	os.makedirs(args.save_dir, exist_ok=True)
	logger = get_logger(os.path.join(args.save_dir, 'train.log'))

	model_configs,loss_weights,batch_size,num_epochs,extra_kwargs = parse_yaml(args.configs)
	model_configs['wav2vec2']['module_name'] += "Pretrainer"

	# Both codebooks are configured from the single 'quantizer' entry. The second one gets its
	# own copy so that an in-place change to one config cannot leak into the other.
	model_configs['quantizer1'] = model_configs.pop('quantizer')
	model_configs['quantizer2'] = copy.deepcopy(model_configs['quantizer1'])

	logger.info('wav2vec 2.0 pretraining on {} with dual codebook.'.format(args.data_name))

	learner_class = DDPTaskLearner if args.ddp else TaskLearner
	learner = learner_class(logger, args.save_dir, model_configs, loss_weights,
							device=args.device, seed=args.seed)
	dataset = get_audio_dataset(args.data_name, args.data_root, split=args.split, max_length=extra_kwargs['max_length'])
	learner(dataset, num_epochs, batch_size, args.num_workers)