# coding: utf-8

import os.path,copy,subprocess,datetime
import torch
import torch.nn as nn
from torch.nn.utils import clip_grad_norm_
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
# from timm.scheduler import CosineLRScheduler
from .scheduler import WarmupCosineFlatSchedule
import model as M
from data.dataloader import get_data_loader


class DDPLearner(object):
	def __init__(self, logger, save_dir, model_configs, loss_weights, device='cpu', seed=111):
		self.logger = logger
		# self.device = torch.device(device)
		assert device.startswith('cuda') and torch.cuda.is_available(), 'DDP is only available for CUDA envs.'
		self.slurm = 'SLURM_NTASKS' in os.environ # Check if using Slurm
		if self.slurm:
			# NOTE: Using multiprocessing with Slurm raises some error.
			# See https://discuss.pytorch.org/t/training-on-gpus-from-runtimeerror-address-already-in-use-to-timeout/172460/3
			self.world_size = int(os.environ['SLURM_NTASKS']) # NOTE: != SLURM_NTASKS_PER_NODE.
			self.rank = int(os.environ["SLURM_PROCID"])
			self.local_rank = int(os.environ['SLURM_LOCALID'])
			self.log_once('SLURM_NTASKS:{}'.format(self.world_size))
			self.log_once('Using SLURM.')
			# self._ddp_setup()
		else:
			self.world_size = torch.cuda.device_count()
			self.rank = None
			self.local_rank =  None

		self.log_once("PyTorch ver.: {ver}".format(ver=torch.__version__))
		# self.logger.info('Device: {device}'.format(device=device))
		self.log_once('CUDA Version: {version}'.format(version=torch.version.cuda))
		# if self.slurm:
			# dist.barrier()
		if torch.backends.cudnn.enabled:
			self.log_once('cuDNN Version: {version}'.format(version=torch.backends.cudnn.version()))
		if self.slurm:
			self.logger.info('CUDA Device #{rank}: {device_name}'.format(rank=self.rank, device_name=torch.cuda.get_device_name(self.local_rank)))
		else:
			for device_idx in range(torch.cuda.device_count()):
				self.logger.info('CUDA Device #{device_idx}: {device_name}'.format(device_idx=device_idx, device_name=torch.cuda.get_device_name(device_idx)))
		self.retrieval = os.path.isfile(os.path.join(save_dir, 'checkpoint.pt'))
		self.save_dir = save_dir

		# torch.backends.cudnn.deterministic = True
		# torch.backends.cudnn.benchmark = False

		if self.retrieval:
			self.load_checkpoint()
			self.log_once('Checkpoint loaded.')
			self.seed = self.checkpoint['random_seed']
		else:
			self.seed = seed
			self.log_once('Random seed: {seed}'.format(seed = seed))
			self.checkpoint = dict(modules=dict(), random_seed=seed)
			for module_key,kwargs in model_configs.items():
				self.log_module_info(module_key, **kwargs)
				self.checkpoint['modules'][module_key] = kwargs
			self.checkpoint['last_epoch'] = 0
			self.checkpoint['loss_weights'] = loss_weights
		# if self.slurm:
			# dist.barrier()

	def is_record_proc(self, rank=0):
		"""
		Tell whether the current process is the one that logs and saves.

		Without Slurm, the decision is based on the `rank` argument (0 wins).
		Otherwise, or when `rank` is non-zero, the process qualifies only if
		both `self.rank` and `self.local_rank` equal 0.
		"""
		if not self.slurm and rank == 0:
			return True
		return self.rank == 0 and self.local_rank == 0

	def log_once(self, message, type='info', **kwargs):
		"""
		Emit `message` through the logger only on the recording process.

		Args:
			message: Object passed to the logger method.
			type: Name of the logger method to call (e.g. 'info').
			**kwargs: Forwarded to `is_record_proc` (e.g. `rank`).
		"""
		if not self.is_record_proc(**kwargs):
			return
		emit = getattr(self.logger, type)
		emit(message)

	def _ddp_setup(self, rank):
		"""
		Export MASTER_ADDR / MASTER_PORT and initialize the NCCL process group.

		Args:
			rank: Rank of the calling process. Ignored under Slurm, where
				`self.rank` is used instead.
		"""
		if not self.slurm:
			# Single-node run: all workers rendezvous on the local machine.
			os.environ['MASTER_ADDR'] = 'localhost'
			os.environ['MASTER_PORT'] = '12355'
		else:
			# NOTE: dist.init_process_group() must receive self.rank here,
			# not the local (GPU-level) rank.
			rank = self.rank
			# The first hostname listed by scontrol for this job serves as the master.
			node_list = os.environ["SLURM_JOB_NODELIST"]
			hostnames = subprocess.check_output(["scontrol", "show", "hostnames", node_list])
			os.environ['MASTER_ADDR'] = hostnames.split()[0].decode("utf-8")
			# Port is '1' followed by the last four characters of the job ID.
			os.environ['MASTER_PORT'] = '1' + os.environ['SLURM_JOB_ID'][-4:]

		# Initialize the process group (30-minute timeout).
		dist.init_process_group("nccl", rank=rank, world_size=self.world_size, timeout=datetime.timedelta(minutes=30))
	
	def train(self, rank, dataloader, num_epochs, last_epoch):
		"""
		Run the training loop from epoch `last_epoch`+1 through `num_epochs`.

		Modules are built from `self.checkpoint['modules']` and put in training
		mode. For every batch the RNGs are reseeded, optimizer gradients are
		zeroed and `self.train_per_iteration` is called. After each epoch the
		recording process logs statistics and saves a checkpoint.

		NOTE: `train_per_iteration` and `log_training_stats` are not defined in
		this class as visible here; they are expected to be supplied elsewhere
		(presumably by a subclass).

		Args:
			rank: Device index used both for building modules and for
				`torch.device(rank)`.
			dataloader: Iterable of batches that also supports `len()`.
			num_epochs: Last epoch number to run (inclusive).
			last_epoch: Number of epochs already completed.
		"""
		torch.backends.cudnn.deterministic = True
		torch.backends.cudnn.benchmark = False
		torch.manual_seed(self.seed)
		torch.cuda.manual_seed(self.seed) # Device-specific seeding.

		modules = {module_key:self.build_module(rank=rank, **kwargs) for module_key,kwargs in self.checkpoint['modules'].items()}
		for module in modules.values():
			module['module'].train()
		device = torch.device(rank)

		records = dict()
		# NOTE(review): the counter restarts at 0 even when resuming from a
		# checkpoint, so the per-iteration seeds of a resumed run repeat those
		# of the first epochs. Confirm whether that is intended.
		iteration = 0
		# Stays None when the epoch range is empty so that the final save below
		# does not touch an unbound loop variable.
		epoch = None
		for epoch in range(last_epoch+1,num_epochs+1):
			for batch in dataloader:
				iteration += 1
				# NOTE: Define device-dependent random seed.
				seed = self.seed + iteration*self.world_size + rank
				torch.manual_seed(seed)
				torch.cuda.manual_seed(seed) # Device-specific seeding.

				# Modules without an optimizer (no trainable parameters) are skipped.
				for module in modules.values():
					if module['optimizer'] is not None:
						module['optimizer'].zero_grad()
				self.train_per_iteration(batch, records, modules=modules, device=device, rank=rank)

			if self.is_record_proc(rank=rank):
				self.logger.info('{epoch}/{num_epochs} epochs complete.'.format(epoch=epoch, num_epochs=num_epochs))
				self.log_training_stats(records, len(dataloader))
				self.save_model(modules, epoch)
				records = dict()
		# (Re)save the final checkpoint, but only if at least one epoch ran.
		if epoch is not None and self.is_record_proc(rank=rank):
			self.save_model(modules, epoch)

	def __call__(self, *args):
		"""
		Launch training with the given positional arguments.

		Under Slurm the current process runs `_call_in_subproc` directly with
		`rank=None`. Otherwise `self.world_size` worker processes are spawned,
		each running `_call_in_subproc`, and this call waits for them to finish.

		Args:
			*args: Forwarded to `_call_in_subproc` after the rank argument.
		"""
		if not self.slurm:
			import torch.multiprocessing as mp
			mp.spawn(
				self._call_in_subproc,
				args=args,
				nprocs=self.world_size,
				join=True,
			)
			return
		self._call_in_subproc(None, *args)

	def _call_in_subproc(self, rank, dataset, num_epochs, global_batch_size, global_num_workers):
		"""
		Per-process entry point: set up the process group, build the data
		loader for this process, and run training.

		Args:
			rank: Process rank supplied by the launcher (None under Slurm, where
				`self.rank` / `self.local_rank` are used instead).
			dataset: Dataset handed to `get_data_loader`.
			num_epochs: Total number of epochs to train for.
			global_batch_size: Batch size summed over all processes; must be
				divisible by `self.world_size`.
			global_num_workers: Data-loading workers summed over all processes;
				must be divisible by `self.world_size`.

		Raises:
			AssertionError: If either global quantity is not divisible by
				`self.world_size`.
		"""
		self._ddp_setup(rank)
		if self.slurm:
			rank = self.local_rank # NOTE: Everything below is performed on the local (GPU-level) rank.
		last_epoch = self.checkpoint['last_epoch']
		assert global_batch_size%self.world_size==0, 'global_batch_size must be divisible by world_size={}'.format(self.world_size)
		batch_size = global_batch_size//self.world_size
		assert global_num_workers%self.world_size==0, 'global_num_workers must be divisible by world_size={}'.format(self.world_size)
		num_workers = global_num_workers//self.world_size
		if last_epoch>=num_epochs:
			self.log_once('Training is already completed.', rank=rank)
			return
		self.log_once("START LEARNING.", rank=rank)
		self.log_once("# of epochs: {ep}".format(ep=num_epochs), rank=rank)
		self.log_once("batch size for training data: {size}".format(size=global_batch_size), rank=rank)
		# Logged by every process so that each sub-batch size is on record.
		self.logger.info("sub-batch size for GPU #{rank}: {size}".format(size=batch_size, rank=self.rank if self.slurm else rank))
		dataloader = get_data_loader(
								dataset,
								batch_size=batch_size,
								num_workers=num_workers,
								random_seed=self.seed,
								ddp=True,
								rank=rank,)
		self.train(rank, dataloader, num_epochs, last_epoch)
		# The rank must be passed: without it, every spawned (non-Slurm) worker
		# would pass the default rank=0 check and log this message.
		self.log_once('END OF TRAINING', rank=rank)


	def load_checkpoint(self, checkpoint_path = None):
		if checkpoint_path is None:
			checkpoint_path = os.path.join(self.save_dir, 'checkpoint.pt')
		self.checkpoint = torch.load(checkpoint_path, map_location='cpu') # Random state needs to be loaded to CPU first even when cuda is available.

	def update_records(self, records, name, value):
		"""
		Add `value` to the running total stored under `name` in `records`.

		A missing entry is treated as 0.0. `records` is modified in place.
		"""
		running_total = records.get(name, 0.0)
		records[name] = running_total + value

	def build_module(self, rank, module_name, init_args, optim_config, scheduler_config,
						state_dicts=None, finetune=False):
		"""
		Instantiate one module on the given device together with its optimizer
		and learning-rate scheduler.

		Args:
			rank: Device index for the module (replaced by `self.local_rank`
				under Slurm).
			module_name: Name of the class to look up in the `M` module.
			init_args: Keyword arguments for the module constructor.
			optim_config: Keyword arguments for `torch.optim.AdamW`.
			scheduler_config: Keyword arguments for `WarmupCosineFlatSchedule`.
			state_dicts: Optional dict with a 'module' entry and, optionally,
				'optimizer' and 'scheduler' entries to restore.
			finetune: If True, parameters whose names appear in the restored
				module state are frozen.

		Returns:
			dict with keys 'module', 'optimizer' and 'scheduler'; the latter two
			are None when the module has no trainable parameters.
		"""
		if self.slurm:
			rank = self.local_rank # Overwrite
		module_class = getattr(M, module_name)
		# NOTE: Some modules may pop dictionary arguments at instantiation;
		# deep-copying keeps the init_args stored in the checkpoint intact.
		module = module_class(**copy.deepcopy(init_args))

		restoring = state_dicts is not None
		if restoring:
			weights = state_dicts['module']
			try:
				module.load_state_dict(weights, strict=True)
			except Exception as e:
				# Strict loading failed: report the reason and retry leniently.
				self.log_once(e, rank=rank)
				module.load_state_dict(weights, strict=False)
			if finetune:
				for name,param in module.named_parameters():
					if name in weights:
						param.requires_grad = False
			self.log_once('state_dict loaded on {}.'.format(module_name), rank=rank)

		module = module.to(rank)
		num_trainable_params = sum(param.numel() for param in module.parameters()
													if param.requires_grad)
		if not num_trainable_params:
			# NOTE: No optimizer for modules without trainable parameters (e.g. loss functions).
			return dict(module=module,optimizer=None, scheduler=None)

		# NOTE: DistributedDataParallel is only needed when some parameter requires a gradient.
		module = DDP(module, device_ids=[rank])
		# NOTE: Optimizers skip parameters w/ require_grad=False. https://discuss.pytorch.org/t/passing-to-the-optimizers-frozen-parameters/83358/7
		optimizer = torch.optim.AdamW(module.parameters(), **optim_config)
		scheduler = WarmupCosineFlatSchedule(optimizer, **scheduler_config)
		if restoring:
			for key,target in (('optimizer', optimizer), ('scheduler', scheduler)):
				if key in state_dicts:
					target.load_state_dict(state_dicts[key])
		return dict(module=module,optimizer=optimizer, scheduler=scheduler)
	
	def log_module_info(self, module_key, module_name, init_args, optim_config, scheduler_config, **kwargs):
		"""
		Log (once) the configuration of a module, its optimizer and its scheduler.

		The messages name the optimizer and scheduler that `build_module`
		actually creates: `torch.optim.AdamW` and `WarmupCosineFlatSchedule`.

		Args:
			module_key: Key under which the module is registered.
			module_name: Name of the module class.
			init_args: Constructor arguments of the module.
			optim_config: Optimizer keyword arguments.
			scheduler_config: Scheduler keyword arguments.
			**kwargs: Accepted and ignored, so that a full module config can be
				unpacked into this call.
		"""
		self.log_once('{module_name} instantiated as "{module_key}" w/ following parameters: {init_args}'.format(
							module_name=module_name, module_key=module_key, init_args=str(init_args)))
		self.log_once('(Submodules of) {module_key} is/are trained by AdamW optimizer w/ following parameters: {optim_config}'.format(
			module_key=module_key, optim_config=str(optim_config)))
		self.log_once('Learning rate for (submodules of) {module_key} is scheduled by WarmupCosineFlatSchedule w/ following parameters: {scheduler_config}'.format(
			module_key=module_key, scheduler_config=str(scheduler_config)))

	def update_params(self, name, modules):
		module = modules[name]
		# NOTE: Using bfloat16 doesn't require grad scaling. https://discuss.pytorch.org/t/why-bf16-do-not-need-loss-scaling/176596/3
		clip_grad_norm_(module['module'].parameters(), 1.0)
		module['optimizer'].step()
		module['scheduler'].step()

	def save_model(self, modules, epoch):
		"""
		Save model config.
		"""
		for module_key,info in self.checkpoint['modules'].items():
			module = modules[module_key]
			if not module['optimizer'] is None: # NOTE: Nothing to save for loss function etc.
				info['state_dicts'] = dict(module=module['module'].module.state_dict(),
											optimizer=module['optimizer'].state_dict(),
											scheduler=module['scheduler'].state_dict())
		self.checkpoint['last_epoch'] = epoch
		torch.save(self.checkpoint, os.path.join(self.save_dir, 'checkpoint_after-{epoch}-epochs.pt'.format(epoch=epoch)))
		torch.save(self.checkpoint, os.path.join(self.save_dir, 'checkpoint.pt'))
		self.logger.info('Config successfully saved.')

