import json
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.lines import Line2D
from mpl_sizes import get_format


def _plot_capped_ppl(scores, label, color, start_window, cap):
    """Draw one perplexity curve on the current matplotlib figure.

    Args:
        scores: Mapping loaded from a scores JSON file. Its ``"PPL_mean"``
            entry must convert to an array with at least two dimensions;
            only row 0 is plotted.
        label: Legend entry for the curve.
        color: Matplotlib colour of the curve.
        start_window: Index of the first column of row 0 that is plotted.
        cap: Upper bound; every value above it is replaced by ``cap``.
    """
    ppl = np.array(scores["PPL_mean"])
    # Values above ``cap`` are replaced by ``cap`` so the y-axis stays bounded.
    ppl = np.where(ppl > cap, cap, ppl)
    tail = ppl[0, start_window:]
    # The x values are the column indices in the full row, so the curve
    # starts at ``start_window`` rather than at 0.
    plt.plot(start_window + np.arange(len(tail)), tail, label=label, color=color)


def vis1(
    *,
    data_path="/home/user/latte_trans/tests/ideas/data/",
    save_path="./",
    start_window=128,
    cap=50,
):
    """Plot evaluation perplexity per token index for two models.

    Reads ``scores_att_new.json`` and ``scores_mach2.json`` from
    ``data_path``, plots row 0 of each file's ``"PPL_mean"`` entry from
    ``start_window`` onwards (values capped at ``cap``) and writes the figure
    to ``seq_len_ext.pdf`` inside ``save_path``.

    Args:
        data_path: Directory containing the two score files.
        save_path: Directory the PDF is written to.
        start_window: First column index that is plotted.
        cap: Upper bound applied to the plotted perplexities.

    Raises:
        OSError: If a score file cannot be opened.
        KeyError: If a score file has no ``"PPL_mean"`` entry.
    """
    plt.rcParams["text.usetex"] = True
    formatter = get_format("ICLR")  # options: ICLR, ICML, NeurIPS, InfThesis
    plt.rcParams["font.family"] = "serif"
    plt.rcParams["font.serif"] = ["Times"]

    data_dir = Path(data_path)
    with open(data_dir / "scores_att_new.json", encoding="utf-8") as f:
        att = json.load(f)
    with open(data_dir / "scores_mach2.json", encoding="utf-8") as f:
        mach = json.load(f)

    size = formatter.text_width_plot()
    plt.figure(figsize=(size[0], size[1]))

    _plot_capped_ppl(att, "Standard Causal Attention", "#EB8531", start_window, cap)
    _plot_capped_ppl(mach, "Latte-RGLRU-SWA++", "#30CEEA", start_window, cap)

    # Raw string: with usetex enabled the "#" has to reach LaTeX as "\#", and
    # "\#" is not a valid escape sequence in a normal Python string literal.
    plt.xlabel(r"\# Tokens")
    plt.ylabel("Eval PPL")
    plt.legend()
    plt.tight_layout()
    plt.savefig(Path(save_path) / "seq_len_ext.pdf")


# Tick labels (sequence lengths) shared by both forward-pass timing figures.
_SEQ_LEN_TICK_LABELS = ("4K", "8K", "16K", "32K")

# One fixed colour per model so both figures use the same legend colours.
_COLOR_ATTENTION = "#e39657ff"
_COLOR_LATTE_CONV = "#65a66fff"
_COLOR_LATTE_RGLRU = "#4ba6b6ff"


def _plot_forward_time_bars(series, out_path, bar_width=0.3):
    """Draw a grouped bar chart of forward-pass times and save it.

    Args:
        series: Sequence of ``(label, color, heights, errors)`` tuples, one
            per model, drawn left to right within each group. ``heights`` and
            ``errors`` need one value per entry of ``_SEQ_LEN_TICK_LABELS``.
        out_path: File the figure is written to.
        bar_width: Width of a single bar in x-axis units.
    """
    plt.rcParams["text.usetex"] = True
    formatter = get_format("ICLR")  # options: ICLR, ICML, NeurIPS, InfThesis
    plt.rcParams["font.family"] = "serif"
    plt.rcParams["font.serif"] = ["Times"]

    size = formatter.text_width_plot()
    plt.figure(figsize=(size[0], size[1]))

    group_start = np.arange(len(_SEQ_LEN_TICK_LABELS))
    positions = group_start
    for label, color, heights, errors in series:
        plt.bar(
            positions,
            heights,
            width=bar_width,
            color=color,
            edgecolor="black",
            yerr=errors,
            capsize=7,
            label=label,
        )
        # The next model's bars sit directly to the right of this one's.
        positions = positions + bar_width

    # Put each tick under the centre of its group (the middle bar when three
    # series are drawn).
    tick_offset = bar_width * (len(series) - 1) / 2
    plt.xticks(group_start + tick_offset, list(_SEQ_LEN_TICK_LABELS))
    plt.ylabel("Forward Pass Time (ms)")
    plt.xlabel("Sequence Length")
    plt.legend()

    plt.tight_layout()
    plt.savefig(out_path)


def vis_bar2B():
    """Plot forward-pass time per sequence length for the 2B model.

    Uses hard-coded bar heights and error-bar sizes for three models and
    writes the chart to ``./bar_plot2-6B.pdf``.
    """
    # NOTE(review): the output name says "2-6B" while the description says
    # "2B" -- confirm which model size these numbers belong to.
    _plot_forward_time_bars(
        [
            (
                "Latte-Conv-SWA++",
                _COLOR_LATTE_CONV,
                [1333.45, 1321.52, 1310.54, 1280.68],
                [275.35, 273.12, 270.93, 264.68],
            ),
            (
                "Latte-RGLRU-SWA++",
                _COLOR_LATTE_RGLRU,
                [1352.75, 1364.89, 1372.72, 1391.95],
                [36.68, 3.03, 3.05, 3.34],
            ),
            (
                "Standard Causal Attention",
                _COLOR_ATTENTION,
                [1221.13, 1529.92, 2194.32, 3619.99],
                [252.08, 316.26, 454.27, 750.0],
            ),
        ],
        "./bar_plot2-6B.pdf",
    )


def vis_bar400M():
    """Plot forward-pass time per sequence length for the 400M model.

    Uses hard-coded bar heights and error-bar sizes for three models and
    writes the chart to ``./bar_plot400M.pdf``.
    """
    _plot_forward_time_bars(
        [
            (
                "Latte-Conv-SWA++",
                _COLOR_LATTE_CONV,
                [396.87, 381.49, 398.35, 390.57],
                [81.60, 78.12, 81.0, 80.10],
            ),
            (
                "Latte-RGLRU-SWA++",
                _COLOR_LATTE_RGLRU,
                [416.83, 414.81, 450.37, 484.80],
                [24.87, 0.82, 0.74, 0.69],
            ),
            (
                "Standard Causal Attention",
                _COLOR_ATTENTION,
                [468.45, 681.38, 1170.40, 2191.90],
                [96.58, 140.36, 241.92, 453.83],
            ),
        ],
        "./bar_plot400M.pdf",
    )


if __name__ == "__main__":
    # Running the file as a script regenerates all three figures in turn.
    for render_figure in (vis1, vis_bar2B, vis_bar400M):
        render_figure()
