import time
import shutil
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path

from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.optimizers import SGD, Adam
from tensorflow.keras.callbacks import TerminateOnNaN, ReduceLROnPlateau

from tqdm.keras import TqdmCallback

from callback import StopAtCertainError
from utils import reset_tf


def generate(configs):
    """Sample random input paths and the outputs of a memory functional.

    Inputs are Gaussian draws multiplied by ``configs["input_scale"]``.
    With ``f(x) = x * |x| ** (power - 1)``, the output at step ``t`` is

        y_t = dt * sum_{s=0}^{t} f(x_{t-s}) * rho(s * dt)

    evaluated element-wise over the first (sample) and last (feature) axes.

    Args:
        configs: Mapping providing ``"input_scale"``, ``"size"``,
            ``"path_len"``, ``"input_dim"``, ``"power"``, ``"dt"`` and
            ``"rho"`` (a callable evaluated at ``s * dt``).

    Returns:
        Tuple ``(inputs, outputs)``; both arrays have shape
        ``(size, path_len, input_dim)``.
    """
    path_len = configs["path_len"]
    dt = configs["dt"]
    rho = configs["rho"]

    inputs = configs["input_scale"] * np.random.normal(
        size=(
            configs["size"],
            path_len,
            configs["input_dim"],
        )
    )

    # The transformed input depends only on the time step and the kernel only
    # on the lag, so both are computed once instead of once per (t, s) pair.
    transformed = inputs * np.power(np.abs(inputs), configs["power"] - 1)
    kernel = [rho(s * dt) for s in range(path_len)]

    outputs = []
    for t in range(path_len):
        output = 0
        # Accumulate lags in increasing order (s = 0 first).
        for s in range(t + 1):
            output += transformed[:, t - s, :] * kernel[s]
        outputs.append(output * dt)

    # The list is stacked time-major; swap to (size, path_len, input_dim).
    return inputs, np.asarray(outputs).transpose(1, 0, 2)


def dataset_construction(configs):
    """Build a training set and a test set from identical generator settings.

    The generator is configured with a fixed input scale of 0.1, a sample
    count of ``configs["batch_size"] * 100`` and a path length of
    ``configs["test_time"]``; ``"input_dim"``, ``"rho"``, ``"dt"`` and
    ``"power"`` are forwarded unchanged from ``configs``.

    Returns:
        Tuple ``(x_train, y_train, x_test, y_test)``.
    """
    generator_config = {
        "input_scale": 0.1,
        "size": configs["batch_size"] * 100,
        "path_len": configs["test_time"],
    }
    # Remaining settings are passed through verbatim.
    for key in ("input_dim", "rho", "dt", "power"):
        generator_config[key] = configs[key]

    # Two separate calls, so the two sets are separate random draws.
    train_pair = generate(generator_config)
    test_pair = generate(generator_config)

    return (*train_pair, *test_pair)


def derivative_error(y_pred, y_true, dt):
    """Return the largest absolute gap between two finite-difference derivatives.

    Each input is differenced along its second axis and divided by ``dt``;
    the result is the maximum absolute difference between the two
    derivative estimates over all entries.

    Args:
        y_pred: Predicted paths, indexed as ``[sample, time, ...]``.
        y_true: Reference paths with the same layout.
        dt: Step size used as the denominator of the difference quotient.
    """

    def forward_difference(path):
        # (value at step k + 1 minus value at step k) / dt
        return (path[:, 1:] - path[:, :-1]) / dt

    gap = forward_difference(y_pred) - forward_difference(y_true)
    return np.max(np.abs(gap))


