import torch 
torch.set_num_threads(1)
import torchsde
import torch.optim.swa_utils as swa_utils
import argparse
import os
import IPython
from utils.wgan_ncde_utils import Discriminator

##### check whether we are running inside a Jupyter notebook
def is_notebook():
    """Report whether the code runs inside a Jupyter notebook / qtconsole.

    The class name of the object returned by ``IPython.get_ipython()`` decides:
    ``ZMQInteractiveShell`` -> True. Every other class name (including
    ``TerminalInteractiveShell``, the terminal IPython REPL) -> False.
    A ``NameError`` while looking up the shell also yields False (probably a
    standard Python interpreter).
    """
    try:
        shell_name = type(IPython.get_ipython()).__name__
    except NameError:
        return False
    # Only the ZMQ-based shell (Jupyter notebook or qtconsole) counts.
    return shell_name == 'ZMQInteractiveShell'
from pathlib import Path
import matplotlib as mpl
import matplotlib.pyplot as plt
# import tqdm
from tqdm import tqdm
import importlib
from utils.sde_utils import neuralSDE, predict, f,g
import matplotlib.pyplot as plt
import pandas as pd

# Use float64 as torch's default floating-point dtype for every tensor created
# after this point (e.g. the torch.linspace grids used for plotting below).
torch.set_default_dtype(torch.float64)
##### Define callback function to inspect reconstructed dynamics and trajectory
import numpy as np

def checkpoint(n_iter, iteration = True):
    """Plot the current neural-SDE fit against the ground truth and save it.

    Builds a 3-row figure:
      1. sample paths of ``neuralsde`` (red) over the ground truth ``u_truth`` (black);
      2. drift ``f`` of the true ``sde`` (black) vs. ``neuralsde`` (red) on x in [0, 10];
      3. ``|g|`` of both on the same grid.
    Then saves the figure together with ``neuralsde.state_dict()``.

    Reads the module-level globals ``neuralsde``, ``sde``, ``u_truth``,
    ``cix_truth``, ``cix_nsde``, ``loss_function`` and ``repeat``.

    Args:
        n_iter: current training iteration; used in the plot title and, for
            intermediate checkpoints, in the output file names.
        iteration: True -> intermediate checkpoint (PNG + .pt in ``tmp/``).
            False -> final result (SVG in ``figures/``, .pt in ``models/``).

    Returns:
        ``(mse_f, mse_σ)`` when ``iteration`` is True. These are the mean squared
        differences of the drift and of the diffusion between ``sde`` and
        ``neuralsde`` on x in [2, 6]. Returns ``None`` otherwise.
    """
    # Common stem for the final-figure title and every output file of this run.
    run_label = f'{cix_nsde.nsde_label}_{cix_truth.truth_label}_{loss_function}_repeat_{repeat}'

    # --- Row 1: sample paths, prediction (red) vs. ground truth (black) ---
    u_pred = predict(neuralsde, cix_truth.u0, cix_truth.ts)
    u_pred_np = u_pred.detach().numpy()
    u_truth_np = u_truth.detach().numpy()
    ts_np = cix_truth.ts.detach().numpy()
    fig, axs = plt.subplots(3, 1, figsize=(10, 15))

    ax = axs[0]
    # Paths are indexed [time, sample, state]; draw state 0 of every sample.
    for i in range(u_pred_np.shape[1]):
        ax.plot(ts_np, u_pred_np[:, i, 0], color='red', alpha=0.1)
        ax.plot(ts_np, u_truth_np[:, i, 0], color='black', alpha=0.1)
    ax.set_xlabel('t')
    ax.set_ylabel('u')
    ax.set_title(f'Epoch {n_iter}' if iteration else run_label)
    ax.set_xlim(ts_np.min(), ts_np.max())

    # --- Row 2: drift f(x), true (black) vs. learned (red) ---
    xs = torch.linspace(0, 10, cix_nsde.batch_size)
    xs_np = xs.detach().numpy()
    fs_np = f(sde, xs).detach().numpy()
    fs_pred_np = f(neuralsde, xs).detach().numpy()
    ax = axs[1]
    ax.plot(xs_np, fs_np, color='black', alpha=0.5)
    ax.plot(xs_np, fs_pred_np, color='red', alpha=0.5)
    ax.set_xlabel('x')
    ax.set_ylabel('f(x)')
    ax.set_xlim(xs_np.min(), xs_np.max())

    # --- Row 3: |g(x)|, true (black) vs. learned (red); only the magnitude is shown ---
    σs_np = np.squeeze(torch.abs(g(sde, xs)).detach().numpy())
    σs_pred_np = np.squeeze(torch.abs(g(neuralsde, xs)).detach().numpy())
    ax = axs[2]
    ax.plot(xs_np, σs_np, color='black', alpha=0.5)
    ax.plot(xs_np, σs_pred_np, color='red', alpha=0.5)
    ax.set_xlabel('x')
    ax.set_ylabel('g(x)')
    ax.set_xlim(xs_np.min(), xs_np.max())

    # --- Coefficient errors on the sub-interval x in [2, 6] ---
    xs = torch.linspace(2, 6, cix_nsde.batch_size)
    fs_np = f(sde, xs).detach().numpy()
    fs_pred_np = f(neuralsde, xs).detach().numpy()
    σs_np = g(sde, xs).detach().numpy()
    σs_pred_np = g(neuralsde, xs).detach().numpy()
    mse_f = np.mean((fs_np - fs_pred_np)**2)
    # NOTE(review): unlike the plot above, this compares the *signed* g. If the
    # sign of g is irrelevant for the model, |g| may be the intended comparison.
    mse_σ = np.mean((σs_np - σs_pred_np)**2)

    if iteration:
        # Intermediate checkpoint: PNG figure + weights in tmp/.
        os.makedirs('tmp', exist_ok=True)
        fig.savefig(f"tmp/{run_label}_n_iter_{n_iter}.png")
        plt.close(fig)
        torch.save(neuralsde.state_dict(), f"tmp/{run_label}_n_iter_{n_iter}.pt")
        return mse_f, mse_σ

    # Final result: SVG figure in figures/, weights in models/. Make sure both
    # directories exist first; otherwise savefig/torch.save raise FileNotFoundError.
    os.makedirs('figures', exist_ok=True)
    os.makedirs('models', exist_ok=True)
    fig.savefig(f"figures/{run_label}.svg")
    plt.close(fig)
    torch.save(neuralsde.state_dict(), f"models/{run_label}.pt")

