# -*- coding: utf-8 -*-
"""
Created on Wed Aug  2 11:28:07 2023

@author: xiamingtao
"""

# -*- coding: utf-8 -*-
"""
Created on Mon Jul 31 21:15:51 2023

@author: xiamingtao
"""

import torch
import torchsde
import pandas as pd
import random
import time
import csv

# Location of the reference trajectories, relative to the working directory.
file_path = '../data/nsde_ground_truth.csv'

# Parse the file into a list of float rows.  The first line (header) and the
# first field of every remaining line are dropped, so csv_data[r][c] is the
# (c + 1)-th field of the (r + 1)-th data line.
# newline='' is what the csv module asks for when it is handed a file object.
with open(file_path, 'r', newline='') as csvfile:
    csvreader = csv.reader(csvfile)
    next(csvreader, None)  # skip the header; tolerate a completely empty file
    csv_data = [[float(cell) for cell in row[1:]] for row in csvreader]
        
    # Iterate through each row in the CSV file and append it to the csv_data list.
#for row in csvreader:


    
# Read the CSV file into a DataFrame.
#df = pd.read_csv(file_path)

# Now you can work with the DataFrame, perform operations, and analyze the data.
#print(df)

# Problem dimensions shared by the networks, the solver call and the exports.
batch_size = 100      # number of trajectories handled at once
state_size = 1        # dimension of the SDE state
brownian_size = 1     # dimension of the driving Brownian motion
H = 100               # hidden width handed to TwoLayerNet
t_size = 201          # number of time points in the trajectory grid

class TwoLayerNet(torch.nn.Module):
    """Fully connected ReLU network: D_in -> H -> H -> D_out.

    Despite the name, the network holds three linear layers (two hidden
    layers of width ``H`` followed by a linear read-out).
    """

    def __init__(self, D_in, H, D_out):
        """Create the three linear layers.

        Args:
            D_in: size of the last dimension of the input.
            H: width of both hidden layers.
            D_out: size of the last dimension of the output.
        """
        super().__init__()
        # Creation order is kept as linear1, linear2, linear3 so that seeded
        # initialisation yields the same weights.
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(H, H)
        self.linear3 = torch.nn.Linear(H, D_out)

    def forward(self, x):
        """Map ``x`` through both ReLU hidden layers and the linear read-out."""
        hidden = torch.relu(self.linear1(x))
        hidden = torch.relu(self.linear2(hidden))
        return self.linear3(hidden)
    
class SDE(torch.nn.Module):
    """Trainable SDE whose drift and diffusion are both ``TwoLayerNet`` models.

    Instances are handed to ``torchsde.sdeint``; ``noise_type`` and
    ``sde_type`` are the class-level settings that accompany ``f`` and ``g``.
    """
    noise_type = 'general'
    sde_type = 'ito'

    def __init__(self):
        super().__init__()
        # Drift network: state -> state.
        self.mu = TwoLayerNet(state_size, H, state_size)
        # Diffusion network: state -> flattened (state x brownian) matrix.
        self.sigma = TwoLayerNet(state_size, H, state_size * brownian_size)

    # Drift
    def f(self, t, y):
        """Return the learned drift at ``y``; ``t`` is ignored.

        The output has the same leading dimension as ``y`` and ``state_size``
        columns.
        """
        return self.mu(y)

    # Diffusion
    def g(self, t, y):
        """Return the learned, element-wise non-negative diffusion at ``y``.

        The flat network output is reshaped to
        ``(y.shape[0], y.shape[1], brownian)``.  The batch dimension is taken
        from ``y`` itself rather than from the global ``batch_size``, so the
        method also works for inputs with a different number of rows.
        ``t`` is ignored.
        """
        flat = self.sigma(y)
        # The last axis is inferred; it equals brownian_size when y has
        # state_size columns.
        return torch.abs(flat.view(y.shape[0], y.shape[1], -1))
    
from math import cos

# a and b end up in the file name of the saved model at the end of the
# script; sigma is only referenced from commented-out code.
a, b = 1, 5
sigma = 0.5

