# coding=utf-8
import os
import errno
import torch
import shutil
import datetime


def save_checkpoint(state, is_best, model_dir, filename='checkpoint.pth.tar'):
	"""Serialize ``state`` with ``torch.save`` and optionally keep a "best" copy.

	Args:
		state: any object ``torch.save`` can serialize.
		is_best: when true, the saved file is also copied to a sibling file
			whose name has ``best.`` inserted before ``pth.tar``
			(``checkpoint.pth.tar`` -> ``checkpoint.best.pth.tar``). Names
			without ``pth.tar`` get ``.best`` before their extension
			(``model.pt`` -> ``model.best.pt``).
		model_dir: directory the checkpoint is written into.
		filename: checkpoint file name relative to ``model_dir``.
	"""
	path = os.path.join(model_dir, filename)
	torch.save(state, path)
	if not is_best:
		return

	# Rewrite only the file name: replacing over the full path would also
	# alter any directory component that happens to contain 'pth.tar'.
	head, tail = os.path.split(path)
	if 'pth.tar' in tail:
		best_tail = tail.replace('pth.tar', 'best.pth.tar')
	else:
		# Without this branch the name maps onto itself and copyfile()
		# raises SameFileError.
		root, ext = os.path.splitext(tail)
		best_tail = root + '.best' + ext
	shutil.copyfile(path, os.path.join(head, best_tail))


def get_the_number_of_params(model, is_trainable=False):
	"""Return the total number of scalar elements in ``model``'s parameters.

	Args:
		model: object exposing ``parameters()`` (e.g. a ``torch.nn.Module``).
		is_trainable: when true, count only parameters whose
			``requires_grad`` flag is set.
	"""
	candidates = model.parameters()
	if is_trainable:
		candidates = (tensor for tensor in candidates if tensor.requires_grad)
	return sum(tensor.numel() for tensor in candidates)


def mkdir(path):
	"""Create directory ``path`` and any missing parents, like ``mkdir -p``.

	An already-existing directory, including one created concurrently by
	another process, is not an error.

	Raises:
		FileExistsError: ``path`` exists but is not a directory.
		OSError: any other failure (e.g. permission denied).
	"""
	# exist_ok only tolerates an existing *directory*; the previous errno
	# check also swallowed EEXIST for a regular file at ``path``.
	os.makedirs(path, exist_ok=True)


def remote_download_dataset(args=None):
	"""Copy the dataset from object storage to local disk on a remote job.

	Nothing happens when ``args`` is None, when neither ``args.remote`` nor
	``args.remote_wise`` is set, or when ``args.data_loaded`` is already true.
	The two modes share one code path; ``remote_wise`` additionally calls
	``mox.file.set_auth`` with explicit access keys before copying.

	After the copy, a bundled DALI wheel is pip-installed from
	``args.remote_data_dir`` when ``args.dali_data_dir`` exists.

	Args:
		args: options namespace; reads ``remote``, ``remote_wise``,
			``data_loaded``, ``data_url``, ``remote_data_dir`` and
			``dali_data_dir``.
	"""
	if args is None or not (args.remote or args.remote_wise):
		return
	if args.data_loaded:
		return

	import subprocess
	import sys

	import moxing as mox

	time_s = datetime.datetime.now()
	print("Copying data from s3 to remote server >>>...")
	if args.remote_wise:
		# NOTE(security): these access keys are committed to source control
		# and should be rotated; set MOX_AK / MOX_SK to override them.
		mox.file.set_auth(
			ak=os.environ.get('MOX_AK', 'YVNIJMRIH8ZEFYTBFPPZ'),
			sk=os.environ.get('MOX_SK', 'nkWgmVQSY14qTiuPGyQhkfuJwsUKS9VYw0ROUYNk'))
	mox.file.copy_parallel(args.data_url, args.remote_data_dir)
	# total_seconds(): timedelta.seconds ignores the days component.
	elapsed_min = (datetime.datetime.now() - time_s).total_seconds() / 60
	print('time cost to copy {}::{:.2f} m'.format(args.data_url, elapsed_min))
	print('The content of remote_data_dir {} is...'.format(args.remote_data_dir))
	# Argument lists (no shell) so paths with spaces/metacharacters are safe.
	subprocess.call(['ls', '-l', args.remote_data_dir])

	# NOTE(review): the existence of dali_data_dir gates the install, but the
	# wheel is read from remote_data_dir -- confirm this is intended.
	if os.path.exists(args.dali_data_dir):
		wheel = os.path.join(
			args.remote_data_dir,
			'nvidia_dali_cuda102-1.11.1-4069476-py3-none-manylinux2014_x86_64.whl')
		# "python -m pip" installs into the interpreter running this job,
		# which a bare "pip" on PATH does not guarantee.
		subprocess.call([sys.executable, '-m', 'pip', 'install', wheel])


