import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from src.utils.train import OptimModule

# Fused FFT Conv
try:
    from src.ops.fftconv import fftconv_ref, fftconv_func 
except ImportError:
    fftconv_func = None
    
    
# reference convolution with residual connection
def fftconv_ref(u, k, D, dropout_mask, gelu=True, k_rev=None):
    """Reference (non-fused) FFT convolution with a residual skip term.

    Computes ``conv(u, k) + u * D`` along the last axis, optionally followed
    by GELU and a multiplicative dropout mask.

    Args:
        u: Input signal; the last axis is the sequence axis.
        k: Convolution kernel; the last axis is the kernel length. It may
            differ from the sequence length.
        D: Skip weight; it is unsqueezed on the last axis and multiplied
            with ``u``.
        dropout_mask: Optional 2-D mask (rearranged as ``'b H -> b H 1'``)
            multiplied into the output, or ``None`` to skip masking.
        gelu: When true, apply ``F.gelu`` to the result.
        k_rev: Optional second kernel; the conjugate of its spectrum is added
            to the spectrum of ``k`` before the product is taken.

    Returns:
        Tensor with the same sequence length and dtype as ``u``.
    """
    n_steps = u.shape[-1]

    # Transform both operands at size seqlen + klen so that the circular
    # product behaves like a linear convolution for any kernel length.
    fft_size = n_steps + k.shape[-1]

    signal_f = torch.fft.fft(u.float(), n=fft_size)
    kernel_f = torch.fft.fft(k.float(), n=fft_size) / fft_size
    if k_rev is not None:
        reverse_f = torch.fft.fft(k_rev.float(), n=fft_size) / fft_size
        kernel_f = kernel_f + reverse_f.conj()

    # Inputs with more than three axes get an extra broadcast axis on the kernel
    if u.dim() > 3:
        kernel_f = kernel_f.unsqueeze(1)

    # norm='forward' leaves the inverse transform unscaled; the kernel
    # spectrum was already divided by fft_size above.
    conv = torch.fft.ifft(signal_f * kernel_f, n=fft_size, norm='forward')
    conv = conv[..., :n_steps].real.to(u.dtype)

    result = conv + u * D.unsqueeze(-1)
    if gelu:
        result = F.gelu(result)
    if dropout_mask is not None:
        result = result * rearrange(dropout_mask, 'b H -> b H 1')
    return result.to(dtype=u.dtype)


