import torch
import numpy as np
from torch.autograd import Variable
import os


class PytorchCodec(object):
    """Wrap a PyTorch model as a batched, GPU-resident codec.

    Inputs are clamped to ``bounds`` and, when both ``im_mean`` and ``im_std``
    are given, normalized per channel (dim 1) before being fed to the model.
    Inputs with more than ``max_batch_size`` samples are processed in chunks.

    Args:
        model: a ``torch.nn.Module``; it is put in eval mode and moved to the
            GPU.
        bounds: ``(low, high)`` pair used to clamp inputs.
        im_mean: optional per-channel mean used for normalization.
        im_std: optional per-channel standard deviation used for
            normalization.
    """

    def __init__(self, model, bounds, im_mean=None, im_std=None):
        self.model = model
        self.model.eval()
        self.model.cuda()
        self.bounds = bounds
        if im_mean is not None and im_std is not None:
            print("Normalization parameters were set above for a codec.")
        self.im_mean = im_mean
        self.im_std = im_std
        # Inputs larger than this are split into chunks whose results are
        # gathered on the CPU (presumably to bound GPU memory use).
        self.max_batch_size = 128

    def preprocess(self, image):
        """Convert, clamp and (optionally) normalize one input batch.

        Args:
            image: a ``numpy.ndarray`` (converted to a float32 tensor) or a
                ``torch.Tensor``. A non-4-D input gets a leading batch
                dimension.

        Returns:
            A CUDA tensor clamped to ``self.bounds`` and, if normalization
            parameters were set, standardized per channel.
        """
        if isinstance(image, np.ndarray):
            processed = torch.from_numpy(image).type(torch.FloatTensor)
        else:
            processed = image

        if processed.dim() != 4:
            processed = processed.unsqueeze(0)

        processed = processed.cuda()
        processed = torch.clamp(processed, self.bounds[0], self.bounds[1])

        if self.im_mean is not None and self.im_std is not None:
            # Build the statistics with the batch's (floating) dtype and
            # device. A float64 numpy mean/std would otherwise promote the
            # whole batch to float64 and mismatch float32 model weights.
            stat_dtype = (processed.dtype if processed.is_floating_point()
                          else torch.get_default_dtype())
            shape = (1, processed.shape[1], 1, 1)  # broadcasts over N, H, W
            im_mean = torch.as_tensor(self.im_mean, dtype=stat_dtype,
                                      device=processed.device).reshape(shape)
            im_std = torch.as_tensor(self.im_std, dtype=stat_dtype,
                                     device=processed.device).reshape(shape)
            processed = (processed - im_mean) / im_std
        return processed

    def forward(self, image):
        """Run the model over ``image`` in chunks of ``self.max_batch_size``.

        Args:
            image: a ``torch.Tensor`` or ``numpy.ndarray``. A non-4-D input
                gets a leading batch dimension.

        Returns:
            A tensor (default dtype) holding the model output for every
            sample. It is on the GPU when the input fits in a single chunk
            and on the CPU otherwise. Returns ``None`` for an empty batch.
        """
        # ``ndim`` and ``[None]`` work for both tensors and ndarrays.
        # ``image.size()`` would fail on an ndarray, where ``size`` is an int.
        if image.ndim != 4:
            image = image[None]

        n = len(image)
        spill_to_cpu = n > self.max_batch_size
        out = None
        for start in range(0, n, self.max_batch_size):
            stop = min(start + self.max_batch_size, n)
            embed = self.model(self.preprocess(image[start:stop]))
            if spill_to_cpu:
                embed = embed.cpu()

            if out is None:
                # Allocated lazily: the output shape is only known after the
                # first model call.
                out = torch.zeros(n, *embed.shape[1:])
                if not spill_to_cpu:
                    out = out.cuda()

            out[start:stop] = embed

        return out

    def load_state_dict(self, fpath):
        """Load model weights from the checkpoint file at ``fpath``.

        Raises:
            AssertionError: if no file exists at ``fpath``.
        """
        assert os.path.exists(fpath), f"No file at {fpath}"
        self.model.load_state_dict(torch.load(fpath))

    def __call__(self, image, class_conditional=None):
        """Alias for ``forward``.

        ``class_conditional`` is accepted and ignored. Presumably this lets
        callers use this class and the class-conditional wrappers the same
        way.
        """
        return self.forward(image)
    

