import torch 
# Restrict torch to a single intra-op CPU thread (presumably so several
# repeats can run side by side without oversubscribing cores — TODO confirm).
torch.set_num_threads(1)
import torchsde
import argparse
import os
import IPython
##### check if in jupyter notebook
def is_notebook():
    """Report whether the code is executing inside a Jupyter kernel.

    Returns True only for a ZMQ-based shell (Jupyter notebook / qtconsole).
    Terminal IPython, plain ``python`` and any other shell give False, as does
    the case where the ``IPython`` name is not available at all.
    """
    try:
        shell_name = type(IPython.get_ipython()).__name__
    except NameError:
        # Probably standard Python interpreter
        return False
    return shell_name == 'ZMQInteractiveShell'
from pathlib import Path
import matplotlib as mpl
import matplotlib.pyplot as plt
# import tqdm
from tqdm import tqdm
import importlib
from utils.sde_utils import temporal_neuralSDE, predict, f,g, rel_err_f, rel_err_g
import matplotlib.pyplot as plt
import pandas as pd

# Make float64 (torch.double) the default dtype for newly created float tensors.
torch.set_default_dtype(torch.double)
import numpy as np
# Path to matplotlib's bundled Computer Modern Roman TrueType font (cmr10).
# NOTE(review): ``fpath`` is not referenced in the code visible here.
fpath = Path(mpl.get_data_path()) / "fonts" / "ttf" / "cmr10.ttf"
##### Define callback function to inspect reconstructed dynamics and trajectory

