"""
Library for extracting interesting quantities from autograd, see README.md

Not thread-safe because of module-level variables

Notation:
o: number of output classes (exact Hessian), number of Hessian samples (sampled Hessian)
n: batch-size
do: output dimension (output channels for convolution)
di: input dimension (input channels for convolution)
Hi: per-example Hessian of matmul, shaped as matrix of [dim, dim], indices have been row-vectorized
Hi_bias: per-example Hessian of bias
Oh, Ow: output height, output width (convolution)
Kh, Kw: kernel height, kernel width (convolution)


A, activations: inputs into current layer
B, backprops: backprop values (aka Lop aka vector-Jacobian product) observed at current layer

"""

from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F

from utils.unfold import unfoldNd  # To support 3D convolution

# Class names of the layer types this library knows how to handle
_supported_layers = ['Linear', 'Conv2d', 'Conv3d']
# When True the capture hooks return immediately;
# work-around for https://github.com/pytorch/pytorch/issues/25723
_hooks_disabled: bool = False
# Global switch to catch double backprop errors on Hessian computation
_enforce_fresh_backprop: bool = False


def add_hooks(model: nn.Module, only_enc: bool = False) -> None:
    """
    Register capture hooks on the supported layers of model.

    A layer is hooked when its class name is listed in _supported_layers and
    its weight has requires_grad set. The hooks will
    1. append the layer input to layer.activations_list during forward pass
    2. append backprops to layer.backprops_list during backward pass.

    The hook handles are appended to model.autograd_hacks_hooks, and the
    global "hooks disabled" flag is reset so the hooks are active.

    Call "remove_hooks(model)" to disable this.

    Args:
        model: module whose named_modules() are scanned for supported layers
        only_enc: if True, only modules whose name contains '.blocks' or
            'patch_embed' are hooked
    """

    global _hooks_disabled
    _hooks_disabled = False

    new_handles = []
    for module_name, module in model.named_modules():
        # Optional name-based restriction of the modules that get hooked
        if only_enc and not ('.blocks' in module_name or 'patch_embed' in module_name):
            continue
        if _layer_type(module) not in _supported_layers:
            continue
        # Layers with a frozen weight are left alone
        if not module.weight.requires_grad:
            continue

        new_handles += [
            module.register_forward_hook(_capture_activations),
            module.register_backward_hook(_capture_backprops),
        ]

    # Written through __dict__ so the handle list is stored as a plain attribute
    model.__dict__.setdefault('autograd_hacks_hooks', []).extend(new_handles)


def remove_hooks(model: nn.Module) -> None:
    """
    Detach the hooks registered by add_hooks(model) and drop the stored handles.

    The leading assertion fails unless model compares equal to 0, so this
    function is effectively switched off until the referenced PyTorch issue
    is resolved.
    """

    assert model == 0, "not working, remove this after fix to https://github.com/pytorch/pytorch/issues/25723"

    if hasattr(model, 'autograd_hacks_hooks'):
        for hook_handle in model.autograd_hacks_hooks:
            hook_handle.remove()
        del model.autograd_hacks_hooks
        return

    print("Warning, asked to remove hooks, but no hooks found")


def disable_hooks() -> None:
    """
    Turn off every hook installed by this library by setting the
    module-level _hooks_disabled flag; the hooks stay registered.
    """
    global _hooks_disabled
    _hooks_disabled = True


def enable_hooks() -> None:
    """
    Turn the hooks installed by this library back on by clearing the
    module-level _hooks_disabled flag (the opposite of disable_hooks()).
    """
    global _hooks_disabled
    _hooks_disabled = False


def is_supported(layer: nn.Module) -> bool:
    """Return True when the class name of layer is listed in _supported_layers."""
    class_name = _layer_type(layer)
    return class_name in _supported_layers


def _layer_type(layer: nn.Module) -> str:
    """Return the class name of layer, e.g. 'Linear' for an nn.Linear instance."""
    layer_class = layer.__class__
    return layer_class.__name__


def _capture_activations(layer: nn.Module, input: List[torch.Tensor], output: torch.Tensor):
    """Forward hook: append the detached first input of layer to layer.activations_list.

    Does nothing while hooks are globally disabled. The list is created on
    first use, so several forward calls accumulate several entries.
    """
    if not _hooks_disabled:
        assert _layer_type(layer) in _supported_layers, "Hook installed on unsupported layer, this shouldn't happen"
        if not hasattr(layer, 'activations_list'):
            layer.activations_list = []
        # Only the first positional input is recorded, detached from the graph
        first_input = input[0]
        layer.activations_list.append(first_input.detach())

