import math
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import sympy as sp
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchdiffeq as ode

from cubic_spline import CSpline
from random_fields import GaussianRF

def RBF(x1, x2, params):
    """Return the squared-exponential (RBF) kernel matrix between two point sets.

    Args:
        x1: array of points, presumably shaped (n, d) -- the reduction over
            axis 2 below requires 2-D input; TODO confirm for all callers.
        x2: array of points, presumably shaped (m, d).
        params: ``(length_scale, output_scale)`` tuple.

    Returns:
        Array of shape (n, m) whose entry (i, j) is
        ``output_scale * exp(-0.5 * ||(x1[i] - x2[j]) / length_scale||^2)``.
    """
    scale, amplitude = params
    # Insert a singleton axis so that (n, 1, d) - (1, m, d) -> (n, m, d).
    lhs = np.asarray(x1 / scale)[:, None]
    rhs = np.asarray(x2 / scale)[None]
    pairwise = lhs - rhs
    sq_dist = (pairwise ** 2).sum(axis=2)
    return amplitude * np.exp(-0.5 * sq_dist)

def generate_gaussain_sample(gp_params, N, T, jitter=1e-10):
    """Draw ``N`` sample paths of a zero-mean Gaussian process on ``[0, 1]``.

    The covariance is the ``RBF`` kernel evaluated on ``T`` equally spaced
    points in ``[0, 1]``; samples are generated as ``L @ z`` with ``L`` the
    Cholesky factor of the (jittered) covariance and ``z`` standard normal.

    Args:
        gp_params: ``(length_scale, output_scale)`` forwarded to ``RBF``.
        N: number of independent sample paths.
        T: number of grid points per path.
        jitter: value added to the kernel diagonal before the Cholesky
            factorisation (defaults to the original hard-coded ``1e-10``).

    Returns:
        ``np.ndarray`` of shape ``(N, T)``; row ``i`` is one sample path.

    Raises:
        numpy.linalg.LinAlgError: if ``K + jitter * I`` is not numerically
            positive definite (a larger ``jitter`` may help).
    """
    X = np.linspace(0.0, 1.0, T)[:, None]
    K = RBF(X, X, gp_params)
    L = np.linalg.cholesky(K + jitter * np.eye(T))
    # One draw for all samples: row i takes the same T consecutive normal
    # variates a per-sample loop would, so seeded runs see the same noise.
    z = np.random.normal(loc=0, scale=1, size=(N, T))
    # Row i of z @ L.T equals L @ z[i].
    return z @ L.T


# Correctly spelled alias; the original name is kept for existing callers.
generate_gaussian_sample = generate_gaussain_sample

def _save_split(out_dir, split, n, s, u, f0):
    """Write the s / u / f0 arrays of one data split as CSV files in ``out_dir``."""
    s = np.asarray(s)
    f0 = np.asarray(f0)
    # (n, Nt, ...) -> (n * Nt, prod(...)): one row per (sample, time step).
    # Column count is computed explicitly so an empty split still reshapes.
    s_flat = s.reshape(n * s.shape[1], int(np.prod(s.shape[2:], dtype=int)))
    f0_flat = f0.reshape(n, int(np.prod(f0.shape[1:], dtype=int)))
    pd.DataFrame(s_flat).to_csv(os.path.join(out_dir, f"s_{split}.csv"))
    pd.DataFrame(np.asarray(u)).to_csv(os.path.join(out_dir, f"u_{split}.csv"))
    pd.DataFrame(f0_flat).to_csv(os.path.join(out_dir, f"f0_{split}.csv"))


def sav_data(Ntr, Nva, Nte, s_train, u_train, s_valid, u_valid, s_test, u_test,
             xx, yy, tt, f0_train, f0_valid, f0_test, out_dir="./data_nsPDE"):
    """Save the generated train / valid / test splits and the grids as CSV files.

    For each split ``<name>`` in train, valid, test this writes
    ``s_<name>.csv`` (solution reshaped to ``(n * n_time, rest)``),
    ``u_<name>.csv`` and ``f0_<name>.csv`` (forcing reshaped to ``(n, rest)``).
    The grids are written as ``x.csv``, ``y.csv`` and ``t.csv``.

    Args:
        Ntr, Nva, Nte: number of samples in each split.
        s_train, s_valid, s_test: solution arrays indexed ``[sample, time, ...]``.
        u_train, u_valid, u_test: 2-D arrays of temporal signals.
        xx, yy, tt: 1-D grids (numpy arrays or CPU torch tensors).
        f0_train, f0_valid, f0_test: forcing arrays indexed ``[sample, ...]``.
        out_dir: output directory, created if missing.
    """
    os.makedirs(out_dir, exist_ok=True)
    splits = (
        ("train", Ntr, s_train, u_train, f0_train),
        ("valid", Nva, s_valid, u_valid, f0_valid),
        ("test", Nte, s_test, u_test, f0_test),
    )
    for split, n, s, u, f0 in splits:
        _save_split(out_dir, split, n, s, u, f0)

    # np.asarray: older pandas turns a torch tensor into a column of
    # per-element tensor objects instead of numbers.
    for name, grid in (("x", xx), ("y", yy), ("t", tt)):
        pd.DataFrame(np.asarray(grid)).to_csv(os.path.join(out_dir, f"{name}.csv"))
    print("Data generated successfully!")

