import numpy as np
import matplotlib.pyplot as plt
import itertools


def exp_tw(t, w):
    """Return the exponential decay ``exp(-w * t)`` for time ``t`` and rate ``w``."""
    exponent = -w * t
    return np.exp(exponent)


def c_solver(T, m, time_list, eigenvalues, rho):
    """
    Fit coefficients ``c`` by least squares so that
    ``sum_k c[k] * exp(-eigenvalues[k] * t)`` matches ``rho`` over ``time_list``.

    Note from the original author: T >= m (this is not checked by the code).

    :param T: expected number of time points; must equal ``len(time_list)``
    :param m: expected number of decay rates; must equal ``len(eigenvalues)``
    :param time_list: time points ``t`` (one matrix row per entry)
    :param eigenvalues: decay rates ``w`` (one matrix column per entry)
    :param rho: target values, passed to ``np.linalg.lstsq`` as the right-hand side
    :return: the least-squares solution returned by ``np.linalg.lstsq``
    """
    assert len(time_list) == T
    assert len(eigenvalues) == m

    # Row i, column k holds exp(-eigenvalues[k] * time_list[i]).
    design_matrix = []
    for t in time_list:
        row = [exp_tw(t, rate) for rate in eigenvalues]
        design_matrix.append(row)

    # lstsq returns (solution, residuals, rank, singular values); only the
    # solution is needed here.
    solution, *_ = np.linalg.lstsq(design_matrix, rho, rcond=None)
    return solution


def perturbation_error_evaluation(
    T, m, time_list, eigenvalues, rho, c, perturbation, parametrization=None
):
    """
    Evaluate the worst-case deviation from ``rho`` after perturbing the model.

    The decay rates are shifted by ``perturbation`` in the chosen
    parametrization, the coefficients are rescaled by
    ``1 + 0.01 * perturbation``, and the perturbed model
    ``sum_k c[k] * exp(-w_k * t)`` is compared with ``rho`` at every time point.

    :param T: number of time points; the residual must have shape ``(T,)``
    :param m: number of decay rates (not used in the computation)
    :param time_list: time points ``t``
    :param eigenvalues: unperturbed decay rates ``w``
    :param rho: target values, one per time point
    :param c: fitted coefficients, one per decay rate
    :param perturbation: size of the perturbation ``p``
    :param parametrization: how the rates are perturbed:
        ``None`` -> ``w - p``;
        ``"Exp"`` -> ``w * exp(-p)``;
        ``"Softplus"`` -> ``log(1 + (exp(w) - 1) * exp(-p))``
    :return: maximum absolute residual over the time points
    :raises NotImplementedError: if ``parametrization`` is not one of the above
    """
    rates = np.asarray(eigenvalues, dtype=float)

    if parametrization == "Exp":
        perturbed_rates = rates * np.exp(-perturbation)
    elif parametrization == "Softplus":
        perturbed_rates = np.log(1 + (np.exp(rates) - 1) * np.exp(-perturbation))
    elif parametrization is None:
        perturbed_rates = rates - perturbation
    else:
        raise NotImplementedError(f"Unknown parametrization: {parametrization!r}")

    # Entry (i, k) is exp(-perturbed_rates[k] * time_list[i]), built in one
    # vectorised call instead of a Python double loop.
    time_points = np.asarray(time_list, dtype=float)
    a_perturbed = np.exp(-np.outer(time_points, perturbed_rates))

    # The coefficients are perturbed too, by a relative factor of 1% of p.
    c_perturbed = np.asarray(c, dtype=float) * (1 + perturbation * 0.01)

    diff = a_perturbed @ c_perturbed - np.asarray(rho, dtype=float)
    assert diff.shape == (T,)

    # Report the sup-norm of the residual: a signed maximum would under-report
    # (or even go negative) whenever the perturbed model undershoots rho.
    error = np.max(np.abs(diff))

    return error


