# coding: utf-8

import argparse,os,json,math
import torch
import torch.nn.functional as F
from utils.training_template import Tester
from utils.logging import get_logger
from utils.vq_test_metrics import CodePerplexity,CodeCoverage
from data.dataset import get_audio_dataset

class TaskSpecificMethods(object):
	"""Task-specific evaluation routine, intended to be mixed into a tester class.

	The methods read ``self.device``, which must be provided by the class this
	mixin is combined with.
	"""

	@staticmethod
	def _select_unpadded_codes(code, latent_lengths):
		"""Keep only the code assignments that belong to non-padded frames.

		Args:
			code: Code tensor returned by a quantizer. It is assumed to be
				B x K x L x 1 (dummy width dimension last) -- TODO confirm
				against the quantizer implementation.
			latent_lengths: Tensor holding B valid frame counts; frames whose
				index is >= the corresponding count are discarded.

		Returns:
			Tensor of size (#kept frames) x K x 1 x 1, i.e. with dummy height
			and width dimensions.
		"""
		code = code.squeeze(-1).transpose(1,2) # BxKxL(x1) -> BxLxK
		B,L,K = code.size()
		# Build the mask on the code's own device: the lengths are not
		# guaranteed to live where the codes do, and masked_select needs both
		# operands on the same device.
		lengths = latent_lengths.to(code.device).view(B,1,1)
		unmask = lengths>torch.arange(L, device=code.device).view(1,L,1)
		# The Bx Lx1 mask broadcasts over K, so every kept frame contributes a full K-vector.
		return code.masked_select(unmask).view(-1,K,1,1)

	def test(self, dataloader, modules, save_path, **kwargs):
		"""Accumulate code usage statistics of two quantizers and save them as JSON.

		Each batch is encoded by ``wav2vec2``; the same encoder output is fed to
		both quantizers, padded frames are removed from the resulting codes, and
		the codes are passed to one ``CodePerplexity`` and one ``CodeCoverage``
		tracker per quantizer.

		Args:
			dataloader: Iterable of batches whose first two items are the
				waveform tensor and its lengths. Must expose ``dataset`` with a
				length.
			modules: Mapping in which the entries 'wav2vec2', 'quantizer1' and
				'quantizer2' are themselves mappings with a 'module' entry.
			save_path: Path of the JSON file to write. Missing parent
				directories are created.
			**kwargs: Accepted for interface compatibility; unused here.
		"""
		datasize = len(dataloader.dataset)

		wav2vec2 = modules.get('wav2vec2').get('module')
		# One (quantizer, perplexity tracker, coverage tracker) triple per codebook.
		trackers = [
			(
				modules.get(name).get('module'),
				CodePerplexity(device=self.device),
				CodeCoverage(device=self.device),
			)
			for name in ('quantizer1','quantizer2')
		]

		# Pure evaluation: no autograd graph is needed. Harmless if the caller
		# has already disabled gradients.
		with torch.no_grad():
			for batch in dataloader:
				# Only the waveform and its lengths are needed here.
				waveform,wav_lengths,*_ = batch
				waveform = waveform.to(self.device)

				continuous,*_,latent_lengths = wav2vec2(waveform,wav_lengths)

				continuous = continuous.transpose(1,2).unsqueeze(-1) # BxLxD -> BxDxLx1, with dummy width dimension for conv2d
				for quantizer,code_perplexity,code_coverage in trackers:
					# The code tensor is the second item returned by the quantizer.
					_,code,*_ = quantizer(continuous)

					# Strip off paddings. K is taken from each quantizer's own
					# output, so the two codebooks may differ in size.
					code = self._select_unpadded_codes(code, latent_lengths)

					code_perplexity.update(code)
					code_coverage.update(code)

		# dirname is '' for a bare filename, and os.makedirs('') raises.
		save_dir = os.path.dirname(save_path)
		if save_dir:
			os.makedirs(save_dir, exist_ok=True)

		# NOTE(review): assumes compute() returns JSON-serializable values -- confirm.
		results = dict()
		for index,(_,code_perplexity,code_coverage) in enumerate(trackers, start=1):
			results['code_use_counts_{}'.format(index)] = code_perplexity._counts.tolist()
			results['code_perplexity_normalized_{}'.format(index)] = code_perplexity.compute()
			results['code_coverage_normalized_{}'.format(index)] = code_coverage.compute()
		results['datasize'] = datasize
		with open(save_path, 'w') as f:
			json.dump(results, f)


class TaskTester(TaskSpecificMethods, Tester):
	"""Tester that takes its ``test`` routine from ``TaskSpecificMethods``.

	Everything else is inherited from ``Tester``; this class adds nothing of its own.
	"""



if __name__=='__main__':
	# Command-line entry point: load a checkpoint, run the tester on one data
	# split and write the results to a JSON file.
	parser = argparse.ArgumentParser(description='Evaluate a trained checkpoint on one data split and save the results as JSON.')
	parser.add_argument('data_name', type=str, choices=['LibriSpeech'], help='Name of dataset.')
	parser.add_argument('data_root', type=str, help='Path to the directory where data are stored.')
	parser.add_argument('checkpoint_path', type=str, help='Path to the checkpoint file containing trained model parameters.')
	parser.add_argument('save_path', type=str, help='Path to the json file where results are saved.')

	parser.add_argument('--split', type=str, required=True, choices=['dev-clean','test-clean','dev-other','test-other','train-clean-100'], help='Specify data split by "test-clean" etc.')

	parser.add_argument('--batch_size', type=int, default=1, help='Batch size.')

	parser.add_argument('--num_workers', type=int, default=0, help='# of dataloading workers.')

	parser.add_argument('--device', type=str, default='cpu', help='cpu or cuda.')
	parser.add_argument('--seed', type=int, default=111, help='Random seed.')

	args = parser.parse_args()

	# Seed torch's random number generators so that --seed actually takes effect.
	torch.manual_seed(args.seed)

	logger = get_logger()

	# NOTE(review): the message says "ASR" although this script reports
	# codebook usage statistics -- confirm the intended wording.
	logger.info('Test ASR on {}.'.format(args.data_name))

	tester = TaskTester(logger, args.checkpoint_path, device=args.device)
	# max_length=None is passed through to the dataset factory unchanged.
	dataset = get_audio_dataset(args.data_name, args.data_root, split=args.split, max_length=None)
	tester(dataset, args.batch_size, args.num_workers, save_path=args.save_path)