import torch
import torch.nn as nn
from torchvision.models import vgg16

from inspect import getfullargspec

from typing import Tuple, List, Union

class Info:
    '''
        Very simplistic PyTorch forward hook that records, for every module it is attached to,
        the output shape and a generated layer name.

        Attributes:
            shapes : list of tuples, output shape (size-1 dims removed) of each hooked call.
            names  : list of str, generated name (e.g. 'conv2d_01') of each hooked call.
            traces : dict, per-type counters used to number the generated names.
    '''
    def __init__(self) -> None:
        self.shapes = []
        self.names  = []

        self.traces = {'conv2d' : 0, 'relu' : 0, 'mpool2d' : 0, 'fc' : 0, 'avgpool2d' : 0}

    def __call__(self, module : nn.Module, inp, out) -> None:
        '''
            Forward-hook entry point (signature: module, input, output).

            Records the generated name of `module` and the squeezed shape of `out`.
            Raises ValueError (from get_name) for unsupported module types.
        '''
        # Resolve the name first so that a ValueError leaves `shapes` and `names` aligned.
        name = self.get_name(module)

        # Dropping the size-1 dims of `out.shape` is exactly what `squeeze()` does, but avoids
        # copying the activation to host memory and converting it to numpy (which also fails
        # for dtypes numpy does not support, e.g. bfloat16).
        self.shapes.append(tuple(dim for dim in out.shape if dim != 1))
        self.names.append(name)

    def get_name(self, module : nn.Module) -> str:
        '''
            Return '<type>_<NN>' for `module`, where NN is the 1-based, zero-padded count of
            modules of that type seen so far by this Info instance.

            Raises ValueError if `module` is not one of the supported layer types.
        '''
        if   isinstance(module, nn.Conv2d):    return f"conv2d_{self._update('conv2d')}" 
        elif isinstance(module, nn.MaxPool2d): return f"mpool2d_{self._update('mpool2d')}" 
        elif isinstance(module, nn.ReLU):      return f"relu_{self._update('relu')}" 
        elif isinstance(module, nn.Linear):    return f"fc_{self._update('fc')}"
        elif isinstance(module, nn.AdaptiveAvgPool2d): return f"avgpool2d_{self._update('avgpool2d')}"
        else: raise ValueError(f'Unknow module type {module}')

    def _update(self, key : str) -> str:
        '''Increment the counter for `key` and return its new value zero-padded to 2 digits.'''
        self.traces[key] = self.traces[key] + 1
        return str(self.traces[key]).zfill(2)

def get_info(model : nn.Module, inp_shape : Tuple[int], exclude : List[nn.Module] = None) -> Info:
    '''
        Run a dummy all-zeros forward pass through `model` and record, for every leaf module
        (as returned by `flatten`) whose type is not in `exclude`, its output shape and a
        generated name.

        Args:
            model     : network to inspect. NOTE: it is moved in place to CUDA when available.
            inp_shape : shape of the dummy input tensor, including any batch dimension.
            exclude   : module types whose instances are not hooked (default: none excluded).

        Returns:
            The Info recorder; its `.shapes` and `.names` lists hold one entry per hooked call.
    '''
    exclude = [type(None)] if exclude is None else exclude
    excluded_types = tuple(exclude)

    # Get the flattened model, excluding unwanted layer types
    layers = flatten(model)

    info = Info()
    hook_handles = []
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    try:
        # Attach a basic recording hook to each target layer
        for layer in layers:
            if not isinstance(layer, excluded_types):
                hook_handles.append(layer.register_forward_hook(info))

        model = model.to(device)
        inp = torch.zeros(inp_shape, device=device)

        # Only shapes are recorded, so skip building an autograd graph for the dummy pass.
        with torch.no_grad():
            _ = model(inp)
    finally:
        # Always detach the hooks, even if the forward pass fails, so the caller's model
        # is not left with stale recorders attached.
        for hook in hook_handles: hook.remove()

    return info

def flatten(model : nn.Module):
    '''
        Return the leaf modules of `model` as a flat list, in depth-first order of
        `children()`. A module without children is returned as its own single leaf.
    '''
    direct_children = list(model.children())
    if not direct_children:
        return [model]

    leaves = []
    for child in direct_children:
        leaves.extend(flatten(child))
    return leaves

