import numpy as np
import matplotlib.pyplot as plt
import sympy as sp
import pandas as pd
import torch
import torch.nn as nn
import torchdiffeq as ode
from cubic_spline import CSpline
from torch.fft import fft, ifft

def RBF(x1, x2, params):
    """Squared-exponential (RBF) kernel matrix between two point sets.

    Args:
        x1: array of shape (n, d).
        x2: array of shape (m, d).
        params: ``(length_scale, output_scale)``. Both point sets are divided
            by ``length_scale`` before differences are taken; ``output_scale``
            multiplies the kernel directly, so it equals K(x, x).

    Returns:
        (n, m) array with entries
        ``output_scale * exp(-0.5 * ||(x1[i] - x2[j]) / length_scale||**2)``.
    """
    length_scale, output_scale = params
    scaled_a = x1 / length_scale
    scaled_b = x2 / length_scale
    # Broadcast (n, 1, d) against (1, m, d) to get all pairwise differences.
    pairwise = scaled_a[:, None] - scaled_b[None]
    sq_dist = (pairwise**2).sum(axis=2)
    return output_scale * np.exp(-0.5 * sq_dist)

def generate_gaussain_sample(gp_params, N, T):
    """Draw ``N`` zero-mean Gaussian-process sample paths on ``T`` points in [0, 1].

    The covariance is ``RBF(X, X, gp_params)`` on the uniform grid
    ``X = linspace(0, 1, T)`` plus a 1e-10 diagonal jitter, factored with a
    Cholesky decomposition. Each path is ``chol @ z`` with ``z ~ N(0, I)`` drawn
    from NumPy's global RNG (``np.random``), so results depend on its state.

    Args:
        gp_params: ``(length_scale, output_scale)`` forwarded to ``RBF``.
        N: number of paths to draw.
        T: number of grid points per path.

    Returns:
        (N, T) array, one sample path per row.

    Raises:
        numpy.linalg.LinAlgError: if the jittered covariance is not
            numerically positive definite.
    """
    jitter = 1e-10
    grid = np.linspace(0.0, 1.0, T)[:, None]
    cov = RBF(grid, grid, gp_params) + jitter * np.eye(T)
    chol = np.linalg.cholesky(cov)
    samples = np.zeros((N, T))
    for row in range(N):
        # One standard-normal draw per path, coloured by the Cholesky factor.
        white = np.random.normal(loc=0, scale=1, size=T)
        samples[row, :] = np.dot(chol, white)
    return samples

def sav_data(Ntr, Nte, Nva, s_train, u_train, s_valid, u_valid, s_test, u_test, x, t,
             out_dir="./data_burPDE"):
    """Write the generated train/valid/test splits and the grids to CSV files.

    Files written into ``out_dir`` (created if missing), in this order:
    ``s_train``, ``u_train``, ``s_valid``, ``u_valid``, ``s_test``, ``u_test``,
    ``x``, ``t`` (each with a ``.csv`` suffix and a pandas index column).

    Args:
        Ntr, Nte, Nva: number of samples in the train / test / valid splits;
            each ``s_*`` array is reshaped to ``(n, -1)`` (one flattened
            solution per row) before writing.
        s_train, s_valid, s_test: solution arrays with leading sample axis.
        u_train, u_valid, u_test: forcing arrays, written as-is (2-D).
        x, t: spatial and temporal grids (NumPy arrays or CPU torch tensors).
        out_dir: output directory; defaults to the original hard-coded path.
    """
    import os

    os.makedirs(out_dir, exist_ok=True)

    def _write(name, values):
        # np.asarray makes torch tensors serialize as plain numbers on every
        # pandas version (older pandas would store the tensor objects themselves).
        pd.DataFrame(np.asarray(values)).to_csv(os.path.join(out_dir, name + ".csv"))

    splits = (
        ("train", Ntr, s_train, u_train),
        ("valid", Nva, s_valid, u_valid),
        ("test", Nte, s_test, u_test),
    )
    for tag, n_samples, s_arr, u_arr in splits:
        _write("s_" + tag, np.asarray(s_arr).reshape(n_samples, -1))
        _write("u_" + tag, u_arr)

    _write("x", x)
    _write("t", t)
    print("Data generated successfully!")

