# coding: utf-8

import argparse,os,json
import torch
from torchvision.transforms import Compose,CenterCrop,Normalize,ToTensor,Pad
from torchvision.transforms.functional import center_crop
from ignite.metrics import FID,InceptionScore,RootMeanSquaredError
from ignite.metrics.gan.utils import InceptionModel
from utils.training_template import Tester
from utils.logging import get_logger
from utils.vq_test_metrics import CodePerplexity,CodeCoverage
from data.dataset import get_img_dataset

class TaskSpecificMethods(object):
	"""Evaluation step for an encoder -> quantizer -> decoder autoencoder.

	Intended as a mixin placed before ``Tester`` in the MRO; it relies on
	``self.device`` being provided by the host class.
	"""

	def test(self, dataloader, modules, save_path, **kwargs):
		"""Reconstruct every batch of ``dataloader`` and write metrics to JSON.

		Args:
			dataloader: Iterable of batches ``(img, [recon_target,] label)``.
				When an element sits between the input and the label it is
				used as the reconstruction target; otherwise ``img`` itself is.
				Must expose ``.dataset`` (used for the reported data size).
			modules: Mapping with keys ``'encoder'``, ``'quantizer'`` and
				``'decoder'``; each value is a mapping whose ``'module'`` entry
				is the callable network.
			save_path: Path of the JSON file to write. Its parent directory is
				created if it does not exist.
			**kwargs: Ignored; accepted for interface compatibility.

		The JSON object contains ``recon_loss`` (pixel-wise RMSE), ``fid``,
		``inception_score``, ``code_use_counts``, ``code_perplexity_normalized``,
		``code_coverage_normalized`` and ``datasize``.
		"""
		rmse = RootMeanSquaredError(device=self.device)
		# NOTE: FID & InceptionScore use Inception V3 trained on ImageNet as a feature extractor.
		# NOTE: The inputs are assumed to be standardized with mean=[0.485, 0.456, 0.406] and std=[0.229, 0.224, 0.225].
		#       This is slightly different from the training config.
		# NOTE: By default, torch-ignite's FID uses softmax class probs. as features, which is
		#       non-standard; request the 2048-d Inception features instead.
		fid_feature_extractor = InceptionModel(return_features=True, device=self.device)
		fid = FID(feature_extractor=fid_feature_extractor, num_features=2048, device=self.device)
		inception_score = InceptionScore(device=self.device)  # NOTE: Inception score is fine w/ class probs.
		code_perplexity = CodePerplexity(device=self.device)
		code_coverage = CodeCoverage(device=self.device)
		datasize = len(dataloader.dataset)

		encoder = modules.get('encoder').get('module')
		quantizer = modules.get('quantizer').get('module')
		decoder = modules.get('decoder').get('module')
		# NOTE(review): this method does not switch the modules to eval mode;
		#       presumably the Tester base class does — confirm.

		# Evaluation only: no autograd graph is needed for any of the metrics.
		with torch.no_grad():
			for batch in dataloader:
				img, *extras, _label = batch
				img = img.to(self.device)
				recon_target = img if len(extras) == 0 else extras[0].to(self.device)

				continuous = encoder(img)
				quantized, code, _reg_loss, _stats = quantizer(continuous)
				recon = decoder(quantized)

				# Trim the reconstruction to the target size when the input was padded.
				if recon_target.size(-1) < recon.size(-1):
					recon = center_crop(recon, recon_target.size()[-2:])

				# reshape (not view): center_crop returns a non-contiguous slice,
				# on which .view(-1, 1) raises.
				rmse.update((recon.reshape(-1, 1), recon_target.reshape(-1, 1)))  # NOTE: Pixel-wise

				code_perplexity.update(code)
				code_coverage.update(code)

				fid.update((recon, recon_target))
				inception_score.update(recon)

		# Compute everything before touching the output file so a failing metric
		# does not leave a truncated JSON behind.
		results = dict(
			recon_loss=rmse.compute(),
			fid=fid.compute(),
			inception_score=inception_score.compute(),
			code_use_counts=code_perplexity._counts.tolist(),
			code_perplexity_normalized=code_perplexity.compute(),
			code_coverage_normalized=code_coverage.compute(),
			datasize=datasize,
		)

		save_dir = os.path.dirname(save_path)
		if save_dir:  # os.makedirs('') raises when save_path has no directory part.
			os.makedirs(save_dir, exist_ok=True)
		with open(save_path, 'w') as f:
			json.dump(results, f)

class TaskTester(TaskSpecificMethods, Tester):
	"""Concrete tester for this task.

	``TaskSpecificMethods`` is listed first, so its ``test`` method takes
	precedence in the MRO; everything else is inherited from ``Tester``.
	"""



if __name__ == '__main__':
	# Command-line entry point: evaluate a trained autoencoder checkpoint on a validation split.
	parser = argparse.ArgumentParser()
	parser.add_argument('data_name', type=str, choices=['ImageNet','MNIST','CIFAR10'], help='Name of dataset.')
	parser.add_argument('data_root', type=str, help='Path to the directory where data are stored.')
	parser.add_argument('checkpoint_path', type=str, help='Path to the checkpoint file containing trained model parameters.')
	parser.add_argument('save_path', type=str, help='Path to the json file where results are saved.')

	parser.add_argument('--preprocessed', action='store_true', help='Use preprocessed data, saved as .npy files.')

	parser.add_argument('--batch_size', type=int, default=1, help='Batch size.')

	parser.add_argument('--num_workers', type=int, default=0, help='# of dataloading workers.')

	parser.add_argument('--device', type=str, default='cpu', help='cpu or cuda.')
	parser.add_argument('--seed', type=int, default=111, help='Random seed.')
	# parser.add_argument('--ddp', action='store_true', help='Use DistributedDataParallel instead of DataParallel.')

	args = parser.parse_args()

	# Apply --seed (otherwise parsed but unused) so any stochastic step during
	# evaluation is reproducible. torch.manual_seed seeds all devices.
	torch.manual_seed(args.seed)

	logger = get_logger()

	logger.info('Test autoencoding on {}.'.format(args.data_name))

	tester = TaskTester(logger, args.checkpoint_path, device=args.device)
	target_transform = None
	if args.preprocessed:
		transform = None
	elif args.data_name == 'ImageNet':
		transform = Compose([
			CenterCrop(size=[256,256]),
			ToTensor(),
			# NOTE(review): these stats differ from the standard ImageNet ones; they look
			#       like CLIP's preprocessing constants — confirm they match training.
			Normalize(mean=torch.tensor([0.4815, 0.4578, 0.4082]), std=torch.tensor([0.2686, 0.2613, 0.2758]))
		])
	elif args.data_name == 'MNIST':
		transform = Compose([
			ToTensor(),
			# Zero-pad 2 px per side: 28x28 -> 32x32 (padding, not resampling).
			Pad((32-28)//2, fill=0, padding_mode='constant'),
			# Normalize(mean=torch.tensor([0.1306]), std=torch.tensor([0.308]))
		])
		target_transform = CenterCrop(size=[28,28]) # Only reconstruct non-padded pixels
	elif args.data_name == 'CIFAR10':
		transform = Compose([
			ToTensor(),
			Normalize(mean=torch.tensor([0.491, 0.482, 0.446]), std=torch.tensor([0.247, 0.243, 0.261]))
		])
	dataset = get_img_dataset(args.data_name, args.data_root, split='val',
							transform=transform, target_transform=target_transform,
							preprocessed=args.preprocessed)
	tester(dataset, args.batch_size, args.num_workers, save_path=args.save_path)