import numpy as np
import random as random
import torch
import torch.nn as nn 
import pandas as pd
import torchdiffeq as ode
import matplotlib.pyplot as plt
from cubic_spline import CSpline

# Experiment tag: sav_data() writes its CSV files under "./data_" + na + "/".
na = "E1"
# Device handed to ode1 and used for the initial states. The commented line
# below is the alternative that picks CUDA when it is available.
#device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = 'cpu'

class ode1(nn.Module):
    """Vector field ``f(t, s) = (s**3) @ true_A + u(t)`` for the ODE solver.

    ``true_A`` is a fixed 2x2 matrix placed on ``device``.  ``u`` starts as
    ``None`` and has to be replaced by an object exposing ``fit(t)`` before
    ``forward`` is called.
    """

    def __init__(self, device):
        super().__init__()
        coefficients = [[-0.1, 2.0], [-2.0, -0.1]]
        self.true_A = torch.tensor(coefficients).to(device)
        # Forcing term; assigned from outside once per trajectory.
        self.u = None

    def forward(self, t, s):
        """Return the time derivative of the state ``s`` at time ``t``."""
        forcing = self.u.fit(t)
        # torch.mm requires a 2-D state, i.e. s is (batch, 2).
        autonomous = torch.mm(s.pow(3), self.true_A)
        return autonomous + forcing

def RBF(x1, x2, params):
    """Squared-exponential kernel matrix between two point sets.

    x1 and x2 are 2-D arrays of shape (n, d) and (m, d); ``params`` is the
    pair ``(length_scale, output_scale)``.  The result is the (n, m) array
    ``output_scale * exp(-0.5 * ||x1_i - x2_j||^2 / length_scale^2)``.
    """
    length_scale, output_scale = params
    # Insert singleton axes so the subtraction broadcasts to (n, m, d).
    rows = (x1 / length_scale)[:, None]
    cols = (x2 / length_scale)[None, :]
    sq_dist = ((rows - cols) ** 2).sum(axis=2)
    return output_scale * np.exp(-0.5 * sq_dist)

def generate_gaussain_sample(gp_params,N,T):
    """Draw N zero-mean Gaussian-process paths on a T-point grid over [0, 1].

    ``gp_params`` is the ``(length_scale, output_scale)`` pair forwarded to
    RBF.  Returns an (N, T) array with one sample path per row, drawn from
    NumPy's global random state.
    """
    # Unpacking rejects a malformed gp_params up front; the values themselves
    # are consumed inside RBF.
    _length_scale, _output_scale = gp_params
    grid = np.linspace(0.0, 1.0, T)[:, None]
    # A small diagonal jitter is added before the Cholesky factorisation.
    jitter = 1e-10
    chol = np.linalg.cholesky(RBF(grid, grid, gp_params) + jitter*np.eye(T))
    paths = np.zeros((N, T))
    for row in range(N):
        white_noise = np.random.normal(loc=0, scale=1, size=T)
        # Colour the white noise with the Cholesky factor of the covariance.
        paths[row, :] = np.dot(chol, white_noise)
    return paths

def rand_gp():
    """Draw random GP kernel hyper-parameters from NumPy's global RNG.

    Returns the tuple ``(length_scale, output_scale)`` with length_scale
    uniform on [0.1, 0.4) and output_scale uniform on [50, 100).
    """
    # Two draws, in this order: first for the length scale, then the amplitude.
    length_draw = np.random.random()
    amplitude_draw = np.random.random()
    length_scale = length_draw*0.3+0.1
    output_scale = (amplitude_draw*0.5+0.5)*100
    return length_scale, output_scale
    
