# -*- coding: utf-8 -*-
"""
Created on Mon Jul 31 21:15:51 2023

@author: xiamingtao
"""

import torch
import torchsde
import pandas as pd
import random
import time
import csv


# Location of a pre-generated ground-truth dataset.
# NOTE(review): this path is not read anywhere in this script; the ground
# truth is simulated from the SDE class instead.
file_path = '../data/nsde_ground_truth.csv'

batch_size = 200     # number of sample paths simulated together (rows of y0)
state_size = 1       # dimension of the SDE state y
brownian_size = 1    # dimension of the driving Brownian motion
H = 32               # hidden width of each TwoLayerNet
t_size = 41          # number of time points in ts

class TwoLayerNet(torch.nn.Module):
    """Small MLP: Linear(D_in -> H), ReLU, then Linear(H -> D_out).

    Args:
        D_in: number of input features.
        H: number of hidden units.
        D_out: number of output features.
    """

    def __init__(self, D_in, H, D_out):
        super().__init__()
        # These attribute names become state_dict keys ("linear1.weight", ...),
        # so they must stay stable for saved checkpoints to load.
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(H, D_out)

    def forward(self, x):
        """Map ``x`` (last dimension D_in) to an output with last dimension D_out."""
        hidden = torch.relu(self.linear1(x))
        return self.linear2(hidden)
    
class neuralSDE(torch.nn.Module):
    """Neural SDE whose drift and diffusion are both TwoLayerNet networks.

    The class attributes tell torchsde how to interpret ``f`` and ``g``:
    ``noise_type = 'general'`` means ``g`` returns a full
    (batch, state_size, brownian_size) matrix, and ``sde_type = 'ito'``
    selects the Ito interpretation.
    """
    noise_type = 'general'
    sde_type = 'ito'

    def __init__(self):
        super().__init__()
        self.mu = TwoLayerNet(state_size, H, state_size)
        self.sigma = TwoLayerNet(state_size, H, state_size * brownian_size)

    # Drift
    def f(self, t, y):
        """Return the learned drift mu(y), shape (batch, state_size); ``t`` is unused."""
        return self.mu(y)

    # Diffusion
    def g(self, t, y):
        """Return the learned diffusion as a (batch, state_size, brownian_size) tensor.

        The batch dimension is taken from ``y`` rather than the module-level
        ``batch_size``, so the model can be evaluated on any number of paths.
        """
        return self.sigma(y).view(y.size(0), state_size, brownian_size)
    
# Parameters of the ground-truth SDE  dY = a * (b - Y) dt + sigma * Y dW.
a = 1
b = 5
sigma = 0.5

class SDE(torch.nn.Module):
    """Ground-truth SDE dY = a * (b - Y) dt + sigma * Y dW (Ito, general noise).

    Only state component 0 carries dynamics. Any further state components
    get zero drift and zero diffusion. The module has no trainable
    parameters.
    """
    noise_type = 'general'
    sde_type = 'ito'

    def __init__(self):
        super().__init__()

    # Drift
    def f(self, t, y):
        """Return the drift, shape (batch_size, state_size); ``t`` is unused.

        Vectorised over the batch because the solver calls it every step.
        The output is allocated with ``y``'s dtype and device.
        """
        f_truth = torch.zeros_like(y)
        f_truth[:, 0] = a * (b - y[:, 0])
        return f_truth

    # Diffusion
    def g(self, t, y):
        """Return the diffusion, shape (batch_size, state_size, brownian_size).

        Only entry [:, 0, 0] is non-zero: sigma * y[:, 0].
        """
        g_truth = y.new_zeros(y.shape[0], y.shape[1], brownian_size)
        g_truth[:, 0, 0] = sigma * y[:, 0]
        return g_truth

neuralsde = neuralSDE()  # model being trained
sde = SDE()              # ground-truth dynamics used to generate the data

# Initial states: column 0 is 2 + 0.5 * N(0, 1); any other columns stay 0.
y0 = torch.zeros(batch_size, state_size)
y0[:, 0] = 2 + 0.5 * torch.randn(batch_size)
if (y0[:, 0] <= 0).any():
    # Non-positive starting states are reported, not corrected.
    print("error")

ts = torch.linspace(0, 2, t_size)
# The SDE is solved over [ts[0], ts[-1]] and sdeint returns the state at every
# time in ts: shape (t_size, batch_size, state_size).
ys_truth = torchsde.sdeint(sde, y0, ts)

criterion = torch.nn.MSELoss(reduction='sum')
# Optimise the neural SDE. SDE has no parameters, and torch.optim refuses an
# empty parameter list.
optimizer = torch.optim.Adam(neuralsde.parameters(), lr=0.002, betas=(0.9, 0.999), weight_decay=0.005)
epoch = 1000
loss_list = []
ttime = time.time()
g_error_list = []
sigma_error_list = []


def quantile(samples_sorted):
    """Return a function p -> empirical quantile of ascending ``samples_sorted``.

    The returned function maps a tensor of probabilities ``p`` to
    ``samples_sorted[floor(p * n)]`` (n = number of samples). The index is
    clamped to [0, n - 1], so p = 1 yields the largest sample instead of
    indexing out of range.
    """
    n = len(samples_sorted)

    def quantile_func(p):
        idx = torch.floor(torch.as_tensor(p) * n).long().clamp(0, n - 1)
        return samples_sorted[idx]
    return quantile_func


