import matplotlib.pyplot as plt
from mpl_sizes import get_format

# Render all figure text with LaTeX, using a Times serif font.
plt.rcParams["text.usetex"] = True
formatter = get_format("ICLR")  # options: ICLR, ICML, NeurIPS, InfThesis
# NOTE(review): usetex is applied both before and after get_format(); this is
# presumably defensive in case get_format() touches rcParams — confirm before
# collapsing the two assignments into one.
plt.rcParams.update(
    {
        "text.usetex": True,
        "font.family": "serif",
        "font.serif": ["Times"],
    }
)
# Alternative blue-shade palette, kept for reference but not active:
# colors = ['#03045e', '#033e8a', '#0077b6', '#0296c8', '#06b4d8', '#49cae4']
colors = ["#EB8531", "#30CEEA"]


def plot_latents(save_path="./latent_ablation.pdf"):
    """Plot eval perplexity against the number of latent variables L.

    Draws Latte-RGLRU++ perplexity for each L as a dashed blue line with
    circle markers. Standard causal attention is added as a single red
    marker at x=1024 for reference. The figure is written to ``save_path``.

    Args:
        save_path: Output file for the figure. Matplotlib infers the format
            from the extension, e.g. ``.pdf`` or ``.png``. Defaults to
            ``"./latent_ablation.pdf"``.
    """
    # These are perplexities (plotted as "Eval PPL"), not accuracies.
    # Fragments of a denser sweep were previously left in comments:
    # L = [4, 8, 16, ...] with PPL [42.23, 40.99, 39.06, ...].
    eval_ppl = [42.23, 37.70, 36.95, 35.06, 34.58, 32.67, 31.80]
    latents = [4, 32, 64, 128, 256, 512, 1024]
    att_ppl = 28.70

    # Query the formatter once rather than once per dimension.
    plot_size = formatter.text_width_plot()
    fig = plt.figure(figsize=(plot_size[0], plot_size[1]))

    plt.plot(latents, eval_ppl, "--bo", label="Latte-RGLRU++")
    plt.plot([1024], [att_ppl], "--rx", label="Standard Causal Attention")
    plt.ylabel("Eval PPL")
    plt.xlabel("L")
    plt.xticks(latents)
    plt.legend()
    plt.tight_layout()
    plt.savefig(save_path)
    # pyplot keeps every figure alive until closed explicitly.
    plt.close(fig)


def seq_len(save_path="./seq_len.pdf"):
    """Plot eval perplexity of standard causal attention against sequence length.

    Attention is run with a gradually increasing sequence length and
    decreasing batch size, while Latte is kept fixed. Attention is drawn as
    a dashed red line with circle markers. Latte-RGLRU-SWA++ is added as a
    single blue marker at x=1024 for reference. The figure is written to
    ``save_path``.

    Args:
        save_path: Output file for the figure. Matplotlib infers the format
            from the extension. Defaults to ``"./seq_len.pdf"``.
    """
    formatter = get_format("ICLR")  # options: ICLR, ICML, NeurIPS, InfThesis
    # Re-apply LaTeX/Times styling after get_format() so this figure matches
    # the module-level styling even if rcParams changed in the meantime.
    plt.rcParams["text.usetex"] = True
    plt.rcParams["font.family"] = "serif"
    plt.rcParams["font.serif"] = ["Times"]

    latte_ppl = 18.59
    att_scores = [21.15, 19.84, 19.10, 18.60]
    seq_lens = [128, 256, 512, 1024]

    # Query the formatter once rather than once per dimension.
    plot_size = formatter.text_width_plot()
    fig = plt.figure(figsize=(plot_size[0], plot_size[1]))

    plt.plot(seq_lens, att_scores, "--ro", label="Standard Causal Attention")
    plt.plot([1024], [latte_ppl], "--bx", label="Latte-RGLRU-SWA++")
    plt.ylabel("Eval PPL")
    plt.xlabel("Sequence Length")
    plt.xticks(seq_lens)
    plt.legend()
    plt.tight_layout()
    plt.savefig(save_path)
    # pyplot keeps every figure alive until closed explicitly.
    plt.close(fig)


if __name__ == "__main__":
    # Build each ablation figure in turn; each builder saves its own file.
    for make_figure in (plot_latents, seq_len):
        make_figure()
