import torch
import numpy as np
from torch.autograd import Variable


class PytorchModel(object):
    """Query-counting wrapper around a PyTorch classifier.

    Inputs are clamped to ``bounds`` and, when both ``im_mean`` and ``im_std``
    are given, normalized per channel before being fed to ``model``. Every
    image passed through :meth:`predict_prob` (and therefore
    :meth:`predict_label`) increments ``num_queries``.

    Args:
        model: ``torch.nn.Module`` called on a 4-D batch; it is moved to
            ``device`` and switched to eval mode.
        bounds: ``(low, high)`` range that inputs are clamped to.
        num_classes: stored on the instance; not used by this class itself.
        im_mean: optional per-channel means, one value per input channel.
        im_std: optional per-channel standard deviations, same length.
        device: where the model and inputs are placed. Defaults to
            ``"cuda"`` (the original behavior); pass ``"cpu"`` to run without
            a GPU.
    """

    def __init__(self, model, bounds=(0, 1), num_classes=10, im_mean=None,
                 im_std=None, device="cuda"):
        self.device = torch.device(device)
        self.model = model
        self.model.to(self.device)
        self.model.eval()
        self.bounds = bounds
        self.num_classes = num_classes
        self.num_queries = 0
        if im_mean is not None and im_std is not None:
            print("Normalization parameters were set above for the model.")
        self.im_mean = im_mean
        self.im_std = im_std
        # predict_prob feeds the model at most this many images per call.
        self.max_batch_size = 128

    def preprocess(self, image):
        """Turn ``image`` into a clamped, optionally normalized 4-D tensor.

        Args:
            image: numpy array or tensor. Anything that is not already 4-D
                gets a leading batch axis (presumably a single ``(C, H, W)``
                image).

        Returns:
            Tensor on ``self.device``, clamped to ``self.bounds`` and, if both
            normalization statistics are set, standardized along axis 1.
        """
        if isinstance(image, np.ndarray):
            image = torch.from_numpy(image).float()
        if image.dim() != 4:
            image = image.unsqueeze(0)
        image = image.to(self.device)
        image = torch.clamp(image, self.bounds[0], self.bounds[1])

        if self.im_mean is not None and self.im_std is not None:
            # (1, C, 1, 1) broadcasts over batch and spatial axes; the explicit
            # channel count keeps a length mismatch an error, not a silent
            # broadcast.
            stats_shape = (1, image.shape[1], 1, 1)
            mean = torch.as_tensor(self.im_mean, device=image.device).view(stats_shape)
            std = torch.as_tensor(self.im_std, device=image.device).view(stats_shape)
            image = (image - mean) / std
        return image

    def forward(self, image):
        """Run the model on ``image`` and return its raw outputs.

        Unlike :meth:`predict_prob`, this does not disable autograd, does not
        split the input into chunks and does not count queries. Accepts the
        same inputs as :meth:`preprocess` (numpy or tensor, single image or
        batch).
        """
        # preprocess already handles numpy conversion and the batch axis;
        # calling ``.size()`` first would fail on numpy arrays.
        return self.model(self.preprocess(image))

    def predict_prob(self, image):
        """Return the model's raw outputs for ``image`` without gradients.

        No softmax is applied despite the name: values are exactly what
        ``self.model`` returns. Each image counts as one query. When the
        input holds more than ``max_batch_size`` images it is processed in
        chunks and the result is collected on the CPU (presumably to limit
        device memory); otherwise the result stays on the model's output
        device.

        Returns:
            Tensor of shape ``(N, ...)`` matching the model's per-image
            output, or ``None`` if ``image`` contains zero images.
        """
        with torch.no_grad():
            if isinstance(image, np.ndarray):
                image = torch.from_numpy(image).float()
            if image.dim() != 4:
                image = image.unsqueeze(0)

            total = len(image)
            gather_on_cpu = total > self.max_batch_size
            out = None
            for start in range(0, total, self.max_batch_size):
                stop = min(start + self.max_batch_size, total)
                batch = self.preprocess(image[start:stop])
                logits = self.model(batch)
                if gather_on_cpu:
                    logits = logits.cpu()

                self.num_queries += batch.size(0)

                if out is None:
                    # Allocated lazily: the per-image output shape is only
                    # known after the first forward pass.
                    out = torch.zeros(total, *logits.shape[1:], device=logits.device)
                out[start:stop] = logits
            return out

    def predict_label(self, image):
        """Return the arg-max class index per image (counts queries)."""
        logits = self.predict_prob(image)
        _, predict = torch.max(logits, 1)
        return predict

    def get_num_queries(self):
        """Return how many images have been queried via predict_prob."""
        return self.num_queries

    def reset_queries(self):
        """Reset the query counter to zero."""
        self.num_queries = 0
            
