import os
import torch
from torchvision.transforms import transforms as T
import torch.nn as nn
from .src.models import stb_net

def load_preprocessor_cap(h=256, w=128):
    """Build the image preprocessing pipeline used with the CAP model.

    The pipeline resizes to ``(h, w)`` with bicubic interpolation, converts
    to a tensor via ``ToTensor``, then normalizes with the ImageNet
    channel mean/std.

    Args:
        h: Target height passed to ``Resize``.
        w: Target width passed to ``Resize``.

    Returns:
        A ``transforms.Compose`` callable.
    """
    # Passing the integer PIL constant (3 == BICUBIC) is deprecated in newer
    # torchvision; use the enum when this torchvision version provides it and
    # keep the integer on older releases that predate InterpolationMode.
    interpolation_mode = getattr(T, 'InterpolationMode', None)
    bicubic = interpolation_mode.BICUBIC if interpolation_mode is not None else 3

    normalizer = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    transform = T.Compose([
        T.Resize((h, w), interpolation=bicubic),
        T.ToTensor(),
        normalizer,
    ])
    return transform

def load_model_cap(path, use_bnneck=True):
    """Build an STB-Net ``MemoryBankModel`` and load weights from a checkpoint.

    The model is wrapped in ``nn.DataParallel`` before loading, so checkpoint
    keys are expected to carry the ``module.`` prefix that ``DataParallel``
    adds (the classifier filter below matches on that prefix).

    Args:
        path: Path to a checkpoint readable by ``torch.load``; its top-level
            object must be a mapping of parameter names to tensors.
        use_bnneck: Forwarded to ``stb_net.MemoryBankModel``.

    Returns:
        The ``nn.DataParallel``-wrapped model, on CUDA when available and on
        CPU otherwise, switched to eval mode.

    Raises:
        RuntimeError: from ``load_state_dict`` if the checkpoint contains keys
            the model does not have, or tensors with mismatched shapes.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = stb_net.MemoryBankModel(out_dim=2048, use_bnneck=use_bnneck)
    model = nn.DataParallel(model.to(device))

    # map_location remaps tensors saved on a GPU so the checkpoint also loads
    # on CPU-only machines (the device fallback above would otherwise be moot).
    trained_dict = torch.load(path, map_location=device)

    # Classifier weights are dropped; the model keeps its own initialized
    # classifier (presumably because its size depends on the training set).
    filtered_trained_dict = {
        k: v for k, v in trained_dict.items()
        if not k.startswith('module.classifier')
    }
    for k in filtered_trained_dict:
        if 'embeding' in k:
            print('pretrained model has key= {}'.format(k))

    # Merge onto the model's own state dict so any keys missing from the
    # checkpoint (e.g. the classifier) keep their current values.
    model_dict = model.state_dict()
    model_dict.update(filtered_trained_dict)
    model.load_state_dict(model_dict)

    model.eval()
    return model

class Wrapper:
    """Callable adapter that converts inputs to tensors before running a model.

    ``model`` is expected to be a ``torch.nn.Module`` (``cuda``/``eval``
    are forwarded to it).
    """

    def __init__(self, model):
        self.model = model

    def _input_device(self):
        """Return the device of the model's first parameter, or CUDA if it has none."""
        parameters = getattr(self.model, 'parameters', None)
        if callable(parameters):
            first = next(iter(parameters()), None)
            if first is not None:
                return first.device
        # No parameters to inspect: keep the historical behavior of using CUDA.
        return torch.device('cuda')

    def __call__(self, x):
        """Run the model on ``x`` and return its output.

        ``x`` may be a tensor, a NumPy array, or anything ``torch.as_tensor``
        accepts; it is moved to the device the model's parameters live on.
        """
        # torch.as_tensor returns tensors unchanged and wraps NumPy arrays
        # without copying; this replaces a call to an undefined ``to_torch``.
        x = torch.as_tensor(x).to(self._input_device())
        feature = self.model(x)
        return feature

    def cuda(self):
        """Move the wrapped model to CUDA and return this wrapper."""
        self.model = self.model.cuda()
        return self

    def eval(self):
        """Put the wrapped model in eval mode and return this wrapper."""
        self.model.eval()
        # Return self for chaining, consistent with cuda() and nn.Module.eval().
        return self
