import numpy as np
import random as random
import torch
import torch.nn as nn 
import pandas as pd
import torchdiffeq as ode
import matplotlib.pyplot as plt
from cubic_spline import CSpline

# Experiment tag; every generated CSV is written under "./data_" + na.
na = "experiment1_ODE"

class ode1(nn.Module):
    """Right-hand side of the scalar ODE  ds/dt = -s**2 + u(t) * s.

    ``u`` starts as ``None`` and must be assigned before the module is
    integrated: it has to expose a ``fit(t)`` method returning the input
    signal at time ``t`` (the script below assigns a ``CSpline``).
    """

    def __init__(self):
        super().__init__()
        # Input-signal interpolant; replaced by the caller before each solve.
        self.u = None

    def forward(self, t, x):
        # Previously considered (unused) dynamics: ds/dt = -s**2 + u(t)
        drive = self.u.fit(t)
        return drive * x - x**2

def RBF(x1, x2, params):
    """Squared-exponential (RBF) kernel matrix between two point sets.

    Args:
        x1: array of shape (n, d).
        x2: array of shape (m, d).
        params: ``(length_scale, output_scale)``.

    Returns:
        (n, m) array with entries
        ``output_scale * exp(-0.5 * ||x1_i - x2_j||**2 / length_scale**2)``.
    """
    length_scale, output_scale = params
    scaled_a = x1 / length_scale
    scaled_b = x2 / length_scale
    # (n, 1, d) - (1, m, d) broadcasts to every pairwise difference.
    pairwise = scaled_a[:, None] - scaled_b[None]
    sq_dist = (pairwise ** 2).sum(axis=2)
    return output_scale * np.exp(-0.5 * sq_dist)

def generate_gaussain_sample(gp_params, N, T, jitter=1e-10):
    """Draw ``N`` samples of a zero-mean GP with an RBF kernel.

    The GP is evaluated on ``T`` equally spaced points covering [0, 1].

    Args:
        gp_params: ``(length_scale, output_scale)`` tuple passed to ``RBF``.
        N: number of independent samples to draw.
        T: number of grid points per sample.
        jitter: diagonal term added to the kernel matrix so the Cholesky
            factorisation succeeds for (near-)singular kernels. Defaults to
            the previously hard-coded 1e-10.

    Returns:
        ``np.ndarray`` of shape ``(N, T)``; row i is ``L @ z_i`` with
        ``L = cholesky(K + jitter * I)`` and ``z_i`` standard normal.

    Raises:
        numpy.linalg.LinAlgError: if ``K + jitter * I`` is not positive
        definite (e.g. ``jitter`` too small for the chosen ``T``).
    """
    X = np.linspace(0.0, 1.0, T)[:, None]
    K = RBF(X, X, gp_params)
    L = np.linalg.cholesky(K + jitter * np.eye(T))
    # A single (N, T) draw takes values from the global NumPy stream in the
    # same order as N separate size-T draws did, so seeded output matches the
    # old per-row loop up to matmul rounding.
    Z = np.random.normal(loc=0, scale=1, size=(N, T))
    # Row i of Z @ L.T equals L @ Z[i]: all samples in one matrix product.
    return Z @ L.T

def rand_gp():
    """Draw random RBF-kernel hyperparameters from the global NumPy RNG.

    Returns:
        ``(length_scale, output_scale)`` with length_scale in [0.2, 1.2) and
        output_scale in [0, 10). The length scale is drawn first.
    """
    length_scale = 0.2 + np.random.random()
    output_scale = 10 * np.random.random()
    return length_scale, output_scale
    

