import torch
from torch import nn
import numpy as np

import abc


class InputCollector(abc.ABC):
    """Accumulate a per-layer statistic of the inputs seen by hooked layers.

    Forward hooks are registered on every ``nn.Linear`` / ``nn.Conv2d``
    module of a network.  On each forward pass the hook hands the layer
    input to the subclass-defined :meth:`_layer_hook` and adds the returned
    value to a running sum kept per hooked layer.  :meth:`collect` divides
    those sums by the number of completed passes over all hooked layers.
    """

    def __init__(self):
        # Shape recorded by `_struct_hook`; consumed (and reset) by the next
        # hooked layer, whose input is reshaped to it.
        self._struct = None

        # Handles returned by `register_forward_hook`, kept for removal.
        self._hooks = []
        # One running sum per hooked Linear/Conv2d layer, in hook order.
        self._accum = []

        # Index into `_accum` of the layer expected to fire next.
        self._pointer = 0
        # Number of times the pointer has wrapped around, i.e. the number of
        # complete passes over all hooked layers.
        self._count = 0

    def detach_hooks(self):
        """Remove every hook registered by :meth:`attach_hooks`."""
        for hook in self._hooks:
            hook.remove()
        self._hooks = []

    def __layer_hook(self, module, args, y):
        """Forward hook that updates the running sum of the current layer.

        The first positional input is cloned, optionally reshaped to the
        shape stored by `_struct_hook`, passed to `_layer_hook`, and the
        result is added to the accumulator slot selected by `_pointer`.
        Returns ``None`` so the module output is left untouched.
        """
        x = args[0].clone()

        if self._struct is not None:
            # Restore the shape captured before flattening; applies to one
            # layer only, so clear it right away.
            x = torch.reshape(x, self._struct)
            self._struct = None

        ret = self._layer_hook(module, [x], y)
        # NOTE(review): `ret` is accumulated as returned; if the forward pass
        # runs with autograd enabled the sum presumably keeps the graph
        # alive -- confirm callers wrap collection in `torch.no_grad()`.
        self._accum[self._pointer] += ret

        # Advance cyclically; a wrap-around marks one completed pass.
        self._pointer = (self._pointer + 1) % len(self._accum)
        if self._pointer == 0:
            self._count += 1

    @abc.abstractmethod
    def _layer_hook(self, module, args, y):
        """Return the statistic of one layer input to be accumulated.

        Args:
            module: The hooked module.
            args: A one-element list holding the (possibly reshaped) input.
            y: The module output as passed to the forward hook.
        """

    def _struct_hook(self, module, args, output):
        """Forward hook that records the output shape of `module`."""
        self._struct = output.shape

    def attach_hooks(self, net, module_before_flatten=None):
        """Register hooks on the sub-modules of `net`.

        Args:
            net: Module whose ``named_modules()`` are scanned.
            module_before_flatten: Optional type (or tuple of types).  The
                output shape of matching modules is recorded and used to
                reshape the input of the next hooked layer.  ``None``
                disables this.
        """
        for _, mod in net.named_modules():
            if isinstance(mod, (nn.Linear, nn.Conv2d)):
                handle = mod.register_forward_hook(self.__layer_hook)
                self._hooks.append(handle)
                self._accum.append(0.)
            elif (module_before_flatten is not None
                    and isinstance(mod, module_before_flatten)):
                # `isinstance(mod, None)` raises TypeError, so the marker
                # type is only tested when one was actually supplied.
                handle = mod.register_forward_hook(self._struct_hook)
                self._hooks.append(handle)

    def collect(self):
        """Return the running sums divided by the number of completed passes.

        No guard is applied for a zero pass count: slots still holding the
        initial float ``0.`` raise ``ZeroDivisionError`` in that case.
        """
        return [a / self._count for a in self._accum]

class InputMeanCollector(InputCollector):
    """Collector whose per-layer statistic is the mean of the layer input."""

    def __init__(self):
        super().__init__()

    def _layer_hook(self, module, args, y):
        """Average the input over its leading dimension(s).

        Rank-3 inputs are averaged over the first two dimensions (presumably
        batch and token axes -- confirm against the hooked model); inputs of
        any other rank are averaged over the first dimension only.
        """
        acts = args[0]
        reduce_dims = (0, 1) if acts.dim() == 3 else 0
        return acts.clone().mean(dim=reduce_dims)


class InputCorrCollector(InputCollector):
    """Collector whose per-layer statistic is a second-moment matrix.

    If `means` is given, ``means[i]`` is subtracted from the input of the
    i-th hooked layer before the matrix is formed.
    """

    def __init__(self, means=None):
        super().__init__()

        self.means = means

    def _layer_hook(self, module, args, y):
        """Return the normalised product of the input with itself over dim 1.

        The einsum contracts the first dimension and every dimension after
        the second; the result is divided by the number of contracted
        positions (first-dimension length times the size of any trailing
        dimensions).
        """
        acts = args[0].clone()

        # Number of positions summed over by the einsum below.
        n_samples = len(acts)
        if acts.dim() > 2:
            n_samples = n_samples * np.prod(acts.shape[2:])

        if self.means is not None:
            acts = acts - self.means[self._pointer]

        scale = 1 / n_samples
        return scale * torch.einsum('bi...,bj...->ij', acts, acts)
    

class DeitInputCorrCollector(InputCollector):
    """Second-moment collector that contracts rank-3 inputs over both leading dims.

    Rank-2 inputs are contracted over the first dimension, rank-3 inputs over
    the first two, and any other rank over the first dimension plus every
    dimension after the second.  If `means` is given, ``means[i]`` is
    broadcast-subtracted from the input of the i-th hooked layer first.
    """

    def __init__(self, means=None):
        super().__init__()

        self.means = means

    def _layer_hook(self, module, args, y):
        """Return the last-dim (rank 2/3) or dim-1 (otherwise) product matrix,
        divided by the number of contracted positions."""
        acts = args[0].clone()
        rank = acts.dim()

        # Leading singleton axes added to the stored mean before subtracting:
        # two for rank-3 inputs, one for every other rank.
        pad = (None, None, Ellipsis) if rank == 3 else (None, Ellipsis)
        if self.means is not None:
            acts = acts - self.means[self._pointer][pad]

        if rank == 2:
            return (1 / len(acts)) * torch.einsum('bi,bj->ij', acts, acts)

        if rank == 3:
            n_rows = np.prod(acts.shape[:-1])
            return (1 / n_rows) * torch.einsum('bni,bnj->ij', acts, acts)

        n_samples = len(acts) * np.prod(acts.shape[2:])
        return 1 / n_samples * torch.einsum('bi...,bj...->ij', acts, acts)