def remote_copy_files(args=None, tar_filename='output.tar.gz', copy_choice=1,
						remove_exist=False):
	"""Upload the run's log directory from local rank 0 to remote storage.

	Modes, selected by flags on ``args``:

	* ``remote_du`` (takes precedence, returns right after): tar
	  ``args.log_dir`` into ``<basename>.tar`` in its parent directory and
	  ``rsync`` it to ``args.output_url``.
	* ``remote``: tar the contents of ``args.log_dir`` into ``tar_filename``
	  inside it and copy the directory to ``args.train_url`` with moxing. If
	  ``train_url`` is non-empty afterwards the process exits via
	  ``sys.exit()``; if it is empty the copy is retried once.
	* ``remote_wise``: authenticate moxing with explicit keys and copy
	  ``args.log_dir`` to ``<args.output_url>/<basename of log_dir>``,
	  deleting that destination first when ``remove_exist`` is true.

	``remote`` and ``remote_wise`` run in that order when both are set
	(unless ``remote`` exits). Only local rank 0 does any work. The
	``remote_du`` and ``remote`` modes change the process working directory
	with ``os.chdir`` and do not restore it.

	Args:
		args: options namespace; reads ``remote``, ``remote_wise``,
			``remote_du``, ``local_rank``, ``log_dir``, ``train_url`` and
			``output_url``.
		tar_filename: archive name used by the ``remote`` mode.
		copy_choice: unused; kept for call-site compatibility.
		remove_exist: ``remote_wise`` only -- delete an existing destination
			directory before copying.
	"""
	if args is None or not (args.remote or args.remote_wise or args.remote_du):
		return
	# Every branch below is rank-0 only. Returning here also stops other ranks
	# from importing moxing, which a remote_du-only host may not have.
	if args.local_rank != 0:
		return

	import glob
	import subprocess
	import sys

	# normpath strips a trailing separator; otherwise basename() is '' and the
	# archive / destination names collapse (remote_wise would then target, and
	# with remove_exist delete, output_url itself).
	run_name = os.path.basename(os.path.normpath(args.log_dir))

	if args.remote_du:
		print('=' * 50)
		os.chdir(args.log_dir)
		os.chdir('..')
		du_tar = '{}.tar'.format(run_name)
		if os.path.exists(du_tar):
			print('remove existing tarfile and create a new tarfile')
			os.remove(du_tar)
		# NOTE(review): '-z' gzips the archive although its name ends in '.tar'.
		subprocess.call(['tar', '-zcvf', du_tar, run_name])
		subprocess.call(['rsync', '-rvP', du_tar, args.output_url])
		subprocess.call(['ls', '-lah', args.output_url])
		print("Copy success!")
		return

	import moxing as mox

	if args.remote:
		print('=' * 50)
		os.chdir(args.log_dir)
		# Same members as the shell glob './*' (hidden entries excluded).
		members = sorted(glob.glob('./*'))
		subprocess.call(['tar', '-zcvf', './{}'.format(tar_filename)] + members)
		mox.file.copy_parallel(args.log_dir, args.train_url)
		if len(mox.file.list_directory(args.log_dir)) > 0:
			# Verify the upload by listing the destination, not the source.
			uploaded = mox.file.list_directory(args.train_url)
			if not uploaded:
				print("train_url has no file!")
				mox.file.copy_parallel(args.log_dir, args.train_url)
			else:
				print("Copy really success!")
				# NOTE(review): ends the whole process from inside a helper;
				# the remote_wise upload below is then skipped.
				sys.exit()
		else:
			print('no file to copy in {}'.format(args.log_dir))

	if args.remote_wise:
		print('=' * 50)
		time_s = datetime.datetime.now()
		# NOTE(security): these access keys are committed to source control
		# and should be rotated; set MOX_AK / MOX_SK to override them.
		mox.file.set_auth(
			ak=os.environ.get('MOX_AK', 'YVNIJMRIH8ZEFYTBFPPZ'),
			sk=os.environ.get('MOX_SK', 'nkWgmVQSY14qTiuPGyQhkfuJwsUKS9VYw0ROUYNk'))
		output_dir = os.path.join(args.output_url, run_name)
		if remove_exist and mox.file.exists(output_dir):
			print("remove existed files...")
			mox.file.remove(output_dir, recursive=True)

		mox.file.make_dirs(output_dir)
		mox.file.copy_parallel(args.log_dir, output_dir)

		# total_seconds(): timedelta.seconds ignores the days component.
		elapsed_min = (datetime.datetime.now() - time_s).total_seconds() / 60
		print('time cost to copy to dir {}::{:.2f} m'.format(output_dir, elapsed_min))
		print('The content of {} is...'.format(output_dir))
		for entry in mox.file.list_directory(output_dir):
			print(entry)