def model_construction(fit_time, m_iter, activation, epochs):
    """Create the compiled recurrent model together with its training callbacks.

    The network is a Sequential stack: an input of shape ``(fit_time, 1)``,
    a ``SimpleRNN`` with ``m_iter`` units that returns the whole sequence,
    and a ``Dense(1)`` read-out. It is compiled with ``Adam(5e-3)`` and
    mean-squared-error loss.

    Args:
        fit_time: Sequence length declared for the model input.
        m_iter: Number of recurrent units.
        activation: Activation passed to the recurrent layer.
        epochs: Not used inside this function; kept in the signature for
            callers that pass it.

    Returns:
        Tuple ``(model, callbacks)``.
    """
    recurrent_layer = layers.SimpleRNN(
        units=m_iter,
        activation=activation,
        recurrent_initializer="zero",
        return_sequences=True,
    )
    readout_layer = layers.Dense(1)

    model = keras.Sequential(
        [
            keras.Input(shape=(fit_time, 1)),
            recurrent_layer,
            readout_layer,
        ]
    )
    model.compile(optimizer=Adam(5e-3), loss="mse")

    # Progress bar, NaN guard, early stop at an error threshold of 1e-8, and
    # learning-rate halving when the validation loss stops improving.
    progress = TqdmCallback(verbose=0)
    nan_guard = TerminateOnNaN()
    error_stop = StopAtCertainError(1e-8)
    lr_schedule = ReduceLROnPlateau(
        monitor="val_loss", factor=0.5, patience=5, min_lr=0.001
    )

    return model, [progress, nan_guard, error_stop, lr_schedule]


def plot_sample_error(
    y_pred_long,
    y_test,
    index,
    save_dir_with_m,
    activation,
    memory_type,
    fit_time,
    functional_type,
):
    """Plot one predicted path against its target and save the figure as PDF.

    The 2x2 figure shows, for sample ``index`` of ``y_test``:
    top-left the prediction and the target, top-right their absolute
    difference, bottom-left the first differences of both paths, and
    bottom-right the absolute difference of those first differences.
    The first differences are plotted as-is (they are not divided by a
    step size).

    Args:
        y_pred_long: Predicted path for the selected sample (already squeezed).
        y_test: Array of target paths; row ``index`` is compared.
        index: Which row of ``y_test`` corresponds to ``y_pred_long``.
        save_dir_with_m: Directory the PDF is written to.
        activation, memory_type, fit_time, functional_type: Only used to
            build the output file name.
    """
    # Keep the leading axis while slicing, then squeeze for plotting.
    sample = y_test[index : index + 1, :]
    target = np.squeeze(sample)
    target_step = np.squeeze(sample[:, 1:] - sample[:, :-1])
    pred_step = y_pred_long[1:] - y_pred_long[:-1]

    fig, ax = plt.subplots(2, 2, figsize=(20, 16))
    try:
        ax[0][0].plot(y_pred_long)
        ax[0][0].plot(target)
        ax[0][0].set_title(r"$H_t, \hat{H}_t$")

        # Compare against the same sample that is drawn in the other panels
        # (previously row 0 was used here regardless of ``index``).
        ax[0][1].plot(np.abs(y_pred_long - target))
        ax[0][1].set_title(r"$|H_t - \hat{H}_t|$")

        ax[1][0].plot(pred_step)
        ax[1][0].plot(target_step)
        ax[1][0].set_title(r"$\frac{dH_t}{dt}, \frac{d\hat{H}_t}{dt}$")

        ax[1][1].plot(np.abs(pred_step - target_step))
        ax[1][1].set_title(r"$|\frac{dH_t}{dt} - \frac{d\hat{H}_t}{dt}|$")

        fig.savefig(
            f"{save_dir_with_m}/sample_{activation}_{memory_type}_T={fit_time}_{functional_type}.pdf"
        )
    finally:
        # pyplot keeps every figure alive until it is closed explicitly.
        plt.close(fig)


