import torch
import torch.nn as nn




class TimmWrapper(nn.Module):
    """Thin ``nn.Module`` wrapper around a backbone model.

    Args:
        model: The module to wrap; ``forward`` delegates to it directly.
        freeze_bn: If True, every ``nn.BatchNorm2d`` inside ``model`` is kept
            in eval mode, including after ``train()`` is called. This freezes
            only the running mean/variance statistics; the affine weight and
            bias still receive gradients unless disabled elsewhere.
        miro: Stored as ``self.miro``; not used by this class itself.
            NOTE(review): presumably read by external code — confirm.
    """

    def __init__(self, model, freeze_bn=False, miro=False):
        super().__init__()
        self.model = model
        self.freeze_bn = freeze_bn
        self.miro = miro

        # Apply immediately: a freshly constructed module is in training mode.
        if self.freeze_bn:
            self.freeze_batchnorm()

    def forward(self, x):
        """Run the wrapped model on ``x`` and return its output unchanged."""
        return self.model(x)

    def train(self, mode=True):
        """Set training mode, keeping BatchNorm layers in eval mode if ``freeze_bn``.

        ``super().train(mode)`` switches every submodule, including BatchNorm
        layers, so the freeze is re-applied afterwards.

        Args:
            mode: True for training mode, False for evaluation mode.

        Returns:
            ``self``, matching ``nn.Module.train``. ``nn.Module.eval()``
            returns ``self.train(False)``, so this also makes ``eval()``
            return the module instead of ``None``.
        """
        super().train(mode)
        if self.freeze_bn:
            self.freeze_batchnorm()
        return self

    def freeze_batchnorm(self):
        """Put every ``nn.BatchNorm2d`` in the wrapped model into eval mode.

        This stops running-statistics updates. It does not set
        ``requires_grad`` on the affine parameters.
        """
        for m in self.model.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()



