import torch 
torch.set_num_threads(1)
import torchsde
import argparse
import os
import IPython
##### check if in jupyter notebook
def is_notebook():
    """Return True when the code runs inside a Jupyter kernel.

    Only the ZMQ-based IPython shell (Jupyter notebook / qtconsole) counts.
    A terminal IPython session, any other shell type, or a plain Python
    interpreter all yield False.
    """
    try:
        shell_name = type(IPython.get_ipython()).__name__
    except NameError:
        # Probably standard Python interpreter
        return False
    # 'TerminalInteractiveShell' (terminal IPython) and every other class name,
    # including 'NoneType' when get_ipython() returns None, are not notebooks.
    return shell_name == 'ZMQInteractiveShell'
from pathlib import Path
import matplotlib as mpl
import matplotlib.pyplot as plt
# import tqdm
from tqdm import tqdm
import importlib
from utils.sde_utils import neuralSDE, predict, f,g, rel_err_f, rel_err_Sigma, Sigma
import matplotlib.pyplot as plt
import pandas as pd

# set the default floating-point dtype to float64 for tensors created from here on
torch.set_default_dtype(torch.float64)
# NOTE(review): numpy is imported here rather than together with the other imports above.
import numpy as np
# Path to matplotlib's bundled cmr10 (Computer Modern Roman) TrueType font.
# NOTE(review): fpath is not referenced anywhere else in the visible code — confirm it is still needed.
fpath = Path(mpl.get_data_path(), "fonts/ttf/cmr10.ttf")



