import torch
from torch import nn
import numpy as np
from tqdm import tqdm

import abc

class InputCollector(abc.ABC):
    """Accumulate a per-layer statistic of layer inputs using forward hooks.

    Subclasses implement :meth:`_layer_hook_impl`, which maps the input of a
    hooked layer to the value that is summed up. Accumulator slots are
    visited round-robin in the order the hooked layers are *called*, so the
    bookkeeping relies on every hooked layer firing once, in the same order,
    on each forward pass.
    """

    def __init__(self):
        # Shape recorded by ``_struct_hook``; consumed and reset by the next
        # ``_layer_hook`` call.
        self._struct = None

        # Handles returned by ``register_forward_hook`` (kept for removal).
        self._hooks = []
        # One running sum per hooked Linear/Conv2d layer (starts as ``0.``).
        self._accum = []

        # Index of the accumulator slot the next layer-hook call adds to.
        self._pointer = 0
        # Number of completed cycles through all accumulator slots.
        self._count = 0

    def detach_hooks(self):
        """Remove every registered hook and forget the handles."""
        for hook in self._hooks:
            hook.remove()
        self._hooks = []

    def _layer_hook(self, module, args, y):
        """Forward hook for Linear/Conv2d layers: accumulate one statistic.

        Args:
            module: The layer the hook fired on.
            args: Positional inputs of the layer; only ``args[0]`` is used.
            y: Output of the layer (passed through to the implementation).
        """
        x = args[0].clone()

        # If a marker module recorded its output shape, view this layer's
        # input in that shape, then clear it so it applies only once.
        if self._struct is not None:
            x = torch.reshape(x, self._struct)
            self._struct = None

        ret = self._layer_hook_impl(module, [x], y)
        self._accum[self._pointer] += ret

        # Advance round-robin; wrapping to slot 0 marks one completed cycle.
        self._pointer = (self._pointer + 1) % len(self._accum)
        if self._pointer == 0:
            self._count += 1

    @abc.abstractmethod
    def _layer_hook_impl(self, module, args, y):
        """Return the value to accumulate for one call of a hooked layer.

        Args:
            module: The layer the hook fired on.
            args: One-element list holding the (possibly reshaped) input.
            y: Output of the layer.
        """

    def _struct_hook(self, module, args, output):
        """Forward hook that records the output shape of a marker module."""
        self._struct = output.shape

    def attach_hooks(self, net, module_before_flatten=None):
        """Register hooks on all Linear/Conv2d layers of ``net``.

        Args:
            net: Module whose ``named_modules()`` are scanned.
            module_before_flatten: Optional type (or tuple of types) handed
                to ``isinstance``. Matching modules get a hook that records
                their output shape for the next hooked layer. ``None``
                (the default) registers no such hooks.
        """
        for _, mod in net.named_modules():
            if isinstance(mod, (nn.Linear, nn.Conv2d)):
                handle = mod.register_forward_hook(self._layer_hook)
                self._hooks.append(handle)
                self._accum.append(0.)
            elif (module_before_flatten is not None
                    and isinstance(mod, module_before_flatten)):
                # Guard on None: isinstance(mod, None) raises TypeError, which
                # would make the default call fail on the first other module.
                handle = mod.register_forward_hook(self._struct_hook)
                self._hooks.append(handle)

    def collect(self):
        """Return each accumulator divided by the number of completed cycles.

        The division is by ``self._count``, which is still 0 until every
        hooked layer has fired at least once.
        """
        return [a / self._count for a in self._accum]


