import os

import numpy as np
import matplotlib.pyplot as plt
import sympy as sp
import pandas as pd
import torch
import torch.nn as nn
import torchdiffeq as ode
from torch.fft import fft, ifft

from cubic_spline import CSpline

def RBF(x1, x2, params):
    """Return the squared-exponential (RBF) kernel matrix between two point sets.

    Args:
        x1: ``(n, d)`` array of points.
        x2: ``(m, d)`` array of points.
        params: ``(length_scale, output_scale)``.

    Returns:
        ``(n, m)`` array with entries
        ``output_scale * exp(-0.5 * ||(a - b) / length_scale||^2)``.
    """
    ell, amplitude = params
    # Insert a broadcast axis so every row of x1 is paired with every row of x2.
    left = (x1 / ell)[:, None]
    right = (x2 / ell)[None]
    sq_dist = np.square(left - right).sum(axis=2)
    return amplitude * np.exp(-0.5 * sq_dist)

def generate_gaussain_sample(gp_params, N, T, jitter=1e-10, max_jitter=1e-6):
    """Draw ``N`` sample paths of a zero-mean Gaussian process on ``[0, 1]``.

    The process has the RBF covariance given by ``gp_params`` and is evaluated
    on ``T`` equispaced points ``np.linspace(0, 1, T)``. Samples are drawn with
    the global NumPy RNG (``np.random.normal``), so ``np.random.seed`` makes
    the output reproducible.

    Args:
        gp_params: ``(length_scale, output_scale)`` forwarded to ``RBF``.
        N: number of sample paths.
        T: number of grid points per path.
        jitter: initial diagonal regulariser added to the covariance before
            the Cholesky factorisation.
        max_jitter: largest jitter tried before giving up.

    Returns:
        ``(N, T)`` float array; each row is one sample path.

    Raises:
        numpy.linalg.LinAlgError: if the covariance is still not positive
            definite once ``jitter`` has been escalated past ``max_jitter``.
    """
    X = np.linspace(0.0, 1.0, T)[:, None]
    K = RBF(X, X, gp_params)
    eye = np.eye(T)
    while True:
        try:
            L = np.linalg.cholesky(K + jitter * eye)
            break
        except np.linalg.LinAlgError:
            # K is deterministic, so retrying with the same jitter (as the
            # caller's retry loop would) can never succeed: escalate instead.
            if jitter >= max_jitter:
                raise
            jitter *= 10.0
    # One (N, T) draw consumes the RNG in the same order as N draws of size T,
    # and row i of z @ L.T equals L @ z[i].
    z = np.random.normal(loc=0, scale=1, size=(N, T))
    return z @ L.T

def sav_data(Ntr, Nte, Nva, s_train, u_train, s_valid, u_valid, s_test, u_test, x, t,
             out_dir="./data_drPDE"):
    """Write the generated dataset splits and grids to CSV files.

    For each split ``name`` in ``train``/``valid``/``test`` this writes
    ``s_<name>.csv`` (each sample's solution flattened to one row) and
    ``u_<name>.csv`` (forcing values). It also writes ``x.csv`` and ``t.csv``
    with the spatial and temporal grids. The output directory is created if
    it does not exist.

    Args:
        Ntr, Nte, Nva: number of samples in the train / test / valid splits
            (note the positional order: test before valid).
        s_train, s_valid, s_test: per-split solution arrays; each is
            reshaped to ``(n_samples, -1)``.
        u_train, u_valid, u_test: per-split forcing arrays.
        x, t: grids; NumPy arrays or CPU torch tensors.
        out_dir: destination directory (default ``./data_drPDE``).
    """
    # pandas refuses to write into a missing directory; create it up front so a
    # long generation run does not fail at the very end.
    os.makedirs(out_dir, exist_ok=True)

    splits = (
        ("train", Ntr, s_train, u_train),
        ("valid", Nva, s_valid, u_valid),
        ("test", Nte, s_test, u_test),
    )
    for name, n_samples, s_split, u_split in splits:
        pd.DataFrame(np.asarray(s_split).reshape(n_samples, -1)).to_csv(
            os.path.join(out_dir, f"s_{name}.csv"))
        pd.DataFrame(np.asarray(u_split)).to_csv(
            os.path.join(out_dir, f"u_{name}.csv"))

    # np.asarray turns torch tensors into plain numeric arrays, so pandas never
    # stores them as an object column of 0-d tensors.
    pd.DataFrame(np.asarray(x)).to_csv(os.path.join(out_dir, "x.csv"))
    pd.DataFrame(np.asarray(t)).to_csv(os.path.join(out_dir, "t.csv"))
    print("Data generated successfully!")