def checkpoint(n_iter, iteration=True):
    """Evaluate, plot and save the current neural SDE against the ground truth.

    Reads module-level state: ``neuralsde``, ``sde``, ``u_truth``,
    ``temporal_OU_truth``, ``temporal_OU_nsde``, ``loss_function``,
    ``repeat``, ``test_flag`` and, for the final call, ``model_savepath``.

    The figure overlays predicted (red) and ground-truth (black) trajectories
    on quiver plots of the true (black) and learned (red) ``f``/``g`` values,
    evaluated on a 20-point grid spanning the observed (t, u) range.

    Args:
        n_iter: Iteration number used in the plot title and snapshot file names.
        iteration: True for an intermediate snapshot (figure + weights under
            ``tmp/``). False for the final save (figure under ``figures/``,
            weights at ``model_savepath``, plotting CSVs under ``data/plots/``).

    Returns:
        ``(mse_f, mse_σ)`` from ``rel_err_f`` / ``rel_err_g`` when
        ``iteration`` is True, otherwise ``None``.
    """
    # Common prefix of every title / file name produced for this run.
    run_label = (f"{temporal_OU_nsde.nsde_label}_{temporal_OU_truth.truth_label}"
                 f"_{loss_function}_repeat_{repeat}")

    # Evaluate on at most as many trajectories as one training minibatch holds.
    _n_samples = min(temporal_OU_truth.n_samples, temporal_OU_nsde.batch_size)
    u0_check = u_truth[0, :_n_samples, :]
    u_truth_check = u_truth[:, :_n_samples, :]
    u_pred = predict(neuralsde, u0_check, temporal_OU_truth.ts, dt=0.1)
    # Convert the tensors to numpy arrays for plotting / CSV export.
    u_pred_np = u_pred.detach().numpy()
    u_truth_np = u_truth_check.detach().numpy()
    ts_np = temporal_OU_truth.ts.detach().numpy()
    fig, ax = plt.subplots(1, 1, figsize=(8, 8))

    # Build the two 20-point axes once and reuse them for the quiver grid and
    # every f/g evaluation (previously each was rebuilt for every call).
    n_grid = 20
    t_axis = torch.linspace(ts_np.min(), ts_np.max(), n_grid)
    u_axis = torch.linspace(u_truth_np.min(), u_truth_np.max(), n_grid)
    t_grid, u_grid = torch.meshgrid(t_axis, u_axis, indexing='ij')
    t_values = t_grid.flatten()
    u_values = u_grid.flatten()

    # f / g take (model, u, t): note u comes before t.
    f_values = f(sde, u_axis, t_axis).detach().numpy().flatten()
    g_values = g(sde, u_axis, t_axis).detach().numpy().flatten()
    f_pred_values = f(neuralsde, u_axis, t_axis).detach().numpy().flatten()
    g_pred_values = g(neuralsde, u_axis, t_axis).detach().numpy().flatten()

    # compute mse_f and mse_σ
    mse_f = rel_err_f(neuralsde, sde, u_truth_check, temporal_OU_truth.ts)
    mse_σ = rel_err_g(neuralsde, sde, u_truth_check, temporal_OU_truth.ts)

    # Ground-truth field first (black), learned field on top (red).
    for u_comp, v_comp, color in ((g_values, f_values, 'black'),
                                  (g_pred_values, f_pred_values, 'red')):
        ax.quiver(
            t_values.numpy(),
            u_values.numpy(),
            u_comp,
            v_comp,
            angles='xy',
            scale_units='xy',
            scale=1,
            color=color,
            width=0.002)

    for i in range(u_pred_np.shape[1]):
        # Plot u_pred(t_size, batch, 0) as a function of t_size
        ax.plot(ts_np, u_pred_np[:, i, 0], color='red', alpha=0.1)
        ax.plot(ts_np, u_truth_np[:, i, 0], color='black', alpha=0.1)
    ax.set_xlabel('t')
    ax.set_ylabel('u')
    ax.set_title(f'Epoch {n_iter}' if iteration else run_label)
    ax.set_xlim(ts_np.min(), ts_np.max())
    # y limits follow the ground-truth range
    ax.set_ylim(u_truth_np.min(), u_truth_np.max())

    if iteration:
        os.makedirs('tmp', exist_ok=True)
        if not test_flag:
            fig.savefig(f"tmp/{run_label}_n_iter_{n_iter}_sgd.png")
        plt.close(fig)
        # Snapshot the weights alongside the figure.
        torch.save(neuralsde.state_dict(), f"tmp/{run_label}_n_iter_{n_iter}_sgd.pt")
        return mse_f, mse_σ

    if not test_flag:
        # Not created anywhere else in this script; savefig fails without it.
        os.makedirs('figures', exist_ok=True)
        fig.savefig(f"figures/{run_label}sgd_.svg")
    plt.close(fig)
    # save the model
    torch.save(neuralsde.state_dict(), model_savepath)

    # Plotting data: one CSV per trajectory (at most 50) plus the vector field.
    os.makedirs('data/plots', exist_ok=True)
    for i in range(min(50, u_pred_np.shape[1])):
        pd.DataFrame({
            't': ts_np,
            'u_pred': u_pred_np[:, i, 0],
            'u_truth': u_truth_np[:, i, 0]
        }).to_csv(f"data/plots/{run_label}_n_{i}_sgd.csv", index=False, header=True)
    pd.DataFrame({
        't': t_values.numpy(),
        'u': u_values.numpy(),
        'f': f_values,
        'g': g_values,
        'f_pred': f_pred_values,
        'g_pred': g_pred_values
    }).to_csv(f"data/plots/{run_label}_vector_field_sgd.csv", index=False, header=True)

