# --------------------------------------------------------
# Swin Transformer
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ze Liu
# --------------------------------------------------------

import os
import torch
import torch.distributed as dist

try:
    # noinspection PyUnresolvedReferences
    from apex import amp
except ImportError:
    amp = None

def batch_index_select(x, idx):
    """
    Select, independently for every batch element, entries of ``x`` along dim 1.

    :param x: Tensor of shape (B, N, ...) with at least two dimensions; any number of trailing
        dimensions is carried through unchanged.
    :param idx: Integer tensor of shape (B, N_new) holding, for each batch element, positions in [0, N).
    :return: Tensor of shape (B, N_new, ...) such that ``out[b, k] == x[b, idx[b, k]]``.
    :raises NotImplementedError: If ``x`` has fewer than two dimensions.
    """
    if x.dim() < 2:
        raise NotImplementedError
    B, N = x.shape[:2]
    trailing = x.shape[2:]
    N_new = idx.size(1)
    # Shift row b by b * N so that per-row positions address the flattened (B * N, ...) view of x.
    offset = torch.arange(B, dtype=torch.long, device=x.device).view(B, 1) * N
    flat_idx = (idx + offset).reshape(-1)
    return x.reshape(B * N, *trailing)[flat_idx].reshape(B, N_new, *trailing)

def load_checkpoint(config, model, optimizer, lr_scheduler, logger):
    """
    Restore model weights (and, outside eval mode, the training state) from ``config.MODEL.RESUME``.

    :param config: Configuration object; ``MODEL.RESUME``, ``EVAL_MODE`` and ``AMP_OPT_LEVEL`` are read,
        ``TRAIN.START_EPOCH`` is written between ``defrost()`` and ``freeze()``.
    :param model: Module whose ``load_state_dict`` receives ``checkpoint['model']`` (non-strict).
    :param optimizer: Object whose ``load_state_dict`` receives ``checkpoint['optimizer']``.
    :param lr_scheduler: Object whose ``load_state_dict`` receives ``checkpoint['lr_scheduler']``.
    :param logger: Logger used for progress messages.
    :return: ``checkpoint['max_accuracy']`` when the training state was restored and the key exists,
        otherwise 0.0.
    """
    resume_from = config.MODEL.RESUME
    logger.info(f"==============> Resuming form {resume_from}....................")
    if resume_from.startswith('https'):
        ckpt = torch.hub.load_state_dict_from_url(resume_from, map_location='cpu', check_hash=True)
    else:
        ckpt = torch.load(resume_from, map_location='cpu')
    logger.info(model.load_state_dict(ckpt['model'], strict=False))

    best_accuracy = 0.0
    # The training state is only restored when all three pieces are present in the checkpoint.
    has_training_state = all(key in ckpt for key in ('optimizer', 'lr_scheduler', 'epoch'))
    if not config.EVAL_MODE and has_training_state:
        optimizer.load_state_dict(ckpt['optimizer'])
        lr_scheduler.load_state_dict(ckpt['lr_scheduler'])

        config.defrost()
        config.TRAIN.START_EPOCH = ckpt['epoch'] + 1
        config.freeze()

        # AMP state is loaded only if both the current run and the run that saved the checkpoint used AMP.
        if 'amp' in ckpt and config.AMP_OPT_LEVEL != "O0" and ckpt['config'].AMP_OPT_LEVEL != "O0":
            amp.load_state_dict(ckpt['amp'])
        logger.info(f"=> loaded successfully '{resume_from}' (epoch {ckpt['epoch']})")
        best_accuracy = ckpt.get('max_accuracy', best_accuracy)

    del ckpt
    torch.cuda.empty_cache()
    return best_accuracy