class genc(nn.Module):
    """ODE right-hand side for the diffusion-reaction PDE, evaluated in Fourier space.

    Complex spectra are stored as real tensors with two columns:
    ``[:, 0]`` holds the real part and ``[:, 1]`` the imaginary part (see
    ``ctom``). ``forward(t, x)`` takes the spectrum ``x`` of the state ``s``
    and returns

        D * (i k)^2 * x  +  K * conv(x, x)  +  source(t)

    where ``conv(x, x)`` is the spectrum of ``s**2`` (see ``conv``). ``source(t)``
    is ``ut0`` plus the spectrum of the spatially constant value
    ``fit_data.fit(t)``.

    In this file the class is used by ``gen_one``; ``genc2`` is the
    physical-space variant.

    Args:
        kk: wavenumber tensor, one entry per Fourier mode. ``gen_one`` passes
            ``torch.tensor(kk)`` built from the module-level wavenumbers.
        cc: complex tensor stored (two-column form) as ``self.cc``; it is not
            read by ``forward``.
        ut0: two-column spectrum added to the source on every call. ``gen_one``
            passes the FFT of ``sin(2*pi*x/L)``.

    ``fit_data`` must be assigned after construction. It needs a ``fit(t)``
    method; ``gen_one`` assigns a ``CSpline``, whose return value is presumably
    a scalar-like tensor (TODO confirm).
    """
    def __init__(self, kk, cc, ut0):
        super(genc, self).__init__()
        self.D = 0.01  # diffusion coefficient (multiplies (ik)^2 in forward)
        self.K = 0.01  # coefficient of the quadratic reaction term
        # Placeholder; assigned again below and set by the caller before use.
        self.fit_data=None 
        # All-zero imaginary column used to build the source spectrum.
        self.psup = torch.zeros(len(kk),1)
        self.kk = kk
        self.cc = self.ctom(cc)
        self.fit_data = None
        # Powers of (i k) in two-column form. Only kk2 (the second derivative,
        # (ik)^2 = -k^2) is read by forward; kk1/kk3/kk4 are unused here.
        self.kk1 = self.ctom((1j*kk)**1)
        self.kk2 = self.ctom((1j*kk)**2)
        self.kk3 = self.ctom((1j*kk)**3)
        self.kk4 = self.ctom((1j*kk)**4)
        self.conv1d = torch.nn.functional.conv1d
        self.lent = len(kk)  # number of Fourier modes N
        self.ut0 = ut0
    
    def ctom(self, c):
        """Convert a complex tensor to two-column real form ``[real, imag]``.

        For a 1-D ``c`` of length N the result has shape ``(N, 2)``.
        """
        m = torch.cat((c.real.unsqueeze(1),c.imag.unsqueeze(1)),axis=1)
        return(m)
    
    def cmult(self, x1, x2):
        """Return the element-wise complex product of two two-column tensors."""
        real = x1[:,0]*x2[:,0] - x1[:,1]*x2[:,1]
        imag = x1[:,0]*x2[:,1] + x1[:,1]*x2[:,0]
        m = torch.cat((real.unsqueeze(1),imag.unsqueeze(1)),axis=1)
        return(m)
    
    def conv(self, x1, x2):
        """Return the circular convolution of two two-column spectra, divided by N.

        ``x1`` is extended periodically by prepending ``x1[1:]``. ``conv1d``
        (a cross-correlation) is then run against the flipped ``x2``, giving
        ``out[j] = sum_q x1[(j - q) mod N] * x2[q]`` for each real/imag
        combination. By the DFT convolution theorem, if ``x1 = DFT(a)`` and
        ``x2 = DFT(b)`` the result is ``DFT(a * b)``.
        """
        # Real part: Re*Re - Im*Im.
        real = self.conv1d(torch.hstack((x1[1:,0],x1[:,0])).view(1, 1, -1),
                      x2[:,0].view(1, 1, -1).flip(2))
        real -= self.conv1d(torch.hstack((x1[1:,1],x1[:,1])).view(1, 1, -1), 
                      x2[:,1].view(1, 1, -1).flip(2))
        # Imaginary part: Re*Im + Im*Re.
        imag = self.conv1d(torch.hstack((x1[1:,0],x1[:,0])).view(1, 1, -1),
                      x2[:,1].view(1, 1, -1).flip(2))
        imag += self.conv1d(torch.hstack((x1[1:,1],x1[:,1])).view(1, 1, -1),
                      x2[:,0].view(1, 1, -1).flip(2))
        m = torch.cat((real.view(-1,1),imag.view(-1,1)),axis=1)/self.lent
        return(m)
    
    def forward(self, t, x):
        """Return d(x)/dt for the two-column spectrum ``x`` at time ``t``."""
        # A spatially constant value c has DFT N*c in the zero mode and 0
        # elsewhere, so the time-dependent forcing only touches row 0.
        fiu = self.fit_data.fit(t)*self.lent
        ut = torch.zeros(self.lent,1)
        ut[0] = fiu
        # Attach the zero imaginary column, then add the static source ut0.
        u = torch.cat((ut,self.psup),axis=1) + self.ut0
        #u = torch.cat((ut,self.psup),axis=1)
        # Diffusion (multiply by (ik)^2) + quadratic reaction + source.
        gu = self.D*self.cmult(self.kk2, x) + self.K*self.conv(x,x) + 1.0*u
        return(gu)