def sav_data(s_train, u_train, s_valid, u_valid, s_test, u_test, t):
    """Write the generated datasets as CSV files under ``./data_<na>/``.

    Each ``s_*`` tensor is flattened to 2-D by merging every leading axis
    (rows = all axes but the last, columns = last axis); each ``u_*`` tensor
    and ``t`` are written as they are.  Files are written with pandas'
    default ``to_csv`` options.  The output directory is created when it
    does not exist yet, so a missing folder no longer aborts the run after
    the data has already been generated.

    Files produced: ``s_train``, ``u_train``, ``s_test``, ``u_test``,
    ``s_valid``, ``u_valid`` and ``t`` (all ``.csv``).
    """
    import os  # local import: only this function needs it

    out_dir = "./data_"+na
    os.makedirs(out_dir, exist_ok=True)

    # Same write order as before: train, test, valid, then the time grid.
    splits = (
        ("train", s_train, u_train),
        ("test", s_test, u_test),
        ("valid", s_valid, u_valid),
    )
    for split, states, inputs in splits:
        flat_states = states.reshape(-1, states.shape[-1])
        # Convert both tensors before touching the disk so a conversion
        # failure does not leave a half-written split behind.
        state_frame = pd.DataFrame(flat_states.numpy())
        input_frame = pd.DataFrame(inputs.numpy())
        state_frame.to_csv(out_dir+"/s_"+split+".csv")
        input_frame.to_csv(out_dir+"/u_"+split+".csv")

    pd.DataFrame(t.numpy()).to_csv(out_dir+"/t.csv")
    print("Data generated successfully!")

# Seed every random source this script draws from so that a rerun regenerates
# the same dataset: the GP forcing signals come from NumPy's global RNG
# (rand_gp and generate_gaussain_sample), the initial states from torch.rand.
random.seed(1)
np.random.seed(1)
torch.manual_seed(1)

V = 2               # number of state components (last axis of the s_* tensors)
zeta = 0            # not referenced in the visible part of this script
tpoint = 1000       # time samples per trajectory
utpoint = 1000      # samples per forcing signal
# NOTE(review): CSpline is built from t (tpoint values) and one forcing row
# (utpoint values); presumably the two counts must match -- confirm.
Ntr = 100           # number of training trajectories
Nte = 20            # number of test trajectories
Nva = 20            # number of validation trajectories
ext_threshod = 100  # trajectories with max |s| above this are rejected and redrawn
func = ode1(device)

################################################################
# Generate data for training and the interpolation tasks
################################################################
T = 10
# tpoint evenly spaced time stamps on [0, T); the end point T itself is dropped.
t = torch.linspace(0,T,tpoint+1)[:-1]

# Fixed GP kernel hyper-parameters used for the validation and test forcing.
length_scale = 0.1
output_scale = 20
##### train data #####
# Every training trajectory pairs a GP-sampled forcing signal (kernel
# hyper-parameters redrawn per attempt by rand_gp) with the ODE response from
# an initial state drawn uniformly from [-5, 5)^2.  An attempt is repeated
# until the solver succeeds and the trajectory is finite and within
# ext_threshod; slot i is simply overwritten by the next attempt.
u_train = torch.zeros(Ntr,utpoint)
s_train = torch.zeros(Ntr,tpoint,V)
i = 0
with torch.no_grad():
    while i < Ntr:
        print('train', i)
        gp_params = rand_gp()
        u_train[i] = torch.tensor(generate_gaussain_sample(gp_params, 1, utpoint)[0])
        # detach().clone() is the documented equivalent of torch.tensor(tensor)
        # and avoids the copy-construct UserWarning.
        func.u = CSpline(t, u_train[i].detach().clone())
        try:
            s0 = ((torch.rand(1,2)-0.5)*10).to(device)
            s_train[i] = ode.odeint(func, s0, t, rtol=1e-6, atol=1e-8, method='dopri5')[:,0]
        except Exception as err:
            # Solver failures are retried with a fresh sample.  Catching
            # Exception (not a bare except) keeps Ctrl-C able to stop the
            # loop, and the cause is reported instead of being hidden.
            print("error", err)
        else:
            if torch.isnan(s_train[i]).any() or torch.max(torch.abs(s_train[i]))>ext_threshod:
                # Rejected: NaN or blown-up trajectory; i is not advanced.
                print(i)
            else:
                i=i+1
                
