from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import copy
import re
import torch
import torchvision


def load_model(model_name, path=None, GPU=False):
    """
    Retrieves a model either from
    disk or by using torchvision.

    Parameters
    ----------
    model_name: str
        Key of a constructor in ``torchvision.models``
        (looked up through ``torchvision.models.__dict__``).
    path: str, optional
        Location of a serialized checkpoint. When ``None``
        the model is built by torchvision with
        ``pretrained=True``.
    GPU: bool
        When False the checkpoint is loaded with
        ``map_location`` set to the CPU device. When True
        ``torch.load`` is called without ``map_location``
        and ``.cuda()`` is called on the result.

    Returns
    -------
    The loaded model. A checkpoint that is not a dict is
    returned exactly as it was deserialized.

    Raises
    ------
    RuntimeError
        If the state_dict cannot be loaded into the
        architecture selected by ``model_name``.
    """
    if path is None:
        # Retrieve the model from TorchVision
        model = torchvision.models.__dict__[model_name](pretrained=True)
    else:
        # Load local model
        # NOTE(security): torch.load may unpickle arbitrary
        # objects, only use checkpoints from trusted sources.
        if GPU:
            checkpoint = torch.load(path)
        else:
            checkpoint = torch.load(path, map_location=torch.device('cpu'))

        if not isinstance(checkpoint, dict):
            # Not a state_dict: use the deserialized object as the model
            model = checkpoint
        else:
            # Retrieve state_dict
            if 'state_dict' in checkpoint:
                # Drop every 'module.' occurrence from the keys
                state_dict = {
                    key.replace('module.', ''): value
                    for key, value in checkpoint['state_dict'].items()}
            else:
                state_dict = checkpoint

            # The first dimension of the last entry is
            # used as the number of classes
            last_layer = list(state_dict)[-1]
            num_classes = state_dict[last_layer].size()[0]

            # Retrieve architecture
            model = torchvision.models.__dict__[
                model_name](num_classes=num_classes)

            try:
                model.load_state_dict(state_dict)
            except RuntimeError:
                if model_name != 'densenet161':
                    # Re-raise the original exception untouched so
                    # its type and traceback are preserved
                    raise

                # FIXME: DenseNet-161 keys might differ
                # Keys such as '...denselayer1.norm.1.weight' are
                # renamed to '...denselayer1.norm1.weight'
                pattern = re.compile(
                  r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.'
                  r'(?:weight|bias|running_mean|running_var))$')
                # Iterate over a snapshot: the dict is edited in the loop
                for key in list(state_dict.keys()):
                    res = pattern.match(key)
                    if res:
                        new_key = res.group(1) + res.group(2)
                        state_dict[new_key] = state_dict[key]
                        del state_dict[key]
                model.load_state_dict(state_dict)

    # Finalize model loading
    if GPU:
        model.cuda()

    return model


def get_names(model):
    """
    Lists the names yielded by
    ``model.named_modules()``,
    leaving out empty names
    (the model itself).
    """
    names = []
    for name, _module in model.named_modules():
        if name:
            names.append(name)
    return names


def get_module(model, name):
    """
    Selects a module by name

    Parameters
    ----------
    model: torch.nn.Module
        Root module the lookup starts from.
    name: str
        Dot-separated path to the module. Empty tokens
        are ignored, so an empty name selects ``model``.
        Inside a ``torch.nn.Sequential`` a token may be
        an integer position or the name of a child.

    Returns
    -------
    The module found at the end of the path.

    Raises
    ------
    KeyError
        A token is not a child of a non-Sequential module.
    ValueError
        A token is neither an integer nor the name of a
        child of a Sequential.
    IndexError
        An integer token is out of range for a Sequential.
    """
    module = model
    for token in name.split('.'):
        if not token:
            continue
        if isinstance(module, torch.nn.Sequential):
            try:
                idx = int(token)
            except ValueError:
                # Position among all the registered children.
                # named_children() skips repeated instances, so a
                # position computed from it would not line up with
                # Sequential indexing.
                idx = list(module._modules).index(token)
            module = module[idx]
        else:
            module = module._modules[token]
    return module


def _child_position(container, token):
    """
    Translates a name token into a position
    inside a torch.nn.Sequential container.
    """
    try:
        return int(token)
    except ValueError:
        # Position among all the registered children.
        # named_children() skips repeated instances, so a
        # position computed from it would not line up with
        # Sequential indexing.
        return list(container._modules).index(token)


def replace_module(model, name, substitute):
    """
    Replaces a module selected by name

    Parameters
    ----------
    model: torch.nn.Module
        Root module, modified in place.
    name: str
        Dot-separated path to the module to replace.
        Empty tokens are ignored. Inside a
        ``torch.nn.Sequential`` a token may be an integer
        position or the name of a child.
    substitute: torch.nn.Module
        Module stored in place of the selected one.

    Returns
    -------
    The same ``model`` object that was passed in.

    Raises
    ------
    IndexError
        ``name`` has no tokens, or an integer token is out
        of range for a Sequential.
    KeyError
        An intermediate token is not a child of a
        non-Sequential module.
    ValueError
        A token is neither an integer nor the name of a
        child of a Sequential.
    """
    tokens = [e for e in name.split('.') if e]

    # Walk down to the parent of the module to replace
    parent = model
    for token in tokens[:-1]:
        if isinstance(parent, torch.nn.Sequential):
            parent = parent[_child_position(parent, token)]
        else:
            parent = parent._modules[token]

    # Swap the child selected by the last token
    token = tokens[-1]
    if isinstance(parent, torch.nn.Sequential):
        parent[_child_position(parent, token)] = substitute
    else:
        parent._modules[token] = substitute

    return model


