import torch 
torch.set_num_threads(1)
import torchsde
import torch.optim.swa_utils as swa_utils
import argparse
import os
import IPython
from utils.wgan_ncde_utils import Discriminator

##### check if running in a Jupyter notebook
def is_notebook():
    """Report whether the code is executing inside a Jupyter kernel.

    Returns True only for the ZMQ-based shell used by Jupyter notebooks and
    qtconsole. A terminal IPython shell, any other shell type, or a context in
    which the ``IPython`` name cannot be resolved all give False.
    """
    try:
        shell_name = type(IPython.get_ipython()).__name__
    except NameError:
        # the IPython name is unavailable -> probably a standard Python interpreter
        return False
    # 'ZMQInteractiveShell' = notebook / qtconsole; everything else
    # (e.g. 'TerminalInteractiveShell') is treated as "not a notebook"
    return shell_name == 'ZMQInteractiveShell'
from pathlib import Path
import matplotlib as mpl
import matplotlib.pyplot as plt
# import tqdm
from tqdm import tqdm
import importlib
from utils.sde_utils import temporal_neuralSDE, predict, f,g, rel_err_f, rel_err_g
import matplotlib.pyplot as plt
import pandas as pd

# set tensor type to float64 (affects every floating-point tensor created after this line)
torch.set_default_dtype(torch.float64)
import numpy as np
# Path to matplotlib's bundled Computer Modern font file (cmr10).
# NOTE(review): fpath is not referenced in the visible part of this script; confirm before removing.
fpath = Path(mpl.get_data_path(), "fonts/ttf/cmr10.ttf")
##### Define callback function to inspect reconstructed dynamics and trajectory