def load_checkpoint_ft(pretrain_path, model, optimizer, lr_scheduler, logger):
    """
    Copy pretrained weights into ``model`` for fine-tuning, leaving every entry whose name contains
    "head" untouched.

    :param pretrain_path: Path of a checkpoint file whose ``'model'`` entry is a state dict.
    :param model: Module receiving the weights; entries missing from its state dict are ignored.
    :param optimizer: Unused by this function.
    :param lr_scheduler: Unused by this function.
    :param logger: Logger used for progress messages.
    :return: Always 0.0.
    """
    logger.info(f"==============> Loading pretrained model form {pretrain_path}....................")
    pretrained = torch.load(pretrain_path, map_location='cpu')['model']
    target = model.state_dict()
    for key, weights in pretrained.items():
        if "head" in key or key not in target:
            continue
        # In-place copy into the tensor returned by the model's state dict.
        target[key].copy_(weights)
    logger.info("=> loaded successfully")
    del pretrained
    torch.cuda.empty_cache()
    return 0.0

def save_checkpoint(config, epoch, model, max_accuracy, optimizer, lr_scheduler, logger):
    """
    Write the full training state to ``<config.OUTPUT>/ckpt_epoch_<epoch>.pth``.

    The file holds the model, optimizer and lr-scheduler state dicts, ``max_accuracy``, ``epoch`` and
    ``config``; the AMP state is added when ``config.AMP_OPT_LEVEL`` is not "O0".

    :param config: Configuration object providing ``OUTPUT`` and ``AMP_OPT_LEVEL``; stored in the file.
    :param epoch: Epoch number, stored in the file and used in the file name.
    :param model: Module whose ``state_dict()`` is stored.
    :param max_accuracy: Value stored under ``'max_accuracy'``.
    :param optimizer: Object whose ``state_dict()`` is stored.
    :param lr_scheduler: Object whose ``state_dict()`` is stored.
    :param logger: Logger used for progress messages.
    """
    state = dict(
        model=model.state_dict(),
        optimizer=optimizer.state_dict(),
        lr_scheduler=lr_scheduler.state_dict(),
        max_accuracy=max_accuracy,
        epoch=epoch,
        config=config,
    )
    if config.AMP_OPT_LEVEL != "O0":
        state['amp'] = amp.state_dict()

    target = os.path.join(config.OUTPUT, f'ckpt_epoch_{epoch}.pth')
    logger.info(f"{target} saving......")
    torch.save(state, target)
    logger.info(f"{target} saved !!!")

def save_checkpoint_best(config, epoch, model, max_accuracy, optimizer, lr_scheduler, logger):
    """
    Write the model weights to ``<config.OUTPUT>/best_model.pth``, overwriting any previous file.

    The file holds the model state dict, ``max_accuracy``, ``epoch`` and ``config``; the AMP state is
    added when ``config.AMP_OPT_LEVEL`` is not "O0". Optimizer and scheduler state are not stored.

    :param config: Configuration object providing ``OUTPUT`` and ``AMP_OPT_LEVEL``; stored in the file.
    :param epoch: Epoch number stored in the file.
    :param model: Module whose ``state_dict()`` is stored.
    :param max_accuracy: Value stored under ``'max_accuracy'``.
    :param optimizer: Unused by this function.
    :param lr_scheduler: Unused by this function.
    :param logger: Logger used for progress messages.
    """
    state = dict(
        model=model.state_dict(),
        max_accuracy=max_accuracy,
        epoch=epoch,
        config=config,
    )
    if config.AMP_OPT_LEVEL != "O0":
        state['amp'] = amp.state_dict()

    target = os.path.join(config.OUTPUT, 'best_model.pth')
    logger.info(f"{target} saving......")
    torch.save(state, target)
    logger.info(f"{target} saved !!!")

