"""
This file is used to analyze the results of the example2d experiment.
Should be run as a jupyter notebook.

It creates sample paths for the ground truth and the neural SDE trained by W2
with a basis profile
"""

import torch 
torch.set_num_threads(1)
import torchsde
import argparse
import os
import IPython
##### Check whether we are running inside a Jupyter notebook
def is_notebook():
    """Report whether the code runs inside a Jupyter (ZMQ) kernel.

    Returns True only for the 'ZMQInteractiveShell' (Jupyter notebook or
    qtconsole). A terminal IPython session, any other shell class, or a
    missing ``IPython`` name all give False.
    """
    try:
        shell_name = IPython.get_ipython().__class__.__name__
    except NameError:
        # Probably a standard Python interpreter
        return False
    return shell_name == 'ZMQInteractiveShell'
from pathlib import Path
import matplotlib as mpl
import matplotlib.pyplot as plt
# import tqdm
from tqdm import tqdm
import importlib
from utils.sde_utils import neuralSDE, predict, f,g, rel_err_f, rel_err_Sigma, Sigma
import matplotlib.pyplot as plt
import pandas as pd

# set tensor type to float64
torch.set_default_dtype(torch.float64)
##### Numerical tools and the Computer Modern (cmr10) font file for plotting
import numpy as np
fpath = Path(mpl.get_data_path(), "fonts/ttf/cmr10.ttf")
##### Setting up the experiment #####

profile = 'n_samples_256'
loss_function = 'W2_rotated'
repeat = 10
overwrite = False


def _load_profile(kind, label, profile_name):
    """Import the example2d profile module for *kind*, falling back to the base profile.

    Looks for ``profiles/example2d_<kind>_<profile_name>.py`` relative to the
    current working directory; if that file is missing, imports
    ``profiles.example2d_<kind>_base`` instead.

    Args:
        kind: Profile family, ``'truth'`` or ``'nsde'``.
        label: Name shown in the progress messages (e.g. ``'ground truth'``).
        profile_name: Profile suffix, e.g. ``'n_samples_256'``.

    Returns:
        The imported profile module.
    """
    profile_file = f'profiles/example2d_{kind}_{profile_name}.py'
    if os.path.exists(profile_file):
        print(f'Loading {label} profile: {profile_name} at {profile_file}')
        return importlib.import_module(f'profiles.example2d_{kind}_{profile_name}')
    # load base
    print(f'Loading {label} base profile at profiles/example2d_{kind}_base.py')
    return importlib.import_module(f'profiles.example2d_{kind}_base')


example2d_truth = _load_profile('truth', 'ground truth', profile)
from sde.example2d import SDE
sde = SDE(example2d_truth.Sigma, example2d_truth.mu1, example2d_truth.mu2)

print('Loading data...')
# map_location='cpu': the tensors are converted with .numpy() further down,
# which requires CPU tensors (and this lets GPU-saved data load on a CPU host).
u_truth = torch.load(example2d_truth.u_truth_savepath, map_location='cpu')
print('Data loaded.')

##### Handling the Neural SDE #####
example2d_nsde = _load_profile('nsde', 'NSDE', profile)

neuralsde = neuralSDE(example2d_nsde.state_size,
                      example2d_nsde.brownian_size,
                      example2d_nsde.hidden_size,
                      example2d_nsde.batch_size)

model_savepath = (f"models/{example2d_nsde.nsde_label}_{example2d_truth.truth_label}"
                  f"_{loss_function}_repeat_{repeat}.pt")
if os.path.exists(model_savepath) and not overwrite:
    print('Model exists, skip training...')
    neuralsde.load_state_dict(torch.load(model_savepath, map_location='cpu'))
    print('Model loaded.')
else:
    # This script only analyses an already-trained model; it cannot train one.
    raise ValueError("Model does not exist. Please train the model first.")

##########################################################
##########################################################
###################### Trajectories ######################
##########################################################
##########################################################




# save the final result
u_pred = predict(neuralsde, example2d_truth.u0, example2d_truth.ts)
# Convert the tensors to numpy arrays (.cpu() is a no-op for CPU tensors and
# keeps .numpy() valid otherwise). The indexing below assumes arrays are laid
# out as (time, sample, state): axis 0 must have len(ts) entries so that each
# column lines up with the 't' column of the CSV, and state 0/1 is x/y.
u_pred_np = u_pred.detach().cpu().numpy()
u_truth_np = u_truth.detach().cpu().numpy()
ts_np = example2d_truth.ts.detach().cpu().numpy()
csv_ts = ts_np
csv_x_pred = u_pred_np[:, :, 0]
csv_y_pred = u_pred_np[:, :, 1]
csv_x_truth = [u_truth_np[:, i, 0] for i in range(u_truth_np.shape[1])]
csv_y_truth = [u_truth_np[:, i, 1] for i in range(u_truth_np.shape[1])]

# save one CSV per sample path
plots_dir = Path('data/plots')
plots_dir.mkdir(parents=True, exist_ok=True)
# Iterate over the sample axis (axis 1); len(csv_x_pred) would count time
# steps instead. Only samples present in both prediction and truth are written.
n_samples = min(csv_x_pred.shape[1], len(csv_x_truth))
for i in range(n_samples):
    df = pd.DataFrame({
        't': csv_ts,
        'x_pred': csv_x_pred[:, i],
        'y_pred': csv_y_pred[:, i],
        'x_truth': csv_x_truth[i],
        'y_truth': csv_y_truth[i],
    })
    df.to_csv(plots_dir / f"example2d_{example2d_truth.truth_label}_sample_{i}.csv", index=False)

##########################################################
##########################################################
###################### Vector Field ######################
##########################################################
##########################################################
# Bounding box of all ground-truth states; used as the domain of the grid.
xmin = u_truth_np[:, :, 0].min()
xmax = u_truth_np[:, :, 0].max()
ymin = u_truth_np[:, :, 1].min()
ymax = u_truth_np[:, :, 1].max()
# plot the vector field on an n_grid x n_grid mesh
n_grid = 20
x_grid, y_grid = torch.meshgrid(
    torch.linspace(float(xmin), float(xmax), n_grid),
    torch.linspace(float(ymin), float(ymax), n_grid),
    indexing='ij',
)
x_grid_values = x_grid.flatten()
y_grid_values = y_grid.flatten()
# (n_grid * n_grid, 2) array of grid points, one (x, y) pair per row
u_vec = torch.stack([x_grid_values, y_grid_values], dim=1)
# evaluate f for the ground-truth SDE and for the trained neural SDE
f_vals = f(sde, u_vec)
f_pred_vals = f(neuralsde, u_vec)
# save to csv for PGFPlots (create the output directory if it is missing)
plots_dir = Path('data/plots')
plots_dir.mkdir(parents=True, exist_ok=True)
df = pd.DataFrame({
    'x': x_grid_values.detach().cpu().numpy(),
    'y': y_grid_values.detach().cpu().numpy(),
    'f_x': f_vals[:, 0].detach().cpu().numpy(),
    'f_y': f_vals[:, 1].detach().cpu().numpy(),
    'f_pred_x': f_pred_vals[:, 0].detach().cpu().numpy(),
    'f_pred_y': f_pred_vals[:, 1].detach().cpu().numpy(),
})

df.to_csv(plots_dir / f"example2d_{example2d_truth.truth_label}_vector_field.csv", index=False)

##########################################################