# -*- coding: utf-8 -*-

import torch
import torch.nn as nn
import numpy as np

# Interpolation settings shared by every interpolate-based resampler below;
# each ``forward`` looks these names up at call time.
scale_type, align_corners = 'bilinear', True


class MNIST_Downsampler(nn.Module):
    """Resize single-channel MNIST images onto a square ``resize_dim`` grid.

    Interpolation uses the module-level ``scale_type`` and ``align_corners``.
    """

    def __init__(self, resize_dim):
        super().__init__()
        self.resize_dim = resize_dim
        self.image_space_channels = 1

    def forward(self, x):
        target_hw = (self.resize_dim,) * 2
        resized = nn.functional.interpolate(
            x, size=target_hw, mode=scale_type, align_corners=align_corners
        )
        return resized


class MNIST_Upsampler(nn.Module):
    """Map low-resolution single-channel inputs back to 28x28 MNIST image space."""

    def __init__(self):
        super().__init__()
        self.image_space_channels = 1

    def forward(self, x):
        """Interpolate ``x`` to a 28x28 spatial size.

        Uses the module-level ``scale_type`` / ``align_corners`` settings;
        batch and channel dimensions are left unchanged by ``interpolate``.
        """
        return nn.functional.interpolate(x, size=(28, 28), mode=scale_type, align_corners=align_corners)


class CIFAR_Downsampler(nn.Module):
    """Resize 3-channel CIFAR images onto a square ``resize_dim`` grid."""

    def __init__(self, resize_dim):
        super().__init__()
        self.resize_dim = resize_dim
        self.image_space_channels = 3

    def forward(self, x):
        """Return ``x`` interpolated to (resize_dim, resize_dim).

        The interpolation mode and corner alignment come from the module-level
        ``scale_type`` / ``align_corners`` settings.
        """
        side = self.resize_dim
        return nn.functional.interpolate(
            x,
            size=(side, side),
            mode=scale_type,
            align_corners=align_corners,
        )


class CIFAR_Upsampler(nn.Module):
    """Map low-resolution 3-channel inputs back to 32x32 CIFAR image space."""

    def __init__(self):
        super().__init__()
        self.image_space_channels = 3

    def forward(self, x):
        """Interpolate ``x`` to a 32x32 spatial size.

        Uses the module-level ``scale_type`` / ``align_corners`` settings;
        batch and channel dimensions are left unchanged by ``interpolate``.
        """
        return nn.functional.interpolate(x, size=(32, 32), mode=scale_type, align_corners=align_corners)
    
    
class CIFAR_RandomDownsampler(nn.Module):
    """Sample a 32x32 CIFAR image at randomly chosen pixel positions.

    Each ``forward`` call draws ``resize_dim * resize_dim`` (row, col) pairs
    with ``np.random.choice`` (with replacement, so positions may repeat),
    records them in ``self.coordinates`` and the input in ``self.original``,
    and returns the sampled pixel values shaped
    (1, image_space_channels, resize_dim, resize_dim).
    """

    def __init__(self, resize_dim):
        super().__init__()
        self.image_space_channels = 3
        self.resize_dim = resize_dim
        self.coordinates = None  # (resize_dim**2, 2) int array from the last forward
        self.original = None  # input of the last forward, stored by reference
        self.orig_dim = 32  # coordinates are drawn from range(orig_dim)

    def forward(self, x):
        # sample s'*s' (row, col) pairs: draw 2*s'*s' indices and pair them up
        coords = np.random.choice(np.arange(self.orig_dim), 2 * self.resize_dim * self.resize_dim)
        coords = coords.reshape((-1, 2))
        self.coordinates = coords
        self.original = x
        # Only the first batch element is sampled; values are grouped per channel.
        subsample = x[0, :, coords[:, 0], coords[:, 1]]
        # Use the declared channel count rather than a second hard-coded 3.
        subsample = subsample.reshape((self.image_space_channels, self.resize_dim, self.resize_dim))

        return subsample.unsqueeze(0)


class CIFAR_RandomUpsampler(nn.Module):
    """Write a low-dimensional latent back into a 32x32 CIFAR base image.

    ``self.coordinates`` holds (row, col) pixel positions and ``self.original``
    the base image (set via ``update_coordinates`` / ``update_original``).
    ``forward`` returns a copy of the base image with the latent's values
    written at those positions.
    """

    def __init__(self):
        super().__init__()
        self.image_space_channels = 3
        self.coordinates = None  # (K, 2) integer array of (row, col) positions
        self.original = None  # base image; 4-D after update_original

    def forward(self, x):
        """Return the base image with latent ``x`` scattered onto ``self.coordinates``.

        ``x`` is expected to be (1, C, r, r) with r*r == len(self.coordinates)
        (TODO confirm with callers); only ``x[0]`` is used. An input whose last
        dimension is already 32 is treated as image-space and returned as is.
        If coordinates repeat, which value lands there is unspecified.

        Raises:
            RuntimeError: if ``original`` or ``coordinates`` was never set.
        """
        resize_dim = x.shape[-1]
        if resize_dim == 32:
            # already in image space, pass-through similar to Upsampler
            return x

        coords = self.coordinates
        if self.original is None or coords is None:
            raise RuntimeError(
                "CIFAR_RandomUpsampler: call update_original() and "
                "update_coordinates() before forward()"
            )

        # Write into a copy: mutating self.original in place made every call
        # return the same aliased tensor (earlier results silently changed) and
        # could corrupt a caller's image assigned directly to ``original``.
        new_x = self.original.clone()
        subsample = x[0].reshape((self.image_space_channels, resize_dim * resize_dim))
        # Index assignment may reject a source whose dtype/device differs from
        # the destination (e.g. a float64 latent), so match them explicitly.
        subsample = subsample.to(device=new_x.device, dtype=new_x.dtype)
        new_x[0, :, coords[:, 0], coords[:, 1]] = subsample

        return new_x

    def update_original(self, x):
        """Store a detached copy of ``x`` as the base image (batch dim added if not 4-D)."""
        x = x.clone().detach()
        if x.dim() != 4:
            x = x.unsqueeze(0)
        self.original = x

    def update_coordinates(self, coords):
        """Set the (row, col) positions that ``forward`` writes the latent to."""
        self.coordinates = coords


