import math

import numpy as np
import torch
import torch.fft
from utils.config import *

args = parser.parse_args()

# Names of the supported systems / initial-condition setups.
# All code that dispatches on a system name or an initial condition should go through this list.
SYSTEM = ['2D_cdr', '2D_cdr_SR']     # , 'Helmholtz'

# Side length of the square domain (SPATIAL_SIZE) and the wavenumber used by the 'sin'
# initial condition (COEFF). In every branch COEFF * SPATIAL_SIZE == 2*pi, so sin(COEFF*x)
# spans exactly one period across the domain. Any value other than "pi" or "1" selects 2*pi.
if args.spatial_size == "pi":
    SPATIAL_SIZE = math.pi
    COEFF = 2
elif args.spatial_size == "1":
    SPATIAL_SIZE = 1
    COEFF = 2 * math.pi
else:
    SPATIAL_SIZE = 2 * math.pi
    COEFF = 1


#### Define initial condition function ####
def function(u0: str):
    """Return the initial-condition callable ``u0(x, y)`` selected by name.

    Args:
        u0: Name of the initial condition. ``'sin'`` selects
            ``1 + sin(COEFF*x) * sin(COEFF*y)`` (expects torch tensors);
            any other value selects the constant field ``1``.

    Returns:
        A callable ``f(x, y)`` returning a ``torch.Tensor``.
    """
    if u0 == 'sin':
        return lambda x, y: 1 + torch.sin(COEFF * x) * torch.sin(COEFF * y)

    def constant_one(x, y):
        # Return a tensor of ones with the broadcast shape of x and y rather
        # than the Python int 1: callers convert the result with ``.numpy()``
        # and expect one value per grid point, as in the 'sin' branch.
        x = torch.as_tensor(x)
        y = torch.as_tensor(y)
        return torch.ones_like(x + y)

    return constant_one

#### Define reaction terms ####
def reaction(u, rho, epsilon, theta, dt):
    """Advance the pointwise reaction ODE by one step of Heun's method.

    The reaction term is
        du/dt = f(u) = rho*u*(1 - u) + epsilon*(u - u**3) + theta*u**2*(1 - u)
    and it is integrated with Heun's (explicit trapezoidal, second-order) scheme:
        k1 = f(u);  k2 = f(u + dt*k1);  u_new = u + dt*(k1 + k2)/2.

    Args:
        u: Current state (scalar or array; the update is elementwise).
        rho, epsilon, theta: Coefficients of the three reaction terms.
        dt: Time-step size.

    Returns:
        The state after one step, with the same shape as ``u``.
    """
    def du_dt(v):
        return rho * v * (1 - v) + epsilon * (v - v**3) + theta * v**2 * (1 - v)

    # Slope at the current state, evaluated once and reused for both the
    # predictor and the averaged corrector.
    k1 = du_dt(u)
    # Slope at the forward-Euler predicted state.
    k2 = du_dt(u + k1 * dt)
    return u + 0.5 * (k1 + k2) * dt

#### Define convection, diffusion terms ####
def convection_diffusion(u, beta, beta_y, nu, nu_y, dt, IKX, IKX2, IKY, IKY2):
    """Advance the linear convection-diffusion part by one step in Fourier space.

    Solves du/dt = -beta*du/dx - beta_y*du/dy + nu*d2u/dx2 + nu_y*d2u/dy2
    over a step of length ``dt``. Every Fourier mode of ``u`` is multiplied by
    exp(dt * symbol), which is exact in time for this constant-coefficient
    linear operator.

    Args:
        u: Real 2-D field on the periodic grid.
        beta, beta_y: Convection speeds along x and y.
        nu, nu_y: Diffusion coefficients along x and y.
        dt: Time-step size.
        IKX, IKY: Grids of i*k along x and y, matching the shape of fft2(u)
            (as built by convection_diffusion_reaction_discrete_solution_2D).
        IKX2, IKY2: Their squares, i.e. -k**2.

    Returns:
        The real part of the inverse FFT of the propagated spectrum.
    """
    def axis_exponent(ik, ik2, speed, diffusivity):
        # Diffusion contribution minus convection contribution for one axis.
        return diffusivity * ik2 * dt - speed * ik * dt

    propagator = np.exp(axis_exponent(IKX, IKX2, beta, nu)
                        + axis_exponent(IKY, IKY2, beta_y, nu_y))

    spectrum = np.fft.fft2(u)
    spectrum *= propagator
    return np.fft.ifft2(spectrum).real


