import logging
import math

import torch
import torch.nn as nn


_logger = logging.getLogger(__name__)


class SupSupWrapper(nn.Module):
    """Wrap a backbone whose layers carry ``scores`` tensors and route features to per-task heads.

    :meth:`forward` runs ``backbone.forward_features`` followed by
    ``backbone.forward_head(..., pre_logits=True)`` and feeds the result to
    ``self.heads[self.task_idx]``. The ``attach_*_vis_hooks`` methods register
    forward hooks that cache detached intermediate outputs for inspection.

    Args:
        backbone: Model exposing ``forward_features`` and ``forward_head``. The
            attention/block/MLP/residual hook helpers additionally expect a
            ``blocks`` iterable whose items have ``attn``, ``mlp`` and
            ``identity`` submodules.
        iter_backbone: Callable ``iter_backbone(backbone)`` yielding
            ``(layer_name, layer)`` pairs. Each layer must expose a ``scores``
            iterable of tensors and a ``set_task_idx(task_idx)`` method.
        head_factory: Callable ``head_factory(backbone)`` returning an
            indexable collection of heads; :meth:`forward` indexes it with
            ``self.task_idx``.

    Note:
        ``self.task_idx`` is only set by :meth:`set_task_idx`, so that method
        must be called before the first :meth:`forward`.
    """

    def __init__(self, backbone, iter_backbone, head_factory):

        super().__init__()
        self.backbone = backbone
        self.iter_backbone = iter_backbone
        self.heads = head_factory(backbone)

        # Make every score tensor of every wrapped layer trainable.
        for layer_name, layer in self.iter_backbone(self.backbone):
            for score_idx, score in enumerate(layer.scores):
                # Identify each score by its position in ``layer.scores``.
                _logger.info("Setting %s.scores[%d] to require grad", layer_name, score_idx)
                score.requires_grad = True

    def set_task_idx(self, task_idx):
        """Select the active task on this wrapper and on every wrapped layer."""
        self.task_idx = task_idx
        for _layer_name, layer in self.iter_backbone(self.backbone):
            layer.set_task_idx(task_idx)

    def forward(self, x):
        """Return the output of the head selected by ``self.task_idx`` for input ``x``."""
        x = self.backbone.forward_features(x)
        x = self.backbone.forward_head(x, pre_logits=True)
        return self.heads[self.task_idx](x)

    # ------------------------------------------------------------------
    # Hook bookkeeping helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _remove_hooks(handles):
        """Remove every hook handle in ``handles``; ``None`` means nothing is registered."""
        for handle in handles or ():
            handle.remove()

    @staticmethod
    def _register_vis_hooks(modules, make_hook):
        """Register ``make_hook(i)`` as a forward hook on the i-th module and return the handles."""
        return [module.register_forward_hook(make_hook(i)) for i, module in enumerate(modules)]

    def remove_vis_hooks(self):
        """Detach every hook registered by the ``attach_*_vis_hooks`` methods.

        Cached ``*_features`` lists are left as they are.
        """
        for attr in ("projection_hooks", "attn_hooks", "block_hooks", "mlp_hooks", "attn_res_hooks"):
            self._remove_hooks(getattr(self, attr, None))
            setattr(self, attr, [])

    # ------------------------------------------------------------------
    # Projection (iter_backbone layer) outputs
    # ------------------------------------------------------------------

    def projection_vis_hook(self, l):
        """Build a forward hook that stores the detached output of layer ``l``."""

        def hook(_module, _inputs, output):
            self.projection_features[l] = output.detach()

        return hook

    def attach_projection_vis_hooks(self):
        """(Re)attach hooks caching the output of each ``iter_backbone`` layer.

        Hooks from a previous call are removed first so re-attaching does not
        stack duplicate hooks on the same layers.
        """
        self._remove_hooks(getattr(self, "projection_hooks", None))
        layers = [layer for _layer_name, layer in self.iter_backbone(self.backbone)]
        self.projection_features = [None] * len(layers)
        self.projection_hooks = self._register_vis_hooks(layers, self.projection_vis_hook)

    def reset_projection_features(self):
        """Clear cached projection outputs, keeping one ``None`` slot per hooked layer."""
        self.projection_features = [None] * len(self.projection_features)

    # ------------------------------------------------------------------
    # Attention maps
    # ------------------------------------------------------------------

    def attach_attn_vis_hooks(self):
        """(Re)attach hooks on each ``block.attn``, replacing any previously attached ones."""
        self._remove_hooks(getattr(self, "attn_hooks", None))
        blocks = list(self.backbone.blocks)
        self.attn_features = [None] * len(blocks)
        self.attn_hooks = self._register_vis_hooks(
            [block.attn for block in blocks], self.attn_vis_hook
        )

    def attn_vis_hook(self, l):
        """Build a forward hook that stores a slice of block ``l``'s attention output."""

        def hook(_module, _inputs, output):
            # The attention module is expected to return ``(x, attn)``.
            # NOTE(review): assumes ``attn`` is (batch, heads, query, key); the
            # slice then keeps query 0's weights over keys 1: (e.g. CLS -> patch
            # tokens) — confirm against the attention implementation.
            _x, attn = output
            self.attn_features[l] = attn[:, :, 0, 1:].detach()

        return hook

    def reset_attn_features(self):
        """Clear cached attention slices, keeping one ``None`` slot per block."""
        self.attn_features = [None] * len(self.attn_features)

    # ------------------------------------------------------------------
    # Whole-block outputs
    # ------------------------------------------------------------------

    def attach_block_vis_hooks(self):
        """(Re)attach hooks caching each block's output, replacing previous ones."""
        self._remove_hooks(getattr(self, "block_hooks", None))
        blocks = list(self.backbone.blocks)
        self.block_features = [None] * len(blocks)
        self.block_hooks = self._register_vis_hooks(blocks, self.block_vis_hook)

    def block_vis_hook(self, l):
        """Build a forward hook that stores the detached output of block ``l``."""

        def hook(_module, _inputs, output):
            self.block_features[l] = output.detach()

        return hook

    def reset_block_features(self):
        """Clear cached block outputs, keeping one ``None`` slot per block."""
        self.block_features = [None] * len(self.block_features)

    # ------------------------------------------------------------------
    # MLP outputs
    # ------------------------------------------------------------------

    def attach_mlp_vis_hooks(self):
        """(Re)attach hooks caching each ``block.mlp`` output, replacing previous ones."""
        self._remove_hooks(getattr(self, "mlp_hooks", None))
        blocks = list(self.backbone.blocks)
        self.mlp_features = [None] * len(blocks)
        self.mlp_hooks = self._register_vis_hooks(
            [block.mlp for block in blocks], self.mlp_vis_hook
        )

    def mlp_vis_hook(self, l):
        """Build a forward hook that stores the detached MLP output of block ``l``."""

        def hook(_module, _inputs, output):
            self.mlp_features[l] = output.detach()

        return hook

    def reset_mlp_features(self):
        """Clear cached MLP outputs, keeping one ``None`` slot per block."""
        self.mlp_features = [None] * len(self.mlp_features)

    # ------------------------------------------------------------------
    # ``block.identity`` outputs
    # ------------------------------------------------------------------

    def attach_attn_res_vis_hooks(self):
        """(Re)attach hooks caching each ``block.identity`` output, replacing previous ones."""
        self._remove_hooks(getattr(self, "attn_res_hooks", None))
        blocks = list(self.backbone.blocks)
        self.attn_res_features = [None] * len(blocks)
        self.attn_res_hooks = self._register_vis_hooks(
            [block.identity for block in blocks], self.attn_res_vis_hook
        )

    def attn_res_vis_hook(self, l):
        """Build a forward hook that stores the detached ``identity`` output of block ``l``."""

        def hook(_module, _inputs, output):
            self.attn_res_features[l] = output.detach()

        return hook

    def reset_attn_res_features(self):
        """Clear cached ``identity`` outputs, keeping one ``None`` slot per block."""
        self.attn_res_features = [None] * len(self.attn_res_features)
