## Stage correction module built from fully connected (FC) layers

import torch
import torch.nn as nn
import torch.nn.parallel
from torch.nn.parameter import Parameter


class Rational(nn.Module):
    """Learnable rational activation, applied element-wise.

    Evaluates ``P(x) / Q(x)`` with

    * ``P(x) = a3*x**3 + a2*x**2 + a1*x + a0``
    * ``Q(x) = (b1*x + b0)**2 + 1``

    All six coefficients are trainable one-element parameters, so they
    broadcast against an input of any shape.  ``Q`` is a square plus one,
    hence at least 1 for finite real inputs and the division has no poles.
    """

    def __init__(self):
        super().__init__()
        # (attribute name, initial value) pairs, listed in registration
        # order: numerator coefficients first, then the denominator ones.
        initial_coeffs = (
            ("a0", 0.0),
            ("a1", 0.7),
            ("a2", 1.0),
            ("a3", 0.3),
            ("b0", -1.0),
            ("b1", 1.0),
        )
        for attr_name, start_value in initial_coeffs:
            # Assigning a Parameter on an nn.Module registers it.
            setattr(self, attr_name, Parameter(torch.tensor([start_value])))

    def forward(self, x):
        """Return the rational function of ``x``; output has the shape of ``x``."""
        numerator = self.a3 * x**3 + self.a2 * x**2 + self.a1 * x + self.a0
        affine = self.b1 * x + self.b0
        denominator = affine**2 + 1
        return numerator / denominator


class StageCorrFC(nn.Module):
    """Fully connected network mapping ``n`` features back to ``n`` features.

    The network is ``Linear(n, nhidden) -> Rational() -> Linear(nhidden, n)``
    and is stored as an ``nn.Sequential`` in ``self.net``.

    Args:
        n: Size of the last dimension of both the input and the output.
        nhidden: Width of the hidden layer (default 256).
    """

    def __init__(self, n, nhidden=256):
        super().__init__()
        self.n = n

        # Layers are created in forward order and placed at Sequential
        # indices 0, 1 and 2.
        widen = nn.Linear(n, nhidden)
        # Author's notes list nn.GELU() and nn.Tanh() as alternatives that
        # were considered for this activation slot.
        activation = Rational()
        narrow = nn.Linear(nhidden, n)
        self.net = nn.Sequential(widen, activation, narrow)

    def forward(self, v):
        """Apply the network to ``v`` (last dimension of size ``n``)."""
        out = self.net(v)
        return out