def convection_diffusion_reaction_discrete_solution_2D(u0: str, beta, beta_y, nu, nu_y, rho, epsilon, theta,
                                                       nx=32, ny=32, nt=101, total_nt=1001):
    """Simulate the 2-D convection-diffusion-reaction equation on a square domain.

    Each of the ``total_nt - 1`` time steps applies ``reaction`` and then
    ``convection_diffusion`` (Lie splitting). Every ``(total_nt - 1) // (nt - 1)``-th
    state is stored, so ``nt`` evenly spaced snapshots cover [0, args.max_time].

    Args:
        u0: Initial-condition name passed to ``function``.
        beta, beta_y: Convection speeds along x and y.
        nu, nu_y: Diffusion coefficients along x and y.
        rho, epsilon, theta: Reaction coefficients (see ``reaction``).
        nx, ny: Number of grid points along x and y.
        nt: Number of stored snapshots, including t = 0.
        total_nt: Number of integration time levels, including t = 0.

    Returns:
        (u_vals, u), where ``u`` has shape (ny, nx, nt) with x along axis 1 and
        y along axis 0 (np.meshgrid 'xy' layout). ``u[:, :, k]`` is the state at
        time k * args.max_time / (nt - 1), and ``u_vals`` is ``u.flatten()``.

    Raises:
        ValueError: If nt < 2, or (total_nt - 1) is not a positive multiple of (nt - 1).
    """
    if nt < 2:
        raise ValueError(f"nt must be at least 2, got {nt}")
    if total_nt < nt or (total_nt - 1) % (nt - 1) != 0:
        raise ValueError(
            "(total_nt - 1) must be a positive multiple of (nt - 1); "
            f"got total_nt={total_nt}, nt={nt}"
        )
    # Integer stride between stored snapshots (exact; avoids float modulo).
    stride = (total_nt - 1) // (nt - 1)

    Lx = SPATIAL_SIZE
    Ly = SPATIAL_SIZE
    t_max = args.max_time
    dx = Lx / (nx - 1)
    dy = Ly / (ny - 1)
    dt = t_max / (total_nt - 1)

    # NOTE(review): np.linspace includes the right endpoint, so the first and last
    # columns/rows sample the same point of the periodic domain, while the FFT below
    # treats the data as periodic with period n*d = L + d. Kept as-is so the grid stays
    # consistent with whatever grid callers evaluate on; confirm whether a periodic
    # grid (endpoint=False, d = L/n) was intended.
    x = np.linspace(0, Lx, nx)
    y = np.linspace(0, Ly, ny)
    X, Y = np.meshgrid(x, y)  # shape (ny, nx): x varies along axis 1, y along axis 0

    # Spectral derivative symbols i*k (angular wavenumbers) laid out like X and Y.
    KX = 1j * np.fft.fftfreq(nx, d=dx) * 2 * np.pi
    KY = 1j * np.fft.fftfreq(ny, d=dy) * 2 * np.pi
    IKX, IKY = np.meshgrid(KX, KY)
    IKX2, IKY2 = IKX**2, IKY**2

    init = function(u0)
    u_curr = init(torch.tensor(X), torch.tensor(Y)).numpy()

    u = np.zeros((ny, nx, nt))
    u[:, :, 0] = u_curr
    for step in range(1, total_nt):
        u_curr = reaction(u_curr, rho, epsilon, theta, dt)
        u_curr = convection_diffusion(u_curr, beta, beta_y, nu, nu_y, dt, IKX, IKX2, IKY, IKY2)
        if step % stride == 0:
            u[:, :, step // stride] = u_curr

    u_vals = u.flatten()
    return u_vals, u