def get_grad_norm(parameters, norm_type=2):
    """
    Compute the norm of all gradients taken together, as a Python float.

    :param parameters: A single tensor or an iterable of tensors; entries whose ``grad`` is None are skipped.
    :param norm_type: Order of the norm. ``float('inf')`` yields the largest absolute gradient entry.
    :return: The total gradient norm; 0.0 when no parameter carries a gradient.
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    grads = [p.grad.detach() for p in parameters if p.grad is not None]
    norm_type = float(norm_type)
    if norm_type == float("inf"):
        # The power/root formula below degenerates for p = inf (x ** inf, then ** 0.0), so the
        # infinity norm is taken directly as the maximum of the per-tensor infinity norms.
        return max((g.norm(norm_type).item() for g in grads), default=0.0)
    total_norm = sum(g.norm(norm_type).item() ** norm_type for g in grads)
    return total_norm ** (1. / norm_type)

def auto_resume_helper(output_dir):
    """
    Locate the most recently modified checkpoint in ``output_dir``.

    :param output_dir: Directory to scan; entries whose name ends with 'pth' count as checkpoints.
    :return: Path (``output_dir`` joined with the file name) of the checkpoint with the greatest
        modification time, or None when the directory holds no checkpoint.
    """
    checkpoints = [entry for entry in os.listdir(output_dir) if entry.endswith('pth')]
    print(f"All checkpoints founded in {output_dir}: {checkpoints}")
    if not checkpoints:
        return None
    candidates = [os.path.join(output_dir, entry) for entry in checkpoints]
    latest_checkpoint = max(candidates, key=os.path.getmtime)
    print(f"The latest checkpoint founded: {latest_checkpoint}")
    return latest_checkpoint


def reduce_tensor(tensor):
    """
    Average a tensor over all processes of the default process group.

    :param tensor: Tensor to average; it is left unmodified (a clone is reduced).
    :return: New tensor holding the sum over all processes divided by the world size.
    """
    averaged = tensor.clone()
    dist.all_reduce(averaged, op=dist.ReduceOp.SUM)
    averaged.div_(dist.get_world_size())
    return averaged


import torch
from torch import nn


@torch.no_grad()
def get_activation(preact_dict, param_name, hook_type):
    """
    Build a hook that records a layer's output or output gradient, as used by the sensitivity
    schedulers (e.g. LOBSTER, Neuron-LOBSTER, SERENE).

    :param preact_dict: Dictionary in which the hook stores what it captures.
    :param param_name: Name of the layer, used as the dictionary key.
    :param hook_type: "forward" selects the forward hook; any other value selects the backward hook.
    :return: A forward hook storing the module output as-is, or a backward hook storing the detached
        first element of ``grad_output``.
    """

    def backward_hook(module, grad_input, grad_output):
        # The entry is reset before the new gradient is stored
        # (presumably to release the previous tensor first -- TODO confirm).
        preact_dict[param_name] = None
        preact_dict[param_name] = grad_output[0].detach()

    def forward_hook(model, inp, output):
        preact_dict[param_name] = output

    if hook_type == "forward":
        return forward_hook
    return backward_hook


@torch.no_grad()
def apply_mask_params(model, mask):
    """
    Multiply, in place, every parameter of a model by its mask.

    :param model: PyTorch model whose parameters are masked.
    :param mask: Dictionary mapping "<module name>.<parameter name>" to the mask tensor, where the module
        name is the one reported by ``model.named_modules()`` for the module that directly owns the
        parameter (so parameters registered on ``model`` itself use the key ".<parameter name>").
    :raises KeyError: If a parameter has no entry in ``mask``.
    """
    for n_m, mo in model.named_modules():
        # recurse=False: visit each parameter exactly once, through its owning module. The recursive
        # default would revisit it from every ancestor under ancestor-relative names (e.g. ".0.weight").
        for n_p, p in mo.named_parameters(recurse=False):
            p.mul_(mask["{}.{}".format(n_m, n_p)])


@torch.no_grad()
def apply_mask_neurons(model, mask):
    """
    Scale the parameters of Linear, Conv2d, ConvTranspose2d and BatchNorm2d layers by per-neuron masks.

    1-D parameters are multiplied element-wise; 2-D parameters are scaled along dim 0; 4-D parameters
    are scaled along dim 0 for Conv2d and along dim 1 for ConvTranspose2d.

    :param model: PyTorch model whose parameters are masked in place.
    :param mask: Dictionary mapping "<module name>.<parameter name>" to a 1-D mask tensor.
    """
    maskable = (nn.modules.Linear, nn.modules.Conv2d, nn.modules.ConvTranspose2d, nn.modules.BatchNorm2d)
    # einsum equation used for a 4-D weight, per layer type.
    conv_equations = (
        (nn.modules.Conv2d, 'ijnm,i->ijnm'),
        (nn.modules.ConvTranspose2d, 'ijnm,j->ijnm'),
    )
    for module_name, module in model.named_modules():
        if not isinstance(module, maskable):
            continue
        for param_name, param in module.named_parameters():
            key = "{}.{}".format(module_name, param_name)
            rank = param.dim()
            if rank == 1:
                param.mul_(mask[key])
            elif rank == 2:
                param.copy_(torch.einsum('ij,i->ij', param, mask[key]))
            elif rank == 4:
                for layer_type, equation in conv_equations:
                    if isinstance(module, layer_type):
                        param.copy_(torch.einsum(equation, param, mask[key]))


@torch.no_grad()
def substitute_module(model, new_module, sub_module_names):
    """
    Replace a sub-module of a PyTorch model with another module.

    :param model: PyTorch model in which the substitution occurs.
    :param new_module: Module to insert; when None the model is left untouched.
    :param sub_module_names: Attribute path of the module to replace, one name per nesting level,
        e.g. ["layer1", "0", "conv1"] for the module named layer1.0.conv1. An empty sequence is a no-op.
    """
    if new_module is None:
        return
    # Walk down to the parent of the target, then rebind the last attribute on it.
    parent = model
    for attr_name in sub_module_names[:-1]:
        parent = getattr(parent, attr_name)
    if len(sub_module_names) > 0:
        setattr(parent, sub_module_names[-1], new_module)


@torch.no_grad()
def find_module(model, name):
    """
    Find a leaf module by name, together with the leaf module that follows it.

    Only modules without children are considered. A module matches when ``name`` is a substring of
    "<module name>.weight"; if several consecutive leaves match, the last of them is kept.

    :param model: PyTorch model.
    :param name: Module name (or weight name) to look for.
    :return: Tuple (matched module, first later leaf module that neither matches nor is an nn.Identity);
        each element is None when no such module exists.
    """
    current_module = None
    next_module = None
    for module_name, module in model.named_modules():
        if any(True for _ in module.children()):
            # Containers are skipped; only leaves take part in the search.
            continue
        if name in f"{module_name}.weight":
            current_module = module
        elif current_module is not None and not isinstance(module, nn.Identity):
            next_module = module
            break

    return current_module, next_module


@torch.no_grad()
def get_model_mask_neurons(model, layers):
    """
    Build a dictionary {"<module name>.<parameter name>": mask} telling which neurons of the selected
    layers are zeroed.

    For weights of Linear, Conv2d and ConvTranspose2d layers the mask has one entry per neuron (dim 0,
    dim 0 and dim 1 of the weight respectively): 0 when the absolute values of all the neuron's weights
    sum to zero, 1 otherwise. Every other parameter gets an element-wise mask: 0 where the parameter is
    zero, 1 elsewhere.

    :param model: PyTorch model.
    :param layers: Tuple of layer types to consider, e.g. (nn.modules.Conv2d, nn.modules.Linear).
    :return: Mask dictionary.
    """
    mask = {}
    for module_name, module in model.named_modules():
        if not isinstance(module, layers):
            continue
        for param_name, param in module.named_parameters():
            # Dimensions summed away so that one value is left per neuron; None keeps the mask element-wise.
            reduce_dims = None
            if "weight" in param_name:
                if isinstance(module, nn.modules.Linear):
                    reduce_dims = 1
                elif isinstance(module, nn.modules.Conv2d):
                    reduce_dims = (1, 2, 3)
                elif isinstance(module, nn.modules.ConvTranspose2d):
                    reduce_dims = (0, 2, 3)

            values = param if reduce_dims is None else torch.abs(param).sum(dim=reduce_dims)
            mask["{}.{}".format(module_name, param_name)] = torch.where(
                values == 0, torch.zeros_like(values), torch.ones_like(values))

    return mask


def get_tensor_mask_neurons(tensor):
    """
    Return the flattened indices at which the absolute values of ``tensor``, summed over dim 1, are zero.

    :param tensor: Tensor with at least two dimensions.
    :return: 1-D integer tensor of indices; for a 2-D input these are the indices of the all-zero rows.
    """
    magnitude = tensor.abs().sum(dim=1)
    zero_positions = torch.nonzero(magnitude == 0, as_tuple=False)
    return zero_positions.flatten()


@torch.no_grad()
def get_model_mask_parameters(model, layers):
    """
    Build a dictionary {"<module name>.<parameter name>": mask} telling which parameters of the selected
    layers are zero.

    :param model: PyTorch model.
    :param layers: Tuple of layer types to consider, e.g. (nn.modules.Conv2d, nn.modules.Linear).
    :return: Mask dictionary; each mask has the shape of its parameter, with 0 where the parameter is
        zero and 1 elsewhere.
    """
    mask = {}
    for module_name, module in model.named_modules():
        if not isinstance(module, layers):
            continue
        for param_name, param in module.named_parameters():
            key = "{}.{}".format(module_name, param_name)
            mask[key] = torch.where(param == 0, torch.zeros_like(param), torch.ones_like(param))

    return mask


@torch.no_grad()
def magnitude_threshold(model, layers, T):
    """
    Magnitude thresholding: zero, in place, every parameter entry of the selected layers whose absolute
    value is below a threshold.

    :param model: PyTorch model on which the thresholding is applied, layer by layer.
    :param layers: Tuple of layer types to process, e.g. (nn.modules.Conv2d, nn.modules.Linear).
    :param T: Threshold value; entries with ``abs(value) < T`` become zero.
    """
    for module in model.modules():
        if not isinstance(module, layers):
            continue
        for param in module.parameters():
            below_threshold = torch.abs(param) < T
            param.copy_(torch.where(below_threshold, torch.zeros_like(param), param))


@torch.no_grad()
def sensitivity_threshold(model, layers, T, sensitivity, layer_name, bn_prune):
    """
    Performs sensitivity thresholding on a network: for each processed layer a binary mask is built from
    its sensitivity (0 where the sensitivity is below the threshold, 1 elsewhere) and multiplied into the
    layer's parameters, in place.

    The mask scales Linear and Conv2d weights along dim 0 and ConvTranspose2d weights along dim 1; any
    other parameter (biases included) is multiplied by the mask directly.

    Modules are visited in reverse ``named_modules()`` order and the last entry returned by
    ``named_modules()`` is never visited.

    :param model: PyTorch model on which apply the thresholding, layer by layer.
    :param layers: Tuple of layers on which apply the threshold procedure. e.g. (nn.modules.Conv2d, nn.modules.Linear)
    :param T: Threshold value.
    :param sensitivity: Mapping indexed by module name that yields the sensitivity tensor of that module
        (presumably one value per neuron -- TODO confirm against the caller).
    :param layer_name: When not None, only the module with exactly this name is processed.
    :param bn_prune: Only used when ``layer_name`` is None. When true, a Conv2d/ConvTranspose2d whose
        neighbour ``modules[i - 1]`` in the reversed list is a BatchNorm2d does not compute a mask of its
        own but reuses the mask left over from an earlier iteration.
    :return: The mask that was applied when ``layer_name`` is given and names a module of one of the
        ``layers`` types; None in every other case.
    """
    # named_modules() without its last entry, walked from back to front.
    modules = list(reversed(list(model.named_modules())[:-1]))
    for i, (nm, mo) in enumerate(modules):
        if layer_name is not None:
            # Single-layer mode: every module except the one called `layer_name` is ignored.
            if nm == layer_name:
                if isinstance(mo, layers):
                    s = sensitivity[nm]
                    # 0 where the sensitivity is below T, 1 elsewhere.
                    prune_mask = torch.where(s < T, torch.zeros_like(s), torch.ones_like(s))
                    for n_p, p in mo.named_parameters():

                        # Keep the mask on the same device as the parameter it multiplies.
                        if prune_mask.device != p.device:
                            prune_mask = prune_mask.to(p.device)

                        if "weight" in n_p:
                            if isinstance(mo, nn.modules.Linear):
                                # Scale along dim 0 of the 2-D weight.
                                p.copy_(torch.einsum(
                                    'ij,i->ij',
                                    p,
                                    prune_mask
                                ))
                            elif isinstance(mo, nn.modules.Conv2d):
                                # Scale along dim 0 of the 4-D weight.
                                p.copy_(torch.einsum(
                                    'ijnm,i->ijnm',
                                    p,
                                    prune_mask
                                ))
                            elif isinstance(mo, nn.modules.ConvTranspose2d):
                                # Scale along dim 1 of the 4-D weight.
                                p.copy_(torch.einsum(
                                    'ijnm,j->ijnm',
                                    p,
                                    prune_mask
                                ))
                            else:
                                p.copy_(torch.mul(p, prune_mask))

                        # Any parameter whose name does not contain "weight" (e.g. the bias).
                        else:
                            p.copy_(torch.mul(p, prune_mask))

                    # Single-layer mode stops at the first match and hands the mask back.
                    return prune_mask
        else:
            # Whole-model mode: every module of a type listed in `layers` is processed.
            if isinstance(mo, layers):
                if bn_prune and isinstance(mo, (nn.modules.Conv2d, nn.modules.ConvTranspose2d)):
                    # modules[i - 1] is the entry visited just before this one in the reversed list.
                    # NOTE(review): for i == 0 the index -1 wraps around to the last list entry -- confirm
                    # this is intended.
                    if isinstance(modules[i - 1][1], nn.modules.BatchNorm2d):
                        # `prune_mask` is NOT recomputed here: the value assigned by an earlier iteration
                        # is reused, and it is not moved to the parameter's device in this branch.
                        # NOTE(review): if no earlier iteration assigned `prune_mask`, this raises
                        # UnboundLocalError -- presumably BatchNorm2d is expected to be part of `layers`.
                        for n_p, p in mo.named_parameters():
                            if "weight" in n_p:
                                # A mask with a single entry is broadcast with a plain multiplication.
                                if prune_mask.shape[0] != 1:
                                    if isinstance(mo, nn.modules.Conv2d):
                                        p.copy_(torch.einsum(
                                            'ijnm,i->ijnm',
                                            p,
                                            prune_mask))
                                    if isinstance(mo, nn.modules.ConvTranspose2d):
                                        p.copy_(torch.einsum(
                                            'ijnm,j->ijnm',
                                            p,
                                            prune_mask))
                                else:
                                    p.copy_(torch.mul(p, prune_mask))
                            else:
                                p.copy_(torch.mul(p, prune_mask))

                        # The layer's own sensitivity is skipped in this case.
                        continue

                s = sensitivity[nm]
                # 0 where the sensitivity is below T, 1 elsewhere.
                prune_mask = torch.where(s < T, torch.zeros_like(s), torch.ones_like(s))
                for n_p, p in mo.named_parameters():

                    # Keep the mask on the same device as the parameter it multiplies.
                    if prune_mask.device != p.device:
                        prune_mask = prune_mask.to(p.device)

                    if "weight" in n_p:
                        if isinstance(mo, nn.modules.Linear):
                            # Scale along dim 0 of the 2-D weight.
                            p.copy_(torch.einsum(
                                'ij,i->ij',
                                p,
                                prune_mask
                            ))
                        elif isinstance(mo, nn.modules.Conv2d):
                            # Scale along dim 0 of the 4-D weight.
                            p.copy_(torch.einsum(
                                'ijnm,i->ijnm',
                                p,
                                prune_mask
                            ))
                        elif isinstance(mo, nn.modules.ConvTranspose2d):
                            # Scale along dim 1 of the 4-D weight.
                            p.copy_(torch.einsum(
                                'ijnm,j->ijnm',
                                p,
                                prune_mask
                            ))
                        else:
                            p.copy_(torch.mul(p, prune_mask))

                    # Any parameter whose name does not contain "weight" (e.g. the bias).
                    else:
                        p.copy_(torch.mul(p, prune_mask))