def replace_module(model : nn.Module, src : nn.Module, trg : nn.Module) -> nn.Module:
    '''
        Recursively swap every submodule of type `src` inside `model` for a new `trg`.

        Each replacement is built from the old module's instance attributes: every entry of
        its `__dict__` whose name matches an argument of `trg`'s constructor is forwarded as
        a keyword argument. A matching module is replaced as a whole; its own children are
        not searched. `model` is modified in place and also returned.
    '''
    spec = getfullargspec(trg)
    accepted = set(spec.args) | set(spec.kwonlyargs)

    for child_name, child in model.named_children():
        if isinstance(child, src):
            ctor_kwargs = {key: value for key, value in vars(child).items() if key in accepted}
            setattr(model, child_name, trg(**ctor_kwargs))
        elif any(True for _ in child.children()):
            replace_module(child, src, trg)

    return model

def generate_hook_keys(model : nn.Module, exclude : List[nn.Module] = None) -> List[str]:
    '''
        Build one key per leaf module of `model` (from `flatten`) whose type is not in `exclude`.

        Keys look like '<idx>_<type>_<NN>':
            idx  : position of the layer among the non-excluded leaves (0-based),
            type : one of conv2d / mpool2d / relu / fc / avgpool2d,
            NN   : 1-based running count of that type, zero-padded to two digits.

        Raises ValueError on the first leaf whose type is not supported.
    '''
    if exclude is None:
        exclude = [type(None)]
    skipped = tuple(exclude)

    # Checked in order; the first matching class decides the tag.
    kinds = (
        (nn.Conv2d,            'conv2d'),
        (nn.MaxPool2d,         'mpool2d'),
        (nn.ReLU,              'relu'),
        (nn.Linear,            'fc'),
        (nn.AdaptiveAvgPool2d, 'avgpool2d'),
    )
    counters = {tag: 0 for _, tag in kinds}

    kept = [layer for layer in flatten(model) if not isinstance(layer, skipped)]

    keys = []
    for position, layer in enumerate(kept):
        tag = next((t for cls, t in kinds if isinstance(layer, cls)), None)
        if tag is None:
            raise ValueError(f'Unknow layer type {layer}')
        counters[tag] += 1
        keys.append(f"{position}_{tag}_{counters[tag]:02d}")

    return keys

def getnet(
    tlayers : Union[str, int, List[int]] = 'all',
    seed : int = None,
    kind : str = 'vgg16',
    verbose = False
    ) -> torch.nn.Module:
    '''
        Build a VGG16 whose selected `features` layers carry ImageNet-pretrained parameters
        while every other layer stays randomly initialised.

        Args:
            tlayers : 'all' -> return the fully pretrained network;
                      'rnd' -> return the fully random network;
                      int n -> copy pretrained weight/bias into features[0] .. features[n-1];
                      list  -> copy pretrained weight/bias into the listed features indices.
            seed    : if not None, passed to torch.manual_seed before the networks are built.
            kind    : NOTE(review): currently ignored; only VGG16 is built.
            verbose : print a message for every selected layer that has no weight / bias.

        Raises:
            ValueError : `tlayers` is a string other than 'all' or 'rnd'.
    '''
    vprint = print if verbose else lambda x : None

    # Fail fast, before any (potentially downloading) model construction.
    if isinstance(tlayers, str) and tlayers not in ('all', 'rnd'):
        raise ValueError(
            f"Unknown tlayers option '{tlayers}': expected 'all', 'rnd', an int or a list of ints."
        )

    # `is not None` so that seed=0 is honoured (a plain truthiness test silently skipped it).
    if seed is not None: torch.manual_seed(seed)

    # Both networks are always built, in this order, so that for a given seed the random
    # network's initialisation is the same whatever `tlayers` is.
    tnet = vgg16(pretrained = True)
    net  = vgg16(pretrained = False)

    if isinstance(tlayers, str) and tlayers == 'all' : return tnet
    if isinstance(tlayers, str) and tlayers == 'rnd' : return net
    if isinstance(tlayers, int): tlayers = range(tlayers)

    # Transfer the weights from trained to random
    for tl in tlayers:
        # Layers such as ReLU / MaxPool2d have no `weight` / `bias`, which raises AttributeError.
        try:
            w = tnet.features[tl].weight.data
            net.features[tl].weight.data.copy_(w)

        except AttributeError:
            vprint(f"Skipping layer {tl}:{tnet.features[tl]} as it doesn't have weights.")

        try:
            b = tnet.features[tl].bias.data
            net.features[tl].bias.data.copy_(b)

        except AttributeError:
            vprint(f"Skipping layer {tl}:{tnet.features[tl]} as it doesn't have biases.")

    return net