class Imagenet_Downsampler(CIFAR_Downsampler):
    """Resize 3-channel ImageNet images onto a square ``resize_dim`` grid.

    Behaves exactly like ``CIFAR_Downsampler``: ``forward`` is inherited
    unchanged (interpolation to (resize_dim, resize_dim) with the module-level
    ``scale_type`` / ``align_corners`` settings), so it is not re-declared here.
    """

    
class Imagenet_Upsampler(CIFAR_Upsampler):
    """Map low-resolution 3-channel inputs back to 224x224 ImageNet image space."""

    def forward(self, x):
        """Interpolate ``x`` to a 224x224 spatial size.

        Uses the module-level ``scale_type`` / ``align_corners`` settings;
        batch and channel dimensions are left unchanged by ``interpolate``.
        """
        return nn.functional.interpolate(x, size=(224, 224), mode=scale_type, align_corners=align_corners)
    
    
class Imagenet_RandomDownsampler(nn.Module):
    """Sample a 224x224 ImageNet image at randomly chosen pixel positions.

    Each ``forward`` call draws ``resize_dim * resize_dim`` (row, col) pairs
    with ``np.random.choice`` (with replacement, so positions may repeat),
    records them in ``self.coordinates`` and the input in ``self.original``,
    and returns the sampled pixel values shaped
    (1, image_space_channels, resize_dim, resize_dim).
    """

    def __init__(self, resize_dim):
        super().__init__()
        self.image_space_channels = 3
        self.resize_dim = resize_dim
        self.coordinates = None  # (resize_dim**2, 2) int array from the last forward
        self.original = None  # input of the last forward, stored by reference
        self.orig_dim = 224  # coordinates are drawn from range(orig_dim)

    def forward(self, x):
        # sample s'*s' (row, col) pairs: draw 2*s'*s' indices and pair them up
        coords = np.random.choice(np.arange(self.orig_dim), 2 * self.resize_dim * self.resize_dim)
        coords = coords.reshape((-1, 2))
        self.coordinates = coords
        self.original = x
        # Only the first batch element is sampled; values are grouped per channel.
        subsample = x[0, :, coords[:, 0], coords[:, 1]]
        # Use the declared channel count rather than a second hard-coded 3.
        subsample = subsample.reshape((self.image_space_channels, self.resize_dim, self.resize_dim))

        return subsample.unsqueeze(0)


class Imagenet_RandomUpsampler(nn.Module):
    """Write a low-dimensional latent back into a 224x224 ImageNet base image.

    ``self.coordinates`` holds (row, col) pixel positions and ``self.original``
    the base image (set via ``update_coordinates`` / ``update_original``).
    ``forward`` returns a copy of the base image with the latent's values
    written at those positions.
    """

    def __init__(self):
        super().__init__()
        self.image_space_channels = 3
        self.coordinates = None  # (K, 2) integer array of (row, col) positions
        self.original = None  # base image; 4-D after update_original

    def forward(self, x):
        """Return the base image with latent ``x`` scattered onto ``self.coordinates``.

        ``x`` is expected to be (1, C, r, r) with r*r == len(self.coordinates)
        (TODO confirm with callers); only ``x[0]`` is used. An input whose last
        dimension is already 224 is treated as image-space and returned as is.
        If coordinates repeat, which value lands there is unspecified.

        Raises:
            RuntimeError: if ``original`` or ``coordinates`` was never set.
        """
        resize_dim = x.shape[-1]
        if resize_dim == 224:
            # already in image space, pass-through similar to Upsampler
            return x

        coords = self.coordinates
        if self.original is None or coords is None:
            raise RuntimeError(
                "Imagenet_RandomUpsampler: call update_original() and "
                "update_coordinates() before forward()"
            )

        # Write into a copy: mutating self.original in place made every call
        # return the same aliased tensor (earlier results silently changed) and
        # could corrupt a caller's image assigned directly to ``original``.
        new_x = self.original.clone()
        subsample = x[0].reshape((self.image_space_channels, resize_dim * resize_dim))
        # Index assignment may reject a source whose dtype/device differs from
        # the destination (e.g. a float64 latent), so match them explicitly.
        subsample = subsample.to(device=new_x.device, dtype=new_x.dtype)
        new_x[0, :, coords[:, 0], coords[:, 1]] = subsample

        return new_x

    def update_original(self, x):
        """Store a detached copy of ``x`` as the base image (batch dim added if not 4-D)."""
        x = x.clone().detach()
        if x.dim() != 4:
            x = x.unsqueeze(0)
        self.original = x

    def update_coordinates(self, coords):
        """Set the (row, col) positions that ``forward`` writes the latent to."""
        self.coordinates = coords