class SDE_truth(torch.nn.Module):
    """Reference SDE with closed-form drift ``0.5 - cos(y)`` and unit diffusion.

    Only the first state component / first Brownian channel is populated;
    every other entry of the returned tensors is zero.
    """
    noise_type = 'general'
    sde_type = 'ito'

    def __init__(self):
        super().__init__()
        # Neither network is used by f or g below; they are kept so the
        # module's parameters and state_dict layout stay as they were.
        self.mu = TwoLayerNet(state_size, H, state_size)
        self.sigma = TwoLayerNet(state_size, H, state_size * brownian_size)

    # Drift
    def f(self, t, y):
        """Return the reference drift for a 2-D batch ``y``; ``t`` is ignored.

        The result has the shape, dtype and device of ``y``.  Column 0 holds
        ``0.5 - cos(y[:, 0])``; remaining columns are zero.  The cosine is
        evaluated by torch in the dtype of ``y`` instead of row by row with
        ``math.cos``, so float32 results may differ from a double-precision
        evaluation in the last bit.
        """
        f_truth = torch.zeros_like(y)
        f_truth[:, 0] = 0.5 - torch.cos(y[:, 0])
        return f_truth

    # Diffusion
    def g(self, t, y):
        """Return the reference diffusion; ``t`` and the values of ``y`` are ignored.

        The result has shape ``(y.shape[0], y.shape[1], brownian_size)`` with
        entry ``[:, 0, 0]`` equal to 1 and zeros elsewhere, on the dtype and
        device of ``y``.
        """
        g_truth = y.new_zeros(y.shape[0], y.shape[1], brownian_size)
        g_truth[:, 0, 0] = 1.0
        return g_truth

sde = SDE()
#sde.load_state_dict(torch.load('nsde_model1_5.pkl'))
sde_truth = SDE_truth()

# Every trajectory starts with its first state component at 0.1.
y0 = torch.zeros(batch_size, state_size)
y0[:, 0] = 0.1
# Guard kept for experiments with a randomised start: flag non-positive values.
if bool((y0[:, 0] <= 0).any()):
    print("error")

# Time grid for the trajectories: t_size evenly spaced points on [0, 20].
ts = torch.linspace(0, 20, t_size)

# Reference trajectories, shape (t_size, batch_size, state_size):
# csv_data[i][j] is the value of trajectory j at time index i, copied into
# every state component.  Rows and columns beyond t_size / batch_size are
# ignored.
if len(csv_data) < t_size or any(len(row) < batch_size
                                 for row in csv_data[:t_size]):
    raise ValueError(
        'ground-truth CSV must provide at least %d rows of %d values'
        % (t_size, batch_size)
    )
ys_truth = (
    torch.tensor([row[:batch_size] for row in csv_data[:t_size]],
                 dtype=torch.float32)
    .reshape(t_size, batch_size, 1)
    .repeat(1, 1, state_size)
)


# generate new trajectories
#ys_truth = torchsde.sdeint(sde_truth, y0, ts)
    
# Squared error summed over all elements (no averaging).
criterion = torch.nn.MSELoss(reduction='sum')

# Adam over the trainable SDE's parameters, with weight decay.
optimizer = torch.optim.Adam(
    sde.parameters(),
    lr=0.001,
    betas=(0.9, 0.999),
    weight_decay=0.005,
)

epoch = 500             # number of training iterations
ttime = time.time()     # start of the current timing window

# Per-iteration loss, and per-report drift / diffusion error profiles.
loss_list, g_error_list, sigma_error_list = [], [], []

def dynamics_error(g_error_list, sigma_error_list):
    """Record how far the learned drift/diffusion are from the reference ones.

    For each of the ``t_size`` reference snapshots ``ys_truth[i]`` the summed
    squared difference between ``sde`` and ``sde_truth`` is computed for the
    drift (``f``) and for the diffusion (``g``), then divided by
    ``batch_size``.

    Args:
        g_error_list: list that receives one new entry, the list of
            per-snapshot drift errors.  Mutated in place.
        sigma_error_list: list that receives one new entry, the list of
            per-snapshot diffusion errors.  Mutated in place.

    Returns:
        None.
    """
    g_error_ep = []
    sigma_error_ep = []
    # Pure evaluation: only .item() values are kept, so no autograd graph is
    # needed for these forward passes.
    with torch.no_grad():
        for i in range(t_size):
            state = ys_truth[i]
            drift_err = criterion(sde.f(0, state), sde_truth.f(0, state))
            diffusion_err = criterion(sde.g(0, state), sde_truth.g(0, state))
            g_error_ep.append(drift_err.item() / batch_size)
            sigma_error_ep.append(diffusion_err.item() / batch_size)

    g_error_list.append(g_error_ep)
    sigma_error_list.append(sigma_error_ep)



def _rank_matched(pred_values, truth_values):
    """Permute 1-D ``truth_values`` so its k-th smallest entry sits at the
    index where ``pred_values`` holds its k-th smallest entry."""
    pred_order = pred_values.detach().sort().indices
    matched = torch.empty_like(truth_values)
    matched[pred_order] = truth_values.sort().values
    return matched


def _as_first_column(values, width):
    """Return a ``(len(values), width)`` tensor that is zero except for
    column 0, which holds ``values`` (gradients flow through the copy)."""
    column = values.new_zeros(values.shape[0], width)
    column[:, 0] = values
    return column