class genc2(nn.Module):
    """Pseudo-spectral right-hand side ``dx/dt`` for ``torchdiffeq.odeint``.

    ``x`` is a real 2-D field, presumably vorticity ``w`` on a periodic grid.
    ``forward`` returns::

        -(u * w_x + v * w_y) + f0 * fit_data.fit(t) + nu * Laplacian(w)

    Here ``psi_hat = w_hat / |k|^2``, ``u = d(psi)/dy`` and ``v = -d(psi)/dx``.
    Spatial derivatives are taken in Fourier space; products are formed in
    physical space (no de-aliasing).

    Args:
        k_x, k_y: wavenumber grids matching the layout of ``torch.fft.fft2``.
            The zero mode is assumed to sit at index ``[0, 0]``.
        f0: spatial forcing field, broadcast-multiplied by ``fit_data.fit(t)``.

    Attributes:
        fit_data: object with a ``fit(t)`` method. It must be assigned before
            the first ``forward`` call (it is ``None`` after construction).
    """

    def __init__(self, k_x, k_y, f0):
        super(genc2, self).__init__()
        self.k_x = k_x
        self.k_y = k_y
        self.f0 = f0
        self.fit_data = None
        self.nu = 1e-3  # viscosity coefficient of the Laplacian term

        self.kx1 = (1j*k_x)**1
        self.kx2 = (1j*k_x)**2
        self.ky1 = (1j*k_y)**1
        self.ky2 = (1j*k_y)**2

        # Fourier symbol of the negative Laplacian, |k|^2 (zero at k = 0).
        lap = k_x**2 + k_y**2
        # The viscous term must use the true symbol. The mean mode has zero
        # Laplacian and must not be damped.
        self.visc = (1.0+0j)*lap
        # Inverting the Laplacian needs a non-zero [0, 0] entry. The value of
        # psi's mean mode is irrelevant because it is only used after
        # multiplication by i*k, which is 0 at k = 0.
        lap_safe = lap.clone()
        lap_safe[0, 0] = 1.0
        self.lap = (1.0+0j)*lap_safe
        self.dlap = (1.0+0j)/lap_safe

    def forward(self, t, x):
        """Return the time derivative of the field ``x`` at time ``t``."""
        f0 = self.f0
        ut = self.fit_data.fit(t).expand(f0.shape)
        kx1 = self.kx1
        ky1 = self.ky1

        forcing = f0*ut
        w_hat = torch.fft.fft2(x)
        psi_hat = w_hat*self.dlap
        u_hat = ky1*psi_hat     # u =  d(psi)/dy
        v_hat = -kx1*psi_hat    # v = -d(psi)/dx
        wx_hat = kx1*w_hat
        wy_hat = ky1*w_hat
        advection = (torch.fft.ifft2(u_hat).real * torch.fft.ifft2(wx_hat).real
                     + torch.fft.ifft2(v_hat).real * torch.fft.ifft2(wy_hat).real)
        diffusion = torch.fft.ifft2(self.visc*w_hat).real
        return -advection + forcing - self.nu*diffusion
    