def evaluate_perturbation_error(
    perturbation_list, model, num_repeats, x_test, y_test, dt
):
    """Measure the mean derivative error under random weight perturbations.

    For every scale in ``perturbation_list`` the trained weights are
    perturbed ``num_repeats`` times with independent standard-normal noise
    multiplied by that scale, the first two layers of ``model`` are applied
    to ``x_test``, and ``derivative_error`` against ``y_test`` is recorded.

    Every perturbation starts from the weights the model had on entry, and
    those weights are put back before returning (also if an error occurs).

    Args:
        perturbation_list: Iterable of noise scales.
        model: Model exposing ``get_weights``, ``set_weights`` and ``layers``.
        num_repeats: Number of noise draws per scale.
        x_test: Inputs fed through ``model.layers[0]`` then ``model.layers[1]``.
        y_test: Targets compared against the prediction.
        dt: Step size forwarded to ``derivative_error``.

    Returns:
        List with one mean error per entry of ``perturbation_list``.
    """
    # Snapshot once: reading the weights inside the loop would pick up the
    # noise left behind by the previous draw and let perturbations accumulate.
    base_weights = list(model.get_weights())

    error_list = []
    try:
        for perturbation_scale in perturbation_list:
            errors_at_scale = []
            for _ in range(num_repeats):
                perturbed_weights = [
                    weight + perturbation_scale * np.random.randn(*(weight.shape))
                    for weight in base_weights
                ]
                model.set_weights(perturbed_weights)

                prediction = model.layers[1](model.layers[0](x_test[:, :]))
                errors_at_scale.append(derivative_error(prediction, y_test, dt))
            error_list.append(np.mean(errors_at_scale))
    finally:
        # Do not hand a perturbed model back to the caller.
        model.set_weights(base_weights)

    return error_list


def plot_perturbation_error(
    m_series,
    perturbation_list,
    updated_save_dir,
    activation,
    memory_type,
    fit_time,
    functional_type,
    df,
):
    """Plot perturbation error against perturbation scale and save the results.

    One curve is drawn per entry of ``m_series`` from the column
    ``f"m={m_iter}"`` of ``df``, on log-log axes with the y-range fixed to
    ``[1e-3, 1e5]``. The figure is written as PDF and ``df`` as CSV, both
    into ``updated_save_dir`` under the same base name.

    Args:
        m_series: Values of ``m`` whose columns are plotted.
        perturbation_list: X-coordinates shared by all curves.
        updated_save_dir: Output directory.
        activation, memory_type, fit_time, functional_type: Only used to
            build the output file names.
        df: DataFrame holding one ``"m=<value>"`` column per plotted curve.
    """
    output_stem = f"{updated_save_dir}/error_{activation}_{memory_type}_T={fit_time}_{functional_type}"

    fig, ax = plt.subplots(1, 1, figsize=(10, 8))
    try:
        for m_iter in m_series:
            column = f"m={m_iter}"
            # Label each curve by the column it shows, so the legend cannot
            # drift from the plotted data if ``df`` has other/reordered columns.
            ax.plot(perturbation_list, df[column], label=column)
        ax.set_xlabel(r"perturbation $\omega$")
        ax.set_ylabel(r"Perturbation error $E(\omega)$")
        ax.set_xscale("log")
        ax.set_yscale("log")
        ax.set_ylim([1e-3, 1e5])
        ax.legend()

        ax.set_title("")
        fig.savefig(f"{output_stem}.pdf")
    finally:
        # pyplot keeps every figure alive until it is closed explicitly.
        plt.close(fig)

    df.to_csv(f"{output_stem}.csv")