def sav_data(s_train, u_train, s_test, u_test, t, out_dir=None):
    """Write the generated train/test tensors and the time grid to CSV files.

    Each tensor is converted with ``.numpy()`` and saved via
    ``pandas.DataFrame.to_csv`` (index column included) as
    ``<out_dir>/<name>.csv`` for the names s_train, u_train, s_test, u_test
    and t.

    Args:
        s_train, u_train, s_test, u_test, t: CPU tensors to save.
        out_dir: target directory; defaults to ``"./data_" + na``. It is
            created if missing.
    """
    import os

    if out_dir is None:
        out_dir = "./data_" + na
    # to_csv does not create directories; without this a fresh checkout only
    # fails here, after all the (slow) data generation has already run.
    os.makedirs(out_dir, exist_ok=True)
    tables = {
        "s_train": s_train,
        "u_train": u_train,
        "s_test": s_test,
        "u_test": u_test,
        "t": t,
    }
    for name, tensor in tables.items():
        pd.DataFrame(tensor.numpy()).to_csv(os.path.join(out_dir, name + ".csv"))
    print("Data generated successfully!")

# Seed every RNG the pipeline draws from. The GP inputs come from np.random
# (via rand_gp / generate_gaussain_sample), so seeding only `random` and
# `torch` left the generated data non-reproducible.
random.seed(1)
np.random.seed(1)
torch.manual_seed(1)

V = 1               # NOTE(review): unused in the visible code
zeta = 0            # NOTE(review): unused in the visible code
tpoint = 1000       # time points per saved trajectory
utpoint = 1000      # grid points on which each input signal u is sampled
Ntr = 5             # number of training samples
Nte = 2             # number of test samples
N_extra = 2         # samples per extrapolation set (disabled section below)
ext_threshod = 20   # trajectories with max |s| above this are rejected
func = ode1()

################################################################
# Generate data for training and the interpolation tasks
################################################################
T = 1
t = torch.linspace(0, T, tpoint)
u_grid = torch.linspace(0, T, utpoint)  # spline knots for the input signal


def _generate_dataset(label, n_samples):
    """Collect ``n_samples`` accepted (u, s) pairs by rejection sampling.

    Each attempt draws random GP hyperparameters and one GP input on
    ``u_grid``, interpolates it with ``CSpline``, and integrates ``func``
    over ``t`` from an initial state uniform in [-1, 1). Attempts whose
    solve raises, or whose trajectory contains NaN or exceeds
    ``ext_threshod`` in absolute value, are discarded and retried.

    Returns:
        ``(u_data, s_data)`` tensors of shape (n_samples, utpoint) and
        (n_samples, tpoint).
    """
    u_data = torch.zeros(n_samples, utpoint)
    s_data = torch.zeros(n_samples, tpoint)
    i = 0
    with torch.no_grad():
        while i < n_samples:
            print(label, i)
            gp_params = rand_gp()
            u_data[i] = torch.tensor(generate_gaussain_sample(gp_params, 1, utpoint)[0])
            # .clone() gives the spline its own copy (torch.tensor(tensor)
            # did the same but emits a UserWarning).
            func.u = CSpline(u_grid, u_data[i].clone())
            s0 = (torch.rand(1) - 0.5) * 2
            try:
                s = ode.odeint(func, s0, t, rtol=1e-6, atol=1e-8, method='dopri5')[:, 0]
            except Exception as err:  # not bare: Ctrl-C must still stop the loop
                print("error:", err)
                continue
            if torch.isnan(s).any() or torch.max(torch.abs(s)) > ext_threshod:
                print(label, "rejected", i)
                continue
            s_data[i] = s
            i += 1
    return u_data, s_data


##### train data #####
u_train, s_train = _generate_dataset('train', Ntr)

##### test data #####
u_test, s_test = _generate_dataset('test', Nte)

sav_data(s_train, u_train, s_test, u_test, t)