class genc2(nn.Module):
    """ODE right-hand side for a spectrally discretized, Burgers-type PDE.

    For state ``s`` it returns ``1e-3 * s_xx - (u0 + f(t)) * s * s_x``.
    Derivatives are computed by FFT along dim 1, and ``f`` is
    ``self.fit_data.fit``.

    Attributes:
        kk: wavenumbers multiplying the FFT coefficients; presumably the
            FFT-ordered angular wavenumbers of the periodic grid — confirm.
        u0: base advection coefficient, broadcast against the state.
        fit_data: object exposing ``fit(t)`` that returns the time-dependent
            coefficient. It is ``None`` after construction and must be
            assigned before ``forward`` runs; otherwise ``forward`` raises
            AttributeError.
    """

    def __init__(self, kk, u0):
        super().__init__()
        self.kk = kk
        self.u0 = u0
        self.fit_data = None

    def forward(self, t, x):
        """Return d(state)/dt at time ``t`` for state ``x`` (real tensor)."""
        base = self.u0
        # Time-dependent coefficient from the fitted forcing, broadcast to u0.
        coeff = base + self.fit_data.fit(t).expand(base.shape)
        x_hat = fft(x, dim=1)
        # Second derivative: multiply Fourier coefficients by -(k^2).
        s_xx = ifft(x_hat * self.kk**2 * (-1), dim=1).real
        # First derivative: multiply Fourier coefficients by i*k.
        s_x = ifft(x_hat * 1j * self.kk, dim=1).real
        return 1e-3 * s_xx - coeff * x * s_x
    
def gen_one2(xx, yy, u):
    """Integrate the ``genc2`` right-hand side from initial state ``yy`` over [0, T].

    Reads the module-level globals ``L``, ``T``, ``Nt`` and ``kk``, which must
    be defined before this is called. The base coefficient ``u0`` is
    identically zero (``sin(...) * 0.0``), so the advection coefficient is the
    spline built from ``u[0]`` alone. Integration uses torchdiffeq's
    ``dopri5`` with rtol=1e-6 and atol=1e-8.

    Args:
        xx: spatial grid (1-D tensor); returned unchanged.
        yy: initial state on ``xx``; presumably 1-D of length ``len(xx)``.
        u: forcing samples; only ``u[0]`` is used (the script passes a
            (1, Nt) tensor). It is handed to the project-local ``CSpline``
            together with the time grid.

    Returns:
        ``(xx, tt, traj)`` where ``tt = torch.linspace(0, T, Nt)`` and ``traj``
        is ``sol[:, 0]`` as a NumPy array (time-major; (Nt, Nx) when ``yy`` is
        1-D of length Nx).
    """
    time_grid = torch.linspace(0, T, Nt)
    base = torch.sin(2*np.pi*xx/L).unsqueeze(0)*0.0
    wavenumbers = torch.tensor(kk).unsqueeze(0)

    rhs = genc2(wavenumbers, base)
    rhs.fit_data = CSpline(time_grid, u[0])

    initial = yy.unsqueeze(0)  # add the batch-of-one axis expected by genc2
    sol = ode.odeint(rhs, initial, time_grid, rtol=1e-6, atol=1e-8, method='dopri5')
    return xx, time_grid, sol[:, 0].numpy()

# Number of train / validation / test samples to generate.
Ntr = 10
Nva = 2
Nte = 2
# Spatial grid size and number of output time points per trajectory.
Nx = 64
Nt = 1000
# Final integration time and spatial domain length.
T = 10
L = 4*np.pi
# Angular wavenumbers in FFT order (0..Nx/2-1, then -(Nx/2-1)..-1). The Nyquist
# slot (index Nx/2) is set to 0, a common choice for spectral differentiation.
kk = np.concatenate((np.arange(0, Nx/2),np.array([0]),np.arange(-Nx/2+1, 0)))*2*np.pi/L
# Nx equally spaced points on [0, L); the endpoint L is dropped (periodic grid).
xx = torch.linspace(0, L, Nx+1)[:-1]

# RBF-kernel hyperparameters for the Gaussian-process forcing samples.
length_scale = 0.2
output_scale = 0.1
# Each forcing path is shifted so its minimum equals this value (keeps it > 0).
minmum = 1e-3
gp_params = (length_scale, output_scale)
# Initial condition; the warm-up loop below replaces it with an evolved state.
# NOTE(review): cos(xx/L) and sin(xx/L) only sweep [0, 1) radians across the
# domain, so this profile is not periodic on [0, L) even though the solver
# differentiates spectrally (which assumes periodicity). Confirm that the
# missing 2*pi factor is intentional.
yy = 2*torch.cos(xx/L)*(1+torch.sin(xx/L))