def gen_one(xx, yy, u):
    """Solve the PDE for one forcing path, with the state kept in Fourier space.

    This is the spectral counterpart of ``gen_one2``. The initial state ``yy``
    is transformed with an FFT, integrated with ``genc`` (dopri5), and the
    trajectory is transformed back with an inverse FFT.

    Args:
        xx: spatial grid.
        yy: initial state sampled on ``xx``.
        u: forcing samples. ``u[0]`` is splined over ``torch.linspace(0, T, Nt)``,
            so presumably ``u`` has shape ``(1, Nt)`` (TODO confirm).

    Returns:
        ``(xx, tt, sol_table)``. ``tt`` holds the ``Nt`` output times and
        ``sol_table`` is the complex NumPy array of the inverse-FFT'd states,
        one row per time.

    Relies on the module-level ``L``, ``T``, ``Nt``, ``Nx`` and ``kk``.
    """
    def _spectrum_columns(signal):
        # FFT of ``signal`` in the (n, 2) [real, imag] layout genc expects.
        spec = torch.fft.fft(signal).unsqueeze(0)
        return torch.cat((spec.real, spec.imag), axis=0).t()

    times = torch.linspace(0, T, Nt)
    init_state = _spectrum_columns(yy)
    static_source = _spectrum_columns(torch.sin(2 * np.pi * xx / L))

    rhs = genc(torch.tensor(kk), torch.fft.fft(torch.ones(Nx)), static_source)
    rhs.fit_data = CSpline(torch.linspace(0, T, Nt), u[0])

    # Fixed-step alternative: method='euler', options={'step_size': 0.001}
    sol = ode.odeint(rhs, init_state, times, rtol=1e-6, atol=1e-8, method='dopri5')
    spectra = torch.complex(sol[:, :, 0], sol[:, :, 1])
    physical = torch.fft.ifft(spectra, axis=1).numpy()
    return (xx, times, physical)

