import os
import torch
from torchvision.transforms import transforms as T
from .model import make_model
from .config import cfg


def load_preprocessor_transreid(h=256, w=128):
    """Build the input preprocessing pipeline for the TransReID model.

    The returned transform resizes an image to ``(h, w)``, converts it to a
    tensor with ``ToTensor``, and then normalizes every channel with mean 0.5
    and std 0.5.

    Args:
        h: Target height passed to ``Resize``.
        w: Target width passed to ``Resize``.

    Returns:
        A ``torchvision`` ``Compose`` transform applying the steps above in order.
    """
    pipeline = [
        T.Resize((h, w)),
        T.ToTensor(),
        # Same statistics on all three channels.
        T.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
    return T.Compose(pipeline)


def load_model_transreid(path, *, num_class=1041, camera_num=15, view_num=1):
    """Build a TransReID model from the global ``cfg``, load weights, and ready it for inference.

    Args:
        path: Checkpoint location handed to ``model.load_param``.
        num_class: Number of identity classes passed to ``make_model``.
            This must match the checkpoint the weights come from.
        camera_num: Camera count passed to ``make_model``.
        view_num: View count passed to ``make_model``.

    Returns:
        The model, moved to CUDA and switched to eval mode, with
        ``model.base.cam_num`` and ``model.base.view_num`` set to 0.

    Note:
        Calls ``cfg.freeze()`` on the shared module-level config before building.
    """
    # NOTE(review): the defaults (1041 IDs, 15 cameras, 1 view) look like an
    # MSMT17-trained checkpoint. Confirm before relying on them for other weights.
    cfg.freeze()
    model = make_model(cfg, num_class=num_class, camera_num=camera_num, view_num=view_num)
    model.load_param(path)
    model.cuda()
    model.eval()
    # Zeroed after loading. Presumably this turns off the backbone's
    # camera/view-specific embedding at inference. TODO confirm against model.base.
    model.base.cam_num = 0
    model.base.view_num = 0
    return model

class WrapperForFixedCam:
    """Call a re-ID model with the same camera label for every sample in a batch.

    Args:
        model: Callable accepting ``(x, cam_label=...)``. It must also provide
            ``cuda()`` and ``eval()`` if the corresponding wrapper methods are used.
        cid: Camera id assigned to every sample in each batch.
    """

    def __init__(self, model, cid=0):
        self.model = model
        self.cid = cid

    def __call__(self, x):
        """Run the wrapped model on ``x`` with camera label ``self.cid`` for each sample.

        Args:
            x: Batched input tensor; dimension 0 is treated as the batch size.

        Returns:
            Whatever the wrapped model returns.
        """
        # One long-typed label per sample. It is created on x's device so CPU
        # and GPU inputs both work without a hard-coded .cuda() transfer.
        cid = torch.full((x.shape[0],), self.cid, dtype=torch.long, device=x.device)
        return self.model(x, cam_label=cid)

    def cuda(self):
        """Move the wrapped model to CUDA and return this wrapper for chaining."""
        self.model = self.model.cuda()
        return self

    def eval(self):
        """Put the wrapped model in eval mode and return this wrapper for chaining."""
        self.model.eval()
        return self
