from typing import Optional, Callable

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms.functional as vF
from math import sqrt, ceil

from torch import Tensor

from modules.complexPyTorch.complexLayers import (ComplexSoftmax2d,
                                                  ComplexConv2d,
                                                  ComplexGELU,
                                                  ComplexBatchNorm2d, ComplexAdaptiveAvgPool2d, ComplexAvgPool2d)


def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
    """
    Build a bias-free 1x1 convolution.

    Args:
        in_planes: number of input channels
        out_planes: number of output channels
        stride: convolution stride

    Returns:
        The configured nn.Conv2d layer.
    """
    layer = nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
    return layer


def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
    """
    Build a bias-free 3x3 convolution whose padding equals its dilation.

    Args:
        in_planes: number of input channels
        out_planes: number of output channels
        stride: convolution stride
        groups: number of convolution groups
        dilation: kernel dilation, also used as the padding amount

    Returns:
        The configured nn.Conv2d layer.
    """
    options = dict(
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation,
    )
    return nn.Conv2d(in_planes, out_planes, **options)


class ResidualBottleneck(nn.Module):
    """
    ResNet-style bottleneck block, modified from the torch implementation.

    Main path: 1x1 conv -> norm -> GELU -> 3x3 conv -> norm -> GELU -> 1x1 conv -> norm.
    The shortcut is added to the main path and a final GELU is applied.

    Args:
        inplanes: number of input channels
        planes: base channel count of the block
        stride: stride of the 3x3 conv (and of the shortcut projection)
        groups: groups of the 3x3 conv
        base_width: width per group, relative to a reference of 64
        dilation: dilation of the 3x3 conv
        norm_layer: callable building a norm layer from a channel count, nn.BatchNorm2d when None
        expansion: output channels are planes * expansion
        disable_output_expansion: if True, output channels are just planes
    """

    def __init__(
            self,
            inplanes: int,
            planes: int,
            stride: int = 1,
            groups: int = 1,
            base_width: int = 64,
            dilation: int = 1,
            norm_layer: Optional[Callable[..., nn.Module]] = None,
            expansion=4,
            disable_output_expansion=False
    ) -> None:
        super().__init__()
        self.expansion = expansion
        make_norm = nn.BatchNorm2d if norm_layer is None else norm_layer
        width = int(planes * (base_width / 64.0)) * groups

        # The 3x3 conv and the shortcut projection both downsample when stride != 1.
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = make_norm(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = make_norm(width)

        self.disable_output_expansion = disable_output_expansion
        channel_factor = 1 if self.disable_output_expansion else expansion
        outplanes = planes * channel_factor

        self.conv3 = conv1x1(width, outplanes)
        self.bn3 = make_norm(outplanes)
        # The attribute is called `relu` but the activation is a GELU.
        self.relu = nn.GELU()

        # An empty Sequential passes the input through; a projection is only
        # needed when the spatial size or the channel count changes.
        needs_projection = stride != 1 or inplanes != outplanes
        if needs_projection:
            self.downsample = nn.Sequential(
                conv1x1(inplanes, outplanes, stride),
                make_norm(outplanes),
            )
        else:
            self.downsample = nn.Sequential()
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        """Run the bottleneck and add the (possibly projected) shortcut."""
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.relu(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(y))

        shortcut = x
        if self.downsample is not None:
            shortcut = self.downsample(x)

        y += shortcut
        return self.relu(y)


class CenterCrop(nn.Module):
    """
    Module wrapper around torchvision's functional center_crop.

    Args:
        crop_size: output size handed to center_crop
    """
    def __init__(self, crop_size):
        super().__init__()
        self.crop_size = crop_size

    def forward(self, x):
        # NOTE: this call doesn't work with torch.fx
        target_size = self.crop_size
        return vF.center_crop(x, target_size)


class DropCrop(nn.Module):
    """
    Drop central portion and crop outer part of the tensor.

    This can be used for multi-resolution analysis in discrete Fourier space.

    Given size of drop region, zero out this small region from the centre.
    Given size of the crop region, select only the sub-region of this size from the centre.

    r2c - real to complex - drops and crops relative to real FFTs
    """
    def __init__(self, drop_size, crop_size, r2c=False, use_clone=True):
        """
        Args:
            drop_size: size of the region to drop from the centre (None or <= 0 disables the drop)
            crop_size: size of the region to crop from the centre (None disables the crop)
            r2c: True if using RFFT features
            use_clone: if True, use clone to create a copy of the tensor, otherwise use in-place operations
        """
        super().__init__()
        self.drop_size = drop_size
        self.crop_size = crop_size
        self.r2c = r2c
        self.use_clone = use_clone

    def __str__(self):
        return f"DropCrop(drop_size={self.drop_size}, crop_size={self.crop_size}, r2c={self.r2c}, use_clone={self.use_clone})"

    def _centre_window(self, lx, ly, size):
        """
        Return the (x1, x2, y1, y2) slice bounds of a window of `size` around the
        mid-point of the last two dimensions (lengths lx and ly).

        Shared by the drop and the crop steps so both use the same index arithmetic.
        """
        # Mid-points: l/2 rounded half up, i.e. (l + 1) // 2.
        # NOTE(review): for odd l this is one past l // 2 - confirm this is the intended centre.
        midx = (lx + 1) // 2
        midy = (ly + 1) // 2
        size_y = size
        if self.r2c:
            # in rfft, only N//2+1 coefficients are kept in y for length N
            midy = (ly - 1) // 2
            size_y = size // 2 + 1  # actual region length for r2c

        # An odd window takes one extra element past the mid-point.
        half_x = size // 2
        half_y = size_y // 2
        x1, x2 = midx - half_x, midx + half_x + size % 2
        y1, y2 = midy - half_y, midy + half_y + size_y % 2
        return x1, x2, y1, y2

    def forward(self, x):
        """
        Zero the central drop region, then return the central crop region.

        The last two dimensions of x are treated as the spatial (x, y) axes.
        With use_clone=False the drop is written into x itself.
        """
        lx, ly = x.shape[-2], x.shape[-1]  # channel first

        out = x.clone() if self.use_clone else x

        # drop
        if self.drop_size is not None and self.drop_size > 0:
            x1, x2, y1, y2 = self._centre_window(lx, ly, self.drop_size)
            out[..., x1:x2, y1:y2] = 0

        # crop: skipped when disabled or when the x extent already matches
        # (only the x length is compared against crop_size)
        if self.crop_size is None or self.crop_size == lx:
            return out

        x1, x2, y1, y2 = self._centre_window(lx, ly, self.crop_size)
        return out[..., x1:x2, y1:y2]

def cconv3x3(in_planes, out_planes, stride=1):
    """
    Build a complex 3x3 convolution with padding 1 and a bias term.

    Args:
        in_planes: number of input channels
        out_planes: number of output channels
        stride: convolution stride
    """
    conv_options = dict(kernel_size=3, stride=stride, padding=1, bias=True)
    return ComplexConv2d(in_planes, out_planes, **conv_options)


class DWConv2d(nn.Module):
    '''
    Depthwise-separable 2D convolution: a grouped (per-input-channel) convolution
    followed by a 1x1 pointwise convolution, both without bias.

    Modified from https://github.com/seungjunlee96/Depthwise-Separable-Convolution_Pytorch

    Args:
        in_dim: number of input channels
        dim: number of output channels
        kernels_per_layer: number of depthwise kernels per input channel
        kernel_size: kernel size of the depthwise convolution
        padding: padding of the depthwise convolution
    '''

    def __init__(self, in_dim, dim, kernels_per_layer, kernel_size, padding):
        super().__init__()
        hidden_dim = in_dim * kernels_per_layer
        # groups=in_dim gives every input channel its own set of kernels
        self.depthwise = nn.Conv2d(
            in_dim,
            hidden_dim,
            kernel_size=kernel_size,
            padding=padding,
            groups=in_dim,
            bias=False,
        )
        # the 1x1 conv mixes the depthwise outputs across channels
        self.pointwise = nn.Conv2d(hidden_dim, dim, kernel_size=1, bias=False)

    def forward(self, x):
        return self.pointwise(self.depthwise(x))


class PhasorBlockC(nn.Module):
    """
    Phasor Block (C), which generates additional complex-valued features
    for complex-valued input.

    The real part of the input runs through `depth` Conv-BN-GELU stages. Their
    output is the new real component and is also fed through `depth`
    depthwise-separable Conv-BN-GELU stages to produce the new imaginary
    component. A complex 1x1 shortcut (conv -> batch norm -> GELU, then optional
    2x average pooling) is added to the combined result.

    Args:
        d_in: number of input channels
        d_out: number of output channels
        depth: number of stages in each feature stack
        stride: stride of the first real stage; must be 1 or 2
        dw_kernel_size: kernel size of the depthwise convolutions
        dw_padding: padding of the depthwise convolutions
        dw_kernels_per_layer: depthwise kernels per channel

    Raises:
        ValueError: if stride is neither 1 nor 2
    """

    def __init__(self,
                 d_in,
                 d_out,
                 depth=2,
                 stride=1,
                 dw_kernel_size=3,
                 dw_padding=1,
                 dw_kernels_per_layer=3):
        super().__init__()

        self.d_in = d_in
        self.d_out = d_out
        self.depth = depth
        self.stride = stride

        # depthwise conv configuration
        self.dw_kernel_size = dw_kernel_size
        self.dw_padding = dw_padding
        self.dw_kernels_per_layer = dw_kernels_per_layer

        # real component stack: only the first stage changes channels / applies the stride
        self.real_feature_layers = nn.ModuleList()
        for stage in range(depth):
            first = stage == 0
            conv = nn.Conv2d(self.d_in if first else d_out,
                             d_out,
                             kernel_size=3,
                             stride=self.stride if first else 1,
                             bias=False,
                             padding=1)
            self.real_feature_layers.append(
                nn.Sequential(conv, nn.BatchNorm2d(d_out), nn.GELU())
            )

        # imaginary component stack: channel count stays at d_out throughout
        self.imag_feature_layers = nn.ModuleList()
        for _ in range(depth):
            dw_conv = DWConv2d(d_out,
                               d_out,
                               kernel_size=self.dw_kernel_size,
                               padding=self.dw_padding,
                               kernels_per_layer=self.dw_kernels_per_layer)
            self.imag_feature_layers.append(
                nn.Sequential(dw_conv, nn.BatchNorm2d(d_out), nn.GELU())
            )

        # shortcut: unstrided complex 1x1 conv, spatial reduction is done by pooling
        self.shortcut = ComplexConv2d(d_in, d_out,
                                      kernel_size=1,
                                      stride=1,
                                      padding=0,
                                      bias=False)
        if self.stride == 1:
            self.shortcut_downsample = nn.Identity()
        elif self.stride == 2:
            self.shortcut_downsample = ComplexAvgPool2d(kernel_size=2, stride=2)
        else:
            raise ValueError(f"Invalid stride value: {self.stride}")

        self.sbn = ComplexBatchNorm2d(d_out)
        self.sgelu = ComplexGELU(phase_amp=False)

    def forward(self, x):
        # only the real part of the input feeds the feature stacks
        real_feat = torch.real(x)
        for stage in self.real_feature_layers:
            real_feat = stage(real_feat)

        # the imaginary features are derived from the new real features
        imag_feat = real_feat
        for stage in self.imag_feature_layers:
            imag_feat = stage(imag_feat)

        # combine new real and imaginary features into one complex tensor
        imag_term = 1j * imag_feat.type(torch.complex64)
        combined = real_feat + imag_term

        # shortcut: conv -> norm -> activation -> (optional) pooling
        skip = self.shortcut(x)
        skip = self.sbn(skip)
        skip = self.sgelu(skip)
        skip = self.shortcut_downsample(skip)

        return skip + combined


class ComplexHadamard(nn.Module):
    '''
    Attention module based on a complex-valued Hadamard product (no sum), channels must match

    A learnable weight of shape (dim, N, N_r2c) is multiplied element-wise with the
    input. The weight is stored as a real tensor with a trailing dimension of 2
    (real, imaginary) and viewed as complex in forward.

    Args:
        N: spatial size of the filter weights
        dim: number of channels of the filter weights
        filter_init: if True, add a radial band-pass filter to the real and imaginary weights
        no_activation: if True, return the raw product, otherwise apply ComplexGELU
        quantiser: if True, pass the weights through ComplexSoftmax2d before use
        softshrink: if not None (and quantiser is False), lambd of the soft shrinkage
            applied separately to the real and imaginary weights
        r2c: if True, the last weight dimension is N//2+1 instead of N
        drop_center: if not None, size of the central weight region zeroed via DropCrop
        init_type: 'weibull' or 'normal'; any other value leaves the weights at zero
        init_scale: multiplier of the standard normal samples for init_type 'normal'
    '''
    def __init__(self, N, dim, filter_init=True, no_activation=False, quantiser=False, softshrink=None, r2c=False,
                 drop_center=None, init_type='weibull', init_scale=0.02):
        super().__init__()

        self.N = N
        self.nmax = N/2.
        self.N_r2c = N
        self.dim = dim
        self.filter_init = filter_init
        self.no_activation = no_activation
        self.quantiser = quantiser
        self.softshrink = softshrink
        self.r2c = r2c
        self.init_type = init_type
        self.init_scale = init_scale

        if self.r2c:
            self.N_r2c = self.N//2+1

        # crop size equals N, so DropCrop only zeroes the centre and never crops the weights
        self.drop_center = None
        if drop_center is not None:
            self.drop_center = DropCrop(drop_center, N, r2c=self.r2c)

        self.quant = ComplexSoftmax2d() #quantise to 0-1 mutually exclusively
        self.act = ComplexGELU(phase_amp=False)

        #torch auto mixed precision (torch.amp) doesn't support complex type, use view as complex
        #see also https://github.com/erksch/fnet-pytorch/pull/10
        weight_shape = (self.dim, self.N, self.N_r2c)
        self.weights = nn.Parameter(torch.zeros(weight_shape + (2,), dtype=torch.float32))

        # Weibull init (arg 1: lambda, arg 2: k/shape),
        # https://en.wikipedia.org/wiki/Weibull_distribution
        # shape=2. is equivalent to the Rayleigh dist
        # lambda=sqrt(2)/sqrt(n_in)
        with torch.no_grad():
            if self.init_type == 'weibull':
                lambda_parameter = sqrt(2)/sqrt(self.N*self.N_r2c)
                pdf = torch.distributions.weibull.Weibull(lambda_parameter, 2.0)
                self.weights[..., 0] = pdf.sample(weight_shape)
                self.weights[..., 1] = pdf.sample(weight_shape)
            elif self.init_type == 'normal':
                # fill the existing parameter rather than rebinding it
                self.weights.copy_(
                    torch.randn(*weight_shape, 2, dtype=torch.float32) * self.init_scale
                )

        if self.filter_init:
            with torch.no_grad(): #custom init with band pass filter
                fx = torch.arange(-1.0, 1.0, 1./self.nmax)
                fy = torch.arange(-1.0, 1.0, 1./(self.N_r2c/2))
                # 'ij' is what meshgrid did by default; passing it explicitly
                # avoids the deprecation warning for the missing argument
                xx, yy = torch.meshgrid(fx, fy, indexing='ij')
                radial = torch.sqrt(xx**2 + yy**2) #Ram-Lak
                radial *= torch.sin(radial)/(radial+1e-8) #Shepp-Logan
                self.weights[..., 0] += radial
                self.weights[..., 1] += radial

    def forward(self, x):
        '''
        Multiply x element-wise with the (optionally processed) complex weights.

        The einsum requires x to have four dimensions (b, k, i, j) whose last
        three match the weight shape.
        '''
        weights = torch.view_as_complex(self.weights) #for torch.amp support

        if self.drop_center is not None:
            weights = self.drop_center(weights)

        if self.quantiser:
            weights = self.quant(weights)
        elif self.softshrink is not None: #force weights to be sparse
            w_r = F.softshrink(torch.real(weights), lambd=self.softshrink)
            w_i = F.softshrink(torch.imag(weights), lambd=self.softshrink)
            weights = w_r + 1j*w_i

        hadamard = torch.einsum('bkij, kij -> bkij', x, weights)

        if self.no_activation:
            return hadamard
        return self.act(hadamard)

    def extra_repr(self):
        # Extra information shown when the module is printed.
        return 'N={}, d_out={}'.format(
            self.N, self.dim
        )


class HadamardBlock(nn.Module):
    """
    Norm -> learnt Hadamard filter -> 1x1 complex conv (channel mixing) -> norm -> activation.

    Args:
        in_dim: number of input channels (also the channel count of the filter)
        dim: number of output channels of the 1x1 conv
        scale: spatial size N passed to ComplexHadamard
        act_layer: activation module applied last; a fresh ComplexGELU(phase_amp=False)
            is created per block when None
        norm_layer: callable building a norm layer from a channel count
        r2c: forwarded to ComplexHadamard
        drop_center: forwarded to ComplexHadamard
        quantiser: forwarded to ComplexHadamard
        filter_init_type: forwarded to ComplexHadamard as init_type
        filter_init_scale: forwarded to ComplexHadamard as init_scale
    """

    def __init__(self,
                 in_dim,
                 dim,
                 scale,
                 act_layer=None,
                 norm_layer=ComplexBatchNorm2d,
                 r2c=False,
                 drop_center=None,
                 quantiser=True,
                 filter_init_type=None,
                 filter_init_scale=None):
        super().__init__()
        # Build the default activation here instead of in the signature, so that
        # blocks do not all share one module instance created at definition time.
        if act_layer is None:
            act_layer = ComplexGELU(phase_amp=False)

        self.norm1 = norm_layer(in_dim)
        # NOTE(review): ComplexHadamard only initialises its weights for init_type
        # 'weibull' or 'normal'; with the None default here they stay at zero -
        # confirm callers always pass filter_init_type.
        self.filter = ComplexHadamard(scale, in_dim, filter_init=False, no_activation=True, quantiser=quantiser,
                                      softshrink=None, r2c=r2c, drop_center=drop_center,
                                      init_type=filter_init_type,
                                      init_scale=filter_init_scale)  # learnt filtering
        self.pool = ComplexConv2d(in_dim, dim, kernel_size=1, stride=1, padding='same', bias=False)  # mix channels
        self.norm2 = norm_layer(dim)
        self.act_layer = act_layer

    def forward(self, x):
        out = self.norm1(x)
        out = self.filter(out)
        out = self.pool(out)
        out = self.norm2(out)

        return self.act_layer(out)


class SpectralBranches(nn.Module):
    def __init__(self,
                 N,
                 subdepth,
                 spectral_branch_config,
                 hadamard_norm_layer=ComplexBatchNorm2d,
                 filter_init_type='weibull',
                 filter_init_scale=0.02):
        """
        Build one branch per entry of spectral_branch_config. Every branch is
        DropCrop (or identity) -> filter block -> global average pool.

        N: overall spatial size of filter weights
        subdepth: multiplier applied to each branch's d_in and d_out channel counts
        spectral_branch_config: spectral branch config from .yaml configs, iterated
            as (crop_size, drop_size, d_in, d_out) tuples
        hadamard_norm_layer: norm layer used inside the HadamardBlock branches
        filter_init_type: type of filter initialization
        filter_init_scale: scale of filter initialization for normal distro

        """
        super().__init__()
        self.N = N
        self.subdepth = subdepth
        self.spectral_branch_config = spectral_branch_config

        self.branches = nn.ModuleList()
        self.output_dims = []  # d_out of each branch, as given in the config
        for (crop_size, drop_size, d_in, d_out) in self.spectral_branch_config:
            channels_in = d_in * self.subdepth
            channels_out = d_out * self.subdepth

            # crop_size and drop_size both none disables dropcrop and hadamard
            hadamard_disabled = crop_size is None and drop_size is None
            if hadamard_disabled:
                # no dropcrop or Hadamard filter, just bn -> 1x1 -> bn
                dropcrop = nn.Identity()
                block = nn.Sequential(
                    ComplexBatchNorm2d(channels_in),
                    ComplexConv2d(channels_in, channels_out, kernel_size=1, stride=1, padding='same', bias=False),
                    ComplexBatchNorm2d(channels_out)
                )
            else:
                dropcrop = DropCrop(crop_size=crop_size, drop_size=drop_size)
                block = HadamardBlock(in_dim=channels_in,
                                      dim=channels_out,
                                      scale=crop_size,
                                      norm_layer=hadamard_norm_layer,
                                      filter_init_type=filter_init_type,
                                      filter_init_scale=filter_init_scale)
            ave_pool = ComplexAdaptiveAvgPool2d(1)

            self.branches.append(nn.Sequential(dropcrop, block, ave_pool))
            self.output_dims.append(d_out)

    def forward(self, x):
        """Run every branch on x and return the per-branch outputs as a list."""
        # cast x to complex if real
        if torch.is_floating_point(x):
            x = x.to(torch.complex64)

        return [branch(x) for branch in self.branches]