class DenseFilter(OptimModule):
    """Depthwise convolution whose kernel is learned directly in the time domain.

    The input is tiled ``channels`` times along the feature axis, convolved
    per feature with ``fftconv_ref`` (or the fused ``fftconv_func``), and
    optionally batch-normalised.

    Args:
        d_model: Number of input features.
        kernel_size: Number of learned kernel taps per feature.
        channels: How many times the input features are tiled.
        bidirectional: When true, a second kernel is learned and passed to the
            convolution as ``k_rev``.
        bias: When true, the skip weight handed to the convolution is a
            learnable parameter; otherwise it is a constant zero tensor.
        k_learning_rate: When not ``None``, the kernel is passed to
            ``self.register`` together with ``k_weight_decay``.
        k_weight_decay: Weight decay forwarded to ``self.register``.
        k_dense_dropout: Dropout probability applied to the kernel;
            ``None`` disables dropout.
        use_fused_fft_conv: Use ``fftconv_func`` instead of ``fftconv_ref``.
        use_bn: Apply ``nn.BatchNorm1d`` to the convolution output.
    """

    def __init__(
        self,
        d_model,
        kernel_size,
        channels=1,
        bidirectional=False,
        bias=False,
        k_learning_rate=None,
        k_weight_decay=0.0,
        k_dense_dropout=0.,
        use_fused_fft_conv=False,
        use_bn=True,
        **kwargs
    ):
        super().__init__()
        self.channels = channels
        self.bidirectional = bidirectional
        self.use_fused_fft_conv = use_fused_fft_conv

        # Kernel weight: one row per (tiled) feature, with a leading axis of
        # size 2 holding the forward / reverse kernels when bidirectional.
        if self.bidirectional:
            weight = torch.empty(2, d_model * channels, kernel_size)
        else:
            weight = torch.empty(d_model * channels, kernel_size)
        self.weight = nn.Parameter(nn.init.kaiming_uniform_(weight, a=math.sqrt(5)))
        if k_learning_rate is not None:
            self.register("weight", self.weight, k_learning_rate, k_weight_decay)

        # Kernel dropout
        self.k_dropout = nn.Dropout(p=k_dense_dropout) if k_dense_dropout is not None else nn.Identity()

        # Bias (skip weight). It multiplies the tiled input, so it needs one
        # entry per tiled feature in both the learnable and the constant case.
        if bias:
            self.bias = nn.Parameter(torch.randn(d_model * channels))
        else:
            self.bias = torch.zeros(d_model * channels)

        # BatchNorm
        self.bn = nn.BatchNorm1d(d_model * channels) if use_bn else nn.Identity()

    def get_kernel(self):
        """Return the raw learned kernel parameter."""
        return self.weight

    def forward(self, x, **kwargs):
        """Convolve ``x`` with the learned kernel.

        Args:
            x: Input tensor; it is tiled with ``repeat(1, channels, 1)``, so
                it is expected to be 3-D with the sequence on the last axis.

        Returns:
            Tensor rearranged to ``(b, c, d, l)`` with ``c == self.channels``.
        """
        seq_len = x.shape[-1]
        x = x.repeat(1, self.channels, 1)

        # Kernel: pad to the sequence length (a negative pad crops it)
        k = self.get_kernel()
        k = F.pad(k, (0, seq_len - k.shape[-1]))
        k = self.k_dropout(k)
        # The constant bias is a plain tensor, so it must follow the input device
        bias = self.bias.to(x.device)

        # Convolution
        if self.bidirectional:
            k_forward, k_reverse = k
            if self.use_fused_fft_conv:
                y = fftconv_func(x, k_forward, bias, k_rev=k_reverse, dropout_mask=None, gelu=False, force_fp16_output=torch.is_autocast_enabled())
            else:
                y = fftconv_ref(x, k_forward, bias, dropout_mask=None, gelu=False, k_rev=k_reverse)
        else:
            if self.use_fused_fft_conv:
                y = fftconv_func(x, k, bias, dropout_mask=None, gelu=False, force_fp16_output=torch.is_autocast_enabled())
            else:
                # Pass the full per-feature kernel, matching the fused path;
                # indexing k[0] would share the first feature's kernel
                # across every feature.
                y = fftconv_ref(x, k, bias, dropout_mask=None, gelu=False)

        # BatchNorm
        y = self.bn(y)

        return rearrange(y, 'b (c d) l -> b c d l', c=self.channels)
    
    
class DilatedFilter(OptimModule):
    """Depthwise dilated 1-D convolution followed by batch normalisation.

    The input features are tiled ``channels`` times and each tiled feature is
    convolved with its own kernel (``groups == d_model * channels``).

    Args:
        d_model: Number of input features.
        kernel_size: Taps per kernel; doubled when ``bidirectional`` is true.
        channels: How many times the input features are tiled.
        bidirectional: Use ``padding='same'`` instead of left-only padding.
        dilation: Dilation of the convolution.
        k_learning_rate: When not ``None``, the weight is passed to
            ``self.register`` together with ``k_weight_decay``.
        k_weight_decay: Weight decay forwarded to ``self.register``.
    """

    def __init__(
        self,
        d_model,
        kernel_size,
        channels=1,
        bidirectional=False,
        dilation=1,
        k_learning_rate=None,
        k_weight_decay=0.0,
        **kwargs
    ):
        super().__init__()
        self.d_model = d_model
        self.channels = channels
        self.dilation = dilation
        self.bidirectional = bidirectional

        # Left padding is derived from the requested (one-sided) kernel size
        self.pad = dilation * (kernel_size - 1)
        # The bidirectional variant learns twice as many taps
        self.kernel_size = 2 * kernel_size if bidirectional else kernel_size

        n_groups = d_model * channels

        # Depthwise conv weight: one single-input-channel kernel per group
        taps = torch.empty(n_groups, 1, self.kernel_size)
        nn.init.kaiming_uniform_(taps, a=math.sqrt(5))
        self.weight = nn.Parameter(taps)
        if k_learning_rate is not None:
            self.register("weight", self.weight, k_learning_rate, k_weight_decay)

        # BatchNorm over the tiled feature axis
        self.bn = nn.BatchNorm1d(n_groups)

    def forward(self, x, **kwargs):
        """Apply the dilated depthwise convolution and batch norm.

        Args:
            x: Input tensor; it is tiled with ``repeat(1, channels, 1)``, so
                it is expected to be 3-D with the sequence on the last axis.

        Returns:
            Tensor rearranged to ``(b, c, d, l)`` with ``c == self.channels``.
        """
        n_groups = self.d_model * self.channels

        if not self.bidirectional:
            # Pad on the left only, so the output keeps the input length and
            # each position reads only current and earlier inputs.
            padded = F.pad(x, (self.pad, 0), "constant", 0)
            tiled = padded.repeat(1, self.channels, 1)
            y = F.conv1d(tiled, self.weight, dilation=self.dilation, groups=n_groups)
        else:
            tiled = x.repeat(1, self.channels, 1)
            y = F.conv1d(tiled, self.weight, padding='same', dilation=self.dilation, groups=n_groups)

        y = self.bn(y)

        return rearrange(y, 'b (c d) l -> b c d l', c=self.channels)


