import math
from timeit import default_timer
from tqdm.auto import tqdm

import torch


class GaussianRF(object):
    """
    Gaussian random field generator on a periodic grid with `size` points per axis.

    Samples are drawn in Fourier space with per-mode amplitude

        sqrt_eig(k) = size**dim * sqrt(2) * sigma * (4 pi^2 |k|^2 + tau^2)^(-alpha/2)

    with the zero mode set to 0 (so every sample has zero spatial mean), and are
    mapped back to physical space with an inverse FFT; the real part is returned.
    From: https://github.com/zongyi-li/fourier_neural_operator/blob/master/data_generation/navier_stokes/random_fields.py
    """
    def __init__(self, dim, size, alpha=2, tau=3, sigma=None, boundary="periodic", device=None):
        """
        Args:
            dim: number of spatial dimensions (any positive integer).
            size: number of grid points along each axis (even or odd).
            alpha: exponent controlling how fast mode amplitudes decay with |k|.
            tau: constant added (squared) to 4 pi^2 |k|^2 in the spectrum.
            sigma: overall amplitude; defaults to tau**(0.5*(2*alpha - dim)).
            boundary: currently unused; the generated fields are always periodic.
            device: torch device on which the spectrum and samples are created.

        Raises:
            ValueError: if dim < 1.
        """
        if dim < 1:
            raise ValueError(f"dim must be a positive integer, got {dim}")
        self.dim = dim
        self.device = device
        if sigma is None:
            sigma = tau**(0.5*(2*alpha - self.dim))

        # Integer wavenumbers in FFT order [0, 1, ..., -2, -1]. For even `size`
        # this is [0, ..., size/2 - 1, -size/2, ..., -1]; it also has the right
        # length (and ordering) for odd `size`.
        k = torch.cat((torch.arange(start=0, end=(size + 1) // 2, step=1, device=device),
                       torch.arange(start=-(size // 2), end=0, step=1, device=device)), 0)

        # |k|^2 on the full dim-dimensional grid, accumulated one axis at a time by
        # broadcasting (kept as integers so the result matches the per-dim formula).
        k_sq = torch.zeros((size,) * dim, dtype=k.dtype, device=device)
        for axis in range(dim):
            shape = [1] * dim
            shape[axis] = size
            k_sq = k_sq + k.reshape(shape) ** 2

        self.sqrt_eig = (size**dim)*math.sqrt(2.0)*sigma*((4*(math.pi**2)*k_sq + tau**2)**(-alpha/2.0))
        # Remove the mean (zero) mode so samples have zero spatial average.
        self.sqrt_eig[(0,) * dim] = 0.0

        self.size = (size,) * dim

    def sample(self, N=1):
        """
        Draw N independent real-valued random fields.

        Returns:
            Tensor of shape (N, *self.size) (real part of the inverse FFT).
        """
        coeff = torch.randn(N, *self.size, dtype=torch.cfloat, device=self.device)
        coeff = self.sqrt_eig * coeff
        return torch.fft.ifftn(coeff, dim=list(range(-1, -self.dim - 1, -1))).real

    def sample(self, N=1):
        coeff = torch.randn(N, *self.size, dtype=torch.cfloat, device=self.device)
        coeff = self.sqrt_eig * coeff
        return torch.fft.ifftn(coeff, dim=list(range(-1, -self.dim - 1, -1))).real


def navier_stokes_2d(w0, f, visc, T, delta_t=1e-4, record_steps=1):
    """
    Solve the 2D incompressible Navier-Stokes equations in vorticity form with periodic
    boundary conditions on the unit square, using a pseudo-spectral discretization.
    Reference: https://github.com/zongyi-li/fourier_neural_operator/blob/master/data_generation/navier_stokes/ns_2d.py

    The viscous term is advanced with Crank-Nicolson; the nonlinear advection term and
    the forcing are treated explicitly. The nonlinear term is dealiased with the 2/3 rule.

    Args:
        w0: initial vorticity, shape (..., N, N) with N even. Assumed real-valued.
        f: forcing term, shape (N, N) to share one forcing across the whole batch,
            or broadcastable against w0. Assumed real-valued.
        visc: viscosity
        T: final time; the solver takes ceil(T / delta_t) steps.
        delta_t: time step
        record_steps: number of snapshots to record. A snapshot is taken every
            max(1, steps // record_steps) steps, before that step is applied, so the
            first snapshot is the initial condition at t = 0. If there are fewer time
            steps than snapshots, the trailing snapshots are left as zeros.

    Returns:
        sol: float32 tensor of shape (*w0.shape, record_steps, 1) with the recorded
            vorticity (real part).
        sol_t: tensor of shape (record_steps,) with the time of each snapshot.

    Raises:
        ValueError: if record_steps < 1.
    """
    if record_steps < 1:
        raise ValueError(f"record_steps must be >= 1, got {record_steps}")

    # Grid size - must be even (the wavenumber vector below has 2 * (N // 2) entries)
    N = w0.size()[-1]
    k_max = math.floor(N / 2.0)
    steps = math.ceil(T / delta_t)
    # Steps between snapshots; never 0, so the modulo below cannot divide by zero
    record_time = max(1, steps // record_steps)

    # Initial vorticity and forcing to Fourier space (last two dims)
    w_h = torch.fft.fftn(w0, (N, N))
    f_h = torch.fft.fftn(f, (N, N))
    # Same forcing for the whole batch
    if len(f_h.size()) < len(w_h.size()):
        f_h = torch.unsqueeze(f_h, 0)

    # Wavenumbers: k_y varies along the last dim, k_x along the second-to-last
    k_y = torch.cat((torch.arange(start=0, end=k_max, step=1, device=w0.device),
                     torch.arange(start=-k_max, end=0, step=1, device=w0.device)), 0).repeat(N, 1)
    k_x = k_y.transpose(0, 1)
    # Spectral derivative operators d/dx -> 2*pi*i*k_x, d/dy -> 2*pi*i*k_y
    ik_x = 2j * math.pi * k_x
    ik_y = 2j * math.pi * k_y
    # Negative Laplacian in Fourier space (mean mode set to 1 to avoid dividing by 0)
    lap = 4 * (math.pi ** 2) * (k_x ** 2 + k_y ** 2)
    lap[0, 0] = 1.0
    # Dealiasing mask (2/3 rule)
    dealias = torch.unsqueeze(
        torch.logical_and(torch.abs(k_y) <= (2.0 / 3.0) * k_max, torch.abs(k_x) <= (2.0 / 3.0) * k_max).float(), 0)
    # Crank-Nicolson factors for the viscous term (loop invariant)
    cn_num = 1.0 - 0.5 * delta_t * visc * lap
    cn_den = 1.0 + 0.5 * delta_t * visc * lap

    sol = torch.zeros(*w0.size(), record_steps, 1, device=w0.device, dtype=torch.float)
    sol_t = torch.zeros(record_steps, device=w0.device)
    c = 0
    t = 0.0
    for j in range(steps):
        # `c < record_steps` guards against one extra hit when steps is not a
        # multiple of record_time, which would overflow the snapshot buffer.
        if j % record_time == 0 and c < record_steps:
            w = torch.fft.ifftn(w_h, (N, N))
            sol[..., c, 0] = w.real
            sol_t[c] = t
            c += 1

        # Stream function: solve the Poisson equation -lap(psi) = w
        psi_h = w_h / lap
        # Velocity (u, v) = (psi_y, -psi_x) and vorticity gradient, in physical space.
        # These are real fields for real inputs; taking .real also discards the
        # spurious imaginary part produced by the odd derivative at the Nyquist
        # wavenumber (k = -N/2), which would otherwise leak into the nonlinear term.
        u = torch.fft.ifftn(ik_y * psi_h, (N, N)).real
        v = torch.fft.ifftn(-ik_x * psi_h, (N, N)).real
        w_x = torch.fft.ifftn(ik_x * w_h, (N, N)).real
        w_y = torch.fft.ifftn(ik_y * w_h, (N, N)).real
        # Non-linear term u.grad(w): product in physical space, back to Fourier, dealias
        F_h = dealias * torch.fft.fftn(u * w_x + v * w_y, (N, N))
        # Crank-Nicolson update (explicit nonlinear term and forcing)
        w_h = (-delta_t * F_h + delta_t * f_h + cn_num * w_h) / cn_den
        # Physical time (used only for recording)
        t += delta_t
    return sol, sol_t


if __name__ == "__main__":

    print("Running Navier-Stokes 2D test")

    # Prefer the GPU but fall back to CPU so the script runs anywhere
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Resolution
    s = 256

    # Number of solutions to generate
    N = 20

    # Set up 2d GRF with covariance parameters
    GRF = GaussianRF(2, s, alpha=2.5, tau=7, device=device)

    # Forcing function: 0.1*(sin(2pi(x+y)) + cos(2pi(x+y)))
    t = torch.linspace(0, 1, s + 1, device=device)
    t = t[0:-1]

    # The forcing depends only on X + Y, so the meshgrid indexing order does not matter
    X, Y = torch.meshgrid(t, t)
    f = 0.1 * (torch.sin(2 * math.pi * (X + Y)) + torch.cos(2 * math.pi * (X + Y)))

    # Number of snapshots from solution
    record_steps = 200

    # Inputs (kept on CPU)
    a = torch.zeros(N, s, s)
    # Solutions (kept on CPU)
    u = torch.zeros(N, s, s, record_steps)

    # Solve equations in batches (order of magnitude speed-up).
    # NOTE: if N is not a multiple of bsize, the remainder is not generated.
    bsize = 20

    c = 0
    t0 = default_timer()
    for j in tqdm(range(N // bsize)):

        # Sample random fields
        w0 = GRF.sample(bsize)

        # Solve NS
        sol, sol_t = navier_stokes_2d(w0, f, 1e-3, 50.0, 1e-4, record_steps)

        a[c:(c + bsize), ...] = w0.cpu()
        # sol carries a trailing singleton axis (..., record_steps, 1); drop it to match u
        u[c:(c + bsize), ...] = sol[..., 0].cpu()

        c += bsize
        t1 = default_timer()
        print(j, c, t1 - t0)

    # Requires `import scipy.io`:
    # scipy.io.savemat('ns_data.mat', mdict={'a': a.cpu().numpy(), 'u': u.cpu().numpy(), 't': sol_t.cpu().numpy()})