def module_out_shape(model, module, gpu=False,
                     input_shape=(16, 3, 224, 224)):
    """
    Returns the shape of the output produced by
    a module during a forward pass of the model.

    Parameters
    ----------
    model: torch.nn.Module
        Model that is run on a random input.
    module: torch.nn.Module
        Module whose output is observed
        through a forward hook.
    gpu: bool
        If True the random input is moved to
        CUDA before the forward pass.
    input_shape: sequence of int
        Shape of the random input fed to the model.

    Returns
    -------
    tuple of int
        Shape of the first output recorded for ``module``.

    Raises
    ------
    IndexError
        If ``module`` was not called during the forward pass.
    """
    shapes = []

    def record_shape(m, i, o):
        # Only the shape is needed, no copy of the activation
        shapes.append(tuple(o.shape))

    hook = module.register_forward_hook(record_shape)
    try:
        # Random input
        x = torch.rand(tuple(input_shape))
        if gpu:
            x = x.cuda()

        # No gradients are needed to read a shape
        with torch.no_grad():
            model(x)
    finally:
        # Never leave the hook attached, even if the forward fails
        hook.remove()

    return shapes[0]


def mask_module(model, module_name, units=None, mu=0.0, gpu=False):
    """
    Given a PyTorch model, it replaces a
    given module with another that masks
    the output of some units.
    The output is constrained to mu.

    Parameters
    ----------
    model: torch.nn.Module
        Model to modify (changed in place).
    module_name: str
        Name of the module whose output is masked.
    units: list or container of int, optional
        Indices of the units to mask. A list is converted
        to a tensor, any other value is used as is for
        membership tests. ``None`` masks nothing.
    mu: float
        Value produced by the masked units.
    gpu: bool
        If True the mask weights are moved to CUDA and the
        shape probe is run with a CUDA input.

    Returns
    -------
    The model, where the selected module has been replaced
    by ``Sequential(module, mask)``.
    """
    if units is None:
        units = []
    if isinstance(units, list):
        units = torch.tensor(units)

    module = get_module(model, module_name)

    # Analyze module shape: a four dimensional output gets
    # a 1x1 convolution, any other a fully connected layer
    out_shape = module_out_shape(model, module, gpu)
    is_convolutional = len(out_shape) == 4
    n_units = out_shape[1]

    # keep[i] = 0 if unit i is masked
    keep = torch.tensor([0 if u in units else 1 for u in range(n_units)])

    eye = torch.eye(n_units)

    # bias[i] = \mu if unit i is masked
    # float(mu) and the explicit dtype keep the bias floating
    # point even when mu is an int and every unit is masked
    bias = torch.tensor(
        [0.0 if k else float(mu) for k in keep.tolist()],
        dtype=eye.dtype)

    # Diagonal matrix that lets through only the kept units
    weight = eye * keep

    if is_convolutional:
        # Define mask for convolutional layer
        mask = torch.nn.Conv2d(n_units, n_units, (1, 1))
        weight = torch.reshape(weight, (n_units, n_units, 1, 1))
    else:
        # Define mask for fully connected layer
        mask = torch.nn.Linear(n_units, n_units)

    if gpu:
        weight = weight.cuda()
        bias = bias.cuda()

    mask.weight = torch.nn.Parameter(weight)
    mask.bias = torch.nn.Parameter(bias)

    # Replace module
    module = torch.nn.Sequential(module, mask)
    model = replace_module(model, module_name, module)

    return model


def mask_model(units, model, mu=0.0, gpu=False):
    """
    Builds a masked copy of a model.

    Parameters
    ----------
    units: dict of lists
        Maps each module name to the
        units that should be removed
    model: torch.nn.Module
        The PyTorch model to analyze.
        It is deep copied, the masks
        are applied to the copy.
    mu: float
        Forwarded to ``mask_module``
    gpu: bool
        Forwarded to ``mask_module``

    Returns
    -------
    The masked copy of the model.
    """
    # Work on a copy of the model
    masked = copy.deepcopy(model)

    # Mask one module at a time
    for module_name, unit_ids in units.items():
        masked = mask_module(masked, module_name, unit_ids, mu, gpu)

    return masked


def evaluate_model(model, target_dataset, metric,
                   batch_size=32, silent=False, gpu=False):
    """
    Runs the model over a dataset
    and accumulates a metric.

    Parameters
    ----------
    model: torch.nn.Module
        Model to evaluate. It is
        switched to evaluation mode.
    target_dataset:
        Dataset wrapped in a DataLoader,
        each batch is unpacked as (x, y).
    metric:
        Object providing ``init()``,
        ``update(y, y_m)`` and ``finalize()``.
    batch_size: int
        Batch size given to the DataLoader.
    silent: bool
        If True the progress bar is disabled.
    gpu: bool
        If True the inputs are moved to CUDA and
        the outputs are moved back to the CPU.

    Returns
    -------
    The value returned by ``metric.finalize()``.
    """
    batches = DataLoader(target_dataset, batch_size=batch_size)

    # Evaluation ready
    model.eval()

    with torch.no_grad():
        metric.init()
        for inputs, targets in tqdm(batches, disable=silent):
            if gpu:
                predictions = model(inputs.cuda()).cpu()
            else:
                predictions = model(inputs)
            metric.update(targets, predictions)

    return metric.finalize()
