## stage correction module with FC layer

import torch
import torch.nn as nn
import torch.nn.parallel
import numpy as np
from torch.nn.parameter import Parameter


class Rational(nn.Module):
    """Learnable element-wise rational activation.

    Computes, for every element of ``x``::

        y = (a3*x**3 + a2*x**2 + a1*x + a0) / ((b1*x + b0)**2 + 1)

    All six coefficients are trainable scalars stored as shape-(1,)
    parameters. The squared term plus one keeps the denominator >= 1, so the
    division is well defined for any finite input.
    """

    def __init__(self):
        super().__init__()
        # (name, initial value) pairs, registered in this exact order so the
        # parameter / state_dict ordering is a0, a1, a2, a3, b0, b1.
        initial_values = (
            ("a0", 0.0),
            ("a1", 0.7),
            ("a2", 1.0),
            ("a3", 0.3),
            ("b0", -1.0),
            ("b1", 1.0),
        )
        for param_name, start in initial_values:
            self.register_parameter(param_name, Parameter(torch.tensor([start])))

    def forward(self, x):
        numerator = self.a3 * x**3 + self.a2 * x**2 + self.a1 * x + self.a0
        denominator = (self.b1 * x + self.b0) ** 2 + 1
        return numerator / denominator


class StageCorrFC(nn.Module):
    """Learned correction applied to a real signal's one-sided Fourier spectrum.

    The input is transformed with an orthonormal real FFT along its last
    dimension. The real and imaginary parts of the ``n // 2 + 1`` frequency
    bins are concatenated into one feature vector and passed through a small
    MLP (``Linear -> Rational -> Linear``). The result is reassembled into a
    complex spectrum and transformed back to a length-``n`` real signal.

    Args:
        n: Signal length along the last dimension.
        nhidden: Width of the MLP hidden layer.
    """

    def __init__(self, n, nhidden=256):
        super().__init__()
        self.n = n
        # Number of frequency bins rfft returns for a length-n real signal.
        self.rftn = self.n // 2 + 1

        self.net = nn.Sequential(
            nn.Linear(2 * self.rftn, nhidden),
            Rational(),
            # NOTE: nn.GELU() / nn.Tanh() were previously left here
            # commented out as alternative activations.
            nn.Linear(nhidden, 2 * self.rftn),
        )

    def outlen(self, inlen, k, s, d, p):
        """Return the output length of a 1-D convolution/pooling window.

        Evaluates ``floor((inlen + 2*p - d*(k-1) - 1) / s + 1)``, the L_out
        formula given in the ``torch.nn.Conv1d`` documentation. Not called by
        the other methods of this class.

        Args:
            inlen: Input length.
            k: Kernel size.
            s: Stride.
            d: Dilation.
            p: Padding (counted on both sides).

        Returns:
            The output length as a Python int.
        """
        num = inlen + 2 * p - d * (k - 1) - 1
        return int(np.floor(num / s + 1))

    def forward(self, v):
        """Apply the learned spectral correction to ``v``.

        Args:
            v: Real tensor whose last dimension is expected to have length
                ``n``. Any number of leading (batch) dimensions is allowed.

        Returns:
            Real tensor with the same leading dimensions as ``v`` and a last
            dimension of length ``n``.
        """
        # Every transform and feature op works on the last dimension, so the
        # forward and inverse FFTs always act along the same axis.
        spectrum = torch.fft.rfft(v, dim=-1, norm='ortho')
        features = torch.cat((spectrum.real, spectrum.imag), dim=-1)
        features = self.net(features)
        # Split back into real/imag halves; `...` preserves leading dims.
        spectrum = torch.complex(
            features[..., :self.rftn], features[..., self.rftn:]
        )
        # NOTE: irfft treats its input as a one-sided Hermitian spectrum. The
        # imaginary part of the DC bin (and of the Nyquist bin for even n)
        # cannot be represented in a real output and is ignored.
        return torch.fft.irfft(spectrum, n=self.n, dim=-1, norm='ortho')