"""
This file is used to analyze the results of the cix experiment.
Should be run as a jupyter notebook.

It creates sample paths for the ground truth and the neural SDE trained by W2
with a basis profile
"""

import torch 
torch.set_num_threads(1)
import torchsde
import argparse
import os
import IPython
import numpy as np
import pandas as pd
##### Check whether we are running inside a Jupyter notebook
def is_notebook():
    """Return True only when running under a Jupyter notebook or qtconsole.

    The class name of the active IPython shell is inspected:
    ``ZMQInteractiveShell`` is the kernel-backed shell used by notebooks and
    qtconsole.  Any other shell (terminal IPython, unknown shells, no shell
    at all) is reported as "not a notebook".  A ``NameError`` raised while
    looking up the shell is treated as a plain Python interpreter.
    """
    try:
        shell_name = IPython.get_ipython().__class__.__name__
    except NameError:
        # Probably a standard Python interpreter.
        return False
    # Only the ZMQ (kernel) shell counts as a notebook.
    return shell_name == 'ZMQInteractiveShell'
from pathlib import Path
import matplotlib as mpl
import matplotlib.pyplot as plt
# import tqdm
from tqdm import tqdm
import importlib
from utils.sde_utils import neuralSDE, predict, neuralSDE_legacy
import matplotlib.pyplot as plt

# NOTE: no figure is created here.  The composite figure (and every axes
# object) is created further down in this script; opening an extra empty
# figure at this point would only leave a stray blank canvas behind.

# Path of the Computer Modern (cmr10) font file inside matplotlib's data dir.
fpath = Path(mpl.get_data_path(), "fonts/ttf/cmr10.ttf")

##### Setting up the experiment #####

# Profile name: selects profiles/cix_truth_<profile>.py and
# profiles/cix_nsde_<profile>.py (the *_base profiles are used as fallback).
profile = 'base'
# Loss-function label; part of the checkpoint file name.
loss_function = 'W2'
# Repeat index; part of the checkpoint file name.
repeat = 10
# When True, an existing checkpoint is NOT loaded.
overwrite = False

loss_functions = [
    'W2',
]


##### Handling the Ground Truth #####
# Prefer the profile-specific ground-truth module; fall back to the base one
# when no file exists for the selected profile.
truth_profile_file = f'profiles/cix_truth_{profile}.py'
if os.path.exists(truth_profile_file):
    print(f'Loading ground truth profile: {profile} at {truth_profile_file}')
    truth_module_name = f'profiles.cix_truth_{profile}'
else:
    print('Loading ground truth base profile at profiles/cix_truth_base.py')
    truth_module_name = 'profiles.cix_truth_base'
cix_truth = importlib.import_module(truth_module_name)

from sde.cix import SDE

# Reference SDE built from the a, b and σ values of the ground-truth profile.
sde = SDE(a=cix_truth.a, b=cix_truth.b, σ=cix_truth.σ)




print('Loading data...')
# The rest of this script converts tensors with ``.numpy()``, which only works
# for CPU tensors, so map the stored tensor onto the CPU no matter which
# device it was saved from.
u_truth = torch.load(cix_truth.u_truth_savepath, map_location='cpu')
print('Data loaded.')

# Fail fast: validate the data before any figure is built.
if torch.isnan(u_truth).any():
    raise ValueError("u_truth contains nan.")

# 2 x 2 grid of sub-figures:
#   subfig00: sample paths of u_truth and u_pred
#   subfig01: two panels vertically stacked, for f(x) and g(x)
#   subfig10: dependence of the reconstruction errors on n_samples
#   subfig11: dependence of the reconstruction errors on sigmas
fig = plt.figure(layout='constrained', figsize=(12, 12))
subfigs = fig.subfigures(2, 2, wspace=0.1, hspace=0.1)
subfig00 = subfigs[0, 0]
subfig01 = subfigs[0, 1]
subfig10 = subfigs[1, 0]
subfig11 = subfigs[1, 1]

##### Handling the Neural SDE #####
# Prefer the profile-specific NSDE module; fall back to the base one when no
# file exists for the selected profile.
nsde_profile_file = f'profiles/cix_nsde_{profile}.py'
if os.path.exists(nsde_profile_file):
    print(f'Loading NSDE profile: {profile} at {nsde_profile_file}')
    nsde_module_name = f'profiles.cix_nsde_{profile}'
else:
    print('Loading NSDE base profile at profiles/cix_nsde_base.py')
    nsde_module_name = 'profiles.cix_nsde_base'
cix_nsde = importlib.import_module(nsde_module_name)

# Profiles that define ``layers`` use the current neuralSDE class; profiles
# without it are handled by the legacy class.
if hasattr(cix_nsde, 'layers'):
    # ``resnet`` is optional in the profile and defaults to False.
    cix_nsde.resnet = getattr(cix_nsde, 'resnet', False)
    nsde_class = neuralSDE
    nsde_options = {'layers': cix_nsde.layers, 'resnet': cix_nsde.resnet}
