import numpy as np
import random as random
import torch
import torch.nn as nn 
import pandas as pd
import torchdiffeq as ode
import matplotlib.pyplot as plt
from cubic_spline import CSpline

# Dataset tag; it is appended to the output directory name used by sav_data.
na = 'E2'

class ode1(nn.Module):
    """Right-hand side of a Lorenz-type system with a time-varying ``rho``.

    For a three-component state ``x`` the derivative is::

        dx0/dt = sigma * (x1 - x0)
        dx1/dt = x0 * (rho(t) - x2) - x1
        dx2/dt = x0 * x1 - beta * x2

    where ``rho(t)`` is obtained from ``self.u.fit(t)``.  ``self.u`` must be
    assigned (an object exposing ``fit(t)``) before the module is evaluated.

    Args:
        sigma: coefficient of the first equation (default 10, the value that
            was previously hard-coded).
        beta: coefficient of the third equation (default 8/3, the value that
            was previously hard-coded).
    """

    def __init__(self, sigma=10.0, beta=8.0 / 3.0):
        super(ode1, self).__init__()
        self.u = None  # input interpolant; must provide .fit(t) -> rho
        self.sigma = sigma
        self.beta = beta

    def forward(self, t, x):
        """Return dx/dt at time ``t`` for the state ``x`` (indexed x[0..2])."""
        if self.u is None:
            raise AttributeError("ode1.u must be set before calling forward()")
        rho = self.u.fit(t)
        # Allocate the output with x's dtype/device: a bare torch.ones(3) is
        # always float32 on CPU and silently downcasts float64 states.
        g = torch.empty(3, dtype=x.dtype, device=x.device)
        g[0] = self.sigma * (x[1] - x[0])
        g[1] = x[0] * (rho - x[2]) - x[1]
        g[2] = x[0] * x[1] - self.beta * x[2]
        return g

def RBF(x1, x2, params):
    """Squared-exponential (RBF) kernel matrix between two point sets.

    Args:
        x1: array of shape (n, d).
        x2: array of shape (m, d).
        params: ``(length_scale, output_scale)`` pair.

    Returns:
        (n, m) array whose entry (i, j) is
        ``output_scale * exp(-0.5 * ||x1[i] - x2[j]||^2 / length_scale^2)``.
    """
    ell, amplitude = params
    scaled_a = x1 / ell
    scaled_b = x2 / ell
    # (n, 1, d) - (1, m, d) -> all pairwise scaled differences.
    delta = scaled_a[:, None] - scaled_b[None]
    sq_dist = np.sum(delta ** 2, axis=2)
    return amplitude * np.exp(-0.5 * sq_dist)

def generate_gaussain_sample(gp_params, N, T, jitter=1e-10):
    """Draw ``N`` zero-mean Gaussian-process samples on a uniform grid.

    The samples live on ``T`` equally spaced points in [0, 1] and use the
    covariance produced by ``RBF`` with ``gp_params``.

    Args:
        gp_params: ``(length_scale, output_scale)`` forwarded to ``RBF``.
        N: number of samples (rows of the result).
        T: number of grid points (columns of the result).
        jitter: value added to the covariance diagonal before the Cholesky
            factorisation to keep it numerically positive definite
            (default 1e-10, the previously hard-coded value).

    Returns:
        ``(N, T)`` float array; row ``i`` equals ``L @ z_i`` where ``L`` is
        the Cholesky factor and ``z_i`` are standard-normal draws from
        numpy's global RNG.

    Raises:
        numpy.linalg.LinAlgError: if the jittered covariance is not
            positive definite.
    """
    grid = np.linspace(0.0, 1.0, T)[:, None]
    K = RBF(grid, grid, gp_params)
    L = np.linalg.cholesky(K + jitter * np.eye(T))
    # One (N, T) draw consumes the same values, in the same order, as N
    # consecutive size-T draws, so each row sees the same noise as the old
    # per-row loop; the matmul then colours every row at once.
    z = np.random.normal(loc=0, scale=1, size=(N, T))
    return z @ L.T

def rand_gp():
    """Return random ``(length_scale, output_scale)`` GP hyper-parameters.

    ``length_scale`` is uniform on [0.2, 1.2) and ``output_scale`` uniform on
    [0, 10); both come from numpy's global RNG, length_scale drawn first.
    """
    return (np.random.random() + 0.2, np.random.random() * 10)