class genc2(nn.Module):
    """ODE right-hand side for the 1-D diffusion-reaction equation in physical space.

    ``forward(t, x)`` returns

        D * s_xx + K * s**2 + (u0(x) + u(t))

    The second derivative is computed spectrally, as
    ``ifft(-kk**2 * fft(x)).real`` along axis 1.

    Args:
        kk: wavenumbers in FFT order, broadcastable against ``fft(x, axis=1)``.
            ``gen_one2`` passes the module-level ``kk`` with shape ``(1, Nx)``.
        u0: static forcing profile, added to the time-dependent forcing.

    ``fit_data`` must be assigned before use. It needs a ``fit(t)`` method
    whose result can be expanded to ``u0.shape``.
    """

    def __init__(self, kk, u0):
        super().__init__()
        self.D = 0.01   # diffusion coefficient
        self.K = 0.01   # coefficient of the quadratic reaction term
        self.kk = kk
        self.u0 = u0
        self.fit_data = None  # set by the caller (e.g. a CSpline)

    def forward(self, t, x):
        forcing = self.u0 + self.fit_data.fit(t).expand(self.u0.shape)
        # Spectral second derivative: multiply by -(k^2) in Fourier space.
        laplacian = ifft(-1 * self.kk**2 * fft(x, axis=1), axis=1).real
        # Diffusion-reaction equation.
        return self.D * laplacian + self.K * x**2 + forcing
    
def gen_one2(xx, yy, u):
    """Solve the PDE in physical space for one forcing path.

    The equation is integrated with ``genc2`` (dopri5, ``rtol=1e-6``,
    ``atol=1e-8``). The static forcing profile is ``sin(2*pi*xx/L)``; the
    time-dependent forcing is a cubic spline through ``u[0]`` on the output
    times.

    Args:
        xx: spatial grid.
        yy: initial state on ``xx``.
        u: forcing samples. ``u[0]`` is splined over ``torch.linspace(0, T, Nt)``,
            so presumably ``u`` has shape ``(1, Nt)`` (TODO confirm).

    Returns:
        ``(xx, tt, sol)``. ``tt`` holds the ``Nt`` output times and ``sol`` is
        a NumPy array with one state per output time.

    Relies on the module-level ``L``, ``T``, ``Nt`` and ``kk``.
    """
    times = torch.linspace(0, T, Nt)
    rhs = genc2(torch.tensor(kk).unsqueeze(0),
                torch.sin(2 * np.pi * xx / L).unsqueeze(0))
    rhs.fit_data = CSpline(times, u[0])

    # Fixed-step alternative: method='euler', options={'step_size': 0.01}
    trajectory = ode.odeint(rhs, yy.unsqueeze(0), times,
                            rtol=1e-6, atol=1e-8, method='dopri5')
    # odeint stacks one state per time; index 0 drops the batch axis that
    # unsqueeze(0) added to the initial state.
    return (xx, times, trajectory[:, 0].numpy())

# ---------------------------------------------------------------------------
# Dataset / problem configuration
# ---------------------------------------------------------------------------
Ntr = 10   # number of training trajectories
Nva = 2    # number of validation trajectories
Nte = 2    # number of test trajectories
Nx = 32    # spatial grid points
Nt = 1000  # output time samples on [0, T]
T = 1      # final time
L = 1      # spatial domain length

# Angular wavenumbers in FFT order, scaled by 2*pi/L.
kk = np.concatenate((
    np.arange(0, Nx / 2),        # non-negative modes 0 .. Nx/2 - 1
    np.array([0]),               # Nyquist mode, deliberately set to 0
    np.arange(-Nx / 2 + 1, 0),   # negative modes -(Nx/2 - 1) .. -1
)) * 2 * np.pi / L

# Nx equispaced points on [0, L); the endpoint L is dropped.
xx = torch.linspace(0, L, Nx + 1)[:-1]

# RBF kernel hyper-parameters for the Gaussian-process forcing samples.
length_scale = 0.1
output_scale = 1.0
gp_params = (length_scale, output_scale)

