import torch
from torch import nn
import numpy as np

import abc

from pruning import utils, modules


class InputCollector(abc.ABC):
    """Accumulate per-layer statistics of layer inputs through forward hooks.

    Subclasses implement :meth:`_layer_hook`, which maps a layer's input to a
    statistic. The statistic returned for each hooked layer is summed over
    forward passes, and :meth:`collect` returns the per-layer average.

    Layers are matched to accumulator slots by call order: the i-th hooked
    layer to fire in a forward pass is assumed to be the i-th layer registered
    by :meth:`attach_hooks` (the order of ``net.named_modules()``). This only
    holds when the forward pass visits every hooked layer exactly once, in
    registration order.
    """

    def __init__(self):
        # Output shape recorded by the "module before flatten" hook; consumed
        # (and reset) by the next hooked layer so it sees the unflattened input.
        self._struct = None

        self._hooks = []   # handles returned by register_forward_hook
        self._accum = []   # running sum of statistics, one slot per hooked layer

        self._pointer = 0  # slot of the hooked layer expected to fire next
        self._count = 0    # number of completed rounds over all hooked layers

    def detach_hooks(self):
        """Remove every hook previously registered by :meth:`attach_hooks`."""
        for hook in self._hooks:
            hook.remove()
        self._hooks = []

    def __layer_hook(self, module, args, y):
        # Statistics are bookkeeping only: compute them without autograd so the
        # accumulated tensors do not keep the graph of every forward pass alive.
        with torch.no_grad():
            x = args[0].clone()

            if self._struct is not None:
                # Restore the shape recorded before the flatten so the
                # statistic is computed on the structured input.
                x = torch.reshape(x, self._struct)
                self._struct = None

            ret = self._layer_hook(module, [x], y)
            self._accum[self._pointer] += ret

        # Advance to the next slot; wrapping around marks one complete round.
        self._pointer = (self._pointer + 1) % len(self._accum)
        if self._pointer == 0:
            self._count += 1

    @abc.abstractmethod
    def _layer_hook(self, module, args, y):
        """Return the statistic to accumulate for one layer call.

        Args:
            module: the hooked layer.
            args: one-element list holding a clone of the layer input
                (reshaped to the recorded structure when applicable).
            y: the layer output.
        """

    def _struct_hook(self, module, args, output):
        """Record ``output.shape`` so the next hooked layer can unflatten its input."""
        self._struct = output.shape

    def attach_hooks(self, net, module_before_flatten=None):
        """Register hooks on every ``nn.Linear`` / ``nn.Conv2d`` in ``net``.

        Args:
            net: module whose submodules (including itself) are scanned.
            module_before_flatten: optional type (or tuple of types) whose
                output shape is recorded so that the next hooked layer's input
                is reshaped back to it. ``None`` (the default) disables this.
        """
        for mod in net.modules():
            if isinstance(mod, (nn.Linear, nn.Conv2d)):
                self._hooks.append(mod.register_forward_hook(self.__layer_hook))
                self._accum.append(0.)
            elif module_before_flatten is not None and isinstance(mod, module_before_flatten):
                # isinstance(mod, None) would raise TypeError, hence the guard.
                self._hooks.append(mod.register_forward_hook(self._struct_hook))

    def collect(self):
        """Return the accumulated statistic of each hooked layer divided by the round count.

        Raises:
            ZeroDivisionError: if no complete round over the hooked layers has
                been recorded yet.
        """
        return [a / self._count for a in self._accum]

class InputMeanCollector(InputCollector):
    """Collect the batch mean of each hooked layer's input.

    Only the leading (batch) axis is averaged; all remaining axes are kept, so
    a Conv2d input yields a per-channel, per-position mean.
    """

    def _layer_hook(self, module, args, y):
        # Reduce over dim 0 only; torch.mean already returns a new tensor.
        return torch.mean(args[0], dim=0)


class InputCorrCollector(InputCollector):
    """Collect the (optionally centred) second-moment matrix of layer inputs.

    For an input of shape (B, C, *rest) the hook returns a (C, C) matrix:
    the products x[:, i] * x[:, j] summed over the batch and any trailing
    axes, divided by the number of summed terms (B times the product of the
    trailing sizes).
    """

    def __init__(self, means=None):
        super().__init__()

        # Optional per-layer means subtracted from the input before the
        # products are formed; indexed by the hooked layer's slot.
        self.means = means

    def _layer_hook(self, module, args, y):
        inp = args[0].clone()

        # Count of (sample, position) terms contracted by the einsum below.
        if len(inp.shape) > 2:
            n_terms = len(inp) * np.prod(inp.shape[2:])
        else:
            n_terms = len(inp)

        if self.means is not None:
            inp = inp - self.means[self._pointer]

        second_moment = torch.einsum('bi...,bj...->ij', inp, inp)
        return (1/n_terms) * second_moment

class InputVarCollector(InputCollector):
    """Collect per-feature first and second moments of inputs projected by ``module.R``.

    Hooks are attached to ``modules.Prunable`` layers, which are expected to
    expose a matrix attribute ``R``. For each call the hook returns a
    ``(2, k)`` tensor: row 0 is the per-feature mean and row 1 the per-feature
    mean square of the (optionally centred) projected input, taken over all
    samples (and spatial positions for 4-D inputs).
    """

    def __init__(self, means=None):
        super().__init__()

        # Optional per-layer means subtracted before projecting, indexed by the
        # hooked layer's slot.
        # NOTE(review): if these come from InputMeanCollector, that class hooks
        # nn.Linear/nn.Conv2d while this one hooks modules.Prunable — confirm
        # the two slot orders line up.
        self.means = means

    def _layer_hook(self, module, args, y):
        x = args[0].clone()

        if self.means is not None:
            x = x - self.means[self._pointer]

        # Both branches produce a (features, samples) matrix so that dim 1
        # always indexes samples.
        # NOTE(review): the linear branch computes x @ R, whereas the conv
        # branch uses R as an (out, in) weight, i.e. x @ R.T; these agree only
        # for symmetric R — confirm the convention of modules.Prunable.R.
        if len(x.shape) <= 2:
            # (batch, in) @ R -> (batch, k) -> (k, batch)
            proj = (x @ module.R).t()
        else:
            # 1x1 convolution with weight R: (B, k, H, W) -> (k, B*H*W)
            proj = nn.functional.conv2d(x, module.R[..., None, None])
            proj = proj.moveaxis(1, 0).flatten(start_dim=1)

        m = torch.mean(proj, dim=1)
        v = torch.mean(proj ** 2, dim=1)
        # stack keeps the device and dtype of the projected input.
        return torch.stack((m, v))

    def attach_hooks(self, net, module_before_flatten=None):
        """Register hooks on every ``modules.Prunable`` layer in ``net``.

        Args:
            net: module whose submodules (including itself) are scanned.
            module_before_flatten: optional type (or tuple of types) whose
                output shape is recorded so the next hooked layer's input is
                reshaped back to it. ``None`` (the default) disables this.
        """
        for mod in net.modules():
            if isinstance(mod, modules.Prunable):
                self._hooks.append(mod.register_forward_hook(self._InputCollector__layer_hook))
                self._accum.append(0.)
            elif module_before_flatten is not None and isinstance(mod, module_before_flatten):
                # isinstance(mod, None) would raise TypeError, hence the guard.
                self._hooks.append(mod.register_forward_hook(self._struct_hook))