#!/usr/bin/env python3

import torch
import torch.nn as nn

def requ(x: torch.Tensor) -> torch.Tensor:
    """Rectified quadratic unit, applied elementwise: ``relu(x) ** 2``.

    Yields 0 wherever ``x <= 0`` and ``x * x`` wherever ``x > 0``.
    """
    positive_part = torch.relu(x)
    return torch.square(positive_part)

def requr(x: torch.Tensor) -> torch.Tensor:
    """ReQU with a linear tail, built as ``requ(x) - requ(x - 0.5)``.

    Elementwise this equals::

        0          for x <= 0
        x ** 2     for 0 < x < 0.5
        x - 0.25   for x >= 0.5

    The value and first derivative are continuous at both 0 and 0.5.
    """
    tail_correction = requ(x - 0.5)
    return requ(x) - tail_correction

class FC(nn.Module):
    """Fully connected network: embed, a stack of residual FCBlocks, project back.

    The forward pass computes::

        h = requr(embeding(input))
        for block in AFFNlist:
            h = block(h)
        out = requ(end(h)) + 1e-4

    Since ``requ`` is non-negative, every output element is at least 1e-4.

    Args:
        originaldim: size of the last dimension of both input and output.
        embedingdim: hidden width shared by the embedding and every FCBlock.
        AFFNnum: number of FCBlock layers between embedding and output layer.
    """

    def __init__(self, originaldim, embedingdim, AFFNnum):
        super().__init__()
        self.AFFNnum = AFFNnum
        # Register the (still empty) ModuleList before ``embeding`` and create
        # the blocks after it, exactly as the original did. This keeps the RNG
        # consumption order (hence default initialisation) and the order of
        # ``parameters()`` / ``state_dict()`` unchanged.
        self.AFFNlist = nn.ModuleList()
        self.embeding = nn.Linear(originaldim, embedingdim)
        self.AFFNlist.extend(FCBlock(embedingdim) for _ in range(AFFNnum))
        self.end = nn.Linear(embedingdim, originaldim)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Map ``input`` (last dim ``originaldim``) to an output of the same size.

        ``input`` shadows the builtin. The name is kept so callers passing it
        by keyword keep working.
        """
        x = requr(self.embeding(input))
        # Iterate the registered blocks directly instead of indexing with
        # range(self.AFFNnum), so the loop always matches the modules held.
        for block in self.AFFNlist:
            x = block(x)
        return requ(self.end(x)) + 1e-4
  
class FCBlock(nn.Module):
    """Residual block: ``requr(linear2(requr(linear1(x)))) + x``.

    Both linear layers map ``embedingdim -> embedingdim``. The input's last
    dimension must therefore be ``embedingdim``, and the output has the same
    shape as the input.
    """

    def __init__(self, embedingdim):
        super().__init__()
        # linear1 is created before linear2 so that default random
        # initialisation and parameter order stay the same.
        self.linear1 = nn.Linear(embedingdim, embedingdim)
        self.linear2 = nn.Linear(embedingdim, embedingdim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = requr(self.linear1(x))
        update = requr(self.linear2(hidden))
        # Skip connection: add the block input back onto the transformed path.
        return update + x