def _wavenumbers(n, length, device=None):
    """Return angular wavenumbers ``2*pi*k/length`` for ``n`` points in FFT order.

    The order is ``0, 1, ..., ceil(n/2)-1, -floor(n/2), ..., -1``, the layout
    of ``torch.fft.fft`` output. For even ``n`` this is ``0..n/2-1, -n/2..-1``.
    """
    k_pos = torch.arange(start=0, end=(n + 1) // 2, step=1, device=device)
    k_neg = torch.arange(start=-(n // 2), end=0, step=1, device=device)
    return 2*math.pi/length*torch.cat((k_pos, k_neg), 0)


def gen_one2(sensor, f0):
    """Integrate one trajectory of the forced flow with ``genc2`` as the RHS.

    Reads the module-level settings ``T``, ``Nt``, ``L``, ``Nx`` and ``Ny``.

    Args:
        sensor: tensor whose first row is the temporal signal sampled on the
            ``Nt`` time points (assumed shape ``(1, Nt)`` -- TODO confirm).
        f0: spatial forcing field on the ``(Nx, Ny)`` grid. Its device is used
            for the wavenumber grids.

    Returns:
        ``(xx, yy, tt, sol)``: the x, y and time grids and the ``odeint``
        solution (time along the first axis).
    """
    tt = torch.linspace(0, T, Nt+1)[:-1]
    xx = torch.linspace(0, L, Nx+1)[:-1]
    yy = torch.linspace(0, L, Ny+1)[:-1]

    # Random initial field. NOTE(review): GaussianRF is sampled with size Nx
    # only, so Nx != Ny would additionally need a non-square initial field.
    GRF = GaussianRF(2, Nx, alpha=2.5, tau=7)
    y0 = GRF.sample(1).squeeze(0)

    # Per-axis wavenumbers. k_x varies along the first (x) axis and k_y
    # along the second (y) axis, consistent with indexing='ij'.
    k_x, k_y = torch.meshgrid(_wavenumbers(Nx, L, f0.device),
                              _wavenumbers(Ny, L, f0.device), indexing='ij')

    gend = genc2(k_x, k_y, f0)
    # Cubic spline through the sampled temporal signal, evaluated by genc2 at each solver time.
    gend.fit_data = CSpline(tt, sensor[0])

    sol = ode.odeint(gend, y0, tt, rtol=1e-6, atol=1e-8, method='dopri5')
    return xx, yy, tt, sol

# ---- Dataset sizes and discretisation ---------------------------------------
Ntr = 10      # number of training trajectories
Nva = 3       # number of validation trajectories
Nte = 3       # number of test trajectories
Nx = 32       # grid points along x
Ny = 32       # grid points along y
Nt = 1000     # stored time steps per trajectory
T = 20        # final time of the integration window [0, T)
L = 2         # side length of the square spatial domain [0, L)^2

# ---- Temporal signal: RBF Gaussian-process hyper-parameters ------------------
length_scale = 0.1
output_scale = 1
gp_params = (length_scale, output_scale)

# Rejection threshold for solution values (guards against blown-up runs).
ext_threshold = 100

# ---- Preallocated outputs per split -----------------------------------------
# u: (n, Nt) temporal signal, s: (n, Nt, Nx, Ny) solution snapshots,
# f0: (n, Nx, Ny) spatial forcing field.
u_train = np.zeros((Ntr, Nt))
s_train = np.zeros((Ntr, Nt, Nx, Ny))
f0_train = np.zeros((Ntr, Nx, Ny))
u_valid = np.zeros((Nva, Nt))
s_valid = np.zeros((Nva, Nt, Nx, Ny))
f0_valid = np.zeros((Nva, Nx, Ny))
u_test = np.zeros((Nte, Nt))
s_test = np.zeros((Nte, Nt, Nx, Ny))
f0_test = np.zeros((Nte, Nx, Ny))


def _fill_split(label, n_samples, u_out, s_out, f0_out):
    """Fill the preallocated arrays of one data split with accepted trajectories.

    Each attempt draws a random forcing field ``f0`` and a temporal signal
    ``sensor``, then integrates with ``gen_one2``. Attempts are retried until
    ``n_samples`` trajectories pass the NaN and magnitude checks. Results are
    written into ``u_out``, ``s_out`` and ``f0_out`` in place.

    Returns:
        The ``(xx, yy, tt)`` grids of the last accepted sample, or ``None`` if
        ``n_samples`` is 0.
    """
    grids = None
    i = 0
    while i < n_samples:
        print(label, i, n_samples)
        try:
            GRF = GaussianRF(2, Nx, alpha=2.5, tau=7)
            f0 = GRF.sample(1).squeeze(0)
            sensor = torch.tensor(generate_gaussain_sample(gp_params, 1, Nt))
            xx, yy, tt, s = gen_one2(sensor, f0)
            u_out[i] = sensor
            s_out[i] = s.numpy()
            f0_out[i] = f0
        except Exception as err:
            # Not a bare except: Ctrl-C (KeyboardInterrupt) must still be able
            # to stop this otherwise unbounded retry loop.
            print("error", err)
            continue
        if np.isnan(s_out[i]).any():
            print("isnan")
        elif np.abs(s_out[i]).max() > ext_threshold:
            # Reject blow-ups of either sign, not only large positive values.
            print("ext_threshold")
        else:
            grids = (xx, yy, tt)
            i += 1
    return grids


_grids = None
for _label, _n, _u, _s, _f0 in (
    ("train", Ntr, u_train, s_train, f0_train),
    ("valid", Nva, u_valid, s_valid, f0_valid),
    ("test", Nte, u_test, s_test, f0_test),
):
    _grids = _fill_split(_label, _n, _u, _s, _f0) or _grids
xx, yy, tt = _grids

sav_data(Ntr, Nva, Nte, s_train, u_train, s_valid, u_valid, s_test, u_test, xx, yy, tt,
         f0_train, f0_valid, f0_test)

# Save a heatmap of the first training trajectory every 10 stored time steps.
os.makedirs("./other", exist_ok=True)
s = s_train[0]
nod = np.arange(0, Nt, 10)
# The grid is linspace(0, L, N+1)[:-1], i.e. it excludes the endpoint L.
x_labels = np.linspace(0, L, Nx, endpoint=False)
y_labels = np.linspace(0, L, Ny, endpoint=False)
with plt.ioff():
    for i in nod:
        fig = plt.figure(figsize=(10, 10))
        u = s[i]
        data = pd.DataFrame(u, index=x_labels, columns=y_labels)
        sns.heatmap(data, center=u.mean())
        plt.savefig("./other/NSpde_NODE" + str(i) + ".jpg")
        # Release each figure; otherwise all 100 stay open in memory.
        plt.close(fig)

