import torch
import torch_butterfly
import math
from torch import nn
from torch.nn import functional as F
from torch_butterfly.complex_utils import Real2Complex, Complex2Real
from torch_butterfly.complex_utils import complex_mul
from torch_butterfly.complex_utils import view_as_real, view_as_complex


class KOP1D(nn.Module):
    """Learnable 1D operator shaped like an FFT-based convolution.

    The forward pass evaluates ``K2(sum_in(K1(x) * Kd(w)))``, where ``w`` is
    the weight of an ``nn.Conv1d`` zero-padded to the signal length and
    ``K1``, ``Kd``, ``K2`` are complex ``torch_butterfly`` modules.  With
    ``warm_start=True`` they are built by ``torch_butterfly.special.fft`` /
    ``ifft``; otherwise they are ``'ortho'``-initialised ``Butterfly``
    modules.

    Args:
        in_size: Signal length the transforms are built for.  Inputs whose
            last dimension differs are zero-padded to this length in
            ``forward``.
        in_ch: Input channels of the conv whose weight parameterizes the
            kernel.
        out_ch: Output channels of that conv.
        kernel_size: Kernel length of that conv.
        padding: Convolution-style padding to keep; the output is cropped by
            ``(kernel_size - 1) // 2 - padding`` on each side.
        stride: Subsampling stride applied to the output.
        nblocks: ``nblocks`` passed to ``Butterfly`` when ``warm_start`` is
            false.
        warm_start: Build K1/Kd/K2 from the FFT/iFFT constructors instead of
            'ortho'-initialised Butterfly modules.
        K1, Kd, K2: Optional modules whose ``twiddle`` parameters are shared
            with this layer.  Tying happens only when all three are given.
    """

    def __init__(self, in_size, in_ch, out_ch, kernel_size,
            padding=0, stride=1, nblocks=6, warm_start=False,
            K1=None, Kd=None, K2=None):
        super().__init__()

        self.in_size = in_size
        self.kernel_size = kernel_size
        self.in_ch = in_ch
        self.out_ch = out_ch
        self.padding = padding
        self.stride = stride
        self.nblocks = nblocks
        self.warm_start = warm_start

        # Parameterize using real convolutional filter weights.  Only the
        # weight tensor is read by forward(); the conv itself is not called
        # anywhere in this class.
        self.w = nn.Conv1d(self.in_ch, self.out_ch,
            self.kernel_size,
            padding=(self.kernel_size - 1) // 2,
            padding_mode='circular',
            bias=False)

        if not self.warm_start:
            # NOTE trying to set same norm as conv
            scaled_weights = self.w.weight * math.sqrt(2)
            self.w.weight = nn.Parameter(scaled_weights)

        self.n = self.in_size
        self.padding_d = (self.kernel_size - 1) // 2
        self.log_n = int(math.ceil(math.log2(self.n)))
        # n itself when it is a power of two, otherwise twice the next power
        # of two.
        self.n_extended = (
            self.n if self.n == 1 << self.log_n else 1 << (self.log_n + 1))

        self.r2c = Real2Complex()
        self.c2r = Complex2Real()

        # Diagonal bias term.  Registered as a parameter but currently not
        # applied in forward() (the line using it is disabled there).
        b = nn.Parameter(torch.zeros((self.in_size, 2)))
        self.register_parameter("b", b)

        # TODO try additive architectural params for diagonal
        if warm_start: # TODO try deep fft2d/ifft2d with identity Butterfly
            self.Kd = torch_butterfly.special.fft(
                self.n_extended,
                normalized=False, br_first=False, with_br_perm=False)
            self.K1 = torch_butterfly.special.fft(
                self.n_extended,
                normalized=True, br_first=False, with_br_perm=False)
            self.K2 = torch_butterfly.special.ifft(
                self.n_extended,
                normalized=True, br_first=True, with_br_perm=False)
        else: # TODO
            # NOTE(review): forward() extends the kernel to n_extended
            # entries whenever in_size is not a power of two, while these
            # modules are built for in_size -- confirm that combination is
            # supported.
            self.Kd = torch_butterfly.Butterfly(
                    self.in_size, self.in_size, bias=False,
                    increasing_stride=False, complex=True, init='ortho',
                    nblocks=nblocks)
            self.K1 = torch_butterfly.Butterfly(
                    self.in_size, self.in_size, bias=False,
                    increasing_stride=False, complex=True, init='ortho',
                    nblocks=nblocks)
            self.K2 = torch_butterfly.Butterfly(
                    self.in_size, self.in_size, bias=False,
                    increasing_stride=True, complex=True, init='ortho',
                    nblocks=nblocks)

        # Explicit None checks: module truthiness is unreliable because
        # container modules define __len__.
        if K1 is not None and Kd is not None and K2 is not None:
            print("TYING!!!")
            self.K1.twiddle = K1.twiddle
            self.Kd.twiddle = Kd.twiddle
            self.K2.twiddle = K2.twiddle

        # Kernel size 1 means the pooling only subsamples with the stride.
        self.subsample = nn.AvgPool1d(kernel_size=1, stride=self.stride)

        if not warm_start:
            with torch.no_grad():
                # Every twiddle entry is multiplied by the same factor,
                # sqrt(in_size) ** (1 / (shape[1] * shape[2])); presumably
                # this scales the composed transform by sqrt(in_size) --
                # TODO confirm against the twiddle layout.
                # NOTE(review): when twiddles were tied above, this rescales
                # the shared Kd parameter in place once more.
                scale_kd = self.in_size
                ts = math.sqrt(scale_kd) ** (1.0 / self.Kd.twiddle.shape[1] / self.Kd.twiddle.shape[2])
                self.Kd.twiddle.mul_(ts)

    def _pad_to_in_size(self, x):
        """Zero-pad the last dimension of ``x`` to ``self.in_size``.

        Returns ``(x, left, right)`` with the amount added on each side.
        When the difference is odd the extra sample goes on the right, so the
        result always has exactly ``in_size`` entries.
        """
        deficit = self.in_size - x.shape[-1]
        if deficit == 0:
            return x, 0, 0
        left = deficit // 2
        right = deficit - left
        return F.pad(x, (left, right), "constant", 0), left, right

    def forward(self, x):
        """Apply the operator to ``x``.

        Args:
            x: Real tensor indexed as (batch, in_ch, length) -- the channel
                axis is summed against the conv weight below.

        Returns:
            Real tensor with ``out_ch`` channels, cropped and subsampled.
        """
        # Pad to in_size if needed; the padding is removed again when
        # cropping at the end.
        x, left_pad, right_pad = self._pad_to_in_size(x)

        # Flip the kernel, zero-pad it to length n and rotate it left by
        # (kernel_size - 1) // 2.
        col = F.pad(self.w.weight.flip([-1]),
                    (0, self.n - self.kernel_size)).roll(
                        -self.padding_d, dims=-1)

        # For non-power-of-two n, extend the kernel to n_extended entries.
        if self.n < self.n_extended:
            col_0 = F.pad(col, (0, 2 * ((1 << self.log_n) - self.n)))
            col = torch.cat((col_0, col), dim=-1)

        col_f = self.Kd(self.r2c(col))
        x = self.K1(self.r2c(x))

        # (batch, 1, in_ch, n) * (out_ch, in_ch, n), then sum over in_ch.
        x = complex_mul(x.unsqueeze(1), col_f)
        # Diagonal bias term, currently disabled:
        #x = view_as_complex(view_as_real(x) + self.b)
        x = x.sum(dim=2)
        x = self.c2r(self.K2(x))

        # Handle padding: drop the kernel border not covered by `padding`
        # together with whatever was added to reach in_size.
        border = ((self.kernel_size - 1) // 2) - self.padding
        x = x[:, :, (border + left_pad):(self.in_size - border - right_pad)]

        # Handle stride
        x = self.subsample(x)

        return x


class KOP2D(nn.Module):
    """Learnable 2D operator shaped like an FFT-based convolution.

    The forward pass evaluates ``K2(sum_in(K1(x) * Kd(w)))``, where ``w`` is
    the weight of an ``nn.Conv2d`` zero-padded to the input size and ``K1``,
    ``Kd``, ``K2`` are complex ``torch_butterfly`` modules.  With
    ``warm_start=True`` they are built by ``torch_butterfly.special.fft2d`` /
    ``ifft2d``; otherwise each is a ``TensorProduct`` of two
    ``'ortho'``-initialised ``Butterfly`` modules.

    Args:
        in_size: int or pair matched against ``(x.shape[-2], x.shape[-1])``.
            Inputs of a different size are zero-padded to it in ``forward``.
        in_ch: Input channels of the conv whose weight parameterizes the
            kernel.
        out_ch: Output channels of that conv.
        kernel_size: int or pair, kernel size of that conv.
        padding: int or pair; the output is cropped by
            ``(kernel_size[i] - 1) // 2 - padding[i]`` on each side of
            spatial dimension ``i``.
        stride: int or pair, subsampling stride applied to the output.
        nblocks: ``nblocks`` passed to ``Butterfly`` when ``warm_start`` is
            false.
        warm_start: Build K1/Kd/K2 from the FFT/iFFT constructors instead of
            'ortho'-initialised Butterfly modules.
        K1, Kd, K2: Optional modules whose ``map1``/``map2`` twiddles are
            shared with this layer.  Tying happens only when all three are
            given.
    """

    def __init__(self, in_size, in_ch, out_ch, kernel_size,
            padding=0, stride=1, nblocks=6, warm_start=False,
            K1=None, Kd=None, K2=None):
        super().__init__()

        self.in_size = in_size
        self.kernel_size = kernel_size
        self.in_ch = in_ch
        self.out_ch = out_ch
        self.padding = padding
        self.stride = stride
        self.nblocks = nblocks
        self.warm_start = warm_start

        # Broadcast scalar settings to (dim -2, dim -1) pairs.
        if isinstance(self.in_size, int):
            self.in_size = (self.in_size, self.in_size)

        if isinstance(self.kernel_size, int):
            self.kernel_size = (self.kernel_size, self.kernel_size)

        if isinstance(self.padding, int):
            self.padding = (self.padding, self.padding)

        if isinstance(self.stride, int):
            self.stride = (self.stride, self.stride)

        # Parameterize using real convolutional filter weights.  Only the
        # weight tensor is read by forward(); the conv itself is not called
        # anywhere in this class.
        self.w = nn.Conv2d(self.in_ch, self.out_ch,
            (self.kernel_size[0], self.kernel_size[1]),
            padding=((self.kernel_size[0] - 1) // 2,
                (self.kernel_size[1] - 1) // 2),
            padding_mode='circular',
            bias=False)

        if not self.warm_start:
            # NOTE trying to set same norm as conv
            scaled_weights = self.w.weight * math.sqrt(2)
            self.w.weight = nn.Parameter(scaled_weights)

        # Suffix 1 refers to the last dimension, suffix 2 to the
        # next-to-last one.
        self.n1 = self.in_size[1]
        self.n2 = self.in_size[0]
        w = self.w.weight
        self.kernel_size2, self.kernel_size1 = w.shape[-2], w.shape[-1]
        self.padding1 = (self.kernel_size1 - 1) // 2
        self.padding2 = (self.kernel_size2 - 1) // 2
        self.log_n1 = int(math.ceil(math.log2(self.n1)))
        self.log_n2 = int(math.ceil(math.log2(self.n2)))
        # n itself when it is a power of two, otherwise twice the next power
        # of two.
        self.n_extended1 = (
            self.n1 if self.n1 == 1 << self.log_n1 else 1 << (self.log_n1 + 1))
        self.n_extended2 = (
            self.n2 if self.n2 == 1 << self.log_n2 else 1 << (self.log_n2 + 1))

        self.r2c = Real2Complex()
        self.c2r = Complex2Real()

        # Diagonal bias term.  Registered as a parameter but currently not
        # applied in forward() (the line using it is disabled there).
        b = nn.Parameter(torch.zeros((self.in_size[0], self.in_size[1], 2)))
        self.register_parameter("b", b)

        # TODO try additive architectural params for diagonal
        if warm_start: # TODO try deep fft2d/ifft2d with identity Butterfly
            self.Kd = torch_butterfly.special.fft2d(
                self.n_extended1, self.n_extended2,
                normalized=False, br_first=False, with_br_perm=False,
                flatten=False)
            self.K1 = torch_butterfly.special.fft2d(
                self.n_extended1, self.n_extended2,
                normalized=True, br_first=False, with_br_perm=False,
                flatten=False)
            self.K2 = torch_butterfly.special.ifft2d(
                self.n_extended1, self.n_extended2,
                normalized=True, br_first=True, with_br_perm=False,
                flatten=False)
        else:
            def square_butterfly(size, increasing_stride):
                # One size x size complex Butterfly factor.
                return torch_butterfly.Butterfly(
                    size, size, bias=False,
                    increasing_stride=increasing_stride, complex=True,
                    init='ortho', nblocks=nblocks)

            # TensorProduct applies its first map along the last dimension
            # (length n1) and its second along the next-to-last (length n2),
            # so each factor must be square in its own length.  For square
            # inputs this builds exactly the same modules, in the same order,
            # as the previous Butterfly(in_size[1], in_size[0]) calls.
            # NOTE(review): forward() extends the kernel to n_extended
            # entries for non-power-of-two sizes, while these modules are
            # built for n1/n2 -- confirm that combination is supported.
            self.Kd = torch_butterfly.special.TensorProduct(
                square_butterfly(self.n1, False),
                square_butterfly(self.n2, False))
            self.K1 = torch_butterfly.special.TensorProduct(
                square_butterfly(self.n1, False),
                square_butterfly(self.n2, False))
            self.K2 = torch_butterfly.special.TensorProduct(
                square_butterfly(self.n1, True),
                square_butterfly(self.n2, True))

        # Explicit None checks: module truthiness is unreliable because
        # container modules define __len__.
        if K1 is not None and Kd is not None and K2 is not None:
            print("TYING!!!")
            self.K1.map1.twiddle = K1.map1.twiddle
            self.K1.map2.twiddle = K1.map2.twiddle
            self.Kd.map1.twiddle = Kd.map1.twiddle
            self.Kd.map2.twiddle = Kd.map2.twiddle
            self.K2.map1.twiddle = K2.map1.twiddle
            self.K2.map2.twiddle = K2.map2.twiddle

        # Kernel size 1 means the pooling only subsamples with the stride.
        self.subsample = nn.AvgPool2d(kernel_size=(1, 1), stride=self.stride)

        if not warm_start:
            with torch.no_grad():
                # Each factor of Kd is scaled using the length of the
                # dimension it acts on (map1 -> n1, map2 -> n2).  Every
                # twiddle entry gets the same factor,
                # sqrt(n) ** (1 / (shape[1] * shape[2])); presumably this
                # scales the composed transform by sqrt(n) -- TODO confirm
                # against the twiddle layout.
                # NOTE(review): when twiddles were tied above, this rescales
                # the shared Kd parameters in place once more.
                for factor, length in ((self.Kd.map1, self.n1),
                                       (self.Kd.map2, self.n2)):
                    twiddle = factor.twiddle
                    ts = math.sqrt(length) ** (1.0 / twiddle.shape[1] / twiddle.shape[2])
                    twiddle.mul_(ts)

    def _pad_to_in_size(self, x):
        """Zero-pad the last two dimensions of ``x`` to ``self.in_size``.

        Returns ``(x, (top, bottom), (left, right))`` where the first pair is
        the padding added along dim -2 and the second along dim -1.  When a
        difference is odd the extra element goes on the bottom/right, so the
        result always has exactly ``in_size`` entries per dimension.
        """
        row_deficit = self.in_size[0] - x.shape[-2]
        col_deficit = self.in_size[1] - x.shape[-1]
        if row_deficit == 0 and col_deficit == 0:
            return x, (0, 0), (0, 0)
        top = row_deficit // 2
        bottom = row_deficit - top
        left = col_deficit // 2
        right = col_deficit - left
        # F.pad lists the padding for the LAST dimension first.
        x = F.pad(x, (left, right, top, bottom), "constant", 0)
        return x, (top, bottom), (left, right)

    def forward(self, x):
        """Apply the operator to ``x``.

        Args:
            x: Real tensor indexed as (batch, in_ch, rows, cols) -- the
                channel axis is summed against the conv weight below.

        Returns:
            Real tensor with ``out_ch`` channels, cropped and subsampled.
        """
        # Pad to in_size if needed; the padding is removed again when
        # cropping at the end.
        x, (top, bottom), (left, right) = self._pad_to_in_size(x)

        # Along each spatial dimension: flip the kernel, zero-pad it to the
        # input length and rotate it by (kernel_size - 1) // 2.
        col = F.pad(self.w.weight.flip([-1]),
                    (0, self.n1 - self.kernel_size1)).roll(
                        -self.padding1, dims=-1)
        col = F.pad(col.flip([-2]),
                    (0, 0, 0, self.n2 - self.kernel_size2)).roll(
                        -self.padding2, dims=-2)

        # For non-power-of-two lengths, extend the kernel to n_extended
        # entries along that dimension.
        if self.n1 < self.n_extended1:
            col_0 = F.pad(col, (0, 2 * ((1 << self.log_n1) - self.n1)))
            col = torch.cat((col_0, col), dim=-1)

        if self.n2 < self.n_extended2:
            col_0 = F.pad(col, (0, 0, 0, 2 * ((1 << self.log_n2) - self.n2)))
            col = torch.cat((col_0, col), dim=-2)

        col_f = self.Kd(self.r2c(col))
        x = self.K1(self.r2c(x))

        # (batch, 1, in_ch, rows, cols) * (out_ch, in_ch, rows, cols), then
        # sum over in_ch.
        x = complex_mul(x.unsqueeze(1), col_f)
        # Diagonal bias term, currently disabled:
        #x = view_as_complex(view_as_real(x) + self.b)
        x = x.sum(dim=2)
        x = self.c2r(self.K2(x))

        # Handle padding: drop the kernel border not covered by `padding`
        # together with whatever was added to reach in_size.
        row_border = ((self.kernel_size[0] - 1) // 2) - self.padding[0]
        col_border = ((self.kernel_size[1] - 1) // 2) - self.padding[1]
        x = x[:, :,
              (row_border + top):(self.in_size[0] - row_border - bottom),
              (col_border + left):(self.in_size[1] - col_border - right)]

        # Handle stride
        x = self.subsample(x)

        return x
