
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F

from .fft_convolutions import FFTConv2d

class Flatten(nn.Module):
    '''Flatten every dimension except the first (batch) one.

    ``(N, d1, d2, ...) -> (N, d1 * d2 * ...)``.
    '''
    def forward(self, x):
        # ``reshape`` (unlike ``view``) also accepts non-contiguous inputs, e.g.
        # the result of ``transpose``/``permute``; for contiguous inputs it
        # returns a view exactly like ``view`` did.
        return x.reshape(x.size(0), -1)

class Down_Sampling(nn.Module):
    '''Subsample the last two dimensions, keeping every ``stride``-th element.

    Elements at index 0, stride, 2*stride, ... are kept along each of the two
    trailing axes; all leading axes are left untouched.
    '''
    def __init__(self, stride=1):
        super().__init__()
        self.stride = stride  # step used along both trailing axes

    def forward(self, x):
        keep = slice(None, None, self.stride)
        return x[..., keep, keep]


class FFT_Padding(nn.Module):
    '''Module form of ``fft_input_padding`` bound to one convolution layer.

    The pad amounts are derived from ``conv``'s kernel size and padding.
    '''
    def __init__(self, conv):
        super().__init__()
        self.conv = conv  # conv layer whose geometry determines the padding

    def forward(self, X):
        return fft_input_padding(self.conv, X)


class FFT_Cropping(nn.Module):
    '''Module form of ``fft_output_crop`` bound to a conv layer and an input.

    The stored input image is passed to ``fft_output_crop`` together with the
    conv layer to determine the size the padded output is cropped to.
    '''
    def __init__(self, conv, input_image):
        super().__init__()
        self.conv = conv
        self.X = input_image  # reference (unpadded) input

    def forward(self, X_pad):
        return fft_output_crop(self.conv, self.X, X_pad)

class Conv_Post_Processing(nn.Module):
    '''Post-process a padded FFT-convolution output.

    ``forward`` applies, in order: cropping back to the size of
    ``conv(input_image)`` (``FFT_Cropping``), subsampling of the two trailing
    axes by ``stride`` (``Down_Sampling``) and addition of the optional bias
    ``b`` (``Bias``).

    ``output_init`` holds ``conv(input_image)`` evaluated at construction time;
    it is not used by ``forward``.
    '''
    def __init__(self, conv, input_image, stride=1, b=None):
        super().__init__()
        # Assignment order is kept so submodule/parameter registration order
        # (and therefore state_dict ordering) is unchanged.
        self.conv = conv
        self.X = input_image
        self.output_init = conv(input_image)
        self.stride = stride
        self.b = b

        self.crop = FFT_Cropping(conv, input_image)
        self.downsample = Down_Sampling(stride)
        self.bias = Bias(b)

    def forward(self, x):
        for stage in (self.crop, self.downsample, self.bias):
            x = stage(x)
        return x


class Bias(nn.Module):
    '''Add an optional per-channel bias to the input.

    With ``b=None`` the input is returned unchanged (the same object);
    otherwise ``generate_bias_tensor(b, Y)`` is added to it.
    '''
    def __init__(self, b):
        super().__init__()
        self.b = b  # bias values, or None for a no-op

    def forward(self, Y):
        bias = self.b
        if bias is not None:
            return Y + generate_bias_tensor(bias, Y)
        return Y

class Identity(nn.Module):
    '''No-op module: ``forward`` returns its argument unchanged (same object).'''

    def forward(self, x):
        return x

# =============================================================================
#   FFT operations for convolution layers
# =============================================================================

def fft_input_padding(conv, X):
    '''Zero-pad ``X`` on its last two axes before an FFT convolution.

    Only ``conv.kernel_size[0]`` and ``conv.padding[0]`` are read, so a square
    kernel and symmetric padding are assumed.

    With ``half = kernel_size // 2``, each side gets
    ``2*half - (half - padding)`` zeros. An even kernel gets one extra zero on
    the left and top sides.

    Returns:
        ``F.pad(X, (left, right, top, bottom))`` with the amounts above.
    '''
    ks = conv.kernel_size[0]
    half = ks // 2
    # (half - padding) is how much the direct convolution shrinks the input.
    base = 2 * half - (half - conv.padding[0])
    lead = base + (1 if ks % 2 == 0 else 0)
    return F.pad(X, (lead, base, lead, base))

