import torch.nn
from torchvision.transforms import transforms as T

from .models import resnet50part
from .utils.serialization import load_checkpoint, copy_state_dict


def load_preprocessor_pplr(h: int = 384, w: int = 128):
    """Build the inference-time image transform for the PPLR model.

    The pipeline resizes the image to ``(h, w)`` with bicubic interpolation,
    converts it to a tensor, and normalizes it with the standard ImageNet
    per-channel mean/std.

    Args:
        h: Output height in pixels.
        w: Output width in pixels.

    Returns:
        A ``T.Compose`` transform. Presumably it is applied to PIL images,
        since ``T.Resize`` runs before ``T.ToTensor``. TODO confirm with
        callers.
    """
    normalizer = T.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    transform = T.Compose([
        # Integer interpolation codes (the original used 3 == PIL BICUBIC)
        # are deprecated in torchvision; use the explicit enum instead.
        T.Resize((h, w), interpolation=T.InterpolationMode.BICUBIC),
        T.ToTensor(),
        normalizer,
    ])

    return transform


def load_model_pplr(path, *, num_parts=3, num_classes=3000, device="cuda"):
    """Construct a ``resnet50part`` model, load weights from ``path``, and
    return it in eval mode.

    Args:
        path: Checkpoint location, passed unchanged to ``load_checkpoint``.
        num_parts: Forwarded to ``resnet50part``. The default of 3 matches
            the previously hard-coded value.
        num_classes: Forwarded to ``resnet50part``. The default of 3000
            matches the previously hard-coded value. Presumably this must
            match the checkpoint's classifier size. TODO confirm.
        device: Target device for ``model.to``. The default ``"cuda"`` puts
            the model on the current CUDA device, like the former
            ``model.cuda()`` call.

    Returns:
        The model with checkpoint weights copied in, switched to eval mode.
    """
    model = resnet50part(num_parts=num_parts, num_classes=num_classes)
    model.to(device)
    checkpoint = load_checkpoint(path)
    # NOTE(review): strip='module.' presumably removes the DataParallel key
    # prefix before copying; verify against copy_state_dict.
    copy_state_dict(checkpoint, model, strip='module.')
    model.eval()
    return model