class PytorchEncoder(PytorchCodec):
    """Encoder codec; inherits every behavior of ``PytorchCodec`` unchanged."""


class PytorchDecoder(PytorchCodec):
    """Codec that decodes latents back into images.

    Latents are clamped to ``self.bounds`` before decoding. Decoded images
    are de-normalized with ``im_mean``/``im_std`` (when both are set) and
    clamped to ``[0, 1]``.
    """

    def preprocess(self, image):
        """Clamp a latent batch to ``self.bounds`` and move it to the GPU."""
        low, high = self.bounds[0], self.bounds[1]
        clamped = torch.clamp(image, low, high)
        return clamped.cuda()

    def postprocess(self, image):
        """Undo input normalization on a decoded batch and clamp to [0, 1].

        NOTE: the clamp uses the fixed range [0, 1], not ``self.bounds``.
        """
        if self.im_mean is None or self.im_std is None:
            return torch.clamp(image, 0., 1.).cuda()

        batch, channels = image.shape[0], image.shape[1]
        mean = torch.tensor(self.im_mean).cuda().view(1, channels, 1, 1)
        mean = mean.repeat(batch, 1, 1, 1)
        std = torch.tensor(self.im_std).cuda().view(1, channels, 1, 1)
        std = std.repeat(batch, 1, 1, 1)
        restored = image * std + mean
        return torch.clamp(restored, 0., 1.).cuda()

    def forward(self, latent):
        """Decode ``latent`` in chunks of at most ``self.max_batch_size``.

        A non-4-D latent gets a leading batch dimension. The decoded batch
        stays on the GPU when it fits in one chunk. Otherwise each chunk is
        moved to the CPU and the result is assembled there.
        """
        if latent.dim() != 4:
            latent = latent.unsqueeze(0)

        total = len(latent)
        gather_on_cpu = total > self.max_batch_size
        result = None
        for lo in np.arange(0, total, self.max_batch_size):
            hi = min(lo + self.max_batch_size, total)
            decoded = self.postprocess(self.model(self.preprocess(latent[lo:hi])))
            if gather_on_cpu:
                decoded = decoded.cpu()

            if result is None:
                # Shape is only known once the first chunk has been decoded.
                result = torch.zeros(*[total] + list(decoded.shape[1:]))
                if not gather_on_cpu:
                    result = result.cuda()

            result[lo:hi] = decoded

        return result
    
    
class ClassConditionalEncoder(object):
    """
    Wraps a dictionary of PytorchEncoder objects, one per class label.

    Args:
        models_dict: mapping from class label to a codec exposing
            ``forward(image)`` and ``load_state_dict(path)``.
    """

    def __init__(self, models_dict):
        self.known_classes = list(models_dict.keys())
        self.k_to_models = models_dict

    def forward(self, image, yi):
        """Run the codec registered for class ``yi`` on ``image``."""
        return self.k_to_models[yi].forward(image)

    def load_state_dict(self, path_template, placeholder='CX'):
        """Load one checkpoint per known class from a path template.

        In the file-name component of ``path_template``, every occurrence of
        ``placeholder`` is replaced with ``str(k)`` for each class ``k``. The
        directory part is kept as given, so relative templates stay relative.
        Example: ``"ckpt/enc_CX.pt"`` becomes ``"ckpt/enc_3.pt"`` for class 3.

        Args:
            path_template: checkpoint path containing ``placeholder`` in its
                file name.
            placeholder: substring to substitute with the class label.
        """
        directory, filename = os.path.split(path_template)
        for k in self.known_classes:
            real_path = os.path.join(directory,
                                     filename.replace(placeholder, str(k)))
            self.k_to_models[k].load_state_dict(real_path)

    def __call__(self, image, class_conditional):
        """Alias for ``forward(image, class_conditional)``."""
        return self.forward(image, class_conditional)

    
class ClassConditionalDecoder(ClassConditionalEncoder):
    """Per-class decoder wrapper.

    Holds a dictionary mapping class labels to PytorchDecoder objects. It
    dispatches ``forward``, ``__call__`` and ``load_state_dict`` exactly as
    ``ClassConditionalEncoder`` does.
    """