def W22(u_samples, v_samples):
    """Squared 2-Wasserstein distance between two 1-D empirical distributions.

    Integrates (F_u^{-1}(p) - F_v^{-1}(p))^2 over p in [0, 1]. Both inverse
    CDFs are step functions that are constant between consecutive points of
    the merged grid {k / n_u} U {j / n_v}.

    Adapted from https://github.com/nklb/wasserstein-distance
    """
    u_sorted = u_samples.sort().values
    v_sorted = v_samples.sort().values
    # n + 1 points k/n (k = 0..n) give every sample weight 1/n. A grid of only
    # n points would drop the largest sample and use weights 1/(n - 1).
    u_grid = torch.linspace(0, 1, steps=len(u_samples) + 1, device=u_samples.device)
    v_grid = torch.linspace(0, 1, steps=len(v_samples) + 1, device=v_samples.device)
    grids = torch.unique(torch.cat((u_grid, v_grid)), sorted=True)
    # Evaluate each step at its interval midpoint. Float rounding of k/n
    # therefore cannot push floor(p * n) to the wrong index.
    mids = (grids[:-1] + grids[1:]) / 2
    U_icdf = quantile(u_sorted)(mids)
    V_icdf = quantile(v_sorted)(mids)
    return torch.sum((U_icdf - V_icdf) ** 2 * torch.diff(grids))


def dynamics_error(g_error_list, sigma_error_list):
    """Append per-time-step drift and diffusion errors of the neural SDE.

    At each recorded time, the learned drift/diffusion (neuralsde.f/.g) and
    the true ones (sde.f/.g) are evaluated on the ground-truth states
    ys_truth[i]. The summed squared error divided by batch_size is recorded.
    One list of t_size values is appended to each of the two lists.
    """
    g_error_ep = []
    sigma_error_ep = []
    with torch.no_grad():  # diagnostics only; no autograd graph needed
        for y in ys_truth:
            g_error_ep.append(criterion(neuralsde.f(0, y), sde.f(0, y)).item() / batch_size)
            sigma_error_ep.append(criterion(neuralsde.g(0, y), sde.g(0, y)).item() / batch_size)

    g_error_list.append(g_error_ep)
    sigma_error_list.append(sigma_error_ep)


def W_2_distance(ys, ys_truth, i):
    """Pair predicted and true samples at time index ``i`` by rank.

    Returns ``(ys_slice, ys_truth_sort)``, two (N, D) tensors. Row k of
    ``ys_truth_sort`` holds the true sample whose rank equals the rank of
    ``ys_slice[k, 0]``. So ``criterion(ys_slice, ys_truth_sort)`` is a scaled
    1-D squared W2 distance for state component 0; other columns stay zero.

    If the per-time shapes match, N is the batch size. Otherwise every
    predicted sample is repeated n_true times and every true sample n_pred
    times, giving N = n_pred * n_true equally weighted samples on each side.

    NOTE: the training loop below uses W22 instead of this function.
    """
    pred = ys[i, :, 0]
    truth = ys_truth[i, :, 0]
    if ys[i].shape != ys_truth[i].shape:
        n_pred, n_true = pred.shape[0], truth.shape[0]
        pred = pred.repeat_interleave(n_true)
        truth = truth.repeat_interleave(n_pred)

    n_rows, n_cols = pred.shape[0], ys_truth.shape[2]
    ys_slice = pred.new_zeros(n_rows, n_cols)
    ys_slice[:, 0] = pred  # copy keeps pred's autograd history
    order = pred.detach().sort(stable=True).indices
    ys_truth_sort = truth.new_zeros(n_rows, n_cols)
    # Place the j-th smallest true value where the j-th smallest prediction sits.
    ys_truth_sort[order, 0] = truth.sort().values
    return ys_slice, ys_truth_sort


import pdb  # NOTE(review): unused debugging import
for i0 in range(epoch):
    # Simulate the neural SDE from the same initial states as the truth.
    ys = torchsde.sdeint(neuralsde, y0, ts)
    # Loss: squared W2 distance between predicted and true marginals of
    # state component 0, summed over all recorded times.
    loss = sum(W22(ys[i, :, 0], ys_truth[i, :, 0]) for i in range(t_size))
    if i0 % 10 == 0:
        print(i0, loss.item(), time.time() - ttime)
        dynamics_error(g_error_list, sigma_error_list)
        ttime = time.time()

    loss_list.append(loss.item())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


    

# Export trajectories as (batch_size, t_size) tables: one row per sample
# path, one column per time point in ts (state component 0 only).
truth_data = pd.DataFrame(data=ys_truth[:, :, 0].T.tolist())
truth_data.to_csv('ground_truth1.csv', header=False, index=False)

# Paths from the last training epoch. Detach in case ys carries autograd
# history.
predict_data = pd.DataFrame(data=ys[:, :, 0].detach().T.tolist())
predict_data.to_csv('predict1.csv', header=False, index=False)

loss_data = pd.DataFrame(data=loss_list)
loss_data.to_csv('loss1.csv', header=False, index=False)

g_error = pd.DataFrame(data=g_error_list)
g_error.to_csv('g_error1.csv', header=False, index=False)

sigma_error = pd.DataFrame(data=sigma_error_list)
sigma_error.to_csv('sigma_error1.csv', header=False, index=False)

# Save the trained neural SDE. The ground-truth SDE has no parameters, so
# its state_dict would be empty.
torch.save(neuralsde.state_dict(), 'CIX_model' + str(a) + '_' + str(b) + '.pkl')