def checkpoint(n_iter, iteration = True):
    """Plot the current neural SDE against the ground truth and save a snapshot.

    The figure is drawn on a single axis and contains:
    - the ground-truth (black) and learned (red) (g, f) vector fields as quivers
      on a 20x20 grid over time and the observed range of ``u_truth``;
    - every predicted (red) and ground-truth (black) sample trajectory.

    The figure and the neural SDE weights are then saved.

    Relies on module-level names defined elsewhere in this script:
    ``neuralsde``, ``sde``, ``u_truth``, ``temporal_OU_truth``,
    ``temporal_OU_nsde``, ``loss_function``, ``repeat`` and ``test_flag``.

    Args:
        n_iter: Current training iteration. It is used in the title and file
            names when ``iteration`` is True.
        iteration: If True, write an intermediate PNG (skipped in test mode)
            and the weights under ``tmp/``. If False, write the final SVG
            (skipped in test mode) under ``figures/`` and the weights under
            ``models/``.

    Returns:
        ``(mse_f, mse_σ)`` as computed by ``rel_err_f`` / ``rel_err_g`` when
        ``iteration`` is True, otherwise None.
    """
    run_label = (f"{temporal_OU_nsde.nsde_label}_{temporal_OU_truth.truth_label}"
                 f"_{loss_function}_repeat_{repeat}")

    # sample trajectories from the current neural SDE
    u_pred = predict(neuralsde, temporal_OU_truth.u0, temporal_OU_truth.ts, dt=0.1)
    u_pred_np = u_pred.detach().numpy()
    u_truth_np = u_truth.detach().numpy()
    ts_np = temporal_OU_truth.ts.detach().numpy()
    fig, ax = plt.subplots(1, 1, figsize=(8, 8))

    # 20-point axes over the time window and the observed range of u_truth.
    # They are built once and shared by the quiver grid and the f/g evaluations.
    t_axis = torch.linspace(ts_np.min(), ts_np.max(), 20)
    u_axis = torch.linspace(u_truth_np.min(), u_truth_np.max(), 20)
    t_grid, u_grid = torch.meshgrid(t_axis, u_axis, indexing='ij')
    t_values = t_grid.flatten().numpy()
    u_values = u_grid.flatten().numpy()

    # NOTE(review): f/g receive the 1-D (u, t) axes, while the quiver positions
    # come from an 'ij' meshgrid over (t, u). This assumes f/g return values in
    # the same flattened (t, u) order; confirm against utils.sde_utils.
    f_values = f(sde, u_axis, t_axis).detach().numpy().flatten()
    g_values = g(sde, u_axis, t_axis).detach().numpy().flatten()
    f_pred_values = f(neuralsde, u_axis, t_axis).detach().numpy().flatten()
    g_pred_values = g(neuralsde, u_axis, t_axis).detach().numpy().flatten()

    # errors of the learned drift / diffusion w.r.t. the ground-truth SDE
    mse_f = rel_err_f(neuralsde, sde, u_truth, temporal_OU_truth.ts)
    mse_σ = rel_err_g(neuralsde, sde, u_truth, temporal_OU_truth.ts)

    # vector fields: arrow components (g, f) at each (t, u) grid point
    for g_arrows, f_arrows, colour in ((g_values, f_values, 'black'),
                                       (g_pred_values, f_pred_values, 'red')):
        ax.quiver(
            t_values,
            u_values,
            g_arrows,
            f_arrows,
            angles='xy',
            scale_units='xy',
            scale=1,
            color=colour,
            width=0.002)

    # overlay every sample path (predicted in red, ground truth in black)
    for i in range(u_pred_np.shape[1]):
        ax.plot(ts_np, u_pred_np[:, i, 0], color='red', alpha=0.1)
        ax.plot(ts_np, u_truth_np[:, i, 0], color='black', alpha=0.1)
    ax.set_xlabel('t')
    ax.set_ylabel('u')
    ax.set_title(f'Epoch {n_iter}' if iteration else run_label)
    ax.set_xlim(ts_np.min(), ts_np.max())
    # y limits follow the ground-truth data range
    ax.set_ylim(u_truth_np.min(), u_truth_np.max())

    if iteration:
        os.makedirs('tmp', exist_ok=True)
        if not test_flag:
            fig.savefig(f"tmp/{run_label}_n_iter_{n_iter}.png")
        plt.close(fig)
        # intermediate weights snapshot
        torch.save(neuralsde.state_dict(), f"tmp/{run_label}_n_iter_{n_iter}.pt")
        return mse_f, mse_σ

    if not test_flag:
        # nothing guarantees 'figures' exists before this point; create it on demand
        os.makedirs('figures', exist_ok=True)
        fig.savefig(f"figures/{run_label}.svg")
    plt.close(fig)
    # final weights
    os.makedirs('models', exist_ok=True)
    torch.save(neuralsde.state_dict(), f"models/{run_label}.pt")


