''' Modified from https://github.com/alinlab/LfF/blob/master/module/util.py '''

import torch.nn as nn
from module.mlp import *
from torchvision.models import resnet18
from module.resnet import *



class resnet_wrap(nn.Module):
    """ResNet-18 trunk with a linear classifier and a projection head.

    The trunk is ``resnet18`` with its last child module removed, so it ends
    at the pooling layer and emits one feature vector per sample.

    Args:
        num_classes: Output size of the classifier ``fc``.
        pretrained: Forwarded positionally to ``resnet18``.
        disent: When true, ``fc`` is built with 1024 input features instead
            of 512. The trunk still emits 512 features, so ``forward``'s
            default (classification) path does not match ``fc`` in this mode.
            NOTE(review): presumably callers concatenate two 512-d feature
            vectors and call ``self.fc`` directly -- confirm against callers.
    """

    def __init__(self, num_classes, pretrained=False, disent=False):
        super().__init__()
        backbone = resnet18(pretrained)
        # Keep every child except the last one (the backbone's own classifier).
        self.model = nn.Sequential(*list(backbone.children())[:-1])

        # Only the classifier's input width depends on ``disent``; the
        # projection head and its batch norm are identical in both modes.
        fc_in_features = 1024 if disent else 512
        self.fc = nn.Linear(fc_in_features, num_classes)
        self.projection_head = nn.Linear(512, 128)
        self.bnl = nn.BatchNorm1d(128)

    def forward(self, x, head_ext=False, feat_ext=False):
        """Run the trunk and return one of three outputs.

        Args:
            x: Input batch fed to the trunk.
            head_ext: If true, return the batch-normalised 128-d projection.
            feat_ext: If true (and ``head_ext`` is false), return the raw
                trunk features.

        Returns:
            The projection, the features, or the ``fc`` logits, in that
            order of precedence.
        """
        # flatten(1) collapses everything after the batch axis. Unlike a bare
        # squeeze(), it keeps the batch axis even when the batch size is 1.
        feat = self.model(x).flatten(1)
        if head_ext:
            return self.head(feat)
        if feat_ext:
            return feat
        return self.fc(feat)

    def head(self, x):
        """Apply the projection head followed by batch normalisation to ``x``."""
        return self.bnl(self.projection_head(x))
        


def get_model(model_tag, num_classes, pretrain=False):
    """Build a ``resnet_wrap`` for the requested model tag.

    Args:
        model_tag: ``"ResNet18"`` for the plain wrapper or
            ``'resnet_DISENTANGLE'`` for the wrapper built with ``disent=True``.
        num_classes: Passed through to ``resnet_wrap``.
        pretrain: Passed through to ``resnet_wrap`` as its ``pretrained`` flag.

    Returns:
        The constructed ``resnet_wrap`` instance.

    Raises:
        NotImplementedError: If ``model_tag`` is neither supported tag.
    """
    supported_tags = ("ResNet18", 'resnet_DISENTANGLE')
    if model_tag not in supported_tags:
        raise NotImplementedError

    # The two tags differ only in whether the disentangled variant is built.
    wants_disent = model_tag == 'resnet_DISENTANGLE'
    return resnet_wrap(num_classes, pretrain, disent=wants_disent)

def get_backbone(model_key, num_classes, pretrained=False, first_stage=False, args=None):
    """Build a ``resnet18`` whose ``fc`` layer is replaced by a fresh classifier.

    Args:
        model_key: ``'ResNet'`` always builds with ``pretrained=False``;
            ``'ResNet18'`` honours the ``pretrained`` argument.
        num_classes: Output size of the replacement ``fc`` layer.
        pretrained: Forwarded to ``resnet18`` for the ``'ResNet18'`` key only.
        first_stage: For ``'ResNet18'``, a true value keeps the classifier
            input width at 512 regardless of ``args``.
        args: Optional object read for a ``train_disent_be`` attribute. When
            that attribute is truthy and ``first_stage`` is false, the
            ``'ResNet18'`` classifier takes 1024 input features. ``None`` (or
            an object without the attribute) is treated as "not set".

    Returns:
        The model with ``model.fc`` replaced by
        ``nn.Linear(feature_dim, num_classes)``.

    Raises:
        NotImplementedError: If ``model_key`` is not a supported key.
    """
    if model_key == 'ResNet':
        model = resnet18(pretrained=False)
        feature_dim = 512
    elif model_key == 'ResNet18':
        print(f'Resnet18 pretrained {pretrained} loaded...')
        model = resnet18(pretrained=pretrained)
        feature_dim = 512
        # getattr with a default tolerates the documented ``args=None`` default
        # instead of raising AttributeError.
        if getattr(args, 'train_disent_be', False) and not first_stage:
            feature_dim = 1024
    else:
        # Fail with a clear message rather than an unbound-local error below.
        raise NotImplementedError(f'Unsupported backbone key: {model_key!r}')

    model.fc = nn.Linear(feature_dim, num_classes)
    return model


def get_pretrained(num_classes, args):
    """Load the ResNet-50 checkpoint and return the model on the GPU.

    Reads ``./pret_models/resnet50_webvision.pth``, strips every
    ``'module.'`` substring from the keys under ``checkpoint['model']``, and
    loads the result non-strictly into a ``SupCEResNet``. The ``fc`` layer is
    then replaced by a freshly initialised ``nn.Linear(2048, num_classes)``,
    so the returned classifier does not come from the checkpoint.

    Args:
        num_classes: Output size of the classifier.
        args: Unused; kept so existing callers keep working.

    Returns:
        The model after ``.cuda()``.
    """
    model = SupCEResNet('resnet50', num_classes=num_classes, pool=True)
    # Deserialise onto the CPU so loading does not depend on the CUDA device
    # the checkpoint was saved from; the model is moved to the GPU at the end.
    checkpoint = torch.load('./pret_models/resnet50_webvision.pth', map_location='cpu')
    state = {
        key.replace('module.', ''): weights
        for key, weights in checkpoint['model'].items()
    }
    # strict=False: keys that do not match the model are ignored, not errors.
    model.load_state_dict(state, strict=False)
    model.fc = nn.Linear(2048, num_classes)
    return model.cuda()

def remove_fc(model):
    """Return a ``Sequential`` of all of ``model``'s direct children but the last.

    The child modules are reused, not copied. A model with no children
    yields an empty ``Sequential``.
    """
    children = list(model.children())
    trunk_layers = children[:-1]
    return nn.Sequential(*trunk_layers)