# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmengine.model import BaseModule

from mmpretrain.models.heads import ClsHead
from mmpretrain.registry import MODELS
from ..utils import build_norm_layer


class BatchNormLinear(BaseModule):
    """A normalization layer followed by a fully-connected layer.

    ``forward`` applies ``self.bn`` and then ``self.linear``. :meth:`fuse`
    folds the normalization's running statistics and affine parameters into
    ``self.linear`` so the pair can be replaced by a single ``nn.Linear``
    (e.g. for deployment).

    Args:
        in_channels (int): Number of input features.
        out_channels (int): Number of output features.
        norm_cfg (dict, optional): Config handed to ``build_norm_layer`` to
            build the normalization layer. ``None`` (the default) means
            ``dict(type='BN1d')``.
    """

    def __init__(self, in_channels, out_channels, norm_cfg=None):
        super(BatchNormLinear, self).__init__()
        # Build the default per call rather than sharing one mutable dict
        # default object across every instance.
        if norm_cfg is None:
            norm_cfg = dict(type='BN1d')
        self.bn = build_norm_layer(norm_cfg, in_channels)
        self.linear = nn.Linear(in_channels, out_channels)

    @torch.no_grad()
    def fuse(self):
        """Fold the normalization into the linear layer and return it.

        With ``s = bn.weight / sqrt(bn.running_var + bn.eps)`` and
        ``t = bn.bias - bn.running_mean * s``, the composition
        ``linear(bn(x))`` (using running statistics) equals a linear layer
        with weight ``W * s`` (scaling input columns) and bias ``W @ t + b``.

        ``self.linear`` is modified in place. After calling this, the
        returned layer must be used instead of this module; this module's
        own ``forward`` would apply the normalization twice.

        Returns:
            nn.Linear: ``self.linear`` holding the fused parameters.
        """
        bn, linear = self.bn, self.linear
        # Per-feature scale applied by the normalization in eval mode.
        scale = bn.weight / (bn.running_var + bn.eps)**0.5
        shift = bn.bias - bn.running_mean * scale
        # Both results are computed from the *original* linear weight before
        # it is overwritten below.
        fused_weight = linear.weight * scale[None, :]
        fused_bias = (linear.weight @ shift[:, None]).view(-1) + linear.bias

        linear.weight.data.copy_(fused_weight)
        linear.bias.data.copy_(fused_bias)
        return linear

    def forward(self, x):
        """Apply the normalization, then the linear projection."""
        x = self.bn(x)
        x = self.linear(x)
        return x


def fuse_parameters(module):
    """Recursively replace fusable submodules by their fused form.

    Every direct child that has a ``fuse`` attribute is replaced (under the
    same attribute name on ``module``) by the value returned from
    ``child.fuse()``; such a child is not descended into. Any other child is
    processed recursively in the same way. ``module`` itself is never fused.
    """
    for name, submodule in module.named_children():
        if not hasattr(submodule, 'fuse'):
            # Not fusable at this level: look for fusable layers further down.
            fuse_parameters(submodule)
            continue
        setattr(module, name, submodule.fuse())


@MODELS.register_module()
class LeViTClsHead(ClsHead):
    """Classification head used by LeViT.

    Each head is a :class:`BatchNormLinear` (normalization + linear). An
    optional distillation head ``head_dist`` can be built alongside the
    main ``head``; in eval mode the outputs of both are averaged.

    Args:
        num_classes (int): Number of output categories. Defaults to 1000.
        distillation (bool): Whether to build the extra ``head_dist``.
            Training with distillation enabled is not supported and raises
            ``NotImplementedError``. Defaults to True.
        in_channels (int): Number of input features of each head. Must be
            provided; the ``None`` default is only a placeholder.
        deploy (bool): If True, fuse every normalization + linear head into
            a single ``nn.Linear`` right after construction (see
            :meth:`switch_to_deploy`). Defaults to False.
        **kwargs: Forwarded to ``ClsHead``.
    """

    def __init__(self,
                 num_classes=1000,
                 distillation=True,
                 in_channels=None,
                 deploy=False,
                 **kwargs):
        super(LeViTClsHead, self).__init__(**kwargs)
        self.num_classes = num_classes
        self.distillation = distillation
        # Start in the un-fused state: ``switch_to_deploy`` is a no-op once
        # ``self.deploy`` is True, so setting the flag up front would
        # silently skip the fusion requested by ``deploy=True``.
        self.deploy = False
        self.head = BatchNormLinear(in_channels, num_classes)
        if distillation:
            self.head_dist = BatchNormLinear(in_channels, num_classes)

        if deploy:
            self.switch_to_deploy()

    def switch_to_deploy(self):
        """Fuse all fusable submodules in place (at most once).

        Calls ``fuse_parameters`` on this head, which replaces each
        :class:`BatchNormLinear` child with its fused ``nn.Linear``, then
        marks the head as deployed so later calls do nothing.
        """
        if self.deploy:
            return
        fuse_parameters(self)
        self.deploy = True

    def forward(self, x):
        """Compute classification scores.

        The input is first passed through ``self.pre_logits`` (inherited
        from ``ClsHead``). Presumably the resulting features are
        ``(N, in_channels)`` and the output ``(N, num_classes)`` — TODO
        confirm against the backbone/neck.

        Raises:
            NotImplementedError: If distillation is enabled and the module
                is in training mode.
        """
        x = self.pre_logits(x)
        if not self.distillation:
            return self.head(x)
        # Fail before running the heads so a rejected training call does not
        # update the normalization layers' running statistics.
        if self.training:
            raise NotImplementedError("MMPretrain doesn't support "
                                      'training in distillation mode.')
        # Average the predictions of the classification and distillation
        # heads.
        return (self.head(x) + self.head_dist(x)) / 2