##### valid data #####
# Same rejection-sampling scheme as the training set, except that the forcing
# uses the fixed module-level (length_scale, output_scale) kernel
# hyper-parameters instead of a fresh rand_gp() draw.
u_valid = torch.zeros(Nva,utpoint)
s_valid = torch.zeros(Nva,tpoint,V)
i = 0
with torch.no_grad():
    while i < Nva:
        print('valid', i)
        gp_params = (length_scale, output_scale)
        u_valid[i] = torch.tensor(generate_gaussain_sample(gp_params, 1, utpoint)[0])
        # detach().clone() is the documented equivalent of torch.tensor(tensor)
        # and avoids the copy-construct UserWarning.
        func.u = CSpline(t, u_valid[i].detach().clone())
        try:
            s0 = ((torch.rand(1,2)-0.5)*10).to(device)
            s_valid[i] = ode.odeint(func, s0, t, rtol=1e-6, atol=1e-8, method='dopri5')[:,0]
        except Exception as err:
            # Solver failures are retried with a fresh sample.  Catching
            # Exception (not a bare except) keeps Ctrl-C able to stop the
            # loop, and the cause is reported instead of being hidden.
            print("error", err)
        else:
            if torch.isnan(s_valid[i]).any() or torch.max(torch.abs(s_valid[i]))>ext_threshod:
                # Rejected: NaN or blown-up trajectory; i is not advanced.
                print(i)
            else:
                i=i+1

##### test data #####
# Same rejection-sampling scheme as the training set, except that the forcing
# uses the fixed module-level (length_scale, output_scale) kernel
# hyper-parameters instead of a fresh rand_gp() draw.
u_test = torch.zeros(Nte,utpoint)
s_test = torch.zeros(Nte,tpoint,V)
i = 0
with torch.no_grad():
    while i < Nte:
        print('test', i)
        gp_params = (length_scale, output_scale)
        u_test[i] = torch.tensor(generate_gaussain_sample(gp_params, 1, utpoint)[0])
        # detach().clone() is the documented equivalent of torch.tensor(tensor)
        # and avoids the copy-construct UserWarning.
        func.u = CSpline(t, u_test[i].detach().clone())
        try:
            s0 = ((torch.rand(1,2)-0.5)*10).to(device)
            s_test[i] = ode.odeint(func, s0, t, rtol=1e-6, atol=1e-8, method='dopri5')[:,0]
        except Exception as err:
            # Solver failures are retried with a fresh sample.  Catching
            # Exception (not a bare except) keeps Ctrl-C able to stop the
            # loop, and the cause is reported instead of being hidden.
            print("error", err)
        else:
            if torch.isnan(s_test[i]).any() or torch.max(torch.abs(s_test[i]))>ext_threshod:
                # Rejected: NaN or blown-up trajectory; i is not advanced.
                print(i)
            else:
                i=i+1
                
# Persist all three splits together with the shared time grid.
sav_data(s_train, u_train, s_valid, u_valid, s_test, u_test, t)

# Diagnostic figure for a single training sample (index i).
i = 0
fig, ax = plt.subplots(4, 1, figsize=(30,20))
ax_states, ax_phase, ax_forcing, ax_field = ax
sample_states = s_train[i]
# Both state components against time.
ax_states.plot(t, sample_states[:, 0], 'r-o')
ax_states.plot(t, sample_states[:, 1], 'b-x')
# Phase portrait: second component against the first.
ax_phase.plot(sample_states[:, 0], sample_states[:, 1], 'g-o')
# Forcing signal against time.
ax_forcing.plot(t, u_train[i], 'k-o')
# Unforced part of the vector field, (s**3) @ true_A, against sample index.
ax_field.plot(torch.mm(sample_states**3, func.true_A))






