import torch 
torch.set_num_threads(1)
import torchsde
import argparse
import os
import IPython
##### check if running inside a Jupyter notebook
def is_notebook():
    """Report whether the code runs inside a Jupyter notebook or qtconsole.

    Returns:
        True only when the active IPython shell class is named
        'ZMQInteractiveShell'. A terminal IPython session, any other shell
        type, or a NameError raised during the lookup all give False.
    """
    try:
        shell_name = IPython.get_ipython().__class__.__name__
    except NameError:
        # Probably a standard Python interpreter.
        return False
    # 'TerminalInteractiveShell' and every other shell type count as "not a notebook".
    return shell_name == 'ZMQInteractiveShell'
from pathlib import Path
import matplotlib as mpl
import matplotlib.pyplot as plt
# import tqdm
from tqdm import tqdm
import importlib
from utils.sde_utils import neuralSDE, predict, f,g, rel_err_f, rel_err_g, neuralSDE_legacy
import matplotlib.pyplot as plt
import pandas as pd

# Make float64 the default dtype for floating-point tensors created from here on.
torch.set_default_dtype(torch.float64)
##### Define callback function to inspect reconstructed dynamics and trajectory
import numpy as np

def checkpoint(n_iter, iteration = True):
    """Plot the current fit, save the model, and report drift/diffusion errors.

    Relies on module-level state: ``neuralsde``, ``sde``, ``u_truth``,
    ``cix_truth``, ``cix_nsde``, ``loss_function`` and ``repeat``.

    Args:
        n_iter: Iteration number, used in the plot title and in the file
            names of intermediate snapshots.
        iteration: If True, write an intermediate snapshot (PNG figure and
            model state dict) under ``tmp/``. If False, write the final
            artefacts: an SVG under ``figures/``, the state dict under
            ``models/`` and CSV exports under ``data/plots/``.

    Returns:
        The tuple ``(mse_f, mse_σ)`` produced by ``rel_err_f`` and
        ``rel_err_g`` on the training trajectories.
    """
    # Common stem of every artefact written for this run.
    run_tag = f"{cix_nsde.nsde_label}_{cix_truth.truth_label}_{loss_function}_repeat_{repeat}"

    u_pred = predict(neuralsde, cix_truth.u0, cix_truth.ts)
    u_pred_np = u_pred.detach().numpy()
    u_truth_np = u_truth.detach().numpy()
    ts_np = cix_truth.ts.detach().numpy()

    fig, axs = plt.subplots(3, 1, figsize=(10, 15))

    # Panel 1: predicted (red) against ground-truth (black) trajectories.
    ax = axs[0]
    for i in range(u_pred_np.shape[1]):
        ax.plot(ts_np, u_pred_np[:, i, 0], color='red', alpha=0.1)
        ax.plot(ts_np, u_truth_np[:, i, 0], color='black', alpha=0.1)
    ax.set_xlabel('t')
    ax.set_ylabel('u')
    ax.set_title(f'Epoch {n_iter}' if iteration else run_tag)
    ax.set_xlim(ts_np.min(), ts_np.max())

    # Panel 2: drift of the reference SDE (black) against the learned one (red).
    xs = torch.linspace(0, 10, cix_nsde.batch_size)
    xs_np = xs.detach().numpy()
    fs_np = f(sde, xs).detach().numpy()
    fs_pred_np = f(neuralsde, xs).detach().numpy()
    ax = axs[1]
    ax.plot(xs_np, fs_np, color='black', alpha=0.5)
    ax.plot(xs_np, fs_pred_np, color='red', alpha=0.5)
    ax.set_xlabel('x')
    ax.set_ylabel('f(x)')
    ax.set_xlim(xs_np.min(), xs_np.max())

    # Panel 3: magnitude of the diffusion, same colour convention.
    σs_np = np.squeeze(torch.abs(g(sde, xs)).detach().numpy())
    σs_pred_np = np.squeeze(torch.abs(g(neuralsde, xs)).detach().numpy())
    ax = axs[2]
    ax.plot(xs_np, σs_np, color='black', alpha=0.5)
    ax.plot(xs_np, σs_pred_np, color='red', alpha=0.5)
    ax.set_xlabel('x')
    ax.set_ylabel('g(x)')
    ax.set_xlim(xs_np.min(), xs_np.max())

    # Error of the learned drift / diffusion relative to the reference SDE.
    mse_f = rel_err_f(neuralsde, sde, u_truth, cix_truth.ts)
    mse_σ = rel_err_g(neuralsde, sde, u_truth, cix_truth.ts)

    if iteration:
        os.makedirs('tmp', exist_ok=True)
        fig.savefig(f"tmp/{run_tag}_n_iter_{n_iter}.png")
        plt.close(fig)
        torch.save(neuralsde.state_dict(), f"tmp/{run_tag}_n_iter_{n_iter}.pt")
        return mse_f, mse_σ

    # Final artefacts: create every target directory first so a fresh
    # checkout does not fail on the first write.
    for out_dir in ('figures', 'models', 'data/plots'):
        os.makedirs(out_dir, exist_ok=True)

    fig.savefig(f"figures/{run_tag}.svg")
    plt.close(fig)
    torch.save(neuralsde.state_dict(), f"models/{run_tag}.pt")

    # 20 x 20 (t, u) grid spanning the observed time and state ranges.
    t_axis = torch.linspace(ts_np.min(), ts_np.max(), 20)
    u_axis = torch.linspace(u_truth_np.min(), u_truth_np.max(), 20)
    t_grid, u_grid = torch.meshgrid(t_axis, u_axis, indexing='ij')
    t_values = t_grid.flatten()
    u_values = u_grid.flatten()

    # NOTE(review): the vector-field table below needs each of these to
    # flatten to the same length as the grid (400); that depends on how
    # f/g treat a (state, time) argument pair — confirm in utils.sde_utils.
    f_values = f(sde, u_axis, t_axis).detach().numpy().flatten()
    g_values = g(sde, u_axis, t_axis).detach().numpy().flatten()
    f_pred_values = f(neuralsde, u_axis, t_axis).detach().numpy().flatten()
    g_pred_values = g(neuralsde, u_axis, t_axis).detach().numpy().flatten()

    # One CSV per trajectory, capped at the first 50 samples.
    for i in range(min(50, u_pred_np.shape[1])):
        df = pd.DataFrame({
            't': ts_np,
            'u_pred': u_pred_np[:, i, 0],
            'u_truth': u_truth_np[:, i, 0]
        })
        df.to_csv(f"data/plots/{run_tag}_n_{i}.csv", index=False, header=True)

    # Vector-field table: reference and learned drift/diffusion on the grid.
    df = pd.DataFrame({
        't': t_values.numpy(),
        'u': u_values.numpy(),
        'f': f_values,
        'g': g_values,
        'f_pred': f_pred_values,
        'g_pred': g_pred_values
    })
    df.to_csv(f"data/plots/{run_tag}_vector_field.csv", index=False, header=True)
    return mse_f, mse_σ