##### Setting up the experiment #####
def _str2bool(value):
    """Convert a command-line string such as 'True' or 'false' to a bool.

    argparse's ``type=bool`` applies ``bool()`` to the raw string, and
    ``bool("False")`` is True, so ``--overwrite False`` would silently enable
    overwriting. This parser accepts the usual spellings instead.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1'):
        return True
    if normalized in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(
        f"Expected a boolean value (True/False), got {value!r}.")


if not is_notebook():
    # Command-line run: read the experiment configuration from argv.
    parser = argparse.ArgumentParser(description='Choose a profile.')
    # profile suffix of the profiles/gbm_truth_<profile>.py and
    # profiles/gbm_nsde_<profile>.py modules (the base profile is used when missing)
    parser.add_argument('--profile', 
                        type=str, 
                        required=False,
                        default='base',
                        help='The profile to use.')
    # name of the loss function to train with
    parser.add_argument('--loss_function',
                        type=str,
                        required=False,
                        default='W2',
                        help='The loss function to use.')
    # id of the experiment repeat (also used to seed torch)
    parser.add_argument('--repeat',
                        type=int,
                        required=False,
                        default=1,
                        help='The id of experiment repeat.')
    # Boolean flags accept True/False (also yes/no, 1/0); a bare flag means True.
    parser.add_argument('--overwrite',
                        type=_str2bool,
                        nargs='?',
                        const=True,
                        required=False,
                        default=False,
                        help='Whether to overwrite the existing result.')
    parser.add_argument('--test',
                        type=_str2bool,
                        nargs='?',
                        const=True,
                        required=False,
                        default=False,
                        help='Whether to run the code in test mode.')
    args = parser.parse_args()
    profile = args.profile
    loss_function = args.loss_function
    repeat = args.repeat
    overwrite = args.overwrite
    test_flag = args.test
else:
    # Interactive (Jupyter) run: hard-coded configuration for quick experiments.
    profile = 'n_samples_128'
    loss_function = 'W2_rotated'
    repeat = 21
    overwrite = True
    test_flag = True


# print configuration
print('Profile: ' + profile)
print('Loss Function: ' + loss_function)
print('Repeat ID: ' + str(repeat))
print('Overwrite Flag: ' + str(overwrite))


##### Handling the Ground Truth #####
# Prefer profiles/gbm_truth_<profile>.py; fall back to the base profile when it is absent.
if not os.path.exists(f'profiles/gbm_truth_{profile}.py'):
    print('Loading ground truth base profile at profiles/gbm_truth_base.py')
    gbm_truth = importlib.import_module('profiles.gbm_truth_base')
else:
    print(f'Loading ground truth profile: {profile} at profiles/gbm_truth_{profile}.py')
    gbm_truth = importlib.import_module(f'profiles.gbm_truth_{profile}')

# Ground-truth SDE built from the profile's Sigma and Mu.
from sde.gbm import SDE
sde = SDE(gbm_truth.Sigma, gbm_truth.Mu)


if not os.path.exists('data'):
    os.makedirs('data')
if (not os.path.exists(gbm_truth.u_truth_savepath)) or overwrite:
    # Simulate the ground-truth SDE and cache the trajectories on disk.
    print('Generating data...')
    u_truth = predict(sde, gbm_truth.u0, gbm_truth.ts)
    # indexed below as (t_size, n_samples, 2): component 0 is x, component 1 is y
    torch.save(u_truth, gbm_truth.u_truth_savepath)

    # Publication-quality figure: one subplot per state component (x on top, y below).
    # savefig fails if the output directory is missing, so create it first.
    os.makedirs('figures', exist_ok=True)
    fig, axes = plt.subplots(2, 1, figsize=(8, 8))
    for component, (ax, ylabel) in enumerate(zip(axes, ('$x(t)$', '$y(t)$'))):
        for i in range(u_truth.shape[1]):
            ax.plot(gbm_truth.ts, u_truth[:, i, component], color='black', linewidth=0.1)
        ax.set_ylabel(ylabel, fontsize=14)
        ax.set_xlim([gbm_truth.ts[0], gbm_truth.ts[-1]])
    # annotate the number of sample paths in the top-left corner of the x panel
    axes[0].text(0.05, 0.92, '$n_{samples}$ = ' + str(gbm_truth.n_samples),
                 transform=axes[0].transAxes, fontsize=14)
    axes[1].set_xlabel('$t$', fontsize=14)
    fig.savefig("figures/gbm_truth_" + gbm_truth.truth_label + ".svg", dpi=300)
    # NOTE(review): the pgf backend requires a working LaTeX installation.
    fig.savefig("figures/gbm_truth_" + gbm_truth.truth_label + ".pdf", backend='pgf', dpi=300)
    # release the figure; otherwise it stays open for the rest of the run
    plt.close(fig)

    print('Data generated.')
else:
    print('Loading data...')
    u_truth = torch.load(gbm_truth.u_truth_savepath)
    print('Data loaded.')


# A NaN in the ground truth would propagate into every loss, so fail early.
if torch.isnan(u_truth).any():
    raise Exception("u_truth contains nan.")

##### Handling the Neural SDE #####
# Prefer profiles/gbm_nsde_<profile>.py; fall back to the base profile when it is absent.
if not os.path.exists(f'profiles/gbm_nsde_{profile}.py'):
    print('Loading NSDE base profile at profiles/gbm_nsde_base.py')
    gbm_nsde = importlib.import_module('profiles.gbm_nsde_base')
else:
    print(f'Loading NSDE profile: {profile} at profiles/gbm_nsde_{profile}.py')
    gbm_nsde = importlib.import_module(f'profiles.gbm_nsde_{profile}')
# define loss
# Each branch defines ``loss(u_pred)`` comparing simulated trajectories ``u_pred``
# with the module-level ground truth ``u_truth``.
# NOTE(review): both are indexed as (t_size, n_samples, 2) below — confirm.
if loss_function in ('W2', 'mse'):
    # Per-coordinate 1-D distances, summed over every time step.
    if loss_function == 'W2':
        from utils.sde_utils import W2_distance as distance
    else:
        from utils.sde_utils import mse_distance as distance
    def loss(u_pred):
        loss_cul = 0
        for t in range(0, gbm_truth.t_size):
            loss_cul += distance(u_pred[t,:,0], u_truth[t,:,0])
            loss_cul += distance(u_pred[t,:,1], u_truth[t,:,1])
        return loss_cul

elif loss_function in ('W2_rotated', 'W2_rotated_corrected'):
    # Per-coordinate W2 after rotating both point clouds by n_rotate angles
    # evenly spaced in [0, 90] degrees (default 9 angles).
    from utils.sde_utils import W2_distance as distance
    from utils.sde_utils import rotate_2d_vector
    if not hasattr(gbm_nsde, 'n_rotate'):
        gbm_nsde.n_rotate = 9
    theta_degrees = torch.linspace(0, 90, gbm_nsde.n_rotate)
    if loss_function == 'W2_rotated_corrected' and len(theta_degrees) > 1:
        # 0 == 90: assuming rotate_2d_vector is a planar rotation, 90 degrees only
        # swaps the coordinates (up to sign) and duplicates the 0 degree term.
        # A single angle ([0]) is kept as is.
        theta_degrees = theta_degrees[:-1]
    def loss_rotate(u_pred, u_truth, theta_degree):
        """Per-coordinate W2 loss summed over time after rotating by ``theta_degree``."""
        loss_cul = 0
        for t in range(0, gbm_truth.t_size):
            # shape of u_pred[t,:,:] is (n_samples, 2)
            u_pred_rotated = rotate_2d_vector(u_pred[t,:,:], theta_degree)
            u_truth_rotated = rotate_2d_vector(u_truth[t,:,:], theta_degree)
            loss_cul += distance(u_pred_rotated[:,0], u_truth_rotated[:,0])
            loss_cul += distance(u_pred_rotated[:,1], u_truth_rotated[:,1])
        return loss_cul
    def loss(u_pred):
        # sum of loss_rotate over all theta_degrees
        loss_cul = 0
        for theta_degree in theta_degrees:
            loss_cul += loss_rotate(u_pred, u_truth, theta_degree)
        return loss_cul

elif loss_function == "apprx_loglik":
    # Negative of apprx_loglik (presumably an approximate log-likelihood) of the
    # observed data; u_pred is unused. ``neuralsde`` is the module-level model
    # constructed further down, resolved when loss() is called.
    from utils.sde_utils import apprx_loglik
    def loss(u_pred):
        return -1 * apprx_loglik(u_truth, gbm_truth.ts, neuralsde)

elif loss_function == "MMD":
    from utils.sde_utils import mmd_distance as distance
    def loss(u_pred):
        loss_cul = 0
        # Skip t = 0: when the initial distributions are identical the distance is NaN.
        for t in range(1, gbm_truth.t_size):
            # evaluate the distance once and reuse it for both the NaN check and the sum
            step_loss = distance(u_pred[t,:,:], u_truth[t,:,:])
            if torch.isnan(step_loss).any():
                print(t)
                raise Exception("u_pred contains nan.")
            loss_cul += step_loss
        return loss_cul

elif loss_function == "sliced_W2":
    # Radially sliced W2 between the full 2-D point clouds at each time step.
    from utils.sde_utils import radially_sliced_W2_distance as distance
    def loss(u_pred):
        loss_cul = 0
        for t in range(0, gbm_truth.t_size):
            loss_cul += distance(u_pred[t], u_truth[t])
        return loss_cul

else:
    raise Exception("Unknown loss function.")
    
# Seed torch from the repeat id so each repeat is reproducible and uses its own seed.
torch.manual_seed(1000 * repeat)


# set up neural SDE; all sizes come from the NSDE profile
neuralsde = neuralSDE(
    gbm_nsde.state_size,
    gbm_nsde.brownian_size,
    gbm_nsde.hidden_size,
    gbm_nsde.batch_size,
)

def checkpoint(n_iter, iteration = True):
    """Plot the current neural SDE against the ground truth and save it.

    The figure covers the bounding box of the ground-truth data. It shows the
    drift ``f`` of the true SDE (black arrows) and of the neural SDE (red
    arrows), overlaid with up to 50 predicted (red) and true (black)
    trajectories in the (x, y) plane.

    Args:
        n_iter: iteration number; used in the title and in the file names of
            intermediate checkpoints.
        iteration: if True, save an intermediate snapshot (PNG figure and
            model ``state_dict``) under ``tmp/`` and return the relative
            errors. If False, save the final SVG figure under ``figures/``
            and the final ``state_dict`` under ``models/``.

    Returns:
        ``(mse_f, mse_σ)`` as returned by ``rel_err_f`` / ``rel_err_Sigma``
        when ``iteration`` is True, otherwise None.
    """
    run_label = f"{gbm_nsde.nsde_label}_{gbm_truth.truth_label}_{loss_function}_repeat_{repeat}"
    u_pred = predict(neuralsde, gbm_truth.u0, gbm_truth.ts)
    u_pred_np = u_pred.detach().numpy()
    u_truth_np = u_truth.detach().numpy()

    fig, ax = plt.subplots(1, 1, figsize=(8, 8))
    # Plot window: bounding box of the ground-truth trajectories.
    xmin = u_truth_np[:, :, 0].min()
    xmax = u_truth_np[:, :, 0].max()
    ymin = u_truth_np[:, :, 1].min()
    ymax = u_truth_np[:, :, 1].max()

    # 20 x 20 evaluation grid for the drift fields, stacked into (400, 2) points.
    x_grid, y_grid = torch.meshgrid(
        torch.linspace(xmin, xmax, 20),
        torch.linspace(ymin, ymax, 20),
        indexing='ij'
    )
    x_grid_values = x_grid.flatten()
    y_grid_values = y_grid.flatten()
    u_vec = torch.stack([x_grid_values, y_grid_values], dim=1)
    # drift of the true SDE and of the neural SDE on the grid
    f_vals = f(sde, u_vec)
    f_pred_vals = f(neuralsde, u_vec)
    x_grid_np = x_grid_values.detach().numpy()
    y_grid_np = y_grid_values.detach().numpy()
    for drift, color in ((f_vals, 'black'), (f_pred_vals, 'red')):
        ax.quiver(
            x_grid_np, y_grid_np,
            drift[:, 0].detach().numpy(), drift[:, 1].detach().numpy(),
            color=color, alpha=0.5
        )

    # plot up to 50 trajectories
    for i in range(min(u_pred_np.shape[1], 50)):
        ax.plot(u_pred_np[:, i, 0], u_pred_np[:, i, 1], color='red', alpha=0.1)
        ax.plot(u_truth_np[:, i, 0], u_truth_np[:, i, 1], color='black', alpha=0.1)
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_xlim(xmin, xmax)
    ax.set_ylim(ymin, ymax)

    if iteration:
        ax.set_title(f'Epoch {n_iter}')
        os.makedirs('tmp', exist_ok=True)
        fig.savefig(f"tmp/{run_label}_n_iter_{n_iter}.png")
        plt.close(fig)
        # relative errors of drift and diffusion, evaluated with ts = [0.0]
        mse_f = rel_err_f(neuralsde, sde, u_truth, ts=torch.tensor([0.0]))
        mse_σ = rel_err_Sigma(neuralsde, sde, u_truth, ts=torch.tensor([0.0]))
        torch.save(neuralsde.state_dict(), f"tmp/{run_label}_n_iter_{n_iter}.pt")
        return mse_f, mse_σ

    ax.set_title(run_label)
    # final artefacts go to figures/ and models/, which may not exist yet
    os.makedirs('figures', exist_ok=True)
    fig.savefig(f"figures/{run_label}.svg")
    plt.close(fig)
    os.makedirs('models', exist_ok=True)
    torch.save(neuralsde.state_dict(), f"models/{run_label}.pt")


# set up optimizer: Adam with learning rate, betas and weight decay from the NSDE profile
optimizer = torch.optim.Adam(
    params=neuralsde.parameters(),
    lr=gbm_nsde.η,
    betas=gbm_nsde.β,
    weight_decay=gbm_nsde.weight_decay,
)

# location of this run's trained weights (checked below before training)
model_savepath = "models/{}_{}_{}_repeat_{}.pt".format(
    gbm_nsde.nsde_label, gbm_truth.truth_label, loss_function, repeat)

# the model directory must exist before anything is saved into it
if not os.path.exists('models'):
    os.makedirs('models')

# test mode runs a short 21-epoch training instead of the profile's N_epoch
if test_flag:
    print('Testing...')
N_epoch = 21 if test_flag else gbm_nsde.N_epoch

import time
import psutil
if os.path.exists(model_savepath) and not overwrite:
    # A trained model already exists: reuse it and only regenerate the final figure.
    print('Model exists, skip training...')
    neuralsde.load_state_dict(torch.load(model_savepath))
    print('Model loaded.')
    checkpoint(N_epoch, iteration=False)

else:
    # train the model
    print('Training...')
    run_label = f"{gbm_nsde.nsde_label}_{gbm_truth.truth_label}_{loss_function}_repeat_{repeat}"
    # handle on the current process; its resident memory is sampled at each checkpoint
    process = psutil.Process(os.getpid())
    ####### Train the model ########
    pbar = tqdm(range(N_epoch), desc='Training', leave=True)
    # histories, recorded once per checkpoint (every checkpoint_freq iterations)
    losses = []
    mse_fs = []
    mse_σs = []
    runtimes = []        # wall-clock seconds since the previous checkpoint
    memory_usages = []   # resident set size in MB
    time_0 = time.time()
    for n_iter in pbar:
        optimizer.zero_grad()
        # training simulations use dt=0.1 (checkpoint() calls predict without dt)
        u_pred = predict(neuralsde, gbm_truth.u0, gbm_truth.ts, dt=0.1)
        current_loss = loss(u_pred)
        current_loss.backward()
        optimizer.step()
        loss_value = current_loss.item()
        if n_iter % gbm_nsde.checkpoint_freq == 0:
            print("\nPlotting and saving the checkpoint, may take a while...")
            mse_f, mse_σ = checkpoint(n_iter)
            mse_fs.append(mse_f)
            mse_σs.append(mse_σ)
            time_1 = time.time()
            # the interval includes the time spent in checkpoint() itself
            runtimes.append(time_1 - time_0)
            memory_usage = process.memory_info().rss / 1024 / 1024
            memory_usages.append(memory_usage)
            time_0 = time_1
            # loss normalised by the number of time steps
            losses.append(loss_value / gbm_truth.t_size)

            torch.save(neuralsde.state_dict(), model_savepath)

        # mse_fs is non-empty here: iteration 0 always hits a checkpoint
        postfix = dict(
            loss=loss_value,
            mse_f=mse_fs[-1],
            mse_σ=mse_σs[-1],
        )
        pbar.set_postfix(**postfix)

    # save the final result
    checkpoint(N_epoch, iteration=False)

    ##### visualize and save the losses to csv
    losses = np.array(losses)
    mse_fs = np.array(mse_fs)
    mse_σs = np.array(mse_σs)
    # one panel per history: loss, mse_f, mse_σ (only the first two use a log y-axis)
    fig, axs = plt.subplots(3, 1, figsize=(10, 15))
    for ax, values, ylabel in zip(axs, (losses, mse_fs, mse_σs), ('loss', 'mse_f', 'mse_σ')):
        ax.plot(values)
        ax.set_ylabel(ylabel)
    axs[0].set_title(run_label)
    axs[0].set_yscale('log')
    axs[1].set_yscale('log')
    axs[2].set_xlabel('n_iter')

    # savefig fails if the output directory is missing, so create it first
    os.makedirs('figures', exist_ok=True)
    losses_figure_path = f"figures/{run_label}_losses.png"
    fig.savefig(losses_figure_path)
    print(f"Figure saved at {losses_figure_path}.")
    plt.close(fig)

    # save the losses, mse_fs, mse_σs, runtimes and memory usages to csv
    df = pd.DataFrame({
        'loss': losses,
        'mse_f': mse_fs,
        'mse_σ': mse_σs,
        'runtime': runtimes,
        'memory_usage': memory_usages
    })
    df.to_csv(f"data/{run_label}_losses.csv", index=False, header=True)