def W_2_distance(ys, ys_truth, i):
    """Pair predicted and reference samples of time index ``i`` by rank.

    Only the first state component is used.  The reference samples are
    reordered so that the smallest reference value lines up with the smallest
    prediction, the second smallest with the second smallest, and so on; a
    summed squared difference of the two returned tensors therefore compares
    order statistics.

    When the two sides hold a different number of samples, every prediction
    is repeated once per reference sample and every reference sample once per
    prediction, so both sides have ``ys.shape[1] * ys_truth.shape[1]``
    entries before the rank matching.

    Args:
        ys: predicted trajectories, indexed ``[time, sample, component]``.
        ys_truth: reference trajectories with the same index layout.
        i: time index to compare.

    Returns:
        ``(ys_slice, ys_truth_sort)``: two tensors with ``ys_truth.shape[2]``
        columns each.  Column 0 holds the (possibly repeated) predictions and
        the rank-matched reference values respectively; other columns are
        zero.
    """
    pred = ys[i, :, 0]
    truth = ys_truth[i, :, 0]

    n_pred, n_truth = pred.shape[0], truth.shape[0]
    if n_pred != n_truth:
        # Equalise the sample counts without changing either empirical
        # distribution.
        pred = pred.repeat_interleave(n_truth)
        truth = truth.repeat_interleave(n_pred)

    width = ys_truth.shape[2]
    ys_slice = _as_first_column(pred, width)
    ys_truth_sort = _as_first_column(_rank_matched(pred, truth), width)
    return ys_slice, ys_truth_sort
        
    

import pdb
for i0 in range(epoch):
    # Simulate the trainable SDE from y0 on the grid ts.
    ys = torchsde.sdeint(sde, y0, ts)

    # At every time index, reorder the reference samples so that the k-th
    # smallest reference value sits at the index of the k-th smallest
    # prediction (first state component only).  The ordering carries no
    # gradient, so it is built outside autograd.
    with torch.no_grad():
        pred_order = ys[:, :, 0].argsort(dim=1)
        truth_sorted = ys_truth[:, :, 0].sort(dim=1).values
        matched = torch.empty_like(truth_sorted)
        # matched[t, pred_order[t, k]] = truth_sorted[t, k]
        matched.scatter_(1, pred_order, truth_sorted)
        ys_truth_sort = torch.zeros_like(ys_truth)
        ys_truth_sort[:, :, 0] = matched

    # Summed squared difference between predictions and rank-matched
    # reference values over all time indices.
    loss = criterion(ys, ys_truth_sort)
    if i0 % 5 == 0:
        # Every fifth iteration: report progress and log the dynamics errors.
        print(i0, loss.item(), time.time() - ttime)
        dynamics_error(g_error_list, sigma_error_list)
        ttime = time.time()

    loss_list.append(loss.item())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

def _first_component(values, count):
    """Return ``float(values[k, 0])`` for ``k`` in ``range(count)`` as a list."""
    return [float(values[k, 0]) for k in range(count)]


def _trajectory_rows(paths):
    """Return one list per trajectory holding its first state component at
    each of the ``t_size`` time indices."""
    return [
        [float(paths[step, traj, 0]) for step in range(t_size)]
        for traj in range(batch_size)
    ]


def _dump_csv(values, filename):
    """Write ``values`` to ``filename`` as CSV without header or index."""
    pd.DataFrame(data=values).to_csv(filename, header=False, index=False)


# Evaluate learned and reference drift/diffusion on 100 states starting at -2
# with step 0.05.
y = torch.tensor([[0.05 * i -2] for i in range(100)])
g_re = sde.f(0, y)
sigma_re = sde.g(0, y)
g_truth = sde_truth.f(0, y)
sigma_truth = sde_truth.g(0, y)

# Drift / diffusion curves, one value per line.
for curve, target in (
    (g_truth, 'g_truth2.csv'),
    (g_re, 'g_re2.csv'),
    (sigma_truth, 'sigma_truth2.csv'),
    (sigma_re, 'sigma_re2.csv'),
):
    _dump_csv(_first_component(curve, batch_size), target)

# Reference trajectories and those simulated in the last training iteration,
# one trajectory per line.
_dump_csv(_trajectory_rows(ys_truth), 'ground_truth_example2.csv')
_dump_csv(_trajectory_rows(ys), 'predict_example2.csv')

# Training history.
_dump_csv(loss_list, 'loss_example2.csv')
_dump_csv(g_error_list, 'g_error_example2.csv')
_dump_csv(sigma_error_list, 'sigma_error_example2.csv')

# Persist the trained weights.
torch.save(sde.state_dict(), 'nsde_model2' + str(a) + '_' + str(b) + '.pkl')