import torch
from PIL import Image
from collections import OrderedDict
from torchvision.transforms import transforms as T

from .models.resnet import resnet50
from .utils.serialization import load_checkpoint, copy_state_dict
from .utils.faiss_rerank import compute_jaccard_distance


def load_preprocessor_cc(h=256, w=128):
    """Build the image preprocessing pipeline.

    The returned transform applies, in order:

    1. ``T.Resize`` to ``(h, w)`` with ``interpolation=3`` (the integer
       value of Pillow's bicubic filter constant),
    2. ``T.ToTensor``,
    3. ``T.Normalize`` with per-channel mean ``[0.485, 0.456, 0.406]`` and
       std ``[0.229, 0.224, 0.225]`` (the commonly used ImageNet statistics).

    Args:
        h: Target height passed to ``T.Resize``. Defaults to 256.
        w: Target width passed to ``T.Resize``. Defaults to 128.

    Returns:
        A ``T.Compose`` object chaining the three steps above.
    """
    steps = [
        T.Resize((h, w), interpolation=3),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225]),
    ]
    return T.Compose(steps)


def load_model_cc(path, device=None):
    """Build the ResNet-50 model, load weights from a checkpoint, and return it in eval mode.

    The model is created with a fixed configuration
    (``num_features=0, norm=True, dropout=0, num_classes=0,
    pooling_type="gem"``), placed on the requested device, populated from
    ``checkpoint['state_dict']`` and switched to evaluation mode.

    Args:
        path: Checkpoint location, forwarded unchanged to ``load_checkpoint``.
        device: Optional target device, forwarded to ``model.to(device)``
            (e.g. ``"cpu"`` or ``"cuda:1"``). When ``None`` (the default) the
            model is moved with ``model.cuda()``, exactly as before this
            parameter existed.

    Returns:
        The model, in eval mode, with the checkpoint weights copied in.

    Note:
        Assumes ``resnet50`` returns a ``torch.nn.Module`` (it is already
        used with ``.cuda()`` and ``.eval()``) -- TODO confirm.
        Whether ``load_checkpoint`` itself can deserialize a GPU-saved
        checkpoint on a CPU-only host depends on its implementation, which
        is not visible here -- verify before relying on ``device="cpu"``.
    """
    model = resnet50(num_features=0, norm=True, dropout=0, num_classes=0, pooling_type="gem")
    if device is None:
        # Default path: unchanged legacy behavior (current CUDA device).
        model.cuda()
    else:
        model.to(device)
    checkpoint = load_checkpoint(path)
    # strip='module.' presumably removes the prefix added by
    # DataParallel-style wrappers -- confirm against copy_state_dict.
    copy_state_dict(checkpoint['state_dict'], model, strip='module.')
    model.eval()
    return model


def compute_jaccard_distance_helper(*args, **kwargs):
    """Forward every argument to ``compute_jaccard_distance`` and return its result.

    This is a pure pass-through: positional and keyword arguments are handed
    over unchanged, and whatever ``compute_jaccard_distance`` returns (or
    raises) is propagated to the caller.
    """
    distances = compute_jaccard_distance(*args, **kwargs)
    return distances