fpath = Path(mpl.get_data_path(), "fonts/ttf/cmr10.ttf")


def _str2bool(value):
    """Parse a command-line boolean ('true'/'false', 'yes'/'no', '1'/'0').

    ``type=bool`` must not be used with argparse: ``bool('False')`` is True
    because every non-empty string is truthy.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


##### Setting up the experiment #####
if not is_notebook():
    # Script mode: read the configuration from the command line.
    parser = argparse.ArgumentParser(description='Choose a profile.')
    parser.add_argument('--profile',
                        type=str,
                        required=False,
                        default='base',
                        help='The profile to use.')
    parser.add_argument('--loss_function',
                        type=str,
                        required=False,
                        default='W2',
                        help='The loss function to use.')
    parser.add_argument('--repeat',
                        type=int,
                        required=False,
                        default=1,
                        help='The id of experiment repeat.')
    # Boolean flags accept an explicit value (``--overwrite False``) or may be
    # given bare (``--overwrite``), which means True.
    parser.add_argument('--overwrite',
                        type=_str2bool,
                        nargs='?',
                        const=True,
                        required=False,
                        default=False,
                        help='Whether to overwrite the existing result.')
    parser.add_argument('--test',
                        type=_str2bool,
                        nargs='?',
                        const=True,
                        required=False,
                        default=False,
                        help='Whether to run the code in test mode.')
    parser.add_argument('--layers',
                        type=int,
                        required=False,
                        default=2,
                        help='The number of layers in the neural network.')
    args = parser.parse_args()
    profile = args.profile
    loss_function = args.loss_function
    repeat = args.repeat
    overwrite = args.overwrite
    test_flag = args.test
    # layers = args.layers
else:
    # Notebook mode: there is no command line, so use fixed settings.
    profile = 'n_samples_200'
    loss_function = 'W2'
    repeat = 10
    overwrite = False
    test_flag = False
    # layers = 3


# print configuration
print('Profile: ' + profile)
print('Loss Function: ' + loss_function)
print('Repeat ID: ' + str(repeat))
print('Overwrite Flag: ' + str(overwrite))


##### Handling the Ground Truth #####
# Use the profile-specific ground-truth settings when that file exists,
# otherwise fall back to the base profile.
if os.path.exists('profiles/cix_truth_' + profile + '.py'):
    print('Loading ground truth profile: ' + profile + ' at ' + 'profiles/cix_truth_' + profile + '.py')
    cix_truth = importlib.import_module('profiles.cix_truth_' + profile)
else:
    # load base
    print('Loading ground truth base profile at profiles/cix_truth_base.py')
    cix_truth = importlib.import_module('profiles.cix_truth_base')
from sde.cix import SDE
sde = SDE(a=cix_truth.a,
            b=cix_truth.b,
            σ=cix_truth.σ)

ts = cix_truth.ts
os.makedirs('data', exist_ok=True)
if (not os.path.exists(cix_truth.u_truth_savepath)) or overwrite:
    # No cached trajectories (or an overwrite was requested): simulate the
    # reference SDE and cache the result.
    print('Generating data...')
    u_truth = predict(sde,
                              cix_truth.u0,
                              cix_truth.ts)
    torch.save(u_truth, cix_truth.u_truth_savepath) # shape (t_size, n_samples, 1)

    # Plot every trajectory on a square figure. Detached NumPy copies are
    # used so plotting also works if the tensors carry autograd history.
    ts_plot = cix_truth.ts.detach().numpy()
    u_plot = u_truth.detach().numpy()
    fig, ax = plt.subplots(figsize=(4, 4))
    for i in range(u_plot.shape[1]):
        ax.plot(ts_plot, u_plot[:, i, 0], color='black', linewidth=0.1)
    ax.set_xlabel('$t$', fontsize=14)
    ax.set_ylabel('$u(t)$', fontsize=14)
    ax.set_xlim([0, 2])
    # Annotate the top-left corner with the number of samples.
    ax.text(0.05, 0.95, '$n_{samples}$ = ' + str(cix_truth.n_samples),
        transform=ax.transAxes, fontsize=14)

    # The output directory is not created anywhere else before this point.
    os.makedirs('figures', exist_ok=True)
    fig.savefig("figures/cix_truth_" + cix_truth.truth_label + ".svg", dpi=300)
    # NOTE: matplotlib's pgf backend needs a working LaTeX installation.
    fig.savefig("figures/cix_truth_" + cix_truth.truth_label + ".pdf", backend='pgf', dpi=300)
    plt.close(fig)

    print('Data generated.')
else:
    print('Loading data...')
    u_truth = torch.load(cix_truth.u_truth_savepath)
    print('Data loaded.')



# Abort when the ground truth is unusable. SystemExit is raised directly
# (the interactive `exit()` helper is not guaranteed to exist) with a
# non-zero status so that batch runners can detect the failure.
if torch.isnan(u_truth).any():
    print("u_truth contains nan.")
    raise SystemExit(1)

##### Handling the Neural SDE #####
# Prefer the profile-specific NSDE settings; when no such file exists,
# import the base profile instead.
_nsde_profile_file = f'profiles/cix_nsde_{profile}.py'
if not os.path.exists(_nsde_profile_file):
    print('Loading NSDE base profile at profiles/cix_nsde_base.py')
    cix_nsde = importlib.import_module('profiles.cix_nsde_base')
else:
    print(f'Loading NSDE profile: {profile} at {_nsde_profile_file}')
    cix_nsde = importlib.import_module(f'profiles.cix_nsde_{profile}')
# define loss
# Loss names whose value is a distance evaluated on each time slice of the
# first state component and summed over time, mapped to the helper in
# utils.sde_utils that computes the per-slice distance.
_MARGINAL_DISTANCES = {
    'W2': 'W2_distance',
    'W1': 'W1_distance',
    'mean2_var': 'mean2_var_distance',
    'mse': 'mse_distance',
}
# Loss names that utils.sde_utils provides ready-made. 'W2_mingtao_loss' used
# to be handled after the "unknown loss" check and was therefore unreachable.
_READY_MADE_LOSSES = {
    'W2_alt': 'W2_loss_alt',
    'W2_mingtao_loss': 'W2_mingtao_loss',
}
_sde_utils = importlib.import_module('utils.sde_utils')

if loss_function in _MARGINAL_DISTANCES:
    distance = getattr(_sde_utils, _MARGINAL_DISTANCES[loss_function])
    def loss(u_pred):
        """Sum the per-time-step distance between u_pred and u_truth."""
        loss_cul = 0
        for t in range(0, cix_truth.t_size):
            loss_cul += distance(u_pred[t,:,0], u_truth[t,:,0])
        return loss_cul

elif loss_function in _READY_MADE_LOSSES:
    loss = getattr(_sde_utils, _READY_MADE_LOSSES[loss_function])

elif loss_function == "apprx_loglik":
    apprx_loglik = _sde_utils.apprx_loglik
    def loss(u_pred):
        """Negative approximate log-likelihood of u_truth under neuralsde.

        ``u_pred`` is ignored; it is accepted only so that every loss shares
        the same call signature.
        """
        return -1 * apprx_loglik(u_truth, cix_truth.ts, neuralsde)

elif loss_function == "MMD":
    distance = _sde_utils.mmd_distance
    def loss(u_pred):
        """Sum the MMD distance between u_pred and u_truth over t >= 1.

        Raises:
            Exception: If the distance at any time step contains NaN.
        """
        loss_cul = 0
        # Skip the initial value: if the initial distributions are identical
        # the distance evaluates to NaN.
        for t in range(1, cix_truth.t_size):
            # Evaluate once and reuse the value for both the check and the sum.
            step_distance = distance(u_pred[t,:,:], u_truth[t,:,:])
            if torch.isnan(step_distance).any():
                print(t)
                raise Exception("u_pred contains nan.")
            loss_cul += step_distance
        return loss_cul

else:
    raise Exception("Unknown loss function.")

# set torch seed according to repeat
torch.manual_seed(repeat*1000)

# Choose the network: profiles that define `layers` get the configurable
# neuralSDE, all others fall back to the legacy architecture.
if hasattr(cix_nsde, 'layers'):
    # Profiles without a `resnet` setting default to a plain (non-residual) net.
    cix_nsde.resnet = getattr(cix_nsde, 'resnet', False)
    _build_nsde = neuralSDE
    _extra_options = {'layers': cix_nsde.layers, 'resnet': cix_nsde.resnet}
else:
    print("Using legacy neural SDE...")
    _build_nsde = neuralSDE_legacy
    _extra_options = {}

neuralsde = _build_nsde(
    cix_nsde.state_size,
    cix_nsde.brownian_size,
    cix_nsde.hidden_size,
    cix_nsde.batch_size,
    **_extra_options,
)


# Adam optimizer configured from the NSDE profile.
_adam_options = dict(
    lr=cix_nsde.η,
    betas=cix_nsde.β,
    weight_decay=cix_nsde.weight_decay,
)
optimizer = torch.optim.Adam(neuralsde.parameters(), **_adam_options)

model_savepath = f"models/{cix_nsde.nsde_label}_{cix_truth.truth_label}_{loss_function}_repeat_{repeat}.pt"

# Make sure the output directories written to below exist.
os.makedirs('models', exist_ok=True)
os.makedirs('figures', exist_ok=True)

if test_flag:
    # Test mode: a short run, just enough to exercise the pipeline.
    print('Testing...')
    N_epoch = 11
else:
    N_epoch = cix_nsde.N_epoch

if os.path.exists(model_savepath) and not overwrite:
    # A trained model is already on disk: reuse it instead of retraining.
    print('Model exists, skip training...')
    neuralsde.load_state_dict(torch.load(model_savepath))
    print('Model loaded.')
    # checkpoint(N_epoch, iteration=False)
else:
    # train the model
    print('Training...')
    ####### Train the model ########
    pbar = tqdm(range(N_epoch), desc='Training', leave=True)
    checkpoint_iters = []  # iteration number of every recorded checkpoint
    losses = []
    mse_fs = []
    mse_σs = []
    for n_iter in pbar:
        optimizer.zero_grad()
        u_pred = predict(neuralsde, cix_truth.u0, cix_truth.ts)
        current_loss = loss(u_pred)
        current_loss.backward()
        optimizer.step()
        if n_iter % cix_nsde.checkpoint_freq == 0:
            mse_f, mse_σ = checkpoint(n_iter)
            checkpoint_iters.append(n_iter)
            mse_fs.append(mse_f)
            mse_σs.append(mse_σ)
            # Store the loss averaged over the time steps.
            losses.append(current_loss.item()/cix_truth.t_size)
            # Show the latest metrics on the progress bar.
            pbar.set_postfix(
                loss=current_loss.item(),
                mse_f=mse_f,
                mse_σ=mse_σ
            )

    # save the final result
    # NOTE(review): checkpoint(..., iteration=False) is called again right
    # after this branch, so a training run writes the final artefacts twice.
    # Both calls are kept because each one draws fresh sample paths; removing
    # one would shift the random stream seen by the testing stage.
    checkpoint(N_epoch, iteration=False)

    ##### visualize and save the losses to csv
    losses = np.array(losses)
    mse_fs = np.array(mse_fs)
    mse_σs = np.array(mse_σs)
    # Plot against the true iteration numbers so the 'n_iter' axis label is
    # accurate (metrics are only recorded every `checkpoint_freq` iterations).
    fig, axs = plt.subplots(3, 1, figsize=(10, 15))
    ax = axs[0]
    ax.plot(checkpoint_iters, losses)
    ax.set_ylabel('loss')
    ax.set_title(f'{cix_nsde.nsde_label}_{cix_truth.truth_label}_{loss_function}_repeat_{repeat}')
    ax.set_yscale('log')
    # plot mse_f
    ax = axs[1]
    ax.plot(checkpoint_iters, mse_fs)
    ax.set_ylabel('mse_f')
    ax.set_yscale('log')
    # plot mse_σ
    ax = axs[2]
    ax.plot(checkpoint_iters, mse_σs)
    ax.set_xlabel('n_iter')
    ax.set_ylabel('mse_σ')

    # save the figure
    fig.savefig(f"figures/{cix_nsde.nsde_label}_{cix_truth.truth_label}_{loss_function}_repeat_{repeat}_losses.png")
    print(f"Figure saved at figures/{cix_nsde.nsde_label}_{cix_truth.truth_label}_{loss_function}_repeat_{repeat}_losses.png.")
    plt.close(fig)

    # save the losses, mse_fs, mse_σs to df to csv
    df = pd.DataFrame({
        'loss': losses,
        'mse_f': mse_fs,
        'mse_σ': mse_σs
    })
    df.to_csv(f"data/{cix_nsde.nsde_label}_{cix_truth.truth_label}_{loss_function}_repeat_{repeat}_losses.csv", index=False, header=True)

# Write the final figure, model and CSV exports (also when training was skipped).
checkpoint(N_epoch, iteration=False)


####### Model Testing ########
# Pick the profile-specific test settings when available, else the base ones.
if not os.path.exists(f"profiles/cix_test_{profile}.py"):
    print(f"Loading testing base profile at profiles/cix_test_base.py...")
    cix_test = importlib.import_module('profiles.cix_test_base')
else:
    print(f"Loading testing profile {profile}...")
    cix_test = importlib.import_module('profiles.cix_test_' + profile)

# The test profile provides the sample counts (N_samples) and one initial
# condition per sample count (u0s).
N_samples = cix_test.N_samples
u0s = cix_test.u0s
_n_tests = len(N_samples)

# Reload the trained weights from disk.
neuralsde.load_state_dict(torch.load(f"models/{cix_nsde.nsde_label}_{cix_truth.truth_label}_{loss_function}_repeat_{repeat}.pt"))

# Trajectories of the reference SDE, then of the learned SDE, for every test
# configuration (all reference runs are simulated before any learned run).
u_tests = [predict(sde, u0s[k], cix_test.ts) for k in range(_n_tests)]
u_preds = [predict(neuralsde, u0s[k], cix_test.ts) for k in range(_n_tests)]

# Test errors are only defined for non-mse loss functions; mse yields NaN.
# NOTE(review): the errors are computed from u_tests (reference
# trajectories), while u_preds is never used afterwards — confirm that this
# is the intended quantity.
if loss_function == 'mse':
    test_errors = [np.nan] * _n_tests
else:
    test_errors = [loss(u_tests[k]).detach().numpy() for k in range(_n_tests)]

xs = torch.linspace(2, 6, 100)
# Drift of the reference and learned SDE on xs, plus the drift error.
fs = f(sde, xs).detach().numpy()
fs_pred = f(neuralsde, xs).detach().numpy()
mse_f = rel_err_f(neuralsde, sde, u_truth, cix_truth.ts)
# Diffusion of the reference and learned SDE on xs, plus the diffusion error.
σs = g(sde, xs).detach().numpy()
σs_pred = g(neuralsde, xs).detach().numpy()
mse_σ = rel_err_g(neuralsde, sde, u_truth, cix_truth.ts)
# Re-evaluate the training error with the loaded weights.
train_error = loss(predict(neuralsde, cix_truth.u0, cix_truth.ts)).detach().numpy()

# One row per test configuration; run-level values are repeated on every row.
_summary_columns = {
    'id': [repeat] * _n_tests,
    'train_n_samples': [cix_truth.n_samples] * _n_tests,
    'test_n_samples': N_samples,
    'test_error': test_errors,
    'train_error': [train_error] * _n_tests,
    'mse_f': [mse_f] * _n_tests,
    'mse_σ': [mse_σ] * _n_tests,
    'loss_function': [loss_function] * _n_tests
}
df = pd.DataFrame(_summary_columns)
# Save the summary table next to the other per-run CSV files.
df.to_csv(f"data/{cix_nsde.nsde_label}_{cix_truth.truth_label}_{loss_function}_repeat_{repeat}_test_errors.csv", index=False)