
import numpy as np
import torch
import torch.fft
from config import *

# Parsed command-line options. ``parser`` is not defined in this file; it
# presumably arrives via ``from config import *``. Parsing runs at import time.
# NOTE(review): argparse's parse_args() exits on unrecognised arguments, so
# importing this module from another command-line program may fail — confirm
# this is intended.
args = parser.parse_args()

# Labels that system_determinator can return, and the initial-condition names
# that function() recognises.
SYSTEM = [
    'convection',
    'diffusion',
    'reaction_Fisher',
    'reaction_Allen_Cahn',
    'reaction_Zeldovich',
    'convection-diffusion',
    'reaction-diffusion',
    'convection-diffusion-reaction',
    'random',
]
INIT_COND = [
    '1+sin(x)',
    'gauss',
]

def function(u0: str):
    """Return the initial-condition callable named by ``u0``.

    Recognised names:
        '1+sin(x)' -> 1 + sin(x)
        'gauss'    -> exp(-((x - pi) / (pi / 4))**2 / 2), a Gaussian centred at pi
    Any other name falls back to the constant 1.

    The solver in this module calls the result with a ``torch.Tensor`` of grid
    positions and then calls ``.numpy()`` on it. Every branch therefore returns
    an array-like with the same shape as ``x``. The constant fallback used to
    return the bare int ``1``, which has no ``.numpy()``.
    """
    if u0 == '1+sin(x)':
        return lambda x: 1 + torch.sin(x)
    if u0 == 'gauss':
        x0 = torch.pi
        sigma = torch.pi / 4
        return lambda x: torch.exp(-torch.pow((x - x0) / sigma, 2) / 2.)

    def constant_one(x):
        # Ones shaped (and, for tensors, typed) like x, so callers can treat
        # this exactly like the other initial conditions.
        if isinstance(x, torch.Tensor):
            return torch.ones_like(x)
        return np.ones_like(np.asarray(x, dtype=float))

    return constant_one

def system_determinator(beta, nu, rho, epsilon, theta, is_2D=None):
    """Classify a 1D PDE by which of its coefficients are zero.

    Coefficients (as used by the update functions in this module):
        beta    -- convection speed:       -beta * du/dx
        nu      -- diffusion coefficient:   nu * d2u/dx2
        rho     -- Fisher reaction:         rho * u * (1 - u)
        epsilon -- Allen-Cahn reaction:     epsilon * (u - u^3)
        theta   -- Zeldovich reaction:      theta * u^2 * (1 - u)

    Args:
        is_2D: whether the problem is two-dimensional. When omitted (None),
            the command-line flag ``args.is_2D`` is used, as before.

    Returns:
        One of the labels in ``SYSTEM``. The first matching rule wins, so
        all-zero coefficients yield "convection".

    Raises:
        ValueError: for 2D problems, which are not supported.
    """
    if is_2D is None:
        is_2D = args.is_2D
    if is_2D:
        raise ValueError("System not supported.")

    if nu == 0 and rho == 0 and epsilon == 0 and theta == 0:
        return "convection"
    if beta == 0 and rho == 0 and epsilon == 0 and theta == 0:
        return "diffusion"
    # Pure Fisher reaction must also have nu == 0. Without that check this
    # rule matched every Fisher + diffusion system, leaving the
    # "reaction-diffusion" rule below unreachable.
    if beta == 0 and nu == 0 and epsilon == 0 and theta == 0:
        return "reaction_Fisher"
    # NOTE(review): the Allen-Cahn and Zeldovich rules do not test nu, so they
    # are also returned when diffusion is present — confirm this is intended.
    if beta == 0 and rho == 0 and theta == 0:
        return "reaction_Allen_Cahn"
    if beta == 0 and rho == 0 and epsilon == 0:
        return "reaction_Zeldovich"
    if rho == 0 and epsilon == 0 and theta == 0:
        return "convection-diffusion"
    if beta == 0 and epsilon == 0 and theta == 0:
        return "reaction-diffusion"
    if epsilon == 0 and theta == 0:
        return "convection-diffusion-reaction"
    return "random"

def reaction_Fisher(u, rho, dt):
    """Advance the Fisher (logistic) reaction du/dt = rho*u*(1-u) by dt.

    Uses the closed-form logistic solution

        u(t + dt) = u*e^(rho*dt) / ((1 - u) + u*e^(rho*dt)),

    so the step is exact rather than an ODE-integrator approximation.
    Works elementwise on numpy arrays or on scalars.
    """
    grown = u * np.exp(rho * dt)
    return grown / ((1 - u) + grown)

def reaction_Allen_Cahn(u, epsilon, dt):
    """Advance the Allen-Cahn reaction du/dt = -epsilon*(u^3 - u) by dt.

    Substituting v = u^2 gives the logistic ODE dv/dt = 2*epsilon*v*(1 - v).
    Its exact solution yields

        u(t + dt) = u / sqrt(u^2 + (1 - u^2) * e^(-2*epsilon*dt)).

    Dividing ``u`` keeps its sign. The ODE is odd in u and u = 0 is an
    equilibrium, so negative states must stay negative; the form
    sqrt(u^2 / (...)) would map them to positive values. For epsilon*dt >= 0
    the radicand equals e^(-a) + u^2 * (1 - e^(-a)) with a = 2*epsilon*dt,
    which is strictly positive.
    """
    decay = np.exp(-2 * epsilon * dt)
    return u / np.sqrt(u**2 + decay * (1 - u**2))