class DeitInputCorrCollector(InputCollector):
    """Collector whose per-layer statistic is a feature second-moment matrix.

    When ``means`` is given, ``means[i]`` is subtracted from the input of the
    i-th accumulator slot before the products are formed.
    """

    def __init__(self, means=None):
        super().__init__()

        # Optional per-slot offsets, indexed by the current accumulator slot.
        self.means = means

    def _layer_hook_impl(self, module, args, y):
        """Return the normalised ``(features, features)`` product matrix.

        Rank-2 inputs contract over axis 0, rank-3 inputs over axes 0 and 1,
        and all other inputs over axis 0 plus every axis after axis 1.
        """
        acts = args[0].clone()
        rank = len(acts.shape)

        if self.means is not None:
            offset = self.means[self._pointer]
            # Rank-3 inputs get two leading singleton axes, every other rank
            # gets one. NOTE(review): for rank > 3 this only broadcasts over
            # channels if the stored mean already carries trailing axes --
            # confirm against whatever produces ``means``.
            if rank == 3:
                offset = offset[None, None, ...]
            else:
                offset = offset[None, ...]
            acts = acts - offset

        # Normaliser = reciprocal of the number of summed positions per entry.
        if rank == 2:
            scale = 1 / len(acts)
            spec = 'bi,bj->ij'
        elif rank == 3:
            scale = 1 / np.prod(acts.shape[:-1])
            spec = 'bni,bnj->ij'
        else:
            scale = 1 / (len(acts) * np.prod(acts.shape[2:]))
            spec = 'bi...,bj...->ij'

        return scale * torch.einsum(spec, acts, acts)
    

class OPTCollector(InputCollector):
    """Collector of input second-moment matrices for selected layers.

    Only Linear/Conv2d modules whose qualified name contains ``'out_proj'``
    or ``'fc2'`` are hooked. Running sums are stored on the CPU between
    calls and moved to the device of the new value only for the addition.
    """

    def _layer_hook(self, module, args, y):
        """Forward hook: add this call's statistic to a CPU-resident sum.

        Args:
            module: The layer the hook fired on.
            args: Positional inputs of the layer; only ``args[0]`` is used.
            y: Output of the layer (passed through to the implementation).
        """
        x = args[0].clone()

        # ``attach_hooks`` of this class registers no shape-recording hooks,
        # so this branch only runs if ``_struct`` was set by other means.
        if self._struct is not None:
            x = torch.reshape(x, self._struct)
            self._struct = None

        ret = self._layer_hook_impl(module, [x], y)
        if isinstance(self._accum[self._pointer], float):
            # First value for this slot replaces the ``0.`` placeholder.
            self._accum[self._pointer] = ret.cpu()
        else:
            # Add on the device the new value lives on. A hard-coded
            # ``.cuda()`` breaks on CPU-only hosts and mismatches when the
            # layer runs on a GPU other than the current one.
            accum = self._accum[self._pointer].to(ret.device)
            self._accum[self._pointer] = (accum + ret).cpu()

        # Advance round-robin; wrapping to slot 0 marks one completed cycle.
        self._pointer = (self._pointer + 1) % len(self._accum)
        if self._pointer == 0:
            self._count += 1

    def _layer_hook_impl(self, module, args, y):
        """Return the normalised ``(features, features)`` product matrix.

        Rank-2 inputs contract over axis 0, rank-3 inputs over axes 0 and 1,
        and all other inputs over axis 0 plus every axis after axis 1.
        """
        x = args[0].clone()

        if len(x.shape) == 2:
            C = (1/len(x)) * torch.einsum('bi,bj->ij', x, x)
        elif len(x.shape) == 3:
            C = (1/np.prod(x.shape[:-1])) * torch.einsum('bni,bnj->ij', x, x)
        else:
            C = 1/(len(x) * np.prod(x.shape[2:])) * torch.einsum('bi...,bj...->ij', x, x)

        return C

    def attach_hooks(self, net, module_before_flatten=None):
        """Hook Linear/Conv2d modules named like ``out_proj`` or ``fc2``.

        Args:
            net: Module whose ``named_modules()`` are scanned.
            module_before_flatten: Accepted for signature compatibility with
                the base class; not used here.
        """
        for name, mod in net.named_modules():
            if isinstance(mod, (nn.Linear, nn.Conv2d)) and ('out_proj' in name or 'fc2' in name):
                handle = mod.register_forward_hook(self._layer_hook)
                self._hooks.append(handle)
                self._accum.append(0.)