def prepare_experiments(memory_type, save_dir):
    """Train models of increasing width and record their perturbation errors.

    For each width in ``m_series`` a model is built, fitted on the first
    ``fit_time`` steps of the training paths, saved together with its
    training history, plotted on one full-length test sample, and evaluated
    under weight perturbations. Finally all perturbation curves are plotted
    and written to CSV.

    Args:
        memory_type: ``"pol"`` selects the polynomial memory kernel; any
            other value selects the exponential-sum kernel. The string is
            also used in output directory and file names.
        save_dir: ``pathlib.Path`` under which the results directory
            ``<activation>_<memory_type>_fit_time_T_<fit_time>`` is created.
    """

    def memory_exp(x, sum_terms=20):
        # Mean of exp((-2 - sqrt(i)) * x) over i = 0 .. sum_terms - 1.
        return (
            np.sum(np.exp([(-2 - np.sqrt(i)) * x for i in range(sum_terms)]))
            / sum_terms
        )

    def memory_pol(x):
        # Polynomially decaying kernel (x + 1) ** -1.5.
        return 1 / (x + 1) ** 1.5

    epochs = 1000
    batch_size = 128

    fit_time = 32
    test_time = 100
    m_series = [4, 8, 16, 64, 256]

    activation = "tanh"

    dt = 0.1
    num_repeats = 3
    bool_linear_functional = False
    functional_type = "LF" if bool_linear_functional else "NLF"

    # Scale 0 followed by 1e-11 * 2**k for k = 0 .. 35.
    perturbation_list = [0] + [1e-11 * 2**power for power in range(0, 36)]

    dataset_configs = {
        "batch_size": batch_size,
        "test_time": test_time,
        "input_dim": 1,
        "rho": memory_pol if memory_type == "pol" else memory_exp,
        "dt": dt,
        "power": 1 if bool_linear_functional else 1.01,
    }

    x_train, y_train, x_test, y_test = dataset_construction(dataset_configs)

    df = pd.DataFrame({})

    updated_save_dir = save_dir.joinpath(
        f"{activation}_{memory_type}_fit_time_T_{fit_time}"
    )
    # parents=True so the call also works when save_dir does not exist yet.
    updated_save_dir.mkdir(parents=True, exist_ok=True)
    for m_iter in m_series:
        reset_tf()
        save_dir_with_m = updated_save_dir.joinpath(f"m={m_iter}")
        save_dir_with_m.mkdir(exist_ok=True)

        model, callbacks = model_construction(fit_time, m_iter, activation, epochs)

        # Paths are generated with test_time steps; only the first fit_time
        # steps are used for fitting and validation.
        history = model.fit(
            x=x_train[:, :fit_time],
            y=y_train[:, :fit_time],
            validation_data=(x_test[:, :fit_time], y_test[:, :fit_time]),
            epochs=epochs,
            # Same value that sized the dataset above (was a duplicated literal).
            batch_size=batch_size,
            verbose=0,
            callbacks=callbacks,
        )

        results = pd.DataFrame(history.history)
        results["epoch"] = history.epoch
        model.save(f"{save_dir_with_m}/model.h5")
        results.to_csv(
            f"{save_dir_with_m}/history_{activation}_{memory_type}_T={fit_time}_{functional_type}.csv"
        )

        # The two layers are applied directly to the full-length test path.
        # NOTE(review): presumably because the model input was declared with
        # length fit_time while x_test has test_time steps - confirm.
        index = 0
        y_pred_long = np.squeeze(
            model.layers[1](model.layers[0](x_test[index : index + 1, :]))
        )
        plot_sample_error(
            y_pred_long,
            y_test,
            index,
            save_dir_with_m,
            activation,
            memory_type,
            fit_time,
            functional_type,
        )

        df[f"m={m_iter}"] = evaluate_perturbation_error(
            perturbation_list, model, num_repeats, x_test, y_test, dt
        )

        del model

    plot_perturbation_error(
        m_series,
        perturbation_list,
        updated_save_dir,
        activation,
        memory_type,
        fit_time,
        functional_type,
        df,
    )


if __name__ == "__main__":
    # Each run writes into its own directory named after the start time.
    save_dir = Path("figure_2_" + time.strftime("%Y%m%d-%H%M%S"))
    save_dir.mkdir(exist_ok=True)

    # Keep a copy of the code that produced the results next to them.
    code_save_dir = save_dir / "code_back_up"
    code_save_dir.mkdir(exist_ok=True)
    for source_file in ("paper_fig_2.py", "callback.py", "utils.py"):
        shutil.copy(source_file, code_save_dir)

    # Run the experiment once per memory kernel.
    for memory_type in ("pol", "exp"):
        prepare_experiments(memory_type, save_dir)
