from functools import partial
import numpy as np
from six.moves import map, zip
import torch
import random
import glob
import pdb
import os.path as osp
import torch.distributed as dist
import torch.multiprocessing as mp
import logging
from collections import OrderedDict
from mmcv.runner import load_state_dict


def accuracy(output, target, topk=(1, )):
	"""Compute the top-k accuracy (in percent) for every k in ``topk``.

	Args:
		output: score tensor; the top-k search runs along dim 1, so it is
			indexed as (sample, class).
		target: tensor holding one ground-truth class index per sample;
			``target.size(0)`` is used as the batch size.
		topk: iterable of k values to evaluate.

	Returns:
		A list with one single-element float tensor per entry of ``topk``,
		holding the percentage of samples whose target is among the k
		highest-scoring classes.
	"""
	with torch.no_grad():
		maxk = max(topk)
		batch_size = target.size(0)

		# Indices of the maxk largest scores per sample, sorted descending,
		# then transposed to (maxk, batch).
		_, pred = output.topk(maxk, 1, True, True)
		pred = pred.t()
		correct = pred.eq(target.view(1, -1).expand_as(pred))

		res = []
		for k in topk:
			# ``reshape`` instead of ``view``: ``correct`` stems from a
			# transposed tensor and may be non-contiguous, in which case
			# ``view(-1)`` raises a RuntimeError for k > 1.
			correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
			res.append(correct_k.mul_(100.0 / batch_size))
		return res

def multi_apply(func,*args,**kwargs):
	"""Map ``func`` over ``args`` in lockstep and regroup its outputs.

	``func`` is called once per position with the i-th element of every
	iterable in ``args``; ``kwargs`` are forwarded unchanged to every call.
	Each call is expected to return an iterable of results.

	Returns:
		A tuple of lists: the j-th list collects the j-th result of every
		call. An empty tuple is returned when no call is made.
	"""
	# Bind the shared keyword arguments once; the bound callable must be the
	# one that is mapped, otherwise the keyword arguments are dropped.
	pfunc = partial(func,**kwargs) if kwargs else func
	map_results = map(pfunc,*args)
	return tuple(map(list,zip(*map_results)))

def set_seed(seed,cuda=True):
	"""Seed the Python, PyTorch and NumPy random generators with ``seed``.

	When ``cuda`` is truthy, ``torch.cuda.manual_seed`` is called with the
	same seed as well.
	"""
	for seeder in (random.seed, torch.manual_seed, np.random.seed):
		seeder(seed)
	if not cuda:
		return
	torch.cuda.manual_seed(seed)

def load_weights(model,
				 model_path1,model_path2=None,
				 part1_pre_name='',
				 part2_pre_name='',
				 strict=False,
				 logger=None):
	"""Merge one or two checkpoints into a single state dict and load it.

	Each checkpoint file must contain a ``'state_dict'`` entry. Every key of
	that entry is prefixed with the matching ``part*_pre_name`` before the
	dicts are merged; on a key clash the entry from ``model_path2`` wins.

	Args:
		model: target network. If it exposes a ``module`` attribute (as
			the DataParallel-style wrappers do), the weights are loaded
			into ``model.module``; otherwise into ``model`` itself.
		model_path1: path of the first checkpoint.
		model_path2: optional path of the second checkpoint.
		part1_pre_name: prefix prepended to every key of checkpoint 1.
		part2_pre_name: prefix prepended to every key of checkpoint 2.
		strict: forwarded to ``load_state_dict``.
		logger: forwarded to ``load_state_dict``.
	"""
	sources = [(model_path1, part1_pre_name)]
	if model_path2 is not None:
		sources.append((model_path2, part2_pre_name))

	state_dict = {}
	for path, prefix in sources:
		# NOTE(security): torch.load unpickles the file; only load
		# checkpoints from trusted sources.
		# map_location='cpu' keeps the tensors off the device they were
		# saved from, which may be absent or busy in this process.
		checkpoint = torch.load(path, map_location='cpu')
		state_dict.update(
			{prefix + k: v for k, v in checkpoint['state_dict'].items()})

	# Unwrap a parallel wrapper when present; fall back to the bare model.
	target = getattr(model, 'module', model)
	load_state_dict(target, state_dict, strict, logger)

def filter_dict(dict_file,filter_name=""):
	"""Return the entries of ``dict_file`` whose key contains ``filter_name``.

	The result is a new ``OrderedDict`` that keeps the iteration order of
	``dict_file``. With the default empty ``filter_name`` every string key
	matches, so all entries are kept.
	"""
	return OrderedDict(
		(key, value)
		for key, value in dict_file.items()
		if filter_name in key
	)