# Disabled experiment, kept for reference. It integrates once from
# cos(2*pi*x/L) with a GP forcing shifted to be non-negative, and takes the
# final state as a new initial condition `yy`. The original version optionally
# repeated the solve several times via gen_one2(xx, torch.tensor(s[-1]), sensor).
#
# y0 = torch.cos(2*np.pi*xx/L)
# i = 0
# while i <= 0:
#     try:
#         sensor = torch.tensor(generate_gaussain_sample(gp_params, 1, Nt))
#         sensor = sensor - torch.min(sensor)
#         x, t, s = gen_one2(xx, y0, sensor)
#         yy = torch.tensor(s[-1])
#     except:
#         print("error")
#     else:
#         print("pass!")
#         i = 1

# Trajectories whose state exceeds this value are treated as blown up and redrawn.
ext_threshold = 100
# Abort instead of spinning forever when generation keeps raising.
max_consecutive_errors = 50


def _generate_split(label, n_samples, y_init):
    """Generate ``n_samples`` accepted (forcing, solution) pairs for one split.

    Each attempt draws a GP forcing path and solves the PDE with ``gen_one2``
    from ``y_init``. A result is kept only if it contains no NaN and its
    maximum does not exceed ``ext_threshold``. Rejected or failed attempts are
    redrawn.

    Args:
        label: split name printed in the progress messages.
        n_samples: number of accepted samples to produce.
        y_init: initial condition passed to ``gen_one2``.

    Returns:
        ``(u, s, x, t)``:
        - ``u``: ``(n_samples, Nt)`` forcing values.
        - ``s``: ``(n_samples, Nx, Nt)`` solutions.
        - ``x``, ``t``: the grids returned by the last ``gen_one2`` call, or
          ``None`` if it never returned.

    Raises:
        RuntimeError: if ``max_consecutive_errors`` attempts in a row raise.
    """
    u_split = np.zeros((n_samples, Nt))
    s_split = np.zeros((n_samples, Nx, Nt))
    x_grid = t_grid = None
    accepted = 0
    errors_in_a_row = 0
    while accepted < n_samples:
        print(label, accepted, n_samples)
        try:
            sensor = torch.tensor(generate_gaussain_sample(gp_params, 1, Nt))
            x_grid, t_grid, sol = gen_one2(xx, y_init, sensor)
        except Exception as err:
            # Catch Exception (not a bare except) so Ctrl+C still stops the
            # script. Any sampling/solver failure is reported and redrawn.
            errors_in_a_row += 1
            print("error", err)
            if errors_in_a_row >= max_consecutive_errors:
                raise RuntimeError(
                    f"{label}: {errors_in_a_row} consecutive generation failures"
                ) from err
            continue
        errors_in_a_row = 0

        traj = np.real(sol).transpose()  # (Nt, Nx) -> (Nx, Nt)
        if np.isnan(traj).any():
            print("isnan")
        elif np.max(traj) > ext_threshold:
            print("ext_threshold")
        else:
            u_split[accepted] = sensor
            s_split[accepted] = traj
            accepted += 1
    return u_split, s_split, x_grid, t_grid


# Every split starts from the same initial condition.
y0 = torch.cos(2 * np.pi * xx / L)
u_train, s_train, x, t = _generate_split("train", Ntr, y0)
u_valid, s_valid, x, t = _generate_split("valid", Nva, y0)
u_test, s_test, x, t = _generate_split("test", Nte, y0)

sav_data(Ntr, Nte, Nva, s_train, u_train, s_valid, u_valid, s_test, u_test, x, t)

# Visual sanity check of one training sample. The top panel shows the solution
# field s(x, t); the bottom panel shows the forcing u(t) on the same time axis.
i = 0
s = s_train[i].transpose()  # (Nx, Nt) -> (Nt, Nx) to match the meshgrid below
t_np = np.asarray(t)
xx_mesh, tt_mesh = np.meshgrid(np.asarray(x), t_np)
fig, ax = plt.subplots(2, 1, figsize=(20, 12), sharex=True, tight_layout=True)
ax[0].pcolormesh(tt_mesh, xx_mesh, s, vmin=s.min(), vmax=s.max())
ax[0].set_ylabel("x")
ax[1].plot(t_np, u_train[i])  # plot against t, not the sample index
ax[1].set_xlabel("t")
ax[1].set_ylabel("u")
plt.show()

