import torch
import math
import torch.fft
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from .Attention_Module import AttentionLayer, Attention
from torch.nn.parameter import Parameter

class Binary_activation(torch.autograd.Function):
    """Hard 0/1 gate with a straight-through (clipped identity) gradient.

    Forward: an element becomes 1.0 when its softmax probability along the
    last dimension is at least the uniform share ``1 / input.shape[-1]``,
    otherwise 0.0.  That thresholding has no useful derivative, so backward
    passes the incoming gradient through unchanged except for clipping it
    to [-1, 1].
    """

    @staticmethod
    def forward(ctx, input):
        # Probability each element would get if mass were spread uniformly.
        threshold = 1. / input.shape[-1]
        return (torch.softmax(input, dim=-1) >= threshold).float()

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator: identity gradient clipped to [-1, 1].
        # Out-of-place clamp so the grad_output tensor handed in by autograd
        # is never mutated (the original used the in-place clamp_).
        return grad_output.clamp(-1, 1)

class Frequency_Mask_Router(nn.Module):
    """Split a signal into two complementary frequency parts with a learned binary mask.

    The input is moved to the frequency domain with ``torch.fft.rfft`` along its
    time axis. ``Mask_Generator`` (an ``AttentionLayer``) is called with the
    per-channel power spectrum and a set of ``router_num`` learnable router
    vectors, and its output is hardened to {0, 1} by ``Binary_activation``. The
    masked spectrum and its complement are then transformed back to the time
    domain separately.

    Args:
        router_num: number of learnable router vectors.
        window_size: length of the time axis; fixes the number of rFFT bins
            (``window_size // 2 + 1``) and the output length of the inverse FFT.
    """

    def __init__(self, router_num, window_size):
        super().__init__()
        # Number of bins torch.fft.rfft returns for a real signal of this length.
        mask_dim = window_size // 2 + 1
        self.window_size = window_size
        self.Learnable_Router = Parameter(torch.empty(1, router_num, mask_dim))
        self.Mask_Generator = AttentionLayer(Attention(window_size, False), d_model=mask_dim, d_keys=64, n_heads=8)
        self.activation = Binary_activation.apply
        init.kaiming_uniform_(self.Learnable_Router, a=math.sqrt(5))
        self.weight_init()

    def weight_init(self):
        """Re-initialise every nn.Linear inside Mask_Generator (Kaiming-normal weights, zero biases)."""
        for m in self.Mask_Generator.modules():
            if isinstance(m, nn.Linear):
                init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='leaky_relu')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x: torch.Tensor):
        """Route the frequency bins of ``x`` into two complementary time-domain outputs.

        Args:
            x: 3-D tensor whose dim 1 is the time axis the FFT runs over.
                Presumably shaped (batch, window_size, channels); the ``__main__``
                demo builds the module with ``window_size=x.shape[1]`` (TODO confirm).

        Returns:
            Tuple ``(x_masked, x_unmasked, mask)``:
            - ``x_masked``: inverse FFT of the bins the mask keeps.
            - ``x_unmasked``: inverse FFT of the complementary bins.
            - ``mask``: the binary mask.
            All three are permuted back with ``permute(0, 2, 1)``.
        """
        # Put the time axis last so the FFT runs along it.
        x = x.permute(0, 2, 1)
        # convert time-domain to freq-domain
        freqs = torch.fft.rfft(x, dim=-1)
        # Share the learnable routers across the batch (expand creates a view, no copy).
        router = self.Learnable_Router.expand(x.shape[0], -1, -1)
        power = torch.abs(freqs) ** 2
        # Power spectrum is passed as the first (presumably query) argument and the routers
        # as the other two; the output is assumed to broadcast against freqs -- TODO confirm.
        mask, _ = self.Mask_Generator(power, router, router, attn_mask=None)
        # Harden to {0, 1}; gradients flow through Binary_activation's clipped straight-through backward.
        mask = self.activation(mask)
        masked_freq = freqs * mask
        unmasked_freq = freqs * (1 - mask)
        # convert freq-domain to time-domain; explicit n keeps odd window sizes exact.
        x_t1 = torch.fft.irfft(masked_freq, n=self.window_size, dim=-1)
        x_t2 = torch.fft.irfft(unmasked_freq, n=self.window_size, dim=-1)
        return x_t1.permute(0, 2, 1), x_t2.permute(0, 2, 1), mask.permute(0, 2, 1)

if __name__ == '__main__':
    # Smoke test on a (32, 100, 38) tensor; window_size is taken from dim 1.
    x = torch.randn(32, 100, 38)
    F_mask = Frequency_Mask_Router(router_num=8, window_size=x.shape[1])
    # forward returns three tensors: masked signal, complementary signal, mask.
    x_t1, x_t2, mask = F_mask(x)
    print(x_t1.shape, '\t', x_t2.shape, '\t', mask.shape)