def fft_output_padding(conv, Y):
    '''Zero-pad a direct conv output ``Y = conv(X)`` on its last two axes.

    Every side gets ``2 * (kernel_size // 2)`` zeros, where the kernel size is
    read from ``conv.kernel_size[0]``.
    '''
    amount = 2 * (conv.kernel_size[0] // 2)
    return F.pad(Y, (amount,) * 4)

def fft_output_crop(conv, X, Y_pad):
    '''Crop a padded FFT-conv output back to the size of ``conv(X)``.

    ``Y_pad = fft_conv(X_pad)``. Each of the last two axes is cropped
    independently. With ``diff`` the size excess on that axis, ``diff // 2``
    elements are dropped from each side. When ``diff`` is odd, one extra
    element is dropped from the end.

    Args:
        conv: the direct convolution layer. It is evaluated on ``X`` only to
            obtain the target output size.
        X: the unpadded input.
        Y_pad: the padded output to crop.

    Returns:
        ``Y_pad`` sliced on its last two axes to the spatial size of
        ``conv(X)``.
    '''
    # Only the output shape is needed, so don't build an autograd graph.
    with torch.no_grad():
        target = conv(X).shape

    def _bounds(padded_len, target_len):
        # (start, stop) slice bounds trimming padded_len down to target_len.
        k, offset = divmod(padded_len - target_len, 2)
        return k, padded_len - offset - k

    # Height and width use their own bounds; the previous code reused the
    # width-derived stop index for the height axis, which is wrong for
    # non-square inputs.
    h0, h1 = _bounds(Y_pad.size(-2), target[-2])
    w0, w1 = _bounds(Y_pad.size(-1), target[-1])
    return Y_pad[:, :, h0:h1, w0:w1]

def find_output_cropping_tuple(Y, Y_pad):
    '''Return ``(n, k)`` so that ``Y_pad[..., k:n - k]`` matches ``Y``'s width.

    Only the last axis is inspected. ``k`` is half the size excess (rounded
    down). ``n`` is ``Y_pad``'s width minus one when the excess is odd,
    otherwise the full width.
    '''
    padded_width = Y_pad.size(-1)
    half, odd = divmod(padded_width - Y.size(-1), 2)
    return padded_width - odd, half

def fft_conv_X_pad_crop(conv, X_pad):
    '''Strip the FFT input padding from ``X_pad`` and recover the original X.

    The pad amounts are the same ones ``fft_input_padding`` applies.
    ``2*k - (k - padding)`` zeros were added per side, with ``k = kernel_size // 2``,
    plus one extra on the left/top for even kernels. Only ``conv.kernel_size[0]``
    and ``conv.padding[0]`` are read.

    Returns:
        ``X_pad`` with those amounts removed from its last two axes.
    '''
    kernel_size = conv.kernel_size[0]
    padding = conv.padding[0]

    k = kernel_size // 2
    # reduced padding when forward convolution shrinks in size
    r = k - padding
    base = 2 * k - r
    # extra padding (left/top only) for even kernels
    lead = base + (1 if kernel_size % 2 == 0 else 0)
    left, right, top, bottom = lead, base, lead, base

    # Explicit stop indices: ``a:-0`` would yield an empty slice when a
    # trailing pad is zero (e.g. a 1x1 kernel with padding=0).
    H, W = X_pad.size(-2), X_pad.size(-1)
    return X_pad[:, :, top:H - bottom, left:W - right]

def find_input_padding_tuple(conv):
    '''Return the ``F.pad`` tuple ``(left, right, top, bottom)`` for FFT input.

    With ``half = kernel_size // 2`` every side gets ``2*half - (half - padding)``.
    Even kernels get one extra on the left and top. Only ``conv.kernel_size[0]``
    and ``conv.padding[0]`` are read.
    '''
    ks = conv.kernel_size[0]
    pad = conv.padding[0]
    half = ks // 2
    # NOTE(review): the original author left the open question
    # "what if padding = 0?" here; the formula below is evaluated unchanged
    # for padding == 0 — confirm that case is intended.
    base = 2 * half - (half - pad)
    lead = base + (1 if ks % 2 == 0 else 0)
    return (lead, base, lead, base)

def fft_conv_forward_map(fft_conv, X, tol=None):
    '''Compute the forward pass of a conv layer through its FFT implementation.

    ``X`` is padded with ``fft_input_padding`` and passed through ``fft_conv``.
    The result is then cropped with ``fft_output_crop`` to the size of the
    direct convolution ``fft_conv.conv(X)``.

    Args:
        fft_conv: FFT convolution module exposing the wrapped direct
            convolution as ``fft_conv.conv``.
        X: input tensor.
        tol: optional debugging check. When not ``None``, the result is
            compared with ``fft_conv.conv(X)``. A ``RuntimeError`` is raised
            if the max absolute difference exceeds ``tol``. Default ``None``
            skips the check.

    Returns:
        The cropped FFT-convolution output.

    Raises:
        RuntimeError: only when ``tol`` is given and exceeded.
    '''
    X_pad = fft_input_padding(fft_conv.conv, X)
    Y_pad = fft_conv(X_pad)
    Y = fft_output_crop(fft_conv.conv, X, Y_pad)

    if tol is not None:
        # Opt-in: costs an extra direct convolution plus a device sync
        # (``.item()``), so it is not done on every call.
        err = (Y - fft_conv.conv(X)).abs().max().item()
        if err > tol:
            raise RuntimeError(
                f"FFT convolution deviates from direct convolution: "
                f"max abs error {err:.3e} > tol {tol:.3e}"
            )
    return Y


def generate_bias_tensor(b, Y):
    '''Build a tensor shaped like ``Y`` filled with a per-channel bias.

    Args:
        b: bias with one value per channel of ``Y`` (``Y.size(1)`` values),
            or ``None``.
        Y: tensor whose shape is copied. Dim 1 is treated as the channel
            dimension; any number of trailing dimensions is allowed.

    Returns:
        A new tensor of ``Y``'s shape in the default floating dtype. For
        ``b=None`` it is all zeros on ``Y.device``. Otherwise
        ``result[:, i, ...] == b[i]``, it lives on ``b.device``, and
        gradients flow back to ``b``.

    Raises:
        ValueError: if ``b`` does not have one entry per channel of ``Y``.
    '''
    if b is None:
        return torch.zeros(Y.size(), device=Y.device)

    num_channels = b.size(0)
    if num_channels != Y.size(1):
        raise ValueError(
            f"bias has {num_channels} channels but Y has {Y.size(1)}"
        )

    bias_tensor = torch.empty(Y.size(), device=b.device)
    # One broadcast copy replaces the per-channel Python loop; copy_ also
    # casts to bias_tensor's dtype and records the op for autograd.
    bias_tensor.copy_(b.reshape((1, num_channels) + (1,) * (Y.dim() - 2)))
    return bias_tensor