#
def load_weights_v2(model,
				 model_path1,model_path2=None,
				 part1_pre_name='',
				 part2_pre_name='',
				 part2_filter_name='small_net',
				 strict=False,
				 logger=None):
	"""Merge checkpoint 1 with a filtered subset of checkpoint 2 and load it.

	Each checkpoint file must contain a ``'state_dict'`` entry. All keys of
	checkpoint 1 are taken and prefixed with ``part1_pre_name``. From
	checkpoint 2 only the keys containing ``part2_filter_name`` are taken,
	prefixed with ``part2_pre_name``; on a key clash they override the
	entries of checkpoint 1.

	Args:
		model: target network. If it exposes a ``module`` attribute (as
			the DataParallel-style wrappers do), the weights are loaded
			into ``model.module``; otherwise into ``model`` itself.
		model_path1: path of the first checkpoint.
		model_path2: optional path of the second checkpoint; when ``None``
			only checkpoint 1 is loaded.
		part1_pre_name: prefix prepended to every key of checkpoint 1.
		part2_pre_name: prefix prepended to the kept keys of checkpoint 2.
		part2_filter_name: substring a checkpoint-2 key must contain to
			be kept.
		strict: forwarded to ``load_state_dict``.
		logger: forwarded to ``load_state_dict``.
	"""
	# NOTE(security): torch.load unpickles the file; only load checkpoints
	# from trusted sources.
	# map_location='cpu' keeps the tensors off the device they were saved
	# from, which may be absent or busy in this process.
	checkpoint1 = torch.load(model_path1, map_location='cpu')
	state_dict = {
		part1_pre_name + k: v
		for k, v in checkpoint1['state_dict'].items()
	}

	# The signature defaults model_path2 to None, so it has to be optional
	# here instead of being handed to torch.load unconditionally.
	if model_path2 is not None:
		checkpoint2 = torch.load(model_path2, map_location='cpu')
		state_dict.update({
			part2_pre_name + k: v
			for k, v in checkpoint2['state_dict'].items()
			if part2_filter_name in k
		})

	# Unwrap a parallel wrapper when present; fall back to the bare model.
	target = getattr(model, 'module', model)
	load_state_dict(target, state_dict, strict, logger)


##
def get_logger(log_level):
    """Configure root logging via ``basicConfig`` and return the root logger.

    ``log_level`` is passed as the ``level`` argument; records are formatted
    as ``time - level name - message``.
    """
    log_format = '%(asctime)s - %(levelname)s - %(message)s'
    logging.basicConfig(level=log_level, format=log_format)
    return logging.getLogger()

def init_dist(backend='nccl', **kwargs):
    """Initialise the ``torch.distributed`` process group for this process.

    Sets the multiprocessing start method to ``'spawn'`` if none has been
    chosen yet, reads the process rank from the ``RANK`` environment
    variable, selects the CUDA device ``rank % device_count`` and finally
    calls ``dist.init_process_group``.

    Args:
        backend: backend name passed to ``init_process_group``.
        **kwargs: forwarded unchanged to ``init_process_group``.

    Raises:
        KeyError: if the ``RANK`` environment variable is not set.
        ValueError: if ``RANK`` is not an integer.
    """
    # Imported here because the module only binds ``os.path`` (as ``osp``);
    # the bare name ``os`` is otherwise undefined and raises NameError.
    import os

    if mp.get_start_method(allow_none=True) is None:
        mp.set_start_method('spawn')
    rank = int(os.environ['RANK'])
    num_gpus = torch.cuda.device_count()
    # Skip device selection when no GPU is visible: the modulo would divide
    # by zero, and there is no device to select anyway.
    if num_gpus > 0:
        torch.cuda.set_device(rank % num_gpus)
    dist.init_process_group(backend=backend, **kwargs)


def one_hot(x,N):
	"""Build a one-hot encoding of the class indices in ``x``.

	Args:
		x: array-like of class indices whose first dimension is the batch
			(``x.shape[0]``). Any trailing dimensions are treated as
			several labels for the same sample.
		N: number of classes, i.e. the width of the encoding.

	Returns:
		A ``torch.zeros``-created tensor (default dtype and device) of
		shape ``(batch, N)`` with ``1`` at every ``(i, x[i])``.
	"""
	batch = x.shape[0]
	xx = torch.zeros((batch,N))
	# Cast to int64 on the output's device so any label dtype/device can be
	# used for indexing.
	cols = torch.as_tensor(x, dtype=torch.long, device=xx.device)
	if cols.numel():
		# A single vectorized assignment replaces the per-sample Python loop;
		# the row index is broadcast against every label of its sample.
		rows = torch.arange(batch).view(batch, 1)
		xx[rows, cols.reshape(batch, -1)] = 1
	return xx













###
