import os
import torch
from torchvision.transforms import transforms as T
import torch.nn as nn

from .config import cfg
from .model import make_model

def load_preprocessor_pat(h=256, w=128):
    """Return the image transform used to feed the PAT model.

    The pipeline resizes to ``(h, w)`` with bicubic interpolation, applies
    ``ToTensor`` and then normalizes every channel with mean 0.5 and std 0.5.

    Args:
        h: Target height in pixels.
        w: Target width in pixels.

    Returns:
        A ``Compose`` pipeline built from the steps above.
    """
    # Prefer the named enum when this torchvision exposes it (some releases
    # warn on integer modes); older releases only accept the PIL integer
    # constant, where 3 == PIL.Image.BICUBIC.
    interpolation_mode = getattr(T, "InterpolationMode", None)
    bicubic = interpolation_mode.BICUBIC if interpolation_mode is not None else 3

    normalizer = T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    transform = T.Compose([
        T.Resize((h, w), interpolation=bicubic),
        T.ToTensor(),
        normalizer,
    ])
    return transform

def load_model_pat(path, use_bnneck=True):
    """Build the PAT model from the bundled config and prepare it for inference.

    The function merges ``config/PAT.yml`` (next to this module) into the
    shared ``cfg`` and builds the network with ``make_model``. It loads the
    weights at ``path`` through ``model.load_param`` and moves the model to
    CUDA when available, or to CPU otherwise. When more than one GPU is
    visible it wraps the model in ``nn.DataParallel``. Finally it switches
    the model to eval mode.

    Args:
        path: Checkpoint path forwarded to ``model.load_param``.
        use_bnneck: Currently unused; kept so existing callers keep working.
            NOTE(review): confirm whether this was meant to set a cfg option.

    Returns:
        The model, possibly wrapped in ``nn.DataParallel``, in eval mode.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Create model. Note: merge_from_file mutates the module-level ``cfg``.
    cfg.merge_from_file(os.path.join(os.path.dirname(os.path.abspath(__file__)), "config/PAT.yml"))
    # NOTE(review): the three zeros are extra positional args to make_model,
    # presumably counts that are irrelevant at inference time — TODO confirm.
    model = make_model(cfg, cfg.MODEL.NAME, 0, 0, 0)
    model.load_param(path)

    # Use the device detected above instead of hard-coding "cuda", so this
    # also works on CPU-only machines. DataParallel only makes sense on GPU.
    if device.type == 'cuda' and torch.cuda.device_count() > 1:
        print('Using {} GPUs for inference'.format(torch.cuda.device_count()))
        model = nn.DataParallel(model)
    model.to(device)
    model.eval()
    return model

class Wrapper:
    """Callable adapter around a model used for feature extraction.

    Calling the wrapper converts the input to a tensor, moves it to the
    model's device and returns the model's output unchanged.
    """

    def __init__(self, model):
        self.model = model

    def _device(self):
        """Return the device of the model's first parameter, or CUDA if none."""
        parameters = getattr(self.model, "parameters", None)
        if callable(parameters):
            for param in parameters():
                return param.device
        # No parameters to inspect: keep the previous hard-coded target.
        return torch.device("cuda")

    def __call__(self, x):
        # ``to_torch`` was referenced here but is not imported in this module.
        # torch.as_tensor accepts tensors, NumPy arrays and nested sequences.
        # Moving to the model's own device (instead of always ``.cuda()``)
        # keeps CPU models usable.
        x = torch.as_tensor(x).to(self._device())
        feature = self.model(x)
        return feature

    def cuda(self):
        """Move the wrapped model to CUDA and return this wrapper."""
        self.model = self.model.cuda()
        return self

    def eval(self):
        """Put the wrapped model in eval mode and return this wrapper."""
        self.model.eval()
        # Return self for chaining, consistent with ``cuda()``.
        return self
