from re import I
import numpy as np
from itertools import product
from collections import deque

import torch
import warnings
import torch.nn as nn

from typing import List, Tuple, Union

from utils.miscellaneous import flatten
from utils.miscellaneous import generate_hook_keys

class FeatureRecorder:
    ''' 
        Basic feature recording class that implements a PyTorch forward hook
        to acquire layer activations as they get processed by the network.

        A single recorder instance is meant to be registered on several
        layers. On every call the layer output is stored under the key that
        currently sits at the front of the key queue, then the queue is
        rotated by one position so that the next call records under the
        next key. The order of the provided keys must therefore match the
        order in which the hooked layers are called.

        Args:
            keys:  Labels under which activations are stored, one per
                   hooked layer, in call order.
            units: Optional list (same length as `keys`) of integer index
                   arrays. Each array is read column-wise (1, 2 or 3
                   columns) and used to index the trailing axes of the
                   corresponding layer output, so that only a subset of
                   the layer is recorded.

        Raises:
            AssertionError: if `units` is given and its length differs
                   from the length of `keys`.
    '''
    
    def __init__(self, keys : List[str], units : Union[List[np.ndarray], None] = None) -> None:
        if units is not None and len(keys) != len(units):
            err_msg = 'Provided Keys and Units lists have different lengths'
            # Raised explicitly (instead of via an `assert` statement) so
            # the check is not stripped when Python runs with -O, while
            # the exception type seen by callers stays the same
            raise AssertionError(f'{err_msg}: {len(keys)} and {len(units)}')

        # This is the active key on which to record. Note that the
        # provided list should have the same order as the calling
        # stack of layers one is recording on, as for each call this
        # list is rolled to prepare for next layer
        self.keys = deque(keys)

        # This is the units index list. It is used to index the unit
        # tensor of each recorded layer, thus allowing recording for
        # a subset of the full layer.
        self.units = None if units is None else deque(units)
        
        # Here we define the features dictionary of the recorder.
        # The dictionary is indexed by a key labeling the layer it
        # refers to and contains a list with all the activations
        # that layer produced while the hook was attached.
        self.features = {k : [] for k in self.keys}
        
        
    def __call__(self, module, inp, out):
        '''
            Forward-hook entry point: store `out` under the active key.

            `module` and `inp` are part of the hook signature but are not
            used. `out` must expose `.detach().cpu().numpy()`.

            Raises:
                ValueError: if the active unit index array does not have
                    1, 2 or 3 columns.
        '''
        # Detach layer output from PyTorch graph and convert to numpy array.
        # NOTE: squeeze() drops every size-1 axis, so the unit indices
        # below are applied to the squeezed array
        data = out.detach().cpu().numpy().squeeze()

        # If a units array was registered, use it to index current layer
        if self.units is not None: 
            uidx = self.units[0]
            n_cols = uidx.shape[1]

            if n_cols not in (1, 2, 3):
                raise ValueError(f'Unknown unit indexing shape: {uidx.shape}') 

            # One index array per column, applied to the trailing axes
            columns = tuple(uidx[:, i] for i in range(n_cols))
            data = data[(Ellipsis, *columns)]

        # Register the layer output in active key entries
        self.features[self.keys[0]].append(data)
        
        # Cycle through the keys to prepare for next layer
        self.keys.rotate(-1)
        if self.units is not None: self.units.rotate(-1)

    def clean(self) -> None:
        ''' Discard all recorded activations by installing a fresh, empty features dictionary. '''
        self.features = {k : [] for k in self.keys}

    def __del__ (self):
        # Here we manually drop the reference to the feature list. The
        # attribute may be missing if __init__ raised before creating it
        # (or if __del__ already ran), hence the tolerant pop
        self.__dict__.pop('features', None)

class Traced(nn.Module):
    '''
        A wrapper class that represents a model whose activity is traced via recording
        hooks. This class accepts a general torch.nn.Module as input and attaches a
        recording hook to its layers. The recorded activations are stored in a dictionary,
        whose indexing keys can be specified by the user (they are generated from the
        module otherwise). Moreover, the user can optionally list specific types of
        torch.nn.Module to exclude from recording (ex: Dropout or BatchNorm).

        Args:
            module:  The model to wrap and trace.
            keys:    Labels for the recorded layers, in call order. When None they
                     are produced by `generate_hook_keys(module, exclude=...)`.
            units:   Optional per-layer index arrays forwarded to the recorder to
                     record only a subset of each layer.
            exclude: Module classes that must not receive a hook. When None no
                     layer is excluded.
    '''

    def __init__(self, module : nn.Module, 
                       keys : Union[List[str], None] = None,
                       units : Union[List[np.ndarray], None] = None, 
                       exclude : Union[List[type], None] = None):
        super().__init__()

        self.module  = module
        self.units   = units
        # type(None) acts as a placeholder class that no layer is an instance of
        self.exclude = [type(None)] if exclude is None else exclude 
        self.keys    = generate_hook_keys(module, exclude = self.exclude) if keys is None else keys

        # Create a FeatureRecorder object that will be attached to the module layers as hook
        self.recorder = FeatureRecorder(self.keys, self.units)

        # Get a reference to the complete set of module layers. We do so by flattening the
        # module object to make sure that only leaf-like submodules are considered (ex: no
        # Sequential or nested modules would be considered)
        # NOTE(review): relies on utils.miscellaneous.flatten behaving as described
        self.layers = flatten(module)

        # Register the forward hooks for each layer targeted as 'traced'
        self.hook_handles = self._register_hooks()

        # The recorder cycles through its keys once per hook call, so a length
        # mismatch means activations would end up under the wrong labels
        if len(self.keys) != len(self.hook_handles):
            warn_msg = f'Key list and Hook Handles have different length. Got {len(self.keys)} and {len(self.hook_handles)}'
            warnings.warn(warn_msg, UserWarning)

    def _register_hooks(self) -> list:
        ''' Attach the recorder to every non-excluded layer and return the hook handles. '''
        excluded = tuple(self.exclude)

        return [layer.register_forward_hook(self.recorder)
                for layer in self.layers
                if not isinstance(layer, excluded)]

    def forward(self, x : torch.FloatTensor) -> Tuple[torch.FloatTensor, dict]:
        '''
            Run the wrapped module on `x`.

            Returns:
                The module output together with the recorder features dictionary
                (the dictionary object itself, not a copy).
        '''
        out = self.module(x)

        return out, self.recorder.features

    def clean(self) -> None:
        ''' Discard every activation recorded so far. '''
        self.recorder.clean()

    def trace(self) -> None:
        ''' (Re-)attach the recording hooks, removing any hook already registered. '''
        # Clean all hooks that might be pending
        self.untrace()

        # Register the forward hooks for each layer targeted as 'traced'
        self.hook_handles = self._register_hooks()

    def untrace(self) -> None:
        ''' Detach all recording hooks from the wrapped module layers. '''
        # We unregister the hooks upon untracing
        for hook in self.hook_handles: hook.remove()

        self.hook_handles = []

    def __del__(self) -> None:
        # We unregister the hooks upon destruction. The handle list is read
        # straight from the instance dict because it does not exist when
        # __init__ raised before registering the hooks
        for hook in self.__dict__.get('hook_handles', ()): hook.remove()