# Path to the cmr10 TrueType font shipped in matplotlib's data directory.
# NOTE(review): `fpath` does not appear to be used elsewhere in this file — confirm before removing.
fpath = Path(mpl.get_data_path(), "fonts/ttf/cmr10.ttf")


def _parse_bool(value):
    """Convert a command-line string such as 'True', 'false', '1' or 'no' to a bool.

    ``argparse`` with ``type=bool`` calls ``bool(str)``, which is True for every
    non-empty string. That made ``--overwrite False`` *enable* overwriting.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1'):
        return True
    if normalized in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}.")


##### Setting up the experiment #####
if not is_notebook():
    # Script mode: read the run configuration from the command line.
    parser = argparse.ArgumentParser(description='Choose a profile.')
    # add profile argument
    parser.add_argument('--profile', 
                        type=str, 
                        required=False,
                        default='base',
                        help='The profile to use.')
    # add loss_function argument
    parser.add_argument('--loss_function',
                        type=str,
                        required=False,
                        default='W2',
                        help='The loss function to use.')
    # add repeat argument
    parser.add_argument('--repeat',
                        type=int,
                        required=False,
                        default=1,
                        help='The id of experiment repeat.')
    # add overwrite argument: accepts `--overwrite`, `--overwrite True` or `--overwrite false`
    parser.add_argument('--overwrite',
                        type=_parse_bool,
                        nargs='?',
                        const=True,
                        required=False,
                        default=False,
                        help='Whether to overwrite the existing result.')
    args = parser.parse_args()
    profile = args.profile
    loss_function = args.loss_function
    repeat = args.repeat
    overwrite = args.overwrite
else:
    # Notebook mode: hard-coded settings for interactive runs.
    profile = 'n_sample_200'
    loss_function = 'W2'
    repeat = 11
    overwrite = True

#### loss function = WGAN ####
# This script always trains with the WGAN objective, whatever was requested above.
print('Running WGAN... Overwriting loss function to WGAN.')
loss_function = 'WGAN'

# print configuration
print('Profile: ' + profile)
print('Loss Function: ' + loss_function)
print('Repeat ID: ' + str(repeat))
print('Overwrite Flag: ' + str(overwrite))


##### Handling the Ground Truth #####
# Use profiles/cix_truth_<profile>.py when it exists, otherwise fall back to
# the base ground-truth profile.
if not os.path.exists(f'profiles/cix_truth_{profile}.py'):
    # load base
    print('Loading ground truth base profile at profiles/cix_truth_base.py')
    cix_truth = importlib.import_module('profiles.cix_truth_base')
else:
    print(f'Loading ground truth profile: {profile} at profiles/cix_truth_{profile}.py')
    cix_truth = importlib.import_module(f'profiles.cix_truth_{profile}')
from sde.cix import SDE
# Ground-truth SDE built from the profile's a, b and σ parameters.
sde = SDE(a=cix_truth.a, b=cix_truth.b, σ=cix_truth.σ)


# Directory for the cached ground-truth trajectories (and later CSV outputs).
os.makedirs('data', exist_ok=True)
if (not os.path.exists(cix_truth.u_truth_savepath)) or overwrite:
    # Simulate the ground-truth SDE from the profile's initial condition and
    # cache the trajectories so later runs can reload them.
    print('Generating data...')
    u_truth = predict(sde, 
                              cix_truth.u0, 
                              cix_truth.ts)
    torch.save(u_truth, cix_truth.u_truth_savepath) # shape (t_size, n_samples, 1)

    # Square figure showing every simulated trajectory.
    fig, ax = plt.subplots(figsize=(4, 4))
    for i in range(u_truth.shape[1]):
        ax.plot(cix_truth.ts, u_truth[:,i,0], color='black', linewidth=0.1)
    ax.set_xlabel('$t$', fontsize=14)
    ax.set_ylabel('$u(t)$', fontsize=14)
    # NOTE(review): the x-range is fixed to [0, 2], independent of cix_truth.ts.
    ax.set_xlim([0, 2])
    # Annotate the top-left corner with the number of simulated samples.
    ax.text(0.05, 0.95, '$n_{samples}$ = ' + str(cix_truth.n_samples),
        transform=ax.transAxes, fontsize=14)

    # figures/ is not created anywhere else; without this, savefig raises FileNotFoundError.
    os.makedirs('figures', exist_ok=True)
    fig.savefig("figures/cix_truth_" + cix_truth.truth_label + ".svg", dpi=300)
    fig.savefig("figures/cix_truth_" + cix_truth.truth_label + ".pdf", backend='pgf', dpi=300)
    # The figure is not needed after saving; close it so it does not stay open.
    plt.close(fig)

    print('Data generated.')
else:
    print('Loading data...')
    u_truth = torch.load(cix_truth.u_truth_savepath)
    print('Data loaded.')



# Fail fast if the simulated or loaded ground truth contains NaNs.
if torch.isnan(u_truth).any():
    raise Exception("u_truth contains nan.")

##### Handling the Neural SDE #####
# Use profiles/cix_nsde_<profile>.py when it exists, otherwise the base NSDE profile.
if not os.path.exists(f'profiles/cix_nsde_{profile}.py'):
    # load base
    print('Loading NSDE base profile at profiles/cix_nsde_base.py')
    cix_nsde = importlib.import_module('profiles.cix_nsde_base')
else:
    print(f'Loading NSDE profile: {profile} at profiles/cix_nsde_{profile}.py')
    cix_nsde = importlib.import_module(f'profiles.cix_nsde_{profile}')

# Seed torch's global RNG from the repeat id, so each repeat uses its own
# reproducible stream (repeat 1 -> 1000, repeat 2 -> 2000, ...).
torch.manual_seed(repeat * 1000)

# NOTE(review): loss_function is forced to 'WGAN' earlier in this script, so this
# branch does not run as written; `loss` is also redefined further below.
if loss_function == 'W2_mingtao_loss':
    from utils.sde_utils import W2_mingtao_loss as loss

# Generator: neural SDE sized by the NSDE profile.
neuralsde = neuralSDE(
    cix_nsde.state_size,
    cix_nsde.brownian_size,
    cix_nsde.hidden_size,
    cix_nsde.batch_size,
)

##### Handling the Discriminator #####
# The discriminator always uses the *base* WGAN/NCDE profile, whatever `profile` is.
from profiles import cix_wgan_ncde_base
from torchcde import linear_interpolation_coeffs
discriminator = Discriminator(
    cix_wgan_ncde_base.state_size,
    cix_wgan_ncde_base.hidden_size,
    cix_wgan_ncde_base.hidden_size,
    num_layers=2,
)
# averaged_discriminator = swa_utils.AveragedModel(discriminator)

# Generator optimiser: Adam with the NSDE profile's lr, betas and weight decay.
generator_hparams = None  # placeholder removed below; keeps names explicit
del generator_hparams
optimizer = torch.optim.Adam(
    neuralsde.parameters(),
    lr=cix_nsde.η,
    betas=cix_nsde.β,
    weight_decay=cix_nsde.weight_decay,
)

# Where the final model of this run lives; checked below to decide whether to train.
model_savepath = (
    f"models/{cix_nsde.nsde_label}_{cix_truth.truth_label}"
    f"_{loss_function}_repeat_{repeat}.pt"
)

# Discriminator optimiser: Adadelta with the WGAN profile's lr, but the NSDE
# profile's weight decay.
discriminator_optimiser = torch.optim.Adadelta(
    discriminator.parameters(),
    lr=cix_wgan_ncde_base.η,
    weight_decay=cix_nsde.weight_decay,
)

# Make sure the model directory exists.
if not os.path.exists('models'):
    os.makedirs('models')

def loss(u_truth, u_pred):
    """Return ``discriminator(pred) - discriminator(truth)`` on time-augmented paths.

    Both inputs must be indexed ``[time, sample, channel]`` with sizes
    ``(cix_truth.t_size, cix_truth.n_samples, ·)``; that shape is required by
    the time channel built below. For each input:
      1. ``cix_truth.ts`` is appended as an extra last channel;
      2. the time and sample axes are swapped, putting the sample axis first;
      3. the result is converted with ``linear_interpolation_coeffs`` for the
         discriminator from ``utils.wgan_ncde_utils``.

    Reads the module-level globals ``cix_truth`` and ``discriminator``.
    """
    # Observation times broadcast to a (t_size, n_samples, 1) channel.
    time_channel = cix_truth.ts.unsqueeze(-1).unsqueeze(-1).expand(
        cix_truth.t_size, cix_truth.n_samples, 1
    )

    def to_coeffs(paths):
        # (time, sample, c) -> append time -> (sample, time, c + 1) -> coefficients
        augmented = torch.cat((paths, time_channel), dim=-1).permute(1, 0, 2)
        return linear_interpolation_coeffs(augmented)

    truth_coeffs = to_coeffs(u_truth)
    pred_coeffs = to_coeffs(u_pred)
    # Score the generated paths first, then the real ones.
    fake_score = discriminator(pred_coeffs)
    real_score = discriminator(truth_coeffs)
    return fake_score - real_score


if os.path.exists(model_savepath) and not overwrite:
    print('Model exists, skip training...')
    neuralsde.load_state_dict(torch.load(model_savepath))
    print('Model loaded.')
else:
    # train the model
    print('Training...')
    ####### Train the model ########
    # One backward pass through `loss` = D(generated) - D(real) yields gradients
    # for both networks. The discriminator descends on it. The generator's
    # gradients are negated, so its optimizer step ascends on it instead.
    pbar = tqdm(range(cix_nsde.N_epoch), desc='Training', leave=True)
    losses = []  # current_loss / t_size, recorded every checkpoint_freq iterations
    mse_fs = []  # drift MSE at each checkpoint
    mse_σs = []  # diffusion MSE at each checkpoint

    for n_iter in pbar:

        u_pred = predict(neuralsde, cix_truth.u0, cix_truth.ts)
        current_loss = loss(u_truth, u_pred)
        current_loss.backward()
        # Flip the generator's gradients so it maximises the loss. Parameters
        # that did not take part in the graph have grad None; skip them instead
        # of crashing with a TypeError on `None *= -1`.
        for param in neuralsde.parameters():
            if param.grad is not None:
                param.grad *= -1
        optimizer.step()
        discriminator_optimiser.step()

        optimizer.zero_grad()
        discriminator_optimiser.zero_grad()

        # Weight clipping, as in the original WGAN: bound every Linear layer's
        # weights to [-1/out_features, 1/out_features].
        with torch.no_grad():
            for module in discriminator.modules():
                if isinstance(module, torch.nn.Linear):
                    lim = 1 / module.out_features
                    module.weight.clamp_(-lim, lim)

        # Stochastic weight averaging
        # if n_iter > cix_wgan_ncde_base.swa_start:
        #     averaged_discriminator.update_parameters(discriminator)

        if n_iter % cix_nsde.checkpoint_freq == 0:
            mse_f, mse_σ = checkpoint(n_iter)
            mse_fs.append(mse_f)
            mse_σs.append(mse_σ)
            losses.append(current_loss.item()/cix_truth.t_size)
            # show the latest metrics on the progress bar
            pbar.set_postfix(loss=current_loss.item(), mse_f=mse_f, mse_σ=mse_σ)

    # save the final result
    checkpoint(cix_nsde.N_epoch, iteration=False)

    ##### visualize and save the losses to csv
    losses = np.array(losses)
    mse_fs = np.array(mse_fs)
    mse_σs = np.array(mse_σs)
    run_label = f'{cix_nsde.nsde_label}_{cix_truth.truth_label}_{loss_function}_repeat_{repeat}'
    # Checkpoints are taken at n_iter = 0, freq, 2*freq, ...; plot against those
    # iteration numbers so the 'n_iter' axis label is accurate.
    checkpoint_iters = np.arange(len(losses)) * cix_nsde.checkpoint_freq

    fig, axs = plt.subplots(3, 1, figsize=(10, 15))
    ax = axs[0]
    ax.plot(checkpoint_iters, losses)
    ax.set_ylabel('loss')
    ax.set_title(run_label)
    # D(generated) - D(real) is signed and can be <= 0. A log axis would clip
    # those points, so use symlog, which shows both signs.
    ax.set_yscale('symlog')
    ax = axs[1]
    ax.plot(checkpoint_iters, mse_fs)
    ax.set_ylabel('mse_f')
    ax.set_yscale('log')
    ax = axs[2]
    ax.plot(checkpoint_iters, mse_σs)
    ax.set_xlabel('n_iter')
    ax.set_ylabel('mse_σ')

    # save the figure (figures/ is not guaranteed to exist yet)
    os.makedirs('figures', exist_ok=True)
    fig.savefig(f"figures/{run_label}_losses.png")
    print(f"Figure saved at figures/{run_label}_losses.png.")
    plt.close(fig)

    # save the losses, mse_fs, mse_σs to csv (one row per checkpoint)
    df = pd.DataFrame({
        'loss': losses,
        'mse_f': mse_fs,
        'mse_σ': mse_σs
    })
    df.to_csv(f"data/{run_label}_losses.csv", index=False, header=True)



####### Model Testing ########
# Use profiles/cix_test_<profile>.py when it exists, otherwise the base testing
# profile. It is expected to provide `N_samples` (the test sample counts),
# `u0s` (one initial condition per entry of N_samples) and `ts`.
if not os.path.exists(f"profiles/cix_test_{profile}.py"):
    # load the base profile
    print("Loading testing base profile at profiles/cix_test_base.py...")
    cix_test = importlib.import_module('profiles.cix_test_base')
else:
    print(f"Loading testing profile {profile}...")
    cix_test = importlib.import_module(f'profiles.cix_test_{profile}')
N_samples = cix_test.N_samples
u0s = cix_test.u0s
n_tests = len(N_samples)

# Reload the final model weights of this run.
neuralsde.load_state_dict(
    torch.load(f"models/{cix_nsde.nsde_label}_{cix_truth.truth_label}_{loss_function}_repeat_{repeat}.pt")
)

# Test trajectories, one per test configuration: ground truth first, then the
# neural SDE. NOTE(review): neither list is used below; test_errors are NaN
# placeholders.
u_tests = [predict(sde, u0s[k], cix_test.ts) for k in range(n_tests)]
u_preds = [predict(neuralsde, u0s[k], cix_test.ts) for k in range(n_tests)]
test_errors = [np.nan] * n_tests

# Coefficient errors on x in [2, 6] (100 points); drift first, then diffusion.
xs = torch.linspace(2, 6, 100)
fs = f(sde, xs).detach().numpy()
fs_pred = f(neuralsde, xs).detach().numpy()
mse_f = np.mean((fs - fs_pred)**2)
σs = g(sde, xs).detach().numpy()
σs_pred = g(neuralsde, xs).detach().numpy()
mse_σ = np.mean((σs - σs_pred)**2)
# Re-evaluate the adversarial training loss with the final model.
train_error = loss(u_truth, predict(neuralsde, cix_truth.u0, cix_truth.ts)).detach().numpy()

# One row per test configuration; run-level values are repeated on every row.
df = pd.DataFrame({
    'id': [repeat] * n_tests,
    'train_n_samples': [cix_truth.n_samples] * n_tests,
    'test_n_samples': N_samples,
    'test_error': test_errors,
    'train_error': [train_error] * n_tests,
    'mse_f': [mse_f] * n_tests,
    'mse_σ': [mse_σ] * n_tests,
    'loss_function': [loss_function] * n_tests,
})
df.to_csv(f"data/{cix_nsde.nsde_label}_{cix_truth.truth_label}_{loss_function}_repeat_{repeat}_test_errors.csv", index=False)