import logging
import math

import torch
import torch.nn as nn


_logger = logging.getLogger(__name__)


class ScaleShiftWrapper(nn.Module):
    """Multi-task wrapper adding per-task learnable scale/shift to backbone layers.

    For every ``(layer_name, layer)`` pair produced by ``iter_backbone(backbone)``
    a forward hook is registered on ``layer``. On each forward pass the hook
    multiplies the layer output by a per-task scale parameter (initialised to
    ones) and/or adds a per-task shift parameter (initialised to zeros). The
    active task is chosen by ``self.task_idx``. The backbone's pre-logit
    features are then fed to ``self.heads[self.task_idx]``.

    Parameters are stored as ``scale_params[task_idx][layer_idx]`` and
    ``shift_params[task_idx][layer_idx]``, each of shape ``(1, num_features)``.

    Args:
        backbone: Module exposing ``forward_features(x)`` and
            ``forward_head(x, pre_logits=True)`` (e.g. a timm model).
        iter_backbone: Callable ``iter_backbone(backbone)`` returning an iterable
            of ``(layer_name, layer)`` pairs to instrument.
        head_factory: Callable ``head_factory(backbone)`` returning a sized,
            indexable collection with one head per task. It should presumably be
            an ``nn.Module`` container (e.g. ``nn.ModuleList``) so the heads'
            parameters are registered -- TODO confirm with callers.
        scale: Whether to create and apply per-task scale parameters.
        shift: Whether to create and apply per-task shift parameters.
        default_num_features: Feature width assumed for layers that expose
            neither ``out_features`` nor ``out_channels``.
    """

    def __init__(self, backbone, iter_backbone, head_factory, scale=True, shift=False,
                 default_num_features=768):
        super().__init__()
        self.backbone = backbone
        self.iter_backbone = iter_backbone
        self.heads = head_factory(backbone)

        self.scale = scale
        self.shift = shift
        # Default to the first task so a forward pass before set_task_idx()
        # works instead of failing with AttributeError inside the hooks.
        self.task_idx = 0
        # Handles of the registered forward hooks, kept so they can be removed.
        self._hook_handles = []

        num_tasks = len(self.heads)
        scale_params = [nn.ParameterList() for _ in range(num_tasks)]
        shift_params = [nn.ParameterList() for _ in range(num_tasks)]

        for layer_idx, (layer_name, layer) in enumerate(self.iter_backbone(self.backbone)):
            num_features, channel_first = self._infer_num_features(
                layer_name, layer, default_num_features
            )
            _logger.info("Setting %s", layer_name)
            if self.scale:
                for task_params in scale_params:
                    task_params.append(nn.Parameter(torch.ones(1, num_features), requires_grad=True))
            if self.shift:
                for task_params in shift_params:
                    task_params.append(nn.Parameter(torch.zeros(1, num_features), requires_grad=True))

            handle = layer.register_forward_hook(
                self.scale_shift_hook(layer_idx, channel_first=channel_first)
            )
            self._hook_handles.append(handle)

        self.scale_params = nn.ModuleList(scale_params)
        self.shift_params = nn.ModuleList(shift_params)

    @staticmethod
    def _infer_num_features(layer_name, layer, default_num_features):
        """Return ``(num_features, channel_first)`` for an instrumented layer.

        ``out_features`` (e.g. ``nn.Linear``) takes precedence. Then
        ``out_channels`` (e.g. ``nn.ConvNd``) is used, whose features live on
        dim 1 of the output. Otherwise ``default_num_features`` is used.
        """
        if hasattr(layer, "out_features"):
            _logger.info("Setting %s out_features", layer_name)
            return layer.out_features, False
        if hasattr(layer, "out_channels"):
            _logger.info("Setting %s with out_channels", layer_name)
            return layer.out_channels, True
        _logger.info("Setting %s to default %s.", layer_name, default_num_features)
        return default_num_features, False

    @staticmethod
    def _match_output(param, output, channel_first):
        """Reshape a ``(1, C)`` parameter so it broadcasts correctly against ``output``."""
        if channel_first and output.dim() > 2:
            # (1, C) -> (1, C, 1, ..., 1): align C with dim 1 of an (N, C, ...)
            # output instead of with its last dimension.
            return param.view(1, -1, *([1] * (output.dim() - 2)))
        return param

    def scale_shift_hook(self, l, channel_first=False):
        """Build a forward hook applying the active task's scale/shift for layer ``l``.

        Args:
            l: Index of the instrumented layer within each task's
                ``scale_params`` / ``shift_params`` list.
            channel_first: If True, the layer's features are on dim 1 of its
                output (as for ``nn.ConvNd``) and the parameters are reshaped
                accordingly. Otherwise the features are assumed to be the last
                output dimension.

        Returns:
            A hook ``hook(module, inputs, output)`` returning the modified output.
        """
        def hook(_module, _inputs, output):
            # task_idx is read at call time, so set_task_idx() takes effect
            # on the next forward pass without re-registering hooks.
            if self.scale:
                output = output * self._match_output(
                    self.scale_params[self.task_idx][l], output, channel_first
                )
            if self.shift:
                output = output + self._match_output(
                    self.shift_params[self.task_idx][l], output, channel_first
                )
            return output

        return hook

    def set_task_idx(self, task_idx):
        """Select which task's head and scale/shift parameters are used."""
        self.task_idx = task_idx

    def forward(self, x):
        """Run the (hooked) backbone, take pre-logit features, apply the active head."""
        x = self.backbone.forward_features(x)
        x = self.backbone.forward_head(x, pre_logits=True)
        return self.heads[self.task_idx](x)
