import torch
import math
import torch.nn as nn

class SINOmodel(nn.Module):
    """Learned pseudo-spectral time stepper for fields on a square grid.

    ``f_pi`` computes a learned right-hand side from the 2-D real FFT of the
    input field, and ``forward`` advances the field by one explicit time step
    of size ``dt`` using that right-hand side.
    """

    def __init__(self, input_dim=1, channel=16, k_num=4, output_dim=1, dt=0.005, step_forward='rk4', anti_alias_ratio=2.0/3):
        """
        Args:
            input_dim: number of channels of the input field (its last axis).
            channel: hidden width of the wavenumber MLP and output width of
                the complex mixing layers ``l1``/``l2``/``l3``.
            k_num: number of learned multipliers produced per Fourier mode.
            output_dim: number of channels returned by ``f_pi``. ``forward``
                adds ``f_pi(u)`` to ``u``, so this must be compatible with
                ``input_dim``.
            dt: default step size used by ``forward``.
            step_forward: default integration scheme, ``'rk4'`` or ``'euler'``.
            anti_alias_ratio: fraction of the largest resolved wavenumber kept
                by the mask applied to the product of the ``l1``/``l2``
                branches.
        """
        super().__init__()
        self.channel = channel
        # Maps an integer wavenumber pair (k_x, k_y) to k_num real multipliers.
        self.k_layer = nn.Sequential(
            nn.Linear(2, self.channel),
            nn.ReLU(),
            nn.Linear(self.channel, self.channel),
            nn.ReLU(),
            nn.Linear(self.channel, k_num)
        )

        # Complex-valued channel mixing in Fourier space. The input is the
        # spectrum plus its k_num multiplied copies (see compute_features).
        self.l1 = nn.Linear((k_num + 1) * input_dim, self.channel, dtype=torch.complex64)
        self.l2 = nn.Linear((k_num + 1) * input_dim, self.channel, dtype=torch.complex64)
        self.l3 = nn.Linear((k_num + 1) * input_dim, self.channel, dtype=torch.complex64)
        # 1x1 convolution projecting hidden channels to output_dim.
        self.conv11 = nn.Conv2d(self.channel, output_dim, 1, 1)
        self.dt = dt

        # Maps a physical grid coordinate (x, y) to a scalar forcing value.
        self.f_layer = nn.Sequential(
            nn.Linear(2, 20),
            nn.ReLU(),
            nn.Linear(20, 20),
            nn.ReLU(),
            nn.Linear(20, 1)
        )

        self.anti_alias_ratio = anti_alias_ratio
        self.step_forward = step_forward

    def compute_features(self, x_h, xg):
        """Concatenate a spectrum with its products by the learned multipliers.

        Args:
            x_h: complex spectrum of shape ``(b, c, h, w)``.
            xg: real multipliers of shape ``(n, h, w)``.

        Returns:
            Tensor of shape ``(b, c * (n + 1), h, w)``.
            - Channels ``0..c-1`` are ``x_h`` itself.
            - Channel ``c + i * n + j`` holds ``x_h[:, i] * xg[j]``.
        """
        b, c, h, w = x_h.shape
        n = xg.shape[0]
        # (b, c, 1, h, w) * (1, 1, n, h, w) -> (b, c, n, h, w)
        product = x_h.unsqueeze(2) * xg.unsqueeze(0).unsqueeze(1)
        product_reshaped = product.reshape(b, c * n, h, w)
        return torch.cat([x_h, product_reshaped], dim=1)

    @staticmethod
    def _wavenumbers(n, device):
        """Return integer wavenumber grids ``(k_x, k_y)`` for an ``n x n`` rfft2 spectrum.

        Both grids are float tensors of shape ``(n, n // 2 + 1)``.
        - ``k_x[i, j]`` is the wavenumber of row ``i``, the full
          (second-to-last) FFT axis.
        - ``k_y[i, j]`` is that of column ``j``, the half-spectrum (last) axis.
        Frequencies follow the ``torch.fft.fftfreq(n) * n`` ordering. This
        also works for odd ``n``, where the naive ``arange(0, n // 2)`` /
        ``arange(-(n // 2), 0)`` split is one entry short.
        """
        half = n // 2
        freqs = torch.cat((
            torch.arange(0, (n + 1) // 2, device=device),
            torch.arange(-half, 0, device=device),
        )).float()
        k_y = freqs.repeat(n, 1)
        k_x = k_y.transpose(0, 1)
        # For even n the last kept column of k_y is the Nyquist mode, stored
        # as -(n // 2) as in fftfreq. Only its magnitude matters for the
        # dealiasing mask.
        return k_x[..., :half + 1], k_y[..., :half + 1]

    def f_pi(self, x):
        """Evaluate the learned right-hand side (time derivative) for ``x``.

        Args:
            x: channel-last field of shape ``(batch, N, N, input_dim)``.

        Returns:
            Tensor of shape ``(batch, N, N, output_dim)``.

        Raises:
            ValueError: if the two spatial dimensions differ.
        """
        x = x.permute(0, 3, 1, 2)  # -> (batch, input_dim, H, W)
        n_rows, n_cols = x.shape[-2], x.shape[-1]
        if n_rows != n_cols:
            raise ValueError(
                f"f_pi expects a square spatial grid, got {n_rows}x{n_cols}")
        n = n_rows
        device = x.device
        k_max = n // 2
        k_x, k_y = self._wavenumbers(n, device)

        # Learned forcing on the uniform grid linspace(0, 1, n + 1)[:-1]
        # along each axis. Result has shape (1, n, n).
        t = torch.linspace(0, 1, n + 1, device=device)[:-1]
        grid_x, grid_y = torch.meshgrid(t, t, indexing='ij')
        force = self.f_layer(torch.stack((grid_x, grid_y), dim=-1)).permute(2, 0, 1)

        # Learned per-mode multipliers, shape (k_num, n, n // 2 + 1).
        xg = self.k_layer(torch.stack((k_x, k_y), dim=-1)).permute(2, 0, 1)

        x_h = torch.fft.rfft2(x)
        # Channel-last so the complex Linear layers mix channels per mode.
        x_tp = self.compute_features(x_h, xg).permute(0, 2, 3, 1)
        x_tp1 = torch.fft.irfft2(self.l1(x_tp).permute(0, 3, 1, 2), s=(n, n))
        x_tp2 = torch.fft.irfft2(self.l2(x_tp).permute(0, 3, 1, 2), s=(n, n))
        x_tp3 = torch.fft.irfft2(self.l3(x_tp).permute(0, 3, 1, 2), s=(n, n))

        # Zero modes above anti_alias_ratio * k_max on either axis. The mask
        # is applied to the spectrum of the pointwise product x_tp1 * x_tp2.
        cutoff = self.anti_alias_ratio * k_max
        dealias = ((k_y.abs() <= cutoff) & (k_x.abs() <= cutoff)).float().unsqueeze(0)
        x_f = torch.fft.rfft2(force, s=(n, n))
        nonlinear = torch.fft.irfft2(torch.fft.rfft2(x_tp1 * x_tp2) * dealias + x_f, s=(n, n))
        r = self.conv11(nonlinear + x_tp3)
        return r.permute(0, 2, 3, 1)

    def forward(self, u, step_forward=None, dt=None):
        """Advance ``u`` by one explicit time step using ``f_pi`` as its derivative.

        Args:
            u: channel-last field of shape ``(batch, N, N, input_dim)``.
            step_forward: ``'rk4'`` (classical fourth-order Runge-Kutta) or
                ``'euler'`` (forward Euler). Defaults to ``self.step_forward``.
            dt: step size. Defaults to ``self.dt``.

        Returns:
            The field after one step, with the same shape as ``u``.

        Raises:
            ValueError: if ``step_forward`` names an unknown scheme. Previously
                such values silently returned ``u`` unchanged.
        """
        if step_forward is None:
            step_forward = self.step_forward
        if dt is None:
            dt = self.dt

        if step_forward == 'rk4':
            k1 = self.f_pi(u)
            k2 = self.f_pi(u + dt * k1 / 2)
            k3 = self.f_pi(u + dt * k2 / 2)
            k4 = self.f_pi(u + dt * k3)
            return u + dt * (k1 + 2 * k2 + 2 * k3 + k4) / 6
        if step_forward == 'euler':
            return u + dt * self.f_pi(u)
        raise ValueError(
            f"unknown step_forward {step_forward!r}; expected 'rk4' or 'euler'")