# Extrapolation-data generation -- currently DISABLED: the whole section is
# wrapped in a bare triple-quoted string, so Python builds the string and
# discards it. Remove the enclosing quotes to re-enable it.
# When enabled it raises the rejection threshold to 40, samples inputs from
# fixed GP hyperparameters (length_scale 0.7; output_scale 12.0 for extra1,
# 14.0 for extra2) instead of rand_gp(), and writes s/u/t CSVs for both sets.
# NOTE(review): the odeint calls integrate over `t`, not t_extra1/t_extra2;
# all three are linspace(0, T, tpoint), so the saved t_extra* files match
# only while that stays true.
# NOTE(review): sav_extra's to_csv calls assume "./data_" + na already exists,
# and the bare `except:` would also swallow KeyboardInterrupt.
'''
################################################################
# Generate data for the extrapolation tasks
################################################################
def sav_extra(s_extra1, u_extra1, s_extra2, u_extra2, t_extra1, t_extra2):
    s_extra1 = pd.DataFrame(s_extra1.numpy())
    u_extra1 = pd.DataFrame(u_extra1.numpy())
    s_extra1.to_csv("./data_"+na+"/s_extra1.csv")
    u_extra1.to_csv("./data_"+na+"/u_extra1.csv")
    s_extra2 = pd.DataFrame(s_extra2.numpy())
    u_extra2 = pd.DataFrame(u_extra2.numpy())
    s_extra2.to_csv("./data_"+na+"/s_extra2.csv")
    u_extra2.to_csv("./data_"+na+"/u_extra2.csv")
    t_extra1 = pd.DataFrame(t_extra1.numpy())
    t_extra1.to_csv("./data_"+na+"/t_extra1.csv")
    t_extra2 = pd.DataFrame(t_extra2.numpy())
    t_extra2.to_csv("./data_"+na+"/t_extra2.csv")
    print("Data generated successfully!")

##### extra1 data #####
ext_threshod = 40 
length_scale = 0.7
output_scale = 12.0
gp_params = (length_scale, output_scale)
t_extra1 = torch.linspace(0,T,tpoint)

u_extra1 = torch.zeros(N_extra,utpoint)
s_extra1 = torch.zeros(N_extra,tpoint)
i = 0
with torch.no_grad():
    while i < N_extra:
        print('extra1', i)
        u_extra1[i] = torch.tensor(generate_gaussain_sample(gp_params, 1, utpoint)[0])
        func.u  = CSpline(torch.linspace(0,T,utpoint), torch.tensor(u_extra1[i]))
        try:
            s0 = (torch.rand(1)-0.5)*2
            s_extra1[i] = ode.odeint(func, s0, t, rtol=1e-6, atol=1e-8, method='dopri5')[:,0]
        except:
            print("error")
        else:
            if torch.isnan(s_extra1[i]).any() or torch.max(torch.abs(s_extra1[i]))>ext_threshod:
                print(i)
            else:
                i=i+1

##### extra2 data #####
length_scale = 0.7
output_scale = 14.0
gp_params = (length_scale, output_scale)
t_extra2 = torch.linspace(0,T,tpoint)

u_extra2 = torch.zeros(N_extra,utpoint)
s_extra2 = torch.zeros(N_extra,tpoint)
i = 0
with torch.no_grad():
    while i < N_extra:
        print('extra2', i)
        u_extra2[i] = torch.tensor(generate_gaussain_sample(gp_params, 1, utpoint)[0])
        func.u  = CSpline(torch.linspace(0,T,utpoint), torch.tensor(u_extra2[i]))
        try:
            s0 = (torch.rand(1)-0.5)*2
            s_extra2[i] = ode.odeint(func, s0, t, rtol=1e-6, atol=1e-8, method='dopri5')[:,0]
        except:
            print("error")
        else:
            if torch.isnan(s_extra2[i]).any() or torch.max(torch.abs(s_extra2[i]))>ext_threshod:
                print(i)
            else:
                i=i+1

sav_extra(s_extra1, u_extra1, s_extra2, u_extra2, t_extra1, t_extra2)
'''


# Disabled plotting snippet: wrapped in a bare string literal, so never run.
# It plots every 4th trajectory of s_extra1 / s_extra2, which only exist if
# the extrapolation section above is re-enabled.
# NOTE(review): if re-enabled, `nn = 4` rebinds the module-level name `nn`
# (imported as torch.nn), and no plt.show()/plt.savefig() call follows.
'''
nn = 4
for i in range(0,s_extra1.shape[0],nn):
    plt.plot(t_extra1,s_extra1[i],'o')
for i in range(0,s_extra2.shape[0],nn):
    plt.plot(t_extra2,s_extra2[i],'o')
'''