class FourierFilter(OptimModule):
    """Depthwise long convolution whose kernel is parameterised in the frequency domain.

    The learnable parameter ``kernel_f`` stores the real/imaginary parts of a
    one-sided spectrum; ``get_kernel`` turns it into a time-domain kernel of
    length ``kernel_length`` with an inverse real FFT.

    Args:
        d_model: Number of input features.
        kernel_size: Length of the time-domain signal whose spectrum is
            learned (the spectrum has ``kernel_size // 2 + 1`` bins).
        channels: How many times the input features are tiled.
        bidirectional: When true, twice as many kernels are learned and the
            second half is passed to the convolution as ``k_rev``.
        reduce_factor: The tiled feature axis is split into this many groups,
            which are folded into the batch axis and share kernels.
        kernel_length: Length of the synthesised time-domain kernel.
        bias: When true, the skip weight handed to the convolution is a
            learnable parameter; otherwise it is a constant zero tensor.
        k_learning_rate: When not ``None``, the kernel is passed to
            ``self.register`` together with ``k_weight_decay``.
        k_weight_decay: Weight decay forwarded to ``self.register``.
        k_fourier_dropout: Dropout probability applied to the kernel;
            ``None`` disables dropout.
        k_init: ``'rand'`` or ``'cosine'``; anything else raises
            ``NotImplementedError``.
        use_fused_fft_conv: Use ``fftconv_func`` instead of ``fftconv_ref``.
        norm_type: ``'batchnorm'`` normalises the output, ``'linear_scaling'``
            multiplies the kernel by a learned per-kernel scale; any other
            value applies no normalisation.
    """

    def __init__(
        self,
        d_model,
        kernel_size,
        channels=1,
        bidirectional=False,
        reduce_factor=1,
        kernel_length=1024,
        bias=False,
        k_learning_rate=None,
        k_weight_decay=0.0,
        k_fourier_dropout=0.,
        k_init='rand',
        use_fused_fft_conv=False,
        norm_type='batchnorm',
        **kwargs
    ):
        super().__init__()
        self.kernel_length = kernel_length
        self.channels = channels
        self.bidirectional = bidirectional
        self.reduce_factor = reduce_factor
        self.use_fused_fft_conv = use_fused_fft_conv
        self.norm_type = norm_type

        # Bidirectional
        self.bi_factor = 2 if self.bidirectional else 1

        # Kernel
        self.d_kernel = d_model * self.channels * self.bi_factor // reduce_factor
        if k_init == 'rand':
            kernel_f = torch.fft.rfft(torch.randn(self.d_kernel, kernel_size), norm='forward')
        elif k_init == 'cosine':
            # _cosine_init returns channels * d_model rows; d_kernel already
            # includes d_model, so pass 1 to get exactly d_kernel rows.
            kernel_f = self._cosine_init(self.d_kernel, 1, kernel_size)
        else:
            raise NotImplementedError
        self.kernel_f = nn.Parameter(torch.view_as_real(kernel_f))
        if k_learning_rate is not None:
            self.register("kernel_f", self.kernel_f, k_learning_rate, k_weight_decay)
        self.k_dropout = nn.Dropout(p=k_fourier_dropout) if k_fourier_dropout is not None else nn.Identity()

        # Bias (skip weight). It multiplies the input after the reduce_factor
        # groups are folded into the batch, so both the learnable and the
        # constant variant need d_model * channels // reduce_factor entries.
        bias_dim = d_model * self.channels // reduce_factor
        if bias:
            self.bias = nn.Parameter(torch.randn(bias_dim))
        else:
            self.bias = torch.zeros(bias_dim)

        # Normalization
        if norm_type == 'batchnorm':
            self.norm = nn.BatchNorm1d(d_model * self.channels)
        elif norm_type == 'linear_scaling':
            self.norm = nn.Parameter(torch.randn(self.d_kernel, 1))

        # Reparameterization params (filled in by reparameterize)
        self.reparam_kernel = None
        self.reparam_bias = None

    def _cosine_init(self, channels, d_model, kernel_size):
        """Build ``channels * d_model`` noisy single-frequency spectra.

        For every block, one spectrum is generated per positive frequency bin
        ``1 .. kernel_size // 2``. Each is a unit spike at that bin plus
        Gaussian noise that grows with the block index. The signal is then
        rescaled to unit peak magnitude in the time domain. The rows are
        returned in random order.

        ``channels * d_model`` must be a multiple of ``kernel_size // 2``;
        otherwise the final ``expand`` raises.

        Returns:
            Complex tensor of shape ``(channels * d_model, kernel_size // 2 + 1)``.
        """
        # Define parameters
        Fs = kernel_size                # Sampling frequency
        T = 1 / Fs                      # Sampling interval
        A = 1                           # Amplitude of the sinusoid
        phi = 0                         # Phase of the sinusoid
        # One period sampled at Fs has exactly kernel_size samples. Use that
        # directly rather than the length of a float-stepped arange, which
        # depends on rounding.
        N = kernel_size
        freqs = torch.fft.rfftfreq(N, T)

        # Generate sinusoid in frequency domain
        num_blocks = (channels * d_model) // (kernel_size // 2)
        kernel_list = []
        for b in range(num_blocks):
            for i in range(1, N // 2 + 1):

                # Generate sinusoid
                spectrum = torch.zeros(N // 2 + 1, dtype=torch.complex64)
                spectrum[freqs == i] = A * np.exp(1j * phi)
                spectrum = spectrum + (b + 1) * 0.01 * torch.randn_like(spectrum)

                # Normalise spectrum
                signal = torch.fft.irfft(spectrum)
                max_magnitude = torch.max(abs(signal))
                signal = signal / max_magnitude
                spectrum = torch.fft.rfft(signal, norm='forward')

                kernel_list.append(spectrum.unsqueeze(0))

        # expand acts as a shape check here; the rows are then shuffled
        kernel_f = torch.cat(kernel_list, dim=0).expand(channels * d_model, kernel_size // 2 + 1)[torch.randperm(channels * d_model), :]

        return kernel_f

    def get_kernel(self, sr_factor=1):
        """Synthesise the time-domain kernel.

        Args:
            sr_factor: Multiplier applied to ``kernel_length`` (truncated to
                an int) to get the output kernel length.

        Returns:
            Real tensor with ``d_kernel`` rows and the scaled kernel length.
        """
        kernel_length = int(self.kernel_length * sr_factor)
        k = torch.fft.irfft(torch.view_as_complex(self.kernel_f), n=kernel_length, norm='forward')
        return k

    def reparameterize(self, L):
        """ Fuses convolution and batch norm together into a kernel and bias

        Reads the running statistics and affine parameters from ``self.norm``,
        so it requires ``norm_type='batchnorm'``. The results are stored in
        ``self.reparam_kernel`` and ``self.reparam_bias``.

        Args:
            L: Sequence length the kernel is padded (or cropped) to.
        """

        kernel = self.get_kernel()
        kernel = F.pad(kernel, (0, L - kernel.shape[-1]))

        # Construct bidirectional kernel
        if self.bidirectional:
            fft_size = 2 * L
            k_forward, k_backward = torch.split(kernel, kernel.shape[0]//2, dim=0)
            k_forward_f = torch.fft.rfft(k_forward, n=fft_size) / fft_size
            k_backward_f = torch.fft.rfft(k_backward, n=fft_size) / fft_size
            k_f = k_forward_f + k_backward_f.conj()
            kernel = torch.fft.irfft(k_f, n=fft_size, norm='forward')

        # Fuse BN parameters with conv
        running_mean = self.norm.running_mean
        running_var = self.norm.running_var
        gamma = self.norm.weight
        beta = self.norm.bias
        eps = self.norm.eps
        std = (running_var + eps).sqrt()
        t = (gamma / std).reshape(-1, 1)

        # Store fused parameters
        self.reparam_kernel = kernel * t
        self.reparam_bias = (beta - running_mean * gamma / std).reshape(-1, 1)

    def forward(self, x, dropout_mask=None, sr_factor=1):
        """Convolve ``x`` with the synthesised kernel.

        Args:
            x: Input of shape ``(B, D, L)``.
            dropout_mask: Optional mask forwarded to the convolution.
            sr_factor: Forwarded to ``get_kernel`` to rescale the kernel length.

        Returns:
            Tensor rearranged to ``(b, c, d, l)`` with ``c == self.channels``.
        """
        # Input: tile features, then fold the reduce_factor groups into batch
        B, D, L = x.shape
        x = x.repeat(1, self.channels, 1)
        x = rearrange(x, 'b (c d) l -> (b c) d l', c=self.reduce_factor)

        # Kernel: pad to the sequence length (a negative pad crops it)
        k = self.get_kernel(sr_factor=sr_factor)
        k = F.pad(k, (0, L - k.shape[-1]))
        if self.norm_type == 'linear_scaling':
            # The learned scale must be applied before the convolution; if
            # applied afterwards, it would never reach the output.
            k = self.norm * k
        k = self.k_dropout(k)
        # The constant bias is a plain tensor, so it must follow the input device
        bias = self.bias.to(x.device)

        # Convolution
        if self.bidirectional:
            k_bi = torch.split(k, k.shape[0]//2, dim=0)
            if self.use_fused_fft_conv:
                y = fftconv_func(x, k_bi[0], bias, k_rev=k_bi[1], dropout_mask=dropout_mask, gelu=False, force_fp16_output=torch.is_autocast_enabled())
            else:
                y = fftconv_ref(x, k_bi[0], bias, dropout_mask=dropout_mask, gelu=False, k_rev=k_bi[1])
        else:
            if self.use_fused_fft_conv:
                y = fftconv_func(x, k, bias, dropout_mask=dropout_mask, gelu=False, force_fp16_output=torch.is_autocast_enabled())
            else:
                y = fftconv_ref(x, k, bias, dropout_mask=dropout_mask, gelu=False)
        y = rearrange(y, '(b c) d l -> b (c d) l', c=self.reduce_factor)

        # Norm
        if self.norm_type == 'batchnorm':
            y = self.norm(y)

        return rearrange(y, 'b (c d) l -> b c d l', c=self.channels)
    
    
class DilatedFourierFilter(OptimModule):
    """Long convolution whose kernel is a scaled sum of a Fourier kernel and a sparse kernel.

    ``get_kernel`` adds two terms, each multiplied by its own learned
    per-kernel scale: a kernel synthesised from a learned one-sided spectrum,
    and a sparse kernel holding ``kernel_size`` learned taps per row at fixed
    random positions within ``kernel_length``.

    Args:
        d_model: Number of input features.
        kernel_size: Length of the signal whose spectrum is learned, and the
            number of sparse taps per kernel row. Must not exceed
            ``kernel_length``.
        channels: How many times the input features are tiled.
        bidirectional: When true, twice as many kernels are learned and the
            second half is passed to the convolution as ``k_rev``.
        reduce_factor: The tiled feature axis is split into this many groups,
            which are folded into the batch axis and share kernels.
        kernel_length: Length of the synthesised time-domain kernel.
        bias: When true, the skip weight handed to the convolution is a
            learnable parameter; otherwise it is a constant zero tensor.
        k_learning_rate: When not ``None``, the kernel parameters are passed
            to ``self.register`` together with ``k_weight_decay``.
        k_weight_decay: Weight decay forwarded to ``self.register``.
        k_fourier_dropout: Dropout probability applied to the kernel;
            ``None`` disables dropout.
        k_init: ``'rand'`` or ``'cosine'``; anything else raises
            ``NotImplementedError``.
        use_fused_fft_conv: Use ``fftconv_func`` instead of ``fftconv_ref``.
        use_bn: Apply ``nn.BatchNorm1d`` to the convolution output.
        **kwargs: ``use_1x1`` (default ``False``) forces batch norm on when
            true; other keys are ignored.
    """

    def __init__(
        self,
        d_model,
        kernel_size,
        channels=1,
        bidirectional=False,
        reduce_factor=1,
        kernel_length=1024,
        bias=False,
        k_learning_rate=None,
        k_weight_decay=0.0,
        k_fourier_dropout=0.,
        k_init='rand',
        use_fused_fft_conv=False,
        use_bn=True,
        **kwargs
    ):
        super().__init__()
        self.kernel_size = kernel_size
        self.channels = channels
        self.bidirectional = bidirectional
        self.reduce_factor = reduce_factor
        self.kernel_length = kernel_length
        self.use_fused_fft_conv = use_fused_fft_conv
        # use_1x1 is not a named argument, so take it from kwargs. Reading an
        # attribute that was never set would raise AttributeError.
        self.use_1x1 = kwargs.get('use_1x1', False)
        self.use_bn = True if self.use_1x1 else use_bn
        self.dilation = kernel_length // kernel_size

        # Bidirectional
        self.bi_factor = 2 if self.bidirectional else 1

        # Fourier Kernel
        self.d_kernel = d_model * self.channels * self.bi_factor // reduce_factor
        if k_init == 'rand':
            kernel_f = torch.fft.rfft(torch.randn(self.d_kernel, kernel_size), norm='forward')
        elif k_init == 'cosine':
            # This class has no _cosine_init of its own, so reuse
            # FourierFilter's. It returns channels * d_model rows, so
            # (d_kernel, 1) yields exactly d_kernel rows.
            kernel_f = FourierFilter._cosine_init(self, self.d_kernel, 1, kernel_size)
        else:
            raise NotImplementedError
        self.kernel_f = nn.Parameter(torch.view_as_real(kernel_f))
        if k_learning_rate is not None:
            self.register("kernel_f", self.kernel_f, k_learning_rate, k_weight_decay)

        # Dilated Kernel: each row gets kernel_size distinct random positions.
        # Row indices use repeat_interleave so that entry j of the row index,
        # the column index and kernel_d.flatten() all refer to row
        # j // kernel_size. A tiled arange would scatter taps to the wrong
        # rows and allow duplicate (row, col) targets.
        rows = torch.arange(0, self.d_kernel).repeat_interleave(kernel_size)
        cols = torch.cat([torch.randperm(kernel_length)[:kernel_size] for _ in range(self.d_kernel)])
        self.register_buffer("indices", torch.stack([rows, cols], dim=0))
        self.kernel_d = nn.Parameter(torch.randn(self.d_kernel, kernel_size))
        if k_learning_rate is not None:
            self.register("kernel_d", self.kernel_d, k_learning_rate, k_weight_decay)

        # Scaling: row 0 scales the Fourier kernel, row 1 the sparse kernel
        self.scaling = nn.Parameter(torch.randn(2, self.d_kernel))
        if k_learning_rate is not None:
            self.register("scaling", self.scaling, k_learning_rate, k_weight_decay)

        # Bias (skip weight). It multiplies the input after the reduce_factor
        # groups are folded into the batch, so both the learnable and the
        # constant variant need d_model * channels // reduce_factor entries.
        bias_dim = d_model * self.channels // reduce_factor
        if bias:
            self.bias = nn.Parameter(torch.randn(bias_dim))
        else:
            self.bias = torch.zeros(bias_dim)

        # Batchnorm
        self.bn = nn.BatchNorm1d(d_model * self.channels) if self.use_bn else nn.Identity()

        # Dropout
        self.k_dropout = nn.Dropout(p=k_fourier_dropout) if k_fourier_dropout is not None else nn.Identity()

        # Reparameterization params (filled in by reparameterize)
        self.reparam_kernel = None
        self.reparam_bias = None

    def get_kernel(self):
        """Return the combined kernel: scaled Fourier kernel plus scaled sparse kernel."""
        k_f = self.scaling[0, :].unsqueeze(-1) * torch.fft.irfft(torch.view_as_complex(self.kernel_f), n=self.kernel_length, norm='forward')
        # zeros_like already allocates on k_f's device
        k_d = torch.zeros_like(k_f)
        k_d[self.indices[0], self.indices[1]] = self.kernel_d.flatten()
        k_d = self.scaling[1, :].unsqueeze(-1) * k_d
        return k_f + k_d

    def reparameterize(self, L):
        """ Fuses convolution and batch norm together into a kernel and bias

        Reads the running statistics and affine parameters from ``self.bn``,
        so it requires the module to have been built with batch norm enabled.
        The results are stored in ``self.reparam_kernel`` and
        ``self.reparam_bias``.

        Args:
            L: Sequence length the kernel is padded (or cropped) to.
        """

        kernel = self.get_kernel()
        kernel = F.pad(kernel, (0, L - kernel.shape[-1]))

        # Construct bidirectional kernel
        if self.bidirectional:
            fft_size = 2 * L
            k_forward, k_backward = torch.split(kernel, kernel.shape[0]//2, dim=0)
            k_forward_f = torch.fft.rfft(k_forward, n=fft_size) / fft_size
            k_backward_f = torch.fft.rfft(k_backward, n=fft_size) / fft_size
            k_f = k_forward_f + k_backward_f.conj()
            kernel = torch.fft.irfft(k_f, n=fft_size, norm='forward')

        # Fuse BN parameters with conv. The batch-norm layer is stored as
        # self.bn in this class.
        running_mean = self.bn.running_mean
        running_var = self.bn.running_var
        gamma = self.bn.weight
        beta = self.bn.bias
        eps = self.bn.eps
        std = (running_var + eps).sqrt()
        t = (gamma / std).reshape(-1, 1)

        # Store fused parameters
        self.reparam_kernel = kernel * t
        self.reparam_bias = (beta - running_mean * gamma / std).reshape(-1, 1)

    def forward(self, x, dropout_mask=None, **kwargs):
        """Convolve ``x`` with the combined kernel.

        Args:
            x: Input of shape ``(B, D, L)``.
            dropout_mask: Optional mask forwarded to the convolution.

        Returns:
            Tensor rearranged to ``(b, c, d, l)`` with ``c == self.channels``.
        """
        # Input: tile features, then fold the reduce_factor groups into batch
        B, D, L = x.shape
        x = x.repeat(1, self.channels, 1)
        x = rearrange(x, 'b (c d) l -> (b c) d l', c=self.reduce_factor)

        # Kernel: pad to the sequence length (a negative pad crops it)
        k = self.get_kernel()
        k = F.pad(k, (0, L - k.shape[-1]))
        k = self.k_dropout(k)
        # The constant bias is a plain tensor, so it must follow the input device
        bias = self.bias.to(x.device)

        # Convolution
        if self.bidirectional:
            k_bi = torch.split(k, k.shape[0]//2, dim=0)
            if self.use_fused_fft_conv:
                y = fftconv_func(x, k_bi[0], bias, k_rev=k_bi[1], dropout_mask=dropout_mask, gelu=False, force_fp16_output=torch.is_autocast_enabled())
            else:
                y = fftconv_ref(x, k_bi[0], bias, dropout_mask=dropout_mask, gelu=False, k_rev=k_bi[1])
        else:
            if self.use_fused_fft_conv:
                y = fftconv_func(x, k, bias, dropout_mask=dropout_mask, gelu=False, force_fp16_output=torch.is_autocast_enabled())
            else:
                y = fftconv_ref(x, k, bias, dropout_mask=dropout_mask, gelu=False)
        y = rearrange(y, '(b c) d l -> b (c d) l', c=self.reduce_factor)

        # BatchNorm
        y = self.bn(y)

        return rearrange(y, 'b (c d) l -> b c d l', c=self.channels)
    
