import os
import torch
from torchvision.transforms import transforms as T
from .models import create
from .models.dsbn import convert_dsbn, convert_bn
from .utils.serialization import load_checkpoint, save_checkpoint, copy_state_dict
from .utils import to_torch
import torch.nn as nn

def load_preprocessor_spcl(h=256, w=128):
    """Build the image preprocessing pipeline used for SpCL inference.

    The pipeline resizes the image to ``(h, w)``, converts it with
    ``T.ToTensor()`` and normalizes it per channel with the commonly used
    ImageNet mean/std values.

    Args:
        h: Target height handed to ``T.Resize``.
        w: Target width handed to ``T.Resize``.

    Returns:
        A ``T.Compose`` object chaining the three steps above.
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    steps = [
        # interpolation=3 is an integer code (bicubic in PIL's constant
        # numbering); newer torchvision prefers InterpolationMode, but the
        # int is kept for compatibility with older versions.
        T.Resize((h, w), interpolation=3),
        T.ToTensor(),
        T.Normalize(mean=channel_mean, std=channel_std),
    ]
    return T.Compose(steps)

def load_model_spcl(path, is_dsbn=False):
    """Create a ResNet-50 backbone and load SpCL weights into it.

    Args:
        path: Checkpoint location passed to ``load_checkpoint``. The loaded
            object must contain a ``'state_dict'`` entry.
        is_dsbn: If true, ``convert_dsbn`` is applied to the freshly created
            network *before* the weights are copied. Presumably this makes
            the module layout match a checkpoint trained with domain-specific
            BNs.

    Returns:
        The network moved to CUDA, wrapped in ``nn.DataParallel`` and switched
        to eval mode. Callers may additionally wrap it in ``Wrapper``.
    """
    net = create("resnet50", pretrained=False, num_features=0, dropout=0, num_classes=0)

    if is_dsbn:
        print("==> Load the model with domain-specific BNs")
        convert_dsbn(net)
        # NOTE(review): ``convert_bn`` is imported by this module but never
        # called. Confirm whether the DSBN layers should be collapsed back
        # to single BNs after loading.

    # strip='module.' presumably drops the prefix nn.DataParallel adds to
    # parameter names when the checkpoint was saved — confirm in
    # copy_state_dict.
    weights = load_checkpoint(path)['state_dict']
    copy_state_dict(weights, net, strip='module.')

    net.cuda()
    parallel_net = nn.DataParallel(net)
    parallel_net.eval()
    return parallel_net

class Wrapper:
    """Callable adapter that feeds converted inputs to a wrapped model.

    Calling the wrapper converts the input with ``to_torch``, moves it to
    CUDA and returns whatever the wrapped model produces. ``cuda()`` and
    ``eval()`` forward to the wrapped model and return ``self``, so they
    can be chained.
    """

    def __init__(self, model):
        # The wrapped model, e.g. the object returned by load_model_spcl.
        self.model = model

    def __call__(self, x):
        """Run the wrapped model on ``x`` and return its output.

        ``x`` goes through ``to_torch`` first. That presumably accepts numpy
        arrays as well as tensors — confirm in ``.utils``. The result is then
        moved to CUDA.
        """
        x = to_torch(x).cuda()
        return self.model(x)

    def cuda(self):
        """Move the wrapped model to CUDA and return ``self``."""
        self.model = self.model.cuda()
        return self

    def eval(self):
        """Switch the wrapped model to eval mode and return ``self``.

        Returning ``self`` matches ``cuda()`` above and ``nn.Module.eval``.
        This allows chaining such as ``Wrapper(m).cuda().eval()``; callers
        that ignore the return value are unaffected.
        """
        self.model.eval()
        return self