def _capture_backprops(layer: nn.Module, _input, output):
    """Backward hook: append the detached first element of output to layer.backprops_list.

    Does nothing while hooks are globally disabled. When _enforce_fresh_backprop
    is set, the first hook to fire asserts that no backprops from an earlier
    backward pass are present and then resets the switch.
    """
    global _enforce_fresh_backprop

    if not _hooks_disabled:
        if _enforce_fresh_backprop:
            assert not hasattr(layer, 'backprops_list'), "Seeing result of previous backprop, use clear_backprops(model) to clear"
            # One-shot check: only the first layer reached in this backward pass is tested
            _enforce_fresh_backprop = False

        if not hasattr(layer, 'backprops_list'):
            layer.backprops_list = []
        first_output = output[0]
        layer.backprops_list.append(first_output.detach())


def clear_backprops(model: nn.Module) -> None:
    """Delete layer.backprops_list and layer.activations_list in every layer that has them."""
    captured_attrs = ('backprops_list', 'activations_list')
    for module in model.modules():
        for attr_name in captured_attrs:
            if hasattr(module, attr_name):
                delattr(module, attr_name)


def compute_grad1(model: nn.Module, only_enc: bool = False, loss_type: str = 'mean') -> None:
    """
    Compute per-example gradients and save them under 'param.grad1'. Must be called after loss.backward()

    For every supported layer whose weight requires grad, the activations and
    backprops captured by the hooks are combined into layer.weight.grad1
    (shape [n, *weight.shape]) and, when the layer has a bias, layer.bias.grad1.
    If a layer was called several times since the last clear_backprops(model)
    (e.g. shared weights), the contributions of all calls are summed.

    Convolution activations are unfolded with stride equal to the kernel size and
    without padding or dilation arguments; the layer's own stride, padding and
    dilation settings are not consulted.

    Args:
        model: model on which add_hooks(model) was called before forward/backward
        only_enc: if True, only modules whose name contains '.block' or 'patch_embed' are processed
        loss_type: either "mean" or "sum" depending whether backpropped loss was averaged or summed over batch
    """

    assert loss_type in ('sum', 'mean')
    for name, layer in model.named_modules():
        # NOTE(review): add_hooks filters on '.blocks' while this filters on '.block';
        # kept as-is to preserve behavior, confirm which one is intended.
        if only_enc and not ('.block' in name or 'patch_embed' in name):
            continue

        layer_type = _layer_type(layer)
        if layer_type not in _supported_layers:
            continue

        if not layer.weight.requires_grad:
            continue

        assert hasattr(layer, 'activations_list'), f"No activations detected in {name}, run forward after add_hooks(model)"
        assert hasattr(layer, 'backprops_list'), f"No backprops detected in {name}, run backward after add_hooks(model)"
        assert len(layer.backprops_list) == len(layer.activations_list), f"Number of input and gradient in {name} is different, make sure to call clear_backprops(model)"

        has_bias = layer.bias is not None
        weight_grad1 = None
        bias_grad1 = None

        # Backprops are recorded in the opposite order of the forward calls,
        # so the i-th activation pairs with the i-th backprop counted from the end.
        for A, B in zip(layer.activations_list, reversed(layer.backprops_list)):
            n = A.shape[0]
            if loss_type == 'mean':
                # Undo the 1/n factor introduced by averaging the loss over the batch
                B = B * n

            if layer_type == 'Linear':
                if A.dim() == 3:
                    # [n, d, features] input: contributions are summed over the middle axis
                    weight_term = torch.einsum('ndi,ndj->nij', B, A)
                    bias_term = torch.sum(B, dim=1) if has_bias else None
                else:
                    weight_term = torch.einsum('ni,nj->nij', B, A)
                    # The per-example bias gradient is the backprop itself; no reduction here
                    bias_term = B if has_bias else None

            elif layer_type == 'Conv2d':
                A = torch.nn.functional.unfold(A, layer.kernel_size, stride=layer.kernel_size)
                B = B.reshape(n, -1, A.shape[-1])
                weight_term = torch.einsum('ijk,ilk->ijl', B, A).reshape([n] + list(layer.weight.shape))
                bias_term = torch.sum(B, dim=2) if has_bias else None

            elif layer_type == 'Conv3d':
                A = unfoldNd(A, layer.kernel_size, stride=layer.kernel_size)
                B = B.reshape(n, -1, A.shape[-1])
                weight_term = torch.einsum('ijk,ilk->ijl', B, A).reshape([n] + list(layer.weight.shape))
                bias_term = torch.sum(B, dim=2) if has_bias else None

            else:
                # Listed as supported but no per-example rule is implemented for it
                continue

            # Accumulate out-of-place: for loss_type == 'sum' the first bias term can be
            # the very tensor stored in backprops_list, which must not be modified.
            weight_grad1 = weight_term if weight_grad1 is None else weight_grad1 + weight_term
            if has_bias:
                bias_grad1 = bias_term if bias_grad1 is None else bias_grad1 + bias_term

        if weight_grad1 is not None:
            layer.weight.grad1 = weight_grad1
            if has_bias:
                layer.bias.grad1 = bias_grad1



