import os
from torchvision.transforms import transforms as T
from .model import make_model
from .config import cfg


def load_preprocessor_solider(h=384, w=128):
    """Return the torchvision pipeline that prepares an image for SOLIDER.

    The pipeline applies, in order:
      1. ``Resize((h, w))`` -- resize the input to height ``h`` and width ``w``.
      2. ``ToTensor()`` -- convert the image to a float tensor.
      3. ``Normalize`` with per-channel mean 0.5 and std 0.5, i.e.
         ``(x - 0.5) / 0.5``. Assuming ``ToTensor`` produced values in
         [0, 1], this rescales them to [-1, 1].

    Args:
        h: Target height in pixels (default 384).
        w: Target width in pixels (default 128).

    Returns:
        A ``T.Compose`` callable that takes one image and returns the
        normalized tensor.
    """
    # Three channels, each centered on 0.5 with a 0.5 scale.
    channel_stats = [0.5] * 3
    pipeline_steps = [
        T.Resize((h, w)),
        T.ToTensor(),
        T.Normalize(mean=list(channel_stats), std=list(channel_stats)),
    ]
    return T.Compose(pipeline_steps)


def load_model_solider(path, *, num_class=1041, camera_num=15, view_num=1,
                       semantic_weight=0.2, device=None):
    """Build the SOLIDER model, load weights from ``path``, and return it in eval mode.

    The architecture hyper-parameters default to the values this loader has
    always used. They are exposed as keyword-only arguments so a checkpoint
    trained with a different head size or camera count can be loaded without
    editing this function. Calling ``load_model_solider(path)`` behaves exactly
    as before.

    Args:
        path: Checkpoint location, forwarded unchanged to ``model.load_param``.
        num_class: Number of identity classes passed to ``make_model``.
            NOTE(review): presumably this must match the checkpoint's classifier
            head (the default 1041 looks like the MSMT17 training-ID count);
            confirm.
        camera_num: Number of cameras passed to ``make_model`` (default 15).
        view_num: Number of views passed to ``make_model`` (default 1).
        semantic_weight: Forwarded unchanged to ``make_model`` (default 0.2).
        device: Where to place the model. ``None`` (the default) calls
            ``model.cuda()``, as the original loader did. Any other value,
            e.g. ``"cpu"`` or ``"cuda:1"``, is passed to ``model.to(device)``,
            which assumes the model is a ``torch.nn.Module``.

    Returns:
        The model after ``load_param``, device placement and ``eval()``.

    Side effects:
        Calls ``cfg.freeze()`` on the module-level config object shared with
        the rest of the package, not on a copy. NOTE(review): if ``cfg`` is a
        yacs ``CfgNode``, later attempts to modify it will raise; confirm
        against ``.config``.
    """
    # Freezes the shared global config before it is handed to make_model.
    cfg.freeze()
    model = make_model(
        cfg,
        num_class=num_class,
        camera_num=camera_num,
        view_num=view_num,
        semantic_weight=semantic_weight,
    )
    model.load_param(path)
    if device is None:
        # Preserve the historical default placement exactly.
        model.cuda()
    else:
        model.to(device)
    model.eval()
    return model
    