def memory_RNN(memory_type="exp", rnn_type=None):
    """
    Plot perturbation error against perturbation size and save the figure as PDF.

    For each model size ``m`` the coefficients are fitted to the target memory
    ``rho`` with ``c_solver``, the error is evaluated with
    ``perturbation_error_evaluation`` on a geometric grid of perturbations, and
    one curve per ``m`` is drawn on log-log axes. The figure is written to
    ``./perturbation_error_{memory_type}_{rnn_type}.pdf``.

    :param memory_type: target memory: ``"exp"`` uses ``rho[t] = 0.9**t``,
        ``"pol"`` uses ``rho[t] = 1 / (t + 1)**1.1``
    :param rnn_type: parametrization forwarded to
        ``perturbation_error_evaluation``: ``None``, ``"Exp"`` or ``"Softplus"``
    :raises NotImplementedError: if ``memory_type`` is neither ``"exp"`` nor ``"pol"``
    """
    T = 100
    time_list = np.arange(0, T)
    m_list = [2, 4, 6, 8, 14, 128]
    reference_m = 128  # drawn as an unlabeled black dashed curve, if at all

    if memory_type == "exp":
        rho = [0.9**t for t in range(T)]
    elif memory_type == "pol":
        rho = [1 / (t + 1) ** 1.1 for t in range(T)]
    else:
        raise NotImplementedError
    assert len(rho) == T

    # The m = 128 curve is only shown for polynomial memory combined with a
    # reparametrized model ("Exp" / "Softplus").
    draw_reference = memory_type == "pol" and rnn_type is not None

    # The perturbation grid does not depend on m, so build it once.
    perturbation_list = [5e-4 * 2**p for p in range(20)]

    plt.figure(figsize=(6, 4), dpi=300)
    line_styles = ["-"]
    line_styles_cycle = itertools.cycle(line_styles)

    for m, line_style in zip(m_list, line_styles_cycle):
        is_reference = m == reference_m
        if is_reference and not draw_reference:
            # This curve would never be plotted; skip the fit and the
            # error sweep instead of computing and discarding them.
            continue

        eigenvalues = np.array([1 / np.sqrt(index + 1) for index in range(m)])

        c = c_solver(T, m, time_list, eigenvalues, rho)
        assert len(c) == m

        perturbation_error_list = [
            perturbation_error_evaluation(
                T,
                m,
                time_list,
                eigenvalues,
                rho,
                c,
                perturbation_iter,
                parametrization=rnn_type,
            )
            for perturbation_iter in perturbation_list
        ]

        if is_reference:
            plt.plot(
                perturbation_list,
                perturbation_error_list,
                color="black",
                linestyle="--",
                linewidth=2.2,
            )
        else:
            plt.plot(
                perturbation_list,
                perturbation_error_list,
                label=f"m={m}",
                linestyle=line_style,
                linewidth=1.2 + m / 12,  # thicker lines for larger models
            )

    plt.xlabel(r"Perturbation $\beta$", fontsize=14)
    plt.ylabel(r"Perturbation error $E_m(\beta)$", fontsize=14)
    plt.xscale("log")
    plt.yscale("log")

    plt.xlim([3e-4, 1e1])
    plt.ylim([7e-5, 2e2])

    additional_plot(memory_type, rnn_type)

    plt.legend()
    plt.tight_layout()
    # plt.show()
    plt.savefig(f"./perturbation_error_{memory_type}_{rnn_type}.pdf")


def additional_plot(memory_type, rnn_type):
    """
    Add hand-placed annotations to the current pyplot figure.

    Annotations are drawn for three combinations only:
    ``("exp", None)``, ``("pol", None)`` and ``("pol", "Exp" / "Softplus")``.
    Any other combination leaves the figure untouched. All coordinates are
    hard-coded data/axes positions.

    :param memory_type: ``"exp"`` or ``"pol"``
    :param rnn_type: ``None``, ``"Exp"`` or ``"Softplus"``
    """

    def label_zero():
        plt.text(2e-4, 0.000023, r"$0$")

    def label_infinity():
        plt.text(2e-4, 2e2, r"$\infty$", fontsize=14)

    def filled_corner_dot():
        # Filled marker at (3e-4, 7e-5), drawn outside the clip region.
        plt.plot(3e-4, 7e-5, marker="o", color="black", linestyle="None", clip_on=False)

    def top_dash_dot_line(x_start):
        # Thick dash-dot horizontal line; xmin/xmax are axes fractions.
        plt.axhline(
            y=1.97e2, xmin=x_start, xmax=1, color="k", linewidth=4, linestyle="-."
        )  # E(beta)

    if memory_type == "exp" and rnn_type is None:
        plt.axvline(
            x=0.07, ymin=0.01, ymax=0.57, color="r", linewidth=2.2, linestyle="--"
        )
        top_dash_dot_line(0.52)
        plt.plot([5e-4, 7e-2], [0.00017, 0.29], color="k", linestyle="-.", linewidth=4)
        label_zero()
        label_infinity()
        plt.text(0.075, 2e-4, r"$\beta_0$", fontsize=14)
        filled_corner_dot()

    elif memory_type == "pol" and rnn_type is None:
        top_dash_dot_line(0.01)
        plt.arrow(
            0.11,
            5e-2,
            -0.08,
            -4e-2,
            width=0.0002,
            head_width=7e-3,
            head_length=9e-3,
            fc="red",
            ec="red",
        )
        label_zero()
        label_infinity()
        plt.text(0.11, 5e-3, r"$\beta$")
        filled_corner_dot()
        # Hollow marker at the top-left corner.
        plt.plot(
            3e-4,
            1.97e2,
            marker="o",
            fillstyle="none",
            color="black",
            linestyle="None",
            clip_on=False,
        )
        # Dotted: four small filled markers.
        small_dots = [(0.113, 0.155), (0.055, 0.08), (0.035, 0.057), (0.0165, 0.022)]
        for dot_x, dot_y in small_dots:
            plt.plot(
                dot_x,
                dot_y,
                marker="o",
                color="black",
                linestyle="None",
                clip_on=False,
                markersize=3.2,
            )

    elif memory_type == "pol" and rnn_type in ["Exp", "Softplus"]:
        label_zero()
        label_infinity()
        filled_corner_dot()


if __name__ == "__main__":
    # Produce one figure per (memory type, parametrization) combination,
    # iterating parametrizations fastest.
    memory_types = ["exp", "pol"]
    rnn_types = [None, "Exp", "Softplus"]
    for memory_type, rnn_type in itertools.product(memory_types, rnn_types):
        plt.clf()
        memory_RNN(memory_type=memory_type, rnn_type=rnn_type)