class Learner(DDPLearner):
	"""
	Single-process learner that wraps modules in `nn.DataParallel` instead of
	DistributedDataParallel, reusing the training loop of `DDPLearner`.
	"""
	# Fixed values that stand in for the attributes DDPLearner.__init__ sets
	# per instance; with them the inherited is_record_proc() is always True.
	slurm = False
	world_size = 1
	rank = 0
	local_rank = 0

	def __init__(self, logger, save_dir, model_configs, loss_weights, device='cpu', seed=111):
		"""
		Log software/hardware information and prepare `self.checkpoint`.

		The parent constructor is intentionally not called (it asserts a CUDA
		device). If `save_dir` contains 'checkpoint.pt', it is loaded and its
		stored random seed is used; otherwise a new checkpoint dict is built
		from `model_configs`, `loss_weights` and `seed`.

		Args:
			logger: Object exposing logging-style methods such as `info`.
			save_dir: Directory holding (or receiving) checkpoint files.
			model_configs: Mapping from module key to module keyword arguments.
			loss_weights: Stored verbatim under the 'loss_weights' key.
			device: Device string passed to `torch.device`.
			seed: Random seed used when no checkpoint exists.
		"""
		self.logger = logger
		self.device = torch.device(device)
		self.log_once('Using DataParallel.')
		self.log_once("PyTorch ver.: {ver}".format(ver=torch.__version__))
		self.log_once('CUDA Version: {version}'.format(version=torch.version.cuda))
		if torch.backends.cudnn.enabled:
			self.log_once('cuDNN Version: {version}'.format(version=torch.backends.cudnn.version()))
		for device_idx in range(torch.cuda.device_count()):
			self.logger.info('CUDA Device #{device_idx}: {device_name}'.format(device_idx=device_idx, device_name=torch.cuda.get_device_name(device_idx)))
		self.retrieval = os.path.isfile(os.path.join(save_dir, 'checkpoint.pt'))
		self.save_dir = save_dir

		if self.retrieval:
			self.load_checkpoint()
			self.log_once('Checkpoint loaded.')
			# The seed must be restored as well: __call__() and train() read
			# self.seed, which would otherwise be undefined on resumed runs.
			self.seed = self.checkpoint['random_seed']
		else:
			self.seed = seed
			self.log_once('Random seed: {seed}'.format(seed = seed))
			self.checkpoint = dict(modules=dict(), random_seed=seed)
			for module_key,kwargs in model_configs.items():
				self.log_module_info(module_key, **kwargs)
				self.checkpoint['modules'][module_key] = kwargs
			self.checkpoint['last_epoch'] = 0
			self.checkpoint['loss_weights'] = loss_weights

	def build_module(self, rank, module_name, init_args, optim_config, scheduler_config,
						state_dicts=None, finetune=False):
		"""
		Instantiate one module on `self.device`, wrap it in `nn.DataParallel`,
		and create its optimizer and scheduler.

		Args:
			rank: Unused here; kept for signature compatibility with the parent.
			module_name: Name of the class to look up in the `M` module.
			init_args: Keyword arguments for the module constructor.
			optim_config: Keyword arguments for `torch.optim.AdamW`.
			scheduler_config: Keyword arguments for `WarmupCosineFlatSchedule`.
			state_dicts: Optional dict with a 'module' entry and, optionally,
				'optimizer' and 'scheduler' entries to restore.
			finetune: If True, parameters whose names appear in the restored
				module state are frozen.

		Returns:
			dict with keys 'module', 'optimizer' and 'scheduler'; the latter two
			are None when the module has no parameters at all.
		"""
		init_args = copy.deepcopy(init_args) # NOTE: Some modules may pop dictionary arguments at instantiation, which would also modify the init_args stored in the checkpoint.
		module = getattr(M, module_name)(**init_args).to(self.device)
		if state_dicts is not None:
			module.load_state_dict(state_dicts['module'], strict=False)
			if finetune:
				for name,param in module.named_parameters():
					if name in state_dicts['module']:
						param.requires_grad = False
		module = nn.DataParallel(module)
		parameters = list(module.parameters())
		if not parameters:
			# NOTE: The optimizer constructor rejects an empty parameter list,
			# so parameter-less modules (e.g. loss functions) get no optimizer,
			# which the inherited train()/save_model() already handle.
			return dict(module=module,optimizer=None, scheduler=None)
		optimizer = torch.optim.AdamW(parameters, **optim_config)
		scheduler = WarmupCosineFlatSchedule(optimizer, **scheduler_config)
		if state_dicts is not None:
			if 'optimizer' in state_dicts:
				optimizer.load_state_dict(state_dicts['optimizer'])
			if 'scheduler' in state_dicts:
				scheduler.load_state_dict(state_dicts['scheduler'])
		return dict(module=module,optimizer=optimizer, scheduler=scheduler)

	def __call__(self, dataset, num_epochs, batch_size, num_workers):
		"""
		Train on `dataset` until `num_epochs` epochs are complete.

		Returns immediately when the checkpoint already records at least
		`num_epochs` completed epochs.

		Args:
			dataset: Dataset handed to `get_data_loader`.
			num_epochs: Total number of epochs to train for.
			batch_size: Batch size of the data loader.
			num_workers: Number of data-loading workers.
		"""
		last_epoch = self.checkpoint['last_epoch']
		if last_epoch>=num_epochs:
			self.log_once('Training is already completed.')
			return
		self.log_once("START LEARNING.")
		self.log_once("# of epochs: {ep}".format(ep=num_epochs))
		self.log_once("batch size for training data: {size}".format(size=batch_size))
		dataloader = get_data_loader(
								dataset,
								batch_size=batch_size,
								num_workers=num_workers,
								random_seed=self.seed,
								ddp=False)
		self.train(0, dataloader, num_epochs, last_epoch)
		self.log_once('END OF TRAINING')