# Warm-up: integrate the initial condition under one random forcing path, then
# restart from the final state 20 more times with the same forcing. The last
# state becomes `yy`, the common initial condition for every split below.
# Retries with a fresh forcing path until a run succeeds with finite values.
while True:
    try:
        sensor = torch.tensor(generate_gaussain_sample(gp_params, 1, Nt))
        # Shift so the forcing's minimum equals `minmum` (strictly positive).
        sensor = sensor - torch.min(sensor) + minmum
        x, t, s = gen_one2(xx, yy, sensor)
        for _ in range(20):
            x, t, s = gen_one2(xx, torch.tensor(s[-1]), sensor)
    except Exception as err:
        # Previously a bare `except:`, which also swallowed KeyboardInterrupt
        # and made this retry loop impossible to interrupt.
        print("error", err)
        continue
    if not np.isfinite(s[-1]).all():
        # A non-finite warm-up state would make every later sample fail the
        # NaN check forever, so reject it here and keep the previous `yy`.
        print("isnan")
        continue
    yy = torch.tensor(s[-1])
    print("pass!")
    break

# Trajectories whose magnitude exceeds this are treated as numerically diverged
# and redrawn.
ext_threshold = 100


def _generate_split(label, n_samples, y0):
    """Draw ``n_samples`` accepted (forcing, trajectory) pairs, all starting at ``y0``.

    Each attempt samples a GP forcing path, shifts it so its minimum equals
    ``minmum``, and integrates it with ``gen_one2``. An attempt is discarded
    and retried if it raises, produces NaN, or has ``|s| > ext_threshold``
    anywhere.

    Args:
        label: tag printed with the progress line ("train", "valid", "test").
        n_samples: number of accepted samples to collect.
        y0: initial state passed to ``gen_one2`` for every sample.

    Returns:
        ``(u, s)``. ``u`` has shape (n_samples, Nt) and holds the forcing paths.
        ``s`` has shape (n_samples, Nx, Nt) and holds the trajectories
        (``gen_one2``'s time-major output transposed).
    """
    u = np.zeros((n_samples, Nt))
    s = np.zeros((n_samples, Nx, Nt))
    i = 0
    while i < n_samples:
        print(label, i, n_samples)
        try:
            sensor = torch.tensor(generate_gaussain_sample(gp_params, 1, Nt))
            sensor = sensor - torch.min(sensor) + minmum
            _, _, sol = gen_one2(xx, y0, sensor)
        except Exception as err:
            # Previously a bare `except:`, which also swallowed KeyboardInterrupt
            # and made this retry loop impossible to interrupt.
            print("error", err)
            continue
        traj = np.real(sol).transpose()
        if np.isnan(traj).any():
            print("isnan")
        elif np.max(np.abs(traj)) > ext_threshold:
            # Check magnitude so large negative excursions are rejected too.
            print("ext_threshold")
        else:
            u[i] = sensor
            s[i] = traj
            i += 1
    return u, s


# Splits are generated in order (train, valid, test), so global-RNG consumption
# matches the original sequential loops. The grids `x` and `t` used by the save
# and plot code are the ones gen_one2 returned during the warm-up: gen_one2
# always returns the same `xx` and `linspace(0, T, Nt)`.
u_train, s_train = _generate_split("train", Ntr, yy)
u_valid, s_valid = _generate_split("valid", Nva, yy)
u_test, s_test = _generate_split("test", Nte, yy)

sav_data(Ntr, Nte, Nva, s_train, u_train, s_valid, u_valid, s_test, u_test, x, t)

# Visual check: space-time plot of the first test trajectory (top) and the
# forcing path that drove it (bottom).
i = 0
s = s_test[i].transpose()  # (Nt, Nx): one row per time point
xx_mesh, tt_mesh = np.meshgrid(np.asarray(x), np.asarray(t))
# Requesting tight layout at construction avoids Figure.set_tight_layout, which
# newer Matplotlib discourages.
fig, ax = plt.subplots(2, 1, figsize=(20, 12), tight_layout=True)
# X, Y and C have the same shape, so ask for shading="auto" explicitly. The old
# "flat" default is deprecated/rejected for same-shape inputs.
ax[0].pcolormesh(tt_mesh, xx_mesh, s, vmin=s.min(), vmax=s.max(), shading="auto")
ax[1].plot(u_test[i])
plt.show()