##### Setting up the experiment #####
def _str_to_bool(value):
    """Parse a command-line boolean such as ``True``, ``false``, ``1`` or ``no``.

    ``argparse`` with ``type=bool`` calls ``bool(str)``, which is True for any
    non-empty string, so ``--test False`` would switch test mode ON. This parser
    keeps the ``--flag True`` / ``--flag False`` form working correctly.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1'):
        return True
    if normalized in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}.")


if not is_notebook():
    # show info
    parser = argparse.ArgumentParser(description='Choose a profile.')
    # add profile argument
    parser.add_argument('--profile',
                        type=str,
                        required=False,
                        default='base',
                        help='The profile to use.')
    # add loss_function argument
    parser.add_argument('--loss_function',
                        type=str,
                        required=False,
                        default='W2',
                        help='The loss function to use.')
    # add repeat argument
    parser.add_argument('--repeat',
                        type=int,
                        required=False,
                        default=1,
                        help='The id of experiment repeat.')
    # add overwrite argument (accepts True/False, yes/no, 1/0)
    parser.add_argument('--overwrite',
                        type=_str_to_bool,
                        required=False,
                        default=True,
                        help='Whether to overwrite the existing result.')
    # add a test argument (accepts True/False, yes/no, 1/0)
    parser.add_argument('--test',
                        type=_str_to_bool,
                        required=False,
                        default=False,
                        help='Whether to run the code in test mode.')
    args = parser.parse_args()
    profile = args.profile
    loss_function = args.loss_function
    repeat = args.repeat
    overwrite = args.overwrite
    test_flag = args.test
else:
    # interactive (notebook) runs use these hard-coded settings
    profile = 'base'
    loss_function = 'WGAN'
    repeat = 11
    test_flag = True
    overwrite = True

#### loss function = WGAN ####
# This script only implements the WGAN objective, so any --loss_function value is replaced.
print('Running WGAN... Overwriting loss function to WGAN.')
loss_function = 'WGAN'

# echo the effective run configuration
print(f'Profile: {profile}')
print(f'Loss Function: {loss_function}')
print(f'Repeat ID: {repeat}')
print(f'Overwrite Flag: {overwrite}')


##### Handling the Ground Truth #####
# Use profiles/temporal_OU_truth_<profile>.py when it exists, otherwise the base profile.
_truth_profile_file = 'profiles/temporal_OU_truth_' + profile + '.py'
if not os.path.exists(_truth_profile_file):
    # fall back to the base profile
    print('Loading ground truth base profile at profiles/temporal_OU_truth_base.py')
    temporal_OU_truth = importlib.import_module('profiles.temporal_OU_truth_base')
else:
    print('Loading ground truth profile: ' + profile + ' at ' + _truth_profile_file)
    temporal_OU_truth = importlib.import_module('profiles.temporal_OU_truth_' + profile)

# ground-truth SDE built from the profile's μ, θ, σ
from sde.temporal_OU import SDE
sde = SDE(
    μ=temporal_OU_truth.μ,
    θ=temporal_OU_truth.θ,
    σ=temporal_OU_truth.σ,
)


# Generate ground-truth trajectories (or load a cached copy) and plot them.
os.makedirs('data', exist_ok=True)
if (not os.path.exists(temporal_OU_truth.u_truth_savepath)) or overwrite:
    print('Generating data...')
    u_truth = predict(sde,
                      temporal_OU_truth.u0,
                      temporal_OU_truth.ts,
                      dt = 0.1)
    torch.save(u_truth, temporal_OU_truth.u_truth_savepath) # shape (t_size, n_samples, 1)
    # plot the data to a square publication quality figure
    fig, ax = plt.subplots(figsize=(4, 4))
    for i in range(u_truth.shape[1]):
        ax.plot(temporal_OU_truth.ts, u_truth[:,i,0], color='black', linewidth=0.1)
    ax.set_xlabel('$t$', fontsize=14)
    ax.set_ylabel('$u(t)$', fontsize=14)
    # x limits span the full time grid
    ax.set_xlim([temporal_OU_truth.ts[0], temporal_OU_truth.ts[-1]])
    # annotate the number of sample paths in the top-left corner
    ax.text(0.05, 0.92, '$n_{samples}$ = ' + str(temporal_OU_truth.n_samples),
            transform=ax.transAxes, fontsize=14)

    # nothing guarantees the output directory exists yet; create it before saving
    os.makedirs('figures', exist_ok=True)
    fig.savefig("figures/temporal_OU_truth_" + temporal_OU_truth.truth_label + ".svg", dpi=300)
    fig.savefig("figures/temporal_OU_truth_" + temporal_OU_truth.truth_label + ".pdf", backend='pgf', dpi=300)
    # release the figure; it is not used again
    plt.close(fig)

    print('Data generated.')
else:
    print('Loading data...')
    u_truth = torch.load(temporal_OU_truth.u_truth_savepath)
    print('Data loaded.')



# Refuse to continue when the ground-truth trajectories contain any NaN entry.
if u_truth.isnan().any():
    raise Exception("u_truth contains nan.")

##### Handling the Neural SDE #####
# Use profiles/temporal_OU_nsde_<profile>.py when it exists, otherwise the base profile.
_nsde_profile_file = 'profiles/temporal_OU_nsde_' + profile + '.py'
if not os.path.exists(_nsde_profile_file):
    # fall back to the base profile
    print('Loading NSDE base profile at profiles/temporal_OU_nsde_base.py')
    temporal_OU_nsde = importlib.import_module('profiles.temporal_OU_nsde_base')
else:
    print('Loading NSDE profile: ' + profile + ' at ' + _nsde_profile_file)
    temporal_OU_nsde = importlib.import_module('profiles.temporal_OU_nsde_' + profile)
# define loss
def loss(u_pred):
    """WGAN critic gap between generated and ground-truth trajectories.

    A time channel is appended to both sets of paths. Each set is then made
    batch-first and converted to torchcde linear-interpolation coefficients,
    and both are scored by the module-level ``discriminator``.

    Args:
        u_pred: generated trajectories. They must be
            (t_size, n_samples, k) so they can be concatenated with the
            (t_size, n_samples, 1) time channel.

    Returns:
        ``discriminator(generated) - discriminator(truth)``. The training loop
        steps the discriminator on this value directly and negates the
        generator's gradients.
    """
    # time grid broadcast to (t_size, n_samples, 1) so it can be appended as a channel
    time_channel = temporal_OU_truth.ts[..., None, None].expand(
        temporal_OU_truth.t_size, temporal_OU_truth.n_samples, 1)

    def _as_coeffs(paths):
        # (t_size, n_samples, k) -> batch-first (n_samples, t_size, k + 1) with time last
        batch_first = torch.cat((paths, time_channel), dim=-1).permute(1, 0, 2)
        return linear_interpolation_coeffs(batch_first)

    real_coeffs = _as_coeffs(u_truth)
    fake_coeffs = _as_coeffs(u_pred)
    generator_score = discriminator(fake_coeffs)
    truth_score = discriminator(real_coeffs)
    return generator_score - truth_score

   
# Seed torch's global RNG from the repeat id so each repeat is reproducible.
torch.manual_seed(repeat*1000)

# set up neural SDE
neuralsde = temporal_neuralSDE(
    temporal_OU_nsde.state_size,
    temporal_OU_nsde.brownian_size,
    temporal_OU_nsde.hidden_size,
    temporal_OU_nsde.batch_size,
)

# Scale every drift-network parameter by 1/sqrt(T), where T = temporal_OU_truth.ts[-1].
# Drift parameters are those whose names start with 'mu', e.g. mu.linear1.weight.
# Parameters of the diffusion network ('sigma.*') are not touched here.
with torch.no_grad():
    for name, param in neuralsde.named_parameters():
        if not name.startswith('mu'):
            continue
        param.mul_(1/torch.sqrt(temporal_OU_truth.ts[-1]))

##### Handling the Discriminator #####
from profiles import temporal_OU_wgan_ncde_base
from torchcde import linear_interpolation_coeffs

# Critic hyper-parameters always come from the base WGAN/NCDE profile.
discriminator = Discriminator(
    temporal_OU_wgan_ncde_base.state_size,
    temporal_OU_wgan_ncde_base.hidden_size,
    temporal_OU_wgan_ncde_base.hidden_size,
    num_layers=2,
)
# averaged_discriminator = swa_utils.AveragedModel(discriminator)

# Redraw every discriminator parameter whose name ends in 'weight' from N(0, 1).
# Other parameters (e.g. biases) keep their default initialisation.
with torch.no_grad():
    for name, params in discriminator.named_parameters():
        if name.endswith('weight'):
            torch.nn.init.normal_(params, mean=0, std=1)

# generator (neural SDE) optimiser, configured by the NSDE profile
optimizer = torch.optim.Adam(
    neuralsde.parameters(),
    lr=temporal_OU_nsde.η,
    betas=temporal_OU_nsde.β,
    weight_decay=temporal_OU_nsde.weight_decay,
)

model_savepath = f"models/{temporal_OU_nsde.nsde_label}_{temporal_OU_truth.truth_label}_{loss_function}_repeat_{repeat}.pt"

# Discriminator optimiser: Adadelta with the WGAN profile's learning rate.
# NOTE(review): weight decay is taken from the NSDE profile, not the WGAN one; confirm this is intended.
discriminator_optimiser = torch.optim.Adadelta(
    discriminator.parameters(),
    lr=temporal_OU_wgan_ncde_base.η,
    weight_decay=temporal_OU_nsde.weight_decay,
)

# make sure the directory that holds trained models is present
if not os.path.exists('models'):
    os.makedirs('models')

# test mode runs a single epoch; otherwise the NSDE profile decides
if test_flag:
    print('Testing...')
N_epoch = 1 if test_flag else temporal_OU_nsde.N_epoch

import time
import psutil
if os.path.exists(model_savepath) and not overwrite:
    print('Model exists, skip training...')
    neuralsde.load_state_dict(torch.load(model_savepath))
    print('Model loaded.')
else:
    # train the model
    print('Training...')
    ####### Train the model ########
    # One backward pass on loss(u_pred) = D(generated) - D(truth) serves both
    # networks. The discriminator optimiser steps on it directly. The neural SDE
    # (generator) steps on the negated gradients, i.e. it ascends the same quantity.
    pbar = tqdm(range(N_epoch), desc='Training', leave=True)
    losses = []
    mse_fs = []
    mse_σs = []
    runtimes = []
    memory_usages = []
    # handle on this process, used for RSS measurements at every checkpoint
    process = psutil.Process(os.getpid())
    time_0 = time.time()
    for n_iter in pbar:
        optimizer.zero_grad()
        discriminator_optimiser.zero_grad()

        u_pred = predict(neuralsde, temporal_OU_truth.u0, temporal_OU_truth.ts, dt=0.1)
        current_loss = loss(u_pred)
        current_loss.backward()
        # Flip the generator's gradients (gradient ascent). A parameter that took
        # no part in the graph can have grad None, and `None *= -1` would raise.
        for param in neuralsde.parameters():
            if param.grad is not None:
                param.grad.neg_()
        optimizer.step()
        discriminator_optimiser.step()

        # Weight clipping: clamp each Linear layer of the discriminator
        # to [-1/out_features, 1/out_features].
        with torch.no_grad():
            for module in discriminator.modules():
                if isinstance(module, torch.nn.Linear):
                    lim = 1 / module.out_features
                    module.weight.clamp_(-lim, lim)

        # Stochastic weight averaging (disabled)
        # if n_iter > temporal_OU_wgan_ncde_base.swa_start:
        #     averaged_discriminator.update_parameters(discriminator)

        if n_iter % temporal_OU_nsde.checkpoint_freq == 0:
            mse_f, mse_σ = checkpoint(n_iter)
            mse_fs.append(mse_f)
            mse_σs.append(mse_σ)
            # wall-clock time since the previous checkpoint (including this checkpoint call)
            time_1 = time.time()
            runtimes.append(time_1 - time_0)
            time_0 = time_1
            # resident set size of this process, converted from bytes to MiB
            memory_usages.append(process.memory_info().rss / 1024 / 1024)
            loss_value = current_loss.item()
            # the recorded loss is normalised by the number of time points
            losses.append(loss_value / temporal_OU_truth.t_size)
            if not test_flag:
                # save the model
                torch.save(neuralsde.state_dict(), model_savepath)
            # progress-bar postfix shows the raw (unnormalised) loss
            pbar.set_postfix(loss=loss_value, mse_f=mse_f, mse_σ=mse_σ)

    # save the final result
    checkpoint(N_epoch, iteration=False)

    ##### visualize and save the losses to csv
    run_label = f'{temporal_OU_nsde.nsde_label}_{temporal_OU_truth.truth_label}_{loss_function}_repeat_{repeat}'
    losses = np.array(losses)
    mse_fs = np.array(mse_fs)
    mse_σs = np.array(mse_σs)
    fig, axs = plt.subplots(3, 1, figsize=(10, 15))
    # loss history (one point per checkpoint)
    axs[0].plot(losses)
    axs[0].set_ylabel('loss')
    axs[0].set_title(run_label)
    # drift error, log scale
    axs[1].plot(mse_fs)
    axs[1].set_ylabel('mse_f')
    axs[1].set_yscale('log')
    # diffusion error
    axs[2].plot(mse_σs)
    axs[2].set_xlabel('n_iter')
    axs[2].set_ylabel('mse_σ')

    if not test_flag:
        # nothing guarantees the output directory exists yet; create it before saving
        os.makedirs('figures', exist_ok=True)
        losses_fig_path = f"figures/{run_label}_losses.png"
        fig.savefig(losses_fig_path)
        print(f"Figure saved at {losses_fig_path}.")
    plt.close(fig)

    # save the losses, mse_fs, mse_σs, runtimes and memory usage to csv
    df = pd.DataFrame({
        'loss': losses,
        'mse_f': mse_fs,
        'mse_σ': mse_σs,
        'runtime': runtimes,
        'memory_usage': memory_usages
    })
    if not test_flag:
        losses_csv_path = f"data/{run_label}_losses.csv"
        df.to_csv(losses_csv_path, index=False, header=True)
        print(f"Data saved at {losses_csv_path}.")
    else:
        # save to test filename
        df.to_csv(f"data/{run_label}_losses_test.csv", index=False, header=True)

if test_flag:
    # Profile the memory use of one generator forward/backward/step.
    # NOTE(review): unlike the training loop, this step neither negates the
    # generator's gradients nor steps the discriminator; it is used only for profiling.
    with torch.profiler.profile(profile_memory=True, record_shapes=True) as prof:
        optimizer.zero_grad()
        u_pred = predict(neuralsde, temporal_OU_truth.u0, temporal_OU_truth.ts, dt=0.1)
        current_loss = loss(u_pred)
        current_loss.backward()
        optimizer.step()

    # top-10 operators by self CPU memory, as torch's fixed-width text table
    table=prof.key_averages().table(sort_by="self_cpu_memory_usage", row_limit=10)
    import re
    from io import StringIO

    def table_to_csv(table_string):
        """Convert a torch profiler text table into CSV text.

        Columns are separated by runs of two or more spaces. After blank lines
        are dropped, line 1 is the header and lines 0 and 2 are dashed
        separators, which are skipped. Every line from index 3 onward is
        emitted, including trailing separator/summary lines; the caller drops those.
        """
        lines = [line for line in table_string.split("\n") if line.strip()]
        header = re.split(r'\s{2,}', lines[1].strip())
        data = [re.split(r'\s{2,}', line.strip()) for line in lines[3:]]
        csv_lines = [",".join(header)]
        csv_lines.extend(",".join(row) for row in data)
        return "\n".join(csv_lines)

    csv_table = table_to_csv(table)
    prof_df = pd.read_csv(StringIO(csv_table))
    # remove 2 trailing rows (closing separator and the "Self CPU time total" summary)
    prof_df = prof_df.iloc[:-2]

    # torch's profiler prints memory as binary multiples of BYTES, e.g.
    # "1.53 Gb", "4.00 Mb", "-8.00 Kb", "512 b" (Kb = 1024 B, Mb = 1024**2 B).
    # These are not bits, so there is no division by 8. Everything is converted
    # to MiB to match the psutil-based memory_usage column (rss / 1024 / 1024).
    _MIB_PER_UNIT = {'Gb': 1024.0, 'Mb': 1.0, 'Kb': 1.0 / 1024, 'b': 1.0 / (1024 * 1024)}

    def parse_mem(mem):
        """Convert a profiler memory string such as '1.53 Mb' or '512 b' to MiB."""
        value, _, unit = str(mem).strip().rpartition(' ')
        if unit not in _MIB_PER_UNIT:
            raise Exception("Unknown memory unit.")
        return float(value) * _MIB_PER_UNIT[unit]

    prof_df['Self CPU Mem'] = prof_df['Self CPU Mem'].apply(parse_mem)
    mem_use = prof_df['Self CPU Mem'].sum()
    # NOTE(review): losses / mse_fs / mse_σs / runtimes / memory_usages only exist
    # if the training branch above ran; with an existing model and overwrite=False
    # this raises NameError.
    df = pd.DataFrame({
        'loss': losses,
        'mse_f': mse_fs,
        'mse_σ': mse_σs,
        'runtime': runtimes,
        'memory_usage': memory_usages
    })
    # every row of the test CSV reports the single profiled memory figure
    df['memory_usage'] = mem_use
    df.to_csv(f"data/{temporal_OU_nsde.nsde_label}_{temporal_OU_truth.truth_label}_{loss_function}_repeat_{repeat}_losses_test.csv", index=False, header=True)

