import torch
import torch.nn as nn
import torch.nn.functional as F

class Wrapper(nn.Module):
    """Wrap a backbone with a 3-layer MLP projection head on its features.

    The backbone must accept ``is_feat=True`` and then return a pair
    ``(out, feat)``; ``feat`` is flattened per sample and fed through
    ``proj_head``.

    Args:
        module: Backbone network. Its last child module determines the
            projection head's input width via ``in_features``; if that child
            is an ``nn.Sequential``, the first of its children exposing
            ``in_features`` is used instead.
        last_feat: Stored as ``self.last_feat``; it is not read anywhere in
            this class.
        max_proj_dim: Upper bound on the hidden/output width of the
            projection head (keyword-only). Defaults to 4096.

    Raises:
        ValueError: If ``max_proj_dim`` is not positive.
    """

    def __init__(self, module, last_feat=True, *, max_proj_dim=4096):
        super().__init__()
        if max_proj_dim < 1:
            raise ValueError(f"max_proj_dim must be positive, got {max_proj_dim}")

        self.backbone = module
        feat_dim = self._classifier_in_features(module)
        # Every projection layer uses this width: the feature width, capped so
        # very wide backbones don't produce huge (feat_dim x feat_dim) weights.
        proj_dim = min(feat_dim, max_proj_dim)
        self.proj_head = nn.Sequential(
            nn.Linear(feat_dim, proj_dim),
            nn.ReLU(inplace=True),
            nn.Linear(proj_dim, proj_dim),
            nn.ReLU(inplace=True),
            nn.Linear(proj_dim, proj_dim),
        )
        self.last_feat = last_feat

    @staticmethod
    def _classifier_in_features(module):
        """Return ``in_features`` of the backbone's final classifier layer."""
        classifier = list(module.children())[-1]
        if isinstance(classifier, nn.Sequential):
            # Heads like Sequential(Dropout, Linear) don't start with the
            # linear layer, so take the first child that has ``in_features``
            # (falling back to the first child; an empty Sequential still
            # raises IndexError as before).
            children = list(classifier.children())
            classifier = next(
                (m for m in children if hasattr(m, "in_features")), children[0]
            )
        return classifier.in_features

    def forward(self, x, bb_grad=True):
        """Run the backbone and project its features.

        Args:
            x: Input batch, passed as ``self.backbone(x, is_feat=True)``.
            bb_grad: If False, ``feat`` is detached before the projection
                head, so gradients from the head do not reach the backbone
                through ``feat``. Note that ``out`` is NOT detached.

        Returns:
            Tuple ``(out, proj, feat)``: the backbone's first output
            unchanged, ``proj_head(feat)``, and ``feat`` flattened to
            ``(batch, -1)`` (detached when ``bb_grad`` is False). The
            flattened width presumably equals the classifier's
            ``in_features`` — TODO confirm for each backbone.
        """
        out, feat = self.backbone(x, is_feat=True)
        # reshape (not view): identical for contiguous tensors, but does not
        # raise when the backbone hands back a non-contiguous feature map.
        feat = feat.reshape(feat.size(0), -1)
        if not bb_grad:
            feat = feat.detach()

        return out, self.proj_head(feat), feat