def reaction_Zeldovich(u, theta, dt):
    """Advance the Zeldovich reaction du/dt = theta*u^2*(1-u) by one step dt.

    Uses Heun's method (explicit trapezoidal rule): an explicit-Euler
    predictor, then the average of the slopes at the start of the step and at
    the predicted end point.
    """
    slope_start = theta * u**2 * (1 - u)
    predicted = u + slope_start * dt
    slope_end = theta * predicted**2 * (1 - predicted)
    return u + 0.5 * (slope_start + slope_end) * dt

def reaction(u, rho, epsilon, theta, dt):
    """Advance the combined reaction terms by one Heun step of size dt.

    Right-hand side (Zeldovich + Allen-Cahn + Fisher):

        du/dt = theta*u^2*(1-u) + epsilon*(u - u^3) + rho*u*(1-u)

    Heun's method: an Euler predictor, then the average of the start and
    predicted-end slopes.
    """
    def rate(v):
        return theta * v**2 * (1 - v) + epsilon * (v - v**3) + rho * v * (1 - v)

    k_start = rate(u)
    k_end = rate(u + k_start * dt)
    return u + 0.5 * (k_start + k_end) * dt

def diffusion(u, nu, dt, IKX2):
    """Advance du/dt = nu * d2u/dx2 by dt in Fourier space.

    ``IKX2`` is expected to hold (i*k)^2 for every mode, in the same order as
    ``np.fft.fft(u)``. Each mode is scaled by exp(nu * (ik)^2 * dt), and the
    real part of the inverse FFT is returned.
    """
    spectrum = np.fft.fft(u)
    spectrum *= np.exp(nu * IKX2 * dt)
    return np.fft.ifft(spectrum).real

def convection_diffusion(u, beta, nu, dt, IKX, IKX2):
    """Advance du/dt = -beta*du/dx + nu*d2u/dx2 by dt in Fourier space.

    ``IKX`` and ``IKX2`` are expected to hold i*k and (i*k)^2 for every mode,
    in the same order as ``np.fft.fft(u)``. Each mode is scaled by
    exp((nu*(ik)^2 - beta*ik) * dt), which integrates the linear PDE exactly
    in time, mode by mode. The real part of the inverse FFT is returned.
    """
    spectrum = np.fft.fft(u)
    spectrum *= np.exp(nu * IKX2 * dt - beta * IKX * dt)
    return np.fft.ifft(spectrum).real

def convection_diffusion_reaction_discrete_solution(u0 : str, beta, nu, rho, epsilon, theta, nx = 256, nt = 100):
    """Reference solution of the 1D convection-diffusion-reaction problem.

    Solves

        du/dt = -beta*du/dx + nu*d2u/dx2
                + theta*u^2*(1-u) + epsilon*(u - u^3) + rho*u*(1-u)

    on the periodic domain x in [0, 2*pi), for t in [0, 1]. Each time step uses
    operator splitting: a Heun step of the reaction terms (``reaction``),
    followed by a Fourier-space convection-diffusion step
    (``convection_diffusion``).

    Args:
        u0: initial-condition name understood by ``function``.
        beta, nu, rho, epsilon, theta: PDE coefficients as above.
        nx: number of spatial points, x_j = j * 2*pi / nx. The endpoint 2*pi
            is excluded because the domain is periodic. Odd nx is supported.
        nt: number of time levels, t_i = i * T / (nt - 1), i.e. the grid
            ``np.linspace(0, T, nt)``, which includes both t = 0 and t = T.

    Returns:
        (u_vals, u): ``u`` has shape (nt, nx), with ``u[i, j]`` the solution at
        (t_i, x_j). ``u_vals`` is ``u`` flattened in row-major (time-major)
        order.
    """
    L = 2 * np.pi
    T = 1
    # linspace(endpoint=False) always yields exactly nx points.
    # arange(0, L, L/nx) can yield nx + 1 points through rounding.
    x = np.linspace(0, L, nx, endpoint=False)
    # Output rows are the levels of linspace(0, T, nt), which are T/(nt-1)
    # apart. With dt = T/nt the last row would sit at t = T*(nt-1)/nt, not T.
    dt = T / (nt - 1) if nt > 1 else 0.0

    # Integer wavenumbers in np.fft order: 0 .. nx//2, then the negative ones.
    # For even nx this matches the previous float ranges exactly, and it also
    # gives the correct nx modes for odd nx.
    k = np.concatenate((np.arange(0, nx // 2 + 1), np.arange(-((nx - 1) // 2), 0)))
    IKX = 1j * k
    IKX2 = IKX * IKX

    initial = function(u0)(torch.tensor(x))
    # Broadcast so a scalar-valued initial condition still gives one value per
    # grid point. Copy so the state is an ordinary writable float array.
    u_ = np.array(np.broadcast_to(np.asarray(initial, dtype=float), x.shape))

    u = np.zeros((nt, nx))
    u[0] = u_
    for i in range(1, nt):
        u_ = reaction(u_, rho, epsilon, theta, dt)
        u_ = convection_diffusion(u_, beta, nu, dt, IKX, IKX2)
        u[i] = u_

    u_vals = u.flatten()
    return u_vals, u