class Tester(Learner):
	"""
	Base class for evaluation: loads a checkpoint, rebuilds its modules in
	eval mode, and hands them to `test`, which subclasses must implement.
	"""
	# NOTE: Testing is performed using nn.DataParallel rather than DDP.
	def __init__(self, logger, checkpoint_path, device='cpu'):
		"""
		Args:
			logger: Object exposing logging-style methods such as `info`.
			checkpoint_path: Checkpoint file to load.
			device: Device string passed to `torch.device`.
		"""
		self.logger = logger
		self.device = torch.device(device)
		self.load_checkpoint(checkpoint_path=checkpoint_path)

	def build_module(self, module_name, init_args, state_dicts=None, **kwargs):
		"""
		Instantiate one module on `self.device` and wrap it in `nn.DataParallel`.

		The stored weights are loaded strictly first; if that fails, the error
		is logged and loading is retried with strict=False.

		Args:
			module_name: Name of the class to look up in the `M` module.
			init_args: Keyword arguments for the module constructor.
			state_dicts: Optional dict whose 'module' entry holds the weights.
			**kwargs: Remaining module config entries (e.g. optimizer settings);
				accepted and ignored.

		Returns:
			dict with the single key 'module'.
		"""
		init_args = copy.deepcopy(init_args) # NOTE: Some modules may pop dictionary arguments at instantiation, which would also modify the init_args stored in the checkpoint.
		module = getattr(M, module_name)(**init_args).to(self.device)
		if state_dicts is not None:
			try:
				module.load_state_dict(state_dicts['module'], strict=True)
			except Exception as e:
				self.logger.info('{}'.format(e))
				self.logger.info('Loads state_dict with strict=False.')
				module.load_state_dict(state_dicts['module'], strict=False)
		module = nn.DataParallel(module)
		return dict(module=module)

	def is_record_proc(self, rank=0):
		"""Testing runs in a single process, which always records."""
		return True

	def __call__(self, dataset, batch_size, num_workers=1, **kwargs):
		"""
		Build all checkpointed modules in eval mode and run `test` without
		gradient tracking.

		Args:
			dataset: Dataset handed to `get_data_loader`.
			batch_size: Batch size of the data loader.
			num_workers: Number of data-loading workers.
			**kwargs: Forwarded to `test`.
		"""
		dataloader = get_data_loader(dataset,
									batch_size,
									shuffle=False,
									num_workers=num_workers)
		# A distinct loop-variable name keeps the caller's **kwargs visibly
		# separate from the per-module configs.
		modules = {module_key:self.build_module(**module_kwargs)
					for module_key,module_kwargs in self.checkpoint['modules'].items()}
		for module in modules.values():
			module['module'].eval()
		with torch.no_grad():
			self.test(dataloader, modules, **kwargs)

	def test(self, dataloader, modules, **kwargs):
		"""
		Evaluate `modules` on `dataloader`; must be overridden by subclasses.

		The signature matches the call made in `__call__`, so that an
		un-overridden call raises NotImplementedError rather than TypeError.

		Raises:
			NotImplementedError: Always, in this base class.
		"""
		raise NotImplementedError