else:
    print("Using legacy neural SDE...")
    nsde_class = neuralSDE_legacy
    nsde_options = {}

# set up neural SDE
neuralsde = nsde_class(
    cix_nsde.state_size,
    cix_nsde.brownian_size,
    cix_nsde.hidden_size,
    cix_nsde.batch_size,
    **nsde_options,
)
# set up optimizer
# optimizer = torch.optim.Adam(
#     neuralsde.parameters(), 
#     lr=cix_nsde.η,
#     betas=cix_nsde.β,
#     weight_decay=cix_nsde.weight_decay
# )



# model_savepaths = [f"models/{cix_nsde.nsde_label}_{cix_truth.truth_label}_{loss_function}_repeat_{repeat}.pt"
#                     for loss_function in loss_functions]

# # check if there exists a trained model
# if not os.path.exists('models'):
#     os.makedirs('models')

# model_savepath = model_savepaths[0]
# Checkpoint path is derived from the NSDE label, the ground-truth label, the
# loss function and the repeat index.
model_savepath = f"models/{cix_nsde.nsde_label}_{cix_truth.truth_label}_{loss_function}_repeat_{repeat}.pt"
if os.path.exists(model_savepath) and not overwrite:
    print('Model exists, skip training...')
    # This script works on CPU tensors only (it calls ``.numpy()`` directly),
    # so load the weights onto the CPU even if they were saved from a GPU.
    neuralsde.load_state_dict(torch.load(model_savepath, map_location='cpu'))
    print('Model loaded.')
else:
    # Nothing in this script trains the network, so make it obvious that the
    # figures below come from freshly constructed (untrained) weights.
    print(f'WARNING: no trained model loaded from {model_savepath}; '
          'the neural SDE plotted below is untrained.')


##### Panel (a): sample paths of the ground truth and of the neural SDE #####
# subfig00
ax = subfig00.subplots()

# Simulate the neural SDE from the ground-truth initial state on the
# ground-truth time grid.
u_pred = predict(neuralsde, cix_truth.u0, cix_truth.ts)
# Convert the tensors to numpy arrays.
# Both arrays are indexed below as (time, sample, state component).
u_pred_np = u_pred.detach().numpy()
u_truth_np = u_truth.detach().numpy()
ts_np = cix_truth.ts.detach().numpy()

# Plot state component 0 of every sample path.  ``zip`` pairs predicted and
# true paths and stops at the shorter of the two, so a differing number of
# samples cannot cause an out-of-range index.
n_paths = min(u_pred_np.shape[1], u_truth_np.shape[1])
for pred_path, truth_path in zip(u_pred_np[:, :, 0].T, u_truth_np[:, :, 0].T):
    ax.plot(ts_np, pred_path, color='red', alpha=0.1)
    ax.plot(ts_np, truth_path, color='black', alpha=0.1)
ax.set_xlabel('t')
ax.set_ylabel('u')
# Report the number of paths actually drawn instead of a hard-coded count.
ax.set_title(f'(a) Sample paths with {n_paths} samples')
# set limits
ax.set_xlim(ts_np.min(), ts_np.max())


# save u_truth and u_pred to csv for each sample
# Create the output directory first: to_csv does not create missing folders.
os.makedirs("data/plots", exist_ok=True)
# Only export samples that exist in BOTH arrays, so a differing sample count
# cannot cause an out-of-range index.
for i in range(min(u_pred_np.shape[1], u_truth_np.shape[1])):
    df = pd.DataFrame({
        't': ts_np,
        # reshape the following to one dimension
        'u_truth': u_truth_np[:, i, 0].reshape(-1),
        'u_pred': u_pred_np[:, i, 0].reshape(-1),
    })
    df.to_csv(f"data/plots/cir_{cix_truth.truth_label}_sample_{i}.csv", index=False)

# subfig01: f(x) on top, |g(x)| below; black = ground-truth SDE, red = neural SDE
subfig01 = subfigs[0, 1]
axs = subfig01.subplots(2, 1)

# Evaluation grid on [0, 8]; both models are called at t = 0.0 with the grid
# reshaped to a trailing dimension of size 1.
xs = torch.linspace(0, 8, cix_nsde.batch_size)
xs_column = xs.unsqueeze(-1)

# --- top panel: f(x) ---
fs = sde.f(0.0, xs_column)
fs_pred = neuralsde.f(0.0, xs_column)
xs_np, fs_np, fs_pred_np = (t.detach().numpy() for t in (xs, fs, fs_pred))
ax = axs[0]
ax.set_title("(b) Reconstructed f(x) and σ(x)")
for curve, colour in ((fs_np, 'black'), (fs_pred_np, 'red')):
    ax.plot(xs_np, curve, color=colour, alpha=0.5)
ax.set(xlabel='x', ylabel='f(x)', xlim=(xs_np.min(), xs_np.max()))

# --- bottom panel: absolute value of g(x) ---
σs = torch.abs(sde.g(0.0, xs_column))
σs_pred = torch.abs(neuralsde.g(0.0, xs_column))
σs_np = np.squeeze(σs.detach().numpy())
σs_pred_np = np.squeeze(σs_pred.detach().numpy())
ax = axs[1]
for curve, colour in ((σs_np, 'black'), (σs_pred_np, 'red')):
    ax.plot(xs_np, curve, color=colour, alpha=0.5)
ax.set(xlabel='x', ylabel='g(x)', xlim=(xs_np.min(), xs_np.max()))

# write to csv
# Create the output directory first: to_csv does not create missing folders.
os.makedirs("data/plots", exist_ok=True)
df = pd.DataFrame({
    'x': xs_np,
    # reshape the following to one dimension
    'f': fs_np.reshape(-1),
    'f_pred': fs_pred_np.reshape(-1),
    'g': σs_np.reshape(-1),
    'g_pred': σs_pred_np.reshape(-1),
})
df.to_csv(f"data/plots/cir_{cix_truth.truth_label}_fg.csv", index=False)

# subfig10 / subfig11: reconstruction errors of f and σ, as a function of the
# number of training samples (panel c) and of σ (panel d).
import seaborn as sns
from mpl_toolkits.axes_grid1.inset_locator import inset_axes

markers = ['o', 's', '^', 'D', '*', 'p', 'v', '<', '>']


def _plot_mse_curves(ax, df, x_col, metric, styles):
    """Draw one mean curve with a ±1 std band per loss function.

    Args:
        ax: matplotlib axes to draw on.
        df: table with a ``loss_function`` column, the ``x_col`` column and
            the columns ``<metric>_mean`` and ``<metric>_std``.
        x_col: name of the column used for the x axis.
        metric: column prefix, e.g. ``'mse_f'`` or ``'mse_σ'``.
        styles: mapping ``loss function name -> (marker, color)``.
    """
    mean_col = f'{metric}_mean'
    std_col = f'{metric}_std'
    for loss_func in df['loss_function'].unique():
        subset = df[df['loss_function'] == loss_func]
        marker, color = styles[loss_func]
        ax.plot(subset[x_col], subset[mean_col], label=loss_func,
                marker=marker, color=color, markersize=6, linestyle='-')
        ax.fill_between(subset[x_col],
                        subset[mean_col] - subset[std_col],
                        subset[mean_col] + subset[std_col],
                        color=color, alpha=0.2)


df_n_samples = pd.read_csv("output/cir_n_samples.csv")
df_sigmas = pd.read_csv("output/cir_sigmas.csv")

# Assign ONE (marker, color) pair per loss-function name, shared by both
# panels.  The only legend is drawn in panel (c), so panel (d) must use the
# same styling by name, not by position in its own table.  Markers are
# cycled so that more loss functions than markers cannot raise IndexError.
all_loss_functions = list(dict.fromkeys([
    *df_n_samples['loss_function'].unique(),
    *df_sigmas['loss_function'].unique(),
]))
colors = sns.color_palette("tab10", len(all_loss_functions))
loss_styles = {
    name: (markers[i % len(markers)], colors[i])
    for i, name in enumerate(all_loss_functions)
}

# subfig10: errors against the number of training samples (log x axis)
axs = subfig10.subplots(2, 1)

ax = axs[0]
_plot_mse_curves(ax, df_n_samples, 'train_n_samples', 'mse_f', loss_styles)
ax.set_title('(c) Comparison of Reconstruction Errors')
ax.set_ylabel('MSE in f')
ax.set_xscale('log')

ax = axs[1]
_plot_mse_curves(ax, df_n_samples, 'train_n_samples', 'mse_σ', loss_styles)
ax.set_xlabel('Number of Training Samples')
ax.set_ylabel('MSE in σ')
ax.legend()
ax.set_xscale('log')

# subfig11: errors against σ
axs = subfig11.subplots(2, 1)
unique_loss_functions = df_sigmas['loss_function'].unique()

ax = axs[0]
_plot_mse_curves(ax, df_sigmas, 'σ', 'mse_f', loss_styles)
ax.set_title('(d) Comparison of Reconstruction Errors')
ax.set_ylabel('MSE in f')

ax = axs[1]
_plot_mse_curves(ax, df_sigmas, 'σ', 'mse_σ', loss_styles)
ax.set_xlabel('σ')
ax.set_ylabel('MSE in σ')


# Save the figure BEFORE showing it, so the files are written from the fully
# drawn figure regardless of what the backend does to it on show().
fig.savefig("output/cir_plot.png")
fig.savefig("output/cir_plot.svg")
fig.savefig("output/cir_plot.pdf")
# print(f"Figure saved at figures/{cix_nsde.nsde_label}_{cix_truth.truth_label}_{loss_function}_repeat_{repeat}.png.")

# pyplot.show() displays all open figures and only takes the keyword ``block``;
# a figure must not be passed positionally.
plt.show()
plt.close(fig)
