
import numpy as np
import sys
import torch

from DataUtils import get_loader, ImageDataset, DirectoryDataset

class ModelWrapper():
    """Run a PyTorch model on single images, file lists, or directories.

    All outputs are returned as NumPy arrays on the host. Every forward pass
    runs under ``torch.no_grad()`` so no autograd graph is built. The wrapper
    never calls ``model.eval()``; put the model in the desired train/eval
    mode before using it.
    """

    def __init__(self, model, transform_mode='normalize', feature_hook=None,
                 get_id=None, device='cuda'):
        """Store the model and inference options.

        Args:
            model: callable (e.g. ``torch.nn.Module``) mapping a batch tensor
                to an output tensor.
            transform_mode: forwarded to ``ImageDataset`` / ``DirectoryDataset``.
            feature_hook: optional object whose ``.features`` tensor is read
                after each forward pass (presumably filled by a forward hook
                registered on some layer -- confirm with the caller).
            get_id: optional callable mapping a file name to the key used in
                the returned result dicts.
            device: device inputs are moved to before the forward pass.
                Defaults to ``'cuda'``, matching the previous ``.cuda()`` calls.
        """
        self.model = model
        self.transform_mode = transform_mode
        self.feature_hook = feature_hook
        self.get_id = get_id
        self.device = device

    def predict(self, im):
        """Return model output for one image or a batch as a NumPy array.

        A 3-D tensor is treated as a single image and gets a leading batch
        dimension added; other tensors are passed through unchanged.
        """
        if im.dim() == 3:
            im = torch.unsqueeze(im, 0)
        with torch.no_grad():
            return self.model(im.to(self.device)).detach().cpu().numpy()

    def predict_dataset(self, files, labels):
        """Predict on ``files`` and return ``{id: {'pred', 'label'[, 'rep']}}``.

        ``id`` is the file name, or ``get_id(file_name)`` when ``get_id`` is
        set. If two files map to the same id, the later one wins.
        """
        dataset = ImageDataset(files, labels, transform_mode=self.transform_mode, get_names=True)
        return self._collect_outputs(get_loader(dataset), has_labels=True)

    def predict_directory(self, directory):
        """Predict on every image of ``directory``; return ``{id: {'pred'[, 'rep']}}``."""
        dataset = DirectoryDataset(directory, transform_mode=self.transform_mode)
        return self._collect_outputs(get_loader(dataset), has_labels=False)

    def _collect_outputs(self, dataloader, has_labels):
        """Run the model over every batch and key per-sample results by id.

        Each batch is indexed as ``(x, y, names)`` when ``has_labels`` is true,
        and as ``(x, names)`` otherwise.
        """
        feature_hook = self.feature_hook
        get_id = self.get_id

        out = {}
        with torch.no_grad():
            for batch in dataloader:
                x = batch[0].to(self.device)
                if has_labels:
                    y = batch[1].numpy()
                    names = batch[2]
                else:
                    names = batch[1]

                y_hat = self.model(x).detach().cpu().numpy()
                if feature_hook is not None:
                    # Keeps only spatial position (0, 0); assumes the hooked
                    # features are (N, C, 1, 1) as after a ResNet global pool
                    # -- TODO confirm for other architectures.
                    rep = feature_hook.features.detach().cpu().numpy()[:, :, 0, 0]

                for i, name in enumerate(names):
                    key = get_id(name) if get_id is not None else name
                    entry = {'pred': y_hat[i, :]}
                    if has_labels:
                        entry['label'] = y[i, :]
                    if feature_hook is not None:
                        entry['rep'] = rep[i, :]
                    out[key] = entry

        return out