def sav_data(s_train, u_train, s_valid, u_valid, s_test, u_test, t, out_dir=None):
    """Write every dataset split to CSV files.

    Each tensor is written to ``<out_dir>/<name>.csv`` with pandas' default
    row index.  The state tensors ``s_*`` are flattened to 2-D (one row per
    sample, all trailing dimensions concatenated); ``u_*`` and ``t`` are
    written unchanged.  Files are written in the order s_train, u_train,
    s_valid, u_valid, s_test, u_test, t.

    Args:
        s_train, s_valid, s_test: state tensors with a leading sample axis.
        u_train, u_valid, u_test: input tensors.
        t: time-grid tensor.
        out_dir: destination directory, defaulting to ``"./data_" + na``.
            It is created when missing; previously ``to_csv`` failed if the
            directory did not already exist.
    """
    import os

    if out_dir is None:
        out_dir = "./data_" + na
    os.makedirs(out_dir, exist_ok=True)

    # (file stem, tensor, flatten trailing dims?)
    tables = [
        ("s_train", s_train, True),
        ("u_train", u_train, False),
        ("s_valid", s_valid, True),
        ("u_valid", u_valid, False),
        ("s_test", s_test, True),
        ("u_test", u_test, False),
        ("t", t, False),
    ]
    for stem, tensor, flatten in tables:
        # detach/cpu make .numpy() safe for grad-tracking or non-CPU tensors.
        arr = tensor.detach().cpu().numpy()
        if flatten:
            arr = arr.reshape(arr.shape[0], -1)
        pd.DataFrame(arr).to_csv(os.path.join(out_dir, stem + ".csv"))
    print("Data generated successfully!")

# Seed every RNG the script draws from.  The GP input samples come from
# numpy's global generator (generate_gaussain_sample -> np.random.normal),
# so seeding only `random` and `torch` left the datasets non-reproducible.
random.seed(1)
np.random.seed(1)
torch.manual_seed(1)

V = 1  # not used in the visible part of this script
zeta = 0  # not used in the visible part of this script
tpoint = 1000  # time samples per trajectory
Nx = 3  # state dimension of the ODE
Ntr = 10  # number of training trajectories
Nva = 2  # number of validation trajectories
Nte = 2  # number of test trajectories
ext_threshod = 200  # reject trajectories with any |state| above this
meanu = 24.74  # constant offset added to every GP input sample
func = ode1()

################################################################
# Generate data for training and the interpolation tasks
################################################################
# Shared time grid: tpoint samples on [0, T], used by every trajectory.
T = 10
t = torch.linspace(0, T, tpoint)
# One random initial state (components in [0, 20)) reused for all samples.
s0 = 20 * torch.rand(3)

# Fixed RBF hyper-parameters for the GP input samples.
length_scale, output_scale = 0.3, 10.0
gp_params = (length_scale, output_scale)

def _generate_split(label, n_samples):
    """Integrate ``func`` for ``n_samples`` accepted random inputs.

    Each sample draws a GP path (hyper-parameters ``gp_params``) shifted by
    ``meanu`` on the ``tpoint``-point grid.  The path is wrapped in a
    ``CSpline`` and installed as ``func.u``, and the ODE is integrated from
    ``s0`` over ``t``.  A sample is redrawn when the solver raises, or when
    the trajectory contains NaN or any entry with ``|value| > ext_threshod``.
    (Call ``rand_gp()`` per sample instead of using ``gp_params`` to vary
    the hyper-parameters.)

    Returns:
        ``(u, s)`` with shapes ``(n_samples, tpoint)`` and
        ``(n_samples, tpoint, Nx)``.
    """
    u = torch.zeros(n_samples, tpoint)
    s = torch.zeros(n_samples, tpoint, Nx)
    k = 0
    with torch.no_grad():
        while k < n_samples:
            print(label, k)
            u[k] = torch.tensor(generate_gaussain_sample(gp_params, 1, tpoint)[0]) + meanu
            # Pass copies of the knots and values, as before, but via clone():
            # torch.tensor(<tensor>) emits a copy-construct UserWarning.
            func.u = CSpline(t.clone(), u[k].clone())
            try:
                s[k] = ode.odeint(func, s0, t, rtol=1e-6, atol=1e-8, method='dopri5')
            except Exception as err:
                # Solver failures just trigger a redraw.  A bare `except:`
                # would also swallow KeyboardInterrupt and make this loop
                # impossible to stop.
                print("error", err)
                continue
            if torch.isnan(s[k]).any() or torch.max(torch.abs(s[k])) > ext_threshod:
                print(k)  # rejected (NaN / out of range): redraw this index
            else:
                k += 1
    return u, s


##### train data #####
u_train, s_train = _generate_split('train', Ntr)

##### valid data #####
u_valid, s_valid = _generate_split('valid', Nva)

##### test data #####
u_test, s_test = _generate_split('test', Nte)

sav_data(s_train, u_train, s_valid, u_valid, s_test, u_test, t)


# Visual sanity check of one training sample: state trajectories on top,
# the corresponding GP input underneath.
i = 5
fig = plt.figure(figsize=(20, 15))
ax1 = fig.add_subplot(2, 1, 1)
ax2 = fig.add_subplot(2, 1, 2)
ax1.plot(t, s_train[i], '-o')  # every state component against time
ax2.plot(t, u_train[i], '-o')  # input signal against time