##### Setting up the experiment #####
def _str2bool(value):
    """Parse a command-line boolean such as ``True``, ``false``, ``1`` or ``no``.

    ``argparse`` with ``type=bool`` calls ``bool(str)``, which is True for every
    non-empty string, so ``--overwrite False`` used to enable overwriting.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1'):
        return True
    if normalized in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


if not is_notebook():
    # show info
    parser = argparse.ArgumentParser(description='Choose a profile.')
    # add profile argument
    parser.add_argument('--profile',
                        type=str,
                        required=False,
                        default='base',
                        help='The profile to use.')
    # add loss_function argument
    parser.add_argument('--loss_function',
                        type=str,
                        required=False,
                        default='W2',
                        help='The loss function to use.')
    # add repeat argument
    parser.add_argument('--repeat',
                        type=int,
                        required=False,
                        default=1,
                        help='The id of experiment repeat.')
    # Boolean flags accept an explicit value (True/False/1/0/yes/no) or may be
    # given bare (``--overwrite``), which means True.
    parser.add_argument('--overwrite',
                        type=_str2bool,
                        nargs='?',
                        const=True,
                        required=False,
                        default=False,
                        help='Whether to overwrite the existing result.')
    parser.add_argument('--test',
                        type=_str2bool,
                        nargs='?',
                        const=True,
                        required=False,
                        default=False,
                        help='Whether to run the code in test mode.')
    args = parser.parse_args()
    profile = args.profile
    loss_function = args.loss_function
    repeat = args.repeat
    overwrite = args.overwrite
    test_flag = args.test
else:
    # Hard-coded configuration for interactive (notebook) runs.
    profile = 'n_samples_256'
    loss_function = 'MMD'
    repeat = 3
    overwrite = True
    test_flag = False


# Echo the resolved run configuration.
print(f'Profile: {profile}')
print(f'Loss Function: {loss_function}')
print(f'Repeat ID: {repeat}')
print(f'Overwrite Flag: {overwrite}')


##### Handling the Ground Truth #####
# Prefer profiles/temporal_OU_truth_<profile>.py; otherwise fall back to the base profile.
if os.path.exists(f'profiles/temporal_OU_truth_{profile}.py'):
    print(f'Loading ground truth profile: {profile} at profiles/temporal_OU_truth_{profile}.py')
    temporal_OU_truth = importlib.import_module(f'profiles.temporal_OU_truth_{profile}')
else:
    print('Loading ground truth base profile at profiles/temporal_OU_truth_base.py')
    temporal_OU_truth = importlib.import_module('profiles.temporal_OU_truth_base')

# Reference SDE built from the profile's μ, θ and σ.
from sde.temporal_OU import SDE
sde = SDE(
    μ=temporal_OU_truth.μ,
    θ=temporal_OU_truth.θ,
    σ=temporal_OU_truth.σ,
)


os.makedirs('data', exist_ok=True)
if not os.path.exists(temporal_OU_truth.u_truth_savepath):
    print('Generating data...')
    # Simulate the reference SDE from the profile's initial condition.
    u_truth = predict(sde,
                      temporal_OU_truth.u0,
                      temporal_OU_truth.ts,
                      dt=0.1)
    # The save path may live outside 'data'; make sure its folder exists.
    Path(temporal_OU_truth.u_truth_savepath).parent.mkdir(parents=True, exist_ok=True)
    torch.save(u_truth, temporal_OU_truth.u_truth_savepath)  # shape (t_size, n_samples, 1)

    # Square figure of every ground-truth trajectory.
    fig, ax = plt.subplots(figsize=(4, 4))
    for i in range(u_truth.shape[1]):
        ax.plot(temporal_OU_truth.ts, u_truth[:, i, 0], color='black', linewidth=0.1)
    ax.set_xlabel('$t$', fontsize=14)
    ax.set_ylabel('$u(t)$', fontsize=14)
    ax.set_xlim([temporal_OU_truth.ts[0], temporal_OU_truth.ts[-1]])
    # Annotate the sample count in the top-left corner.
    ax.text(0.05, 0.92, '$n_{samples}$ = ' + str(temporal_OU_truth.n_samples),
            transform=ax.transAxes, fontsize=14)
    # 'figures' is not created anywhere else, and savefig fails without it.
    os.makedirs('figures', exist_ok=True)
    fig.savefig("figures/temporal_OU_truth_" + temporal_OU_truth.truth_label + ".pdf", backend='pgf', dpi=300)
    # Release the figure; it is not needed after saving.
    plt.close(fig)

    print('Data generated.')
else:
    print('Loading data...')
    u_truth = torch.load(temporal_OU_truth.u_truth_savepath)
    print('Data loaded.')



# Refuse to continue if the ground-truth trajectories contain any NaN.
if u_truth.isnan().any():
    raise Exception("u_truth contains nan.")

##### Handling the Neural SDE #####
nsde_profile = "n_samples_base"
if os.path.exists('profiles/temporal_OU_nsde_' + nsde_profile + '.py'):
    print('Loading NSDE profile: ' + nsde_profile + ' at ' + 'profiles/temporal_OU_nsde_' + nsde_profile + '.py')
    temporal_OU_nsde = importlib.import_module('profiles.temporal_OU_nsde_' + nsde_profile)
else:
    # load base
    print('Loading NSDE base profile at profiles/temporal_OU_nsde_base.py')
    temporal_OU_nsde = importlib.import_module('profiles.temporal_OU_nsde_base')


def _sum_marginal_distance(distance):
    """Build a loss summing ``distance`` over every time step.

    The returned ``loss(u_pred, u_truth)`` compares ``u_pred[t, :, 0]`` with
    ``u_truth[t, :, 0]`` for each ``t`` in ``range(temporal_OU_truth.t_size)``
    and returns the sum of the per-step distances.
    """
    def loss(u_pred, u_truth):
        loss_cul = 0
        for t in range(0, temporal_OU_truth.t_size):
            loss_cul += distance(u_pred[t, :, 0], u_truth[t, :, 0])
        return loss_cul
    return loss


# define loss
if loss_function == 'W2':
    from utils.sde_utils import W2_distance as distance
    loss = _sum_marginal_distance(distance)
elif loss_function == 'W2_alt':
    from utils.sde_utils import W2_loss_alt as loss
elif loss_function == 'W1':
    from utils.sde_utils import W1_distance as distance
    loss = _sum_marginal_distance(distance)
elif loss_function == "mean2_var":
    from utils.sde_utils import mean2_var_distance as distance
    loss = _sum_marginal_distance(distance)
elif loss_function == "mse":
    from utils.sde_utils import mse_distance as distance
    loss = _sum_marginal_distance(distance)
elif loss_function == "apprx_loglik":
    from utils.sde_utils import apprx_loglik

    def loss(u_pred, u_truth):
        # u_pred is unused: the objective is the negative approximate
        # log-likelihood of the ground truth under ``neuralsde``.
        return -1 * apprx_loglik(u_truth, temporal_OU_truth.ts, neuralsde)
elif loss_function == "MMD":
    from utils.sde_utils import mmd_distance as distance

    def loss(u_pred, u_truth):
        loss_cul = 0
        # Skip the initial value: identical initial distributions lead to a NaN loss.
        for t in range(1, temporal_OU_truth.t_size):
            # Evaluate once and reuse it for both the NaN check and the sum
            # (it used to be computed twice per step).
            d_t = distance(u_pred[t, :, :], u_truth[t, :, :])
            if torch.isnan(d_t).any():
                print(t)
                raise Exception("u_pred contains nan.")
            loss_cul += d_t
        return loss_cul
elif loss_function == 'W2_mingtao_loss':
    # Listed here so the name is not rejected by the final ``else``.
    from utils.sde_utils import W2_mingtao_loss as loss
else:
    raise Exception("Unknown loss function.")
    
# Seed torch's global RNG from the repeat id, giving each repeat its own fixed seed.
torch.manual_seed(1000 * repeat)

# NOTE(review): the loss-selection chain above raises for names it does not
# recognise, so this only takes effect if that chain accepts 'W2_mingtao_loss'.
if loss_function == 'W2_mingtao_loss':
    loss = importlib.import_module('utils.sde_utils').W2_mingtao_loss

# Neural SDE sized by the NSDE profile (state, Brownian, hidden and batch sizes).
neuralsde = temporal_neuralSDE(
    temporal_OU_nsde.state_size,
    temporal_OU_nsde.brownian_size,
    temporal_OU_nsde.hidden_size,
    temporal_OU_nsde.batch_size,
)

# Rescale only the drift network's parameters (names starting with 'mu', e.g.
# 'mu.linear1.weight') by 1/sqrt(ts[-1]), the final time of the ground-truth
# grid; 'sigma.*' parameters are left as initialised.
with torch.no_grad():
    drift_params = [p for n, p in neuralsde.named_parameters() if n.startswith('mu')]
    for param in drift_params:
        param.mul_(1 / torch.sqrt(temporal_OU_truth.ts[-1]))


# Adam over all neural SDE parameters; hyper-parameters come from the NSDE profile.
optimizer = torch.optim.Adam(
    params=neuralsde.parameters(),
    lr=temporal_OU_nsde.η,
    betas=temporal_OU_nsde.β,
    weight_decay=temporal_OU_nsde.weight_decay,
)

# Location of the final weights for this (nsde, truth, loss, repeat) combination.
model_savepath = f"models/{temporal_OU_nsde.nsde_label}_{temporal_OU_truth.truth_label}_{loss_function}_repeat_{repeat}_sgd.pt"

# Ensure the model directory exists before anything is saved into it.
if not os.path.exists('models'):
    os.makedirs('models')

# Test mode runs a single epoch; otherwise use the profile's epoch count.
N_epoch = 1 if test_flag else temporal_OU_nsde.N_epoch
if test_flag:
    print('Testing...')

import time
import psutil
if os.path.exists(model_savepath) and not overwrite:
    print('Model exists, skip training...')
    neuralsde.load_state_dict(torch.load(model_savepath))
    print('Model loaded.')
else:
    print('Training...')
    # Common prefix of every file written for this run.
    run_label = f"{temporal_OU_nsde.nsde_label}_{temporal_OU_truth.truth_label}_{loss_function}_repeat_{repeat}"
    ####### Train the model ########
    pbar = tqdm(range(N_epoch), desc='Training', leave=True)
    losses = []         # current loss / t_size, recorded at each checkpoint
    mse_fs = []         # rel_err_f at each checkpoint
    mse_σs = []         # rel_err_g at each checkpoint
    runtimes = []       # wall-clock seconds since the previous checkpoint
    memory_usages = []  # resident set size of this process, in MB
    process = psutil.Process(os.getpid())
    time_0 = time.time()
    for n_iter in pbar:
        optimizer.zero_grad()
        # Minibatch of ground-truth trajectories; indices are drawn
        # independently, so the same trajectory may appear more than once.
        random_ind = torch.randint(0, temporal_OU_truth.n_samples, (temporal_OU_nsde.batch_size,))
        u0_random = u_truth[0, random_ind, :]
        u_truth_random = u_truth[:, random_ind, :]
        u_pred = predict(neuralsde, u0_random, temporal_OU_truth.ts, dt=0.1)
        current_loss = loss(u_pred, u_truth_random)
        current_loss.backward()
        optimizer.step()
        # n_iter == 0 always checkpoints, so mse_f / mse_σ exist for the postfix below.
        if n_iter % temporal_OU_nsde.checkpoint_freq == 0:
            print("\nPlotting and saving the checkpoint, may take a while...")
            # Call once: every call re-simulates, re-plots and re-saves the snapshot.
            mse_f, mse_σ = checkpoint(n_iter)
            mse_fs.append(mse_f)
            mse_σs.append(mse_σ)
            time_1 = time.time()
            runtimes.append(time_1 - time_0)
            # CPU memory of the current process, converted from bytes to MB.
            memory_usages.append(process.memory_info().rss / 1024 / 1024)
            time_0 = time_1
            losses.append(current_loss.item() / temporal_OU_truth.t_size)
        pbar.set_postfix(loss=current_loss.item(), mse_f=mse_f, mse_σ=mse_σ)

    # save the final result
    checkpoint(N_epoch, iteration=False)

    ##### visualize and save the losses to csv
    losses = np.array(losses)
    mse_fs = np.array(mse_fs)
    mse_σs = np.array(mse_σs)
    fig, axs = plt.subplots(3, 1, figsize=(10, 15))
    ax = axs[0]
    ax.plot(losses)
    ax.set_ylabel('loss')
    ax.set_title(run_label)
    ax.set_yscale('log')
    ax = axs[1]
    ax.plot(mse_fs)
    ax.set_ylabel('mse_f')
    ax.set_yscale('log')
    ax = axs[2]
    ax.plot(mse_σs)
    ax.set_xlabel('n_iter')
    ax.set_ylabel('mse_σ')

    if not test_flag:
        # 'figures' is not created anywhere else, and savefig fails without it.
        os.makedirs('figures', exist_ok=True)
        fig.savefig(f"figures/{run_label}_losses_sgd.png")
        print(f"Figure saved at figures/{run_label}_losses_sgd.png.")
    plt.close(fig)

    # save the losses, mse_fs, mse_σs and runtimes to csv
    df = pd.DataFrame({
        'loss': losses,
        'mse_f': mse_fs,
        'mse_σ': mse_σs,
        'runtime': runtimes,
    })
    if not test_flag:
        df.to_csv(f"data/{run_label}_losses_sgd.csv", index=False, header=True)
        print(f"Data saved at data/{run_label}_losses_sgd.csv.")
    else:
        # save to test filename
        df.to_csv(f"data/{run_label}_losses_test_sgd.csv", index=False, header=True)


if test_flag:
    # Profile one optimisation step on the full ground-truth ensemble
    # (temporal_OU_truth.u0, not a minibatch) to measure CPU memory use.
    with torch.profiler.profile(profile_memory=True, record_shapes=True) as prof:
        optimizer.zero_grad()
        u_pred = predict(neuralsde, temporal_OU_truth.u0, temporal_OU_truth.ts, dt=0.1)
        current_loss = loss(u_pred, u_truth)
        current_loss.backward()
        optimizer.step()

    # Top-10 operators by self CPU memory, as a fixed-width text table.
    table = prof.key_averages().table(sort_by="self_cpu_memory_usage", row_limit=10)
    import re
    from io import StringIO

    def table_to_csv(table_string):
        """Convert the profiler's fixed-width text table into a CSV string.

        Assumes line 1 is the header and data rows start at line 3 (after the
        dashed separators), with columns separated by two or more spaces.
        """
        lines = [line for line in table_string.split("\n") if line.strip()]
        header = re.split(r'\s{2,}', lines[1].strip())
        data = [re.split(r'\s{2,}', line.strip()) for line in lines[3:]]
        csv_lines = [",".join(header)]
        csv_lines.extend(",".join(row) for row in data)
        return "\n".join(csv_lines)

    prof_df = pd.read_csv(StringIO(table_to_csv(table)))
    # Drop the 2 trailing rows (presumably the closing separator and the
    # "Self CPU time total" footer).
    prof_df = prof_df.iloc[:-2]

    # The profiler prints byte counts with binary prefixes ("Gb" = GiB,
    # "Mb" = MiB, "Kb" = KiB, "b" = bytes). Convert everything to MB, the
    # unit of the psutil-based memory_usages. The old code divided by 8
    # (treating bytes as bits) and crashed on "Kb"/"b" rows.
    _MB_PER_UNIT = {'Gb': 1024.0, 'Mb': 1.0, 'Kb': 1.0 / 1024, 'b': 1.0 / (1024 * 1024)}

    def parse_mem(mem):
        """Convert a profiler memory string such as ``'1.50 Gb'`` to MB."""
        parts = str(mem).split()
        if len(parts) != 2 or parts[1] not in _MB_PER_UNIT:
            raise Exception("Unknown memory unit.")
        return float(parts[0]) * _MB_PER_UNIT[parts[1]]

    prof_df['Self CPU Mem'] = prof_df['Self CPU Mem'].apply(parse_mem)
    mem_use = float(prof_df['Self CPU Mem'].sum())
    # NOTE(review): losses / mse_fs / mse_σs / runtimes / memory_usages only
    # exist when the training branch above ran (not when a saved model was loaded).
    df = pd.DataFrame({
        'loss': losses,
        'mse_f': mse_fs,
        'mse_σ': mse_σs,
        'runtime': runtimes,
        'memory_usage': memory_usages
    })
    # Report the profiled step's memory in every row.
    df['memory_usage'] = mem_use
    df.to_csv(f"data/{temporal_OU_nsde.nsde_label}_{temporal_OU_truth.truth_label}_{loss_function}_repeat_{repeat}_losses_test_sgd.csv", index=False, header=True)

