import functools
import os
import os.path as osp
from collections import OrderedDict
from math import cos, pi

import torch
from torch import distributed as dist


# Epoch counts from 0 to N-1
def cosine_lr_after_step(optimizer, base_lr, epoch, step_epoch, total_epochs, clip=1e-6):
    """Set a constant-then-cosine learning rate on every parameter group.

    Before ``step_epoch`` the learning rate is held at ``base_lr``. From
    ``step_epoch`` onwards it follows half a cosine period, decaying from
    ``base_lr`` towards ``clip`` as ``epoch`` approaches ``total_epochs``.

    Args:
        optimizer: Object exposing ``param_groups``, an iterable of dicts;
            the ``'lr'`` entry of each dict is overwritten.
        base_lr: Learning rate used before the decay starts.
        epoch: Current epoch index.
        step_epoch: Epoch at which the cosine decay begins.
        total_epochs: Epoch at which the decay reaches ``clip``.
        clip: Lower bound the learning rate decays to.
    """
    lr = base_lr
    if epoch >= step_epoch:
        # Fraction of the decay phase already elapsed (0 at step_epoch).
        progress = (epoch - step_epoch) / (total_epochs - step_epoch)
        lr = clip + 0.5 * (base_lr - clip) * (1 + cos(pi * progress))

    for group in optimizer.param_groups:
        group['lr'] = lr


def is_power2(num):
    """Return True if ``num`` is a non-zero integer with exactly one bit set."""
    if num == 0:
        return False
    # Clearing the lowest set bit leaves zero only when a single bit was set.
    lowest_bit_cleared = num & (num - 1)
    return lowest_bit_cleared == 0


def is_multiple(num, multiple):
    """Return True if ``num`` is non-zero and evenly divisible by ``multiple``."""
    if num == 0:
        # Zero is deliberately not treated as a multiple of anything.
        return False
    remainder = num % multiple
    return remainder == 0


def weights_to_cpu(state_dict):
    """Build a copy of a state_dict whose values have been moved to the CPU.

    Args:
        state_dict (OrderedDict): Mapping from parameter names to values that
            expose a ``.cpu()`` method (e.g. model weights on GPU).
    Returns:
        OrderedDict: New mapping with the same keys, in the same order, whose
            values are the result of calling ``.cpu()`` on each input value.
    """
    return OrderedDict(
        (name, weight.cpu()) for name, weight in state_dict.items()
    )


def load_checkpoint(checkpoint, logger, model, optimizer=None, strict=False):
    """Load model (and optionally optimizer) state from a checkpoint file.

    The checkpoint is expected to be a dict with a ``'net'`` entry holding the
    model state_dict, and optionally ``'optimizer'`` and ``'epoch'`` entries.
    Source tensors whose size differs from the matching tensor in ``model``
    are dropped before loading (useful when starting from pretrained weights).

    Args:
        checkpoint: Path or file-like object passed to ``torch.load``.
        logger: Object with an ``info(str)`` method; receives reports about
            removed, missing and unexpected keys.
        model: Module to load into. If it has a ``module`` attribute (e.g. a
            parallel wrapper), the wrapped module is used instead.
        optimizer: If given, its state is restored from the checkpoint's
            ``'optimizer'`` entry.
        strict: Forwarded to ``model.load_state_dict``.

    Returns:
        int: The checkpoint's stored ``'epoch'`` plus one, or ``1`` when the
        checkpoint has no ``'epoch'`` entry.

    Raises:
        AssertionError: If ``optimizer`` is given but the checkpoint has no
            ``'optimizer'`` entry.
    """
    if hasattr(model, 'module'):
        model = model.module

    if torch.cuda.is_available():
        device = torch.cuda.current_device()

        def map_location(storage, loc):
            # Place every loaded storage on the current CUDA device.
            return storage.cuda(device)
    else:
        # No GPU on this host: load onto the CPU rather than failing in
        # torch.cuda.current_device().
        map_location = 'cpu'

    state_dict = torch.load(checkpoint, map_location=map_location)
    src_state_dict = state_dict['net']
    target_state_dict = model.state_dict()

    # skip mismatch size tensors in case of pretraining
    skip_keys = [
        k for k, v in src_state_dict.items()
        if k in target_state_dict and v.size() != target_state_dict[k].size()
    ]
    for k in skip_keys:
        del src_state_dict[k]

    missing_keys, unexpected_keys = model.load_state_dict(src_state_dict, strict=strict)
    if skip_keys:
        logger.info(
            f'removed keys in source state_dict due to size mismatch: {", ".join(skip_keys)}')
    if missing_keys:
        logger.info(f'missing keys in source state_dict: {", ".join(missing_keys)}')
    if unexpected_keys:
        logger.info(f'unexpected key in source state_dict: {", ".join(unexpected_keys)}')

    # load optimizer
    if optimizer is not None:
        assert 'optimizer' in state_dict, 'checkpoint has no optimizer state'
        optimizer.load_state_dict(state_dict['optimizer'])

    # Resume from the epoch after the one recorded in the checkpoint.
    return state_dict.get('epoch', 0) + 1


# def cuda_cast(func):

#     @functools.wraps(func)
#     def wrapper(*args, **kwargs):
#         new_args = []
#         for x in args:
#             if isinstance(x, torch.Tensor):
#                 x = x.cuda()
#             new_args.append(x)
#         new_kwargs = {}
#         for k, v in kwargs.items():
#             if isinstance(v, torch.Tensor):
#                 v = v.cuda()
#             new_kwargs[k] = v
#         return func(*new_args, **new_kwargs)

#     return wrapper

def cuda_cast(func):
    """Decorator that moves arguments to the GPU before calling ``func``.

    Positional arguments:
        * anything with a ``cuda`` attribute is replaced by ``x.cuda()``;
        * anything with a ``keys`` method (e.g. a dict) has each value that is
          a ``torch.Tensor`` replaced **in place** by its ``.cuda()`` copy;
        * everything else (ints, strings, ``None``, lists, ...) is passed
          through unchanged.

    Keyword arguments: values that are ``torch.Tensor`` instances are replaced
    by ``v.cuda()``; all other values are passed through unchanged.

    Args:
        func: Callable to wrap.

    Returns:
        The wrapped callable, carrying ``func``'s metadata.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        new_args = []
        for x in args:
            if hasattr(x, 'cuda'):
                x = x.cuda()
            elif hasattr(x, 'keys'):
                # Mapping argument: move tensor values in place so the caller's
                # container observes the same update as before.
                for key in x.keys():
                    if isinstance(x[key], torch.Tensor):
                        x[key] = x[key].cuda()
            # Plain values fall through untouched instead of raising
            # AttributeError on a missing ``keys`` method.
            new_args.append(x)
        new_kwargs = {
            k: v.cuda() if isinstance(v, torch.Tensor) else v
            for k, v in kwargs.items()
        }
        return func(*new_args, **new_kwargs)

    return wrapper

# def cuda_cast(func):

#     @functools.wraps(func)
#     def wrapper(*args, **kwargs):
#         new_args = []
#         for x in args:
#             if hasattr(x, 'cuda'):
#                 x = x.cuda()
#                 new_args.append(x)
#             else:
#                 for key in x.keys():
#                     x[key].cuda()
#             # new_args.append(x)
#         new_kwargs = {}
#         for k, v in kwargs.items():
#             if hasattr(v, 'cuda'):
#                 v = v.cuda()
#             elif isinstance(v, list) and hasattr(v[0], 'cuda'):
#                 v = [x.cuda() for x in v]
#             new_kwargs[k] = v
#         return func(*new_args, **new_kwargs)

#     return wrapper
