from torchvision.transforms import transforms as T
from .src.model.vision_transformer import vit_small


def load_preprocessor_transreid_ssl(h=256, w=128):
    """Build the image preprocessing pipeline used for TransReID-SSL inputs.

    The pipeline resizes the image to ``(h, w)`` with bicubic interpolation,
    converts it to a float tensor with ``ToTensor``, and then normalizes each
    channel with mean 0.5 and std 0.5. For ``ToTensor`` output in [0, 1],
    this normalization maps values to [-1, 1].

    The defaults (256x128) match the ``img_size`` that
    ``load_model_transreid_ssl`` builds its model with.

    Args:
        h: Target height in pixels.
        w: Target width in pixels.

    Returns:
        A ``torchvision.transforms.Compose`` callable.
    """
    # Use the enum rather than the legacy PIL integer code (3 == BICUBIC);
    # integer interpolation values are deprecated in torchvision.
    return T.Compose([
        T.Resize((h, w), interpolation=T.InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])


def load_model_transreid_ssl(path, *, device="cuda"):
    """Construct the TransReID-SSL ViT-Small model and prepare it for inference.

    The model is built via ``vit_small`` with a 256x128 input size,
    ``drop_path_rate=0.3``, ``hw_ratio=2`` and ``conv_stem=True``. The given
    ``path`` is passed through as ``pretrained_path``; checkpoint loading is
    presumably done inside ``vit_small`` (confirm against its definition).

    Args:
        path: Checkpoint path forwarded to ``vit_small(pretrained_path=...)``.
        device: Device to move the model to. It defaults to ``"cuda"``, which
            preserves the original behavior. Pass ``"cpu"`` to run without a GPU.

    Returns:
        The model, moved to ``device`` and switched to eval mode.
    """
    model = vit_small(
        img_size=(256, 128),
        drop_path_rate=0.3,  # NOTE(review): presumably inactive in eval mode
        pretrained_path=path,
        hw_ratio=2,
        conv_stem=True,
    )
    # Parameterized instead of the hard-coded .cuda() so CPU-only hosts work.
    model.to(device)
    model.eval()
    return model
    
