import argparse
import os
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.stats import gmean

# Display names for the decoding-mode codenames that appear in log file names;
# get_pretty_names() looks modes up here when building figure legends.
codenames = dict(
    ars="autoregressive",
    sps="speculative",
    usps="upper-bound speculative",
    dsps="state-based speculative",
    dhsps="history-based speculative",
    pp="RL-based speculative",
)
# Set the default font size to 16 for every figure created by this module.
plt.rcParams.update({"font.size": 16})


def get_df(
    mode,
    dataset,
    draft_model,
    target_model,
    gpu,
    temperature,
    gamma,
    log_dir="logs",
    expected_rows=25,
):
    """Load the benchmark log CSV for one experiment configuration.

    The file name is derived from the arguments (``temp`` is
    ``str(temperature)`` with ``"."`` replaced by ``"_"``):

    * ``mode == "ars"``:
      ``{mode}_{dataset}_{target_model}_{gpu}_t{temp}.csv``
      (``draft_model`` and ``gamma`` are not part of the name)
    * any other mode:
      ``{mode}_{dataset}_d{draft_model}_t{target_model}_{gpu}_t{temp}_g{gamma}.csv``

    Args:
        mode: Decoding-mode codename (e.g. ``"ars"``, ``"sps"``).
        dataset: Dataset name used in the log file name.
        draft_model: Draft model name (ignored for ``"ars"``).
        target_model: Target model name.
        gpu: GPU name used in the log file name.
        temperature: Sampling temperature used in the log file name.
        gamma: Gamma value used in the log file name (ignored for ``"ars"``).
        log_dir: Directory that contains the log files.
        expected_rows: Required number of rows; pass ``None`` to skip the
            check.

    Returns:
        The log as a ``pandas.DataFrame``.

    Raises:
        AssertionError: If the file does not contain ``expected_rows`` rows.
    """
    temperature_str = str(temperature).replace(".", "_")
    if mode == "ars":
        # Autoregressive runs only involve the target model, so the file name
        # carries neither the draft model nor gamma.
        # NOTE(review): the original TODO asked to switch this name to the
        # target model; it already uses target_model -- confirm it is resolved.
        logfile = f"{mode}_{dataset}_{target_model}_{gpu}_t{temperature_str}.csv"
    else:
        logfile = (
            f"{mode}_{dataset}_d{draft_model}_t{target_model}_{gpu}"
            f"_t{temperature_str}_g{gamma}.csv"
        )

    df = pd.read_csv(os.path.join(log_dir, logfile))
    # Raised explicitly (instead of using ``assert``) so the check is not
    # stripped under ``python -O``; the exception type is kept for callers.
    if expected_rows is not None and df.shape[0] != expected_rows:
        raise AssertionError(f"Error: {logfile} is not length {expected_rows}")
    return df


def get_pretty_names(modes):
    """Return the display name for each mode codename, in input order.

    Each entry of ``modes`` is looked up in the module-level ``codenames``
    mapping; an unknown codename raises ``KeyError``.
    """
    pretty = []
    for code in modes:
        pretty.append(codenames[code])
    return pretty


def get_token_acceptance_stats(
    mode, dataset, draft_model, target_model, gpu, temperature, gamma
):
    """Summarise every column of one experiment's log.

    The log is loaded with ``get_df`` using the same arguments. For each
    column ``c`` the result holds:

    * ``"{c}_mean"``: the *geometric* mean of the column (``scipy.stats.gmean``)
    * ``"{c}_std"``: the column's standard deviation (``Series.std``)

    A per-row acceptance rate, ``accepted_count / draft_sample_count``, is
    returned under ``"acceptance_rate"``. If the log has no
    ``acceptance_rate`` column of its own, the derived rate is also
    summarised, so ``"acceptance_rate_mean"`` and ``"acceptance_rate_std"``
    are always present. A column already in the log takes precedence for
    those two statistics.

    Returns:
        dict mapping the keys described above to their values.
    """
    df = get_df(
        mode,
        dataset,
        draft_model,
        target_model,
        gpu,
        temperature,
        gamma,
    )

    acceptance_rate = df["accepted_count"] / df["draft_sample_count"]
    results = {"acceptance_rate": acceptance_rate}

    # Previously the derived rate only lived in `results`, so the loop over
    # df.columns below never produced its *_mean / *_std entries and callers
    # reading "acceptance_rate_mean" hit a KeyError. assign() returns a new
    # frame, leaving the one returned by get_df untouched.
    if "acceptance_rate" not in df.columns:
        df = df.assign(acceptance_rate=acceptance_rate)

    for column in df.columns:
        results[f"{column}_mean"] = gmean(df[column])
        results[f"{column}_std"] = df[column].std()

    return results


def get_results(modes, gammas, dataset, gpu, temperature, draft_model, target_model):
    """Collect per-gamma summary statistics for every mode.

    For each mode, ``get_token_acceptance_stats`` is called once per gamma
    (in the given order) and selected statistics are gathered into NumPy
    arrays with one entry per gamma.

    Returns:
        dict mapping each mode to a dict with the keys ``tok_s_avg``,
        ``tok_s_std``, ``gen_s_avg``, ``gen_s_std``, ``acceptance_rate_mean``
        and ``acceptance_rate_std``.
    """
    # (output key, key in the per-run summary), in output order.
    field_map = (
        ("tok_s_avg", "total_tok_per_sec_mean"),
        ("tok_s_std", "total_tok_per_sec_std"),
        ("gen_s_avg", "generate_tok_per_sec_mean"),
        ("gen_s_std", "generate_tok_per_sec_std"),
        ("acceptance_rate_mean", "acceptance_rate_mean"),
        ("acceptance_rate_std", "acceptance_rate_std"),
    )

    per_mode = {}
    for mode in modes:
        series = {out_key: [] for out_key, _ in field_map}
        for gamma in gammas:
            summary = get_token_acceptance_stats(
                mode, dataset, draft_model, target_model, gpu, temperature, gamma
            )
            for out_key, src_key in field_map:
                series[out_key].append(summary[src_key])
        per_mode[mode] = {
            out_key: np.array(values) for out_key, values in series.items()
        }
    return per_mode


def plot(gpu, temperature, draft_model, target_model, dataset):
    """Plot throughput and acceptance rate against gamma for one dataset.

    Statistics come from ``get_results`` for gamma 1..8 and all six modes.
    Two PNG files are written (the ``figures/`` directory must exist):

    * ``figures/{dataset}_d{draft_model}_t{target_model}_{gpu}_t{temperature}_latency.png``:
      per mode, total tok/s (solid, ``o``) and generation-only tok/s
      (dashed, ``x``).
    * ``..._acceptance.png``: per mode, the mean acceptance rate.
    """
    gammas = [1, 2, 3, 4, 5, 6, 7, 8]
    modes = ["dsps", "sps", "usps", "ars", "dhsps", "pp"]

    # Same aggregation the inline loop used to perform; logs are read in
    # `modes` order, one file per (mode, gamma).
    results = get_results(
        modes, gammas, dataset, gpu, temperature, draft_model, target_model
    )

    # Drawing order (and therefore legend order) differs from the load order.
    labels = ["ars", "sps", "usps", "dsps", "dhsps", "pp"]
    colors = ["blue", "green", "purple", "red", "orange", "black"]
    stem = f"figures/{dataset}_d{draft_model}_t{target_model}_{gpu}_t{temperature}"

    # ---- Figure 1: throughput in tok/s ----
    plt.figure(figsize=(10, 6))
    for label, color in zip(labels, colors):
        plt.plot(
            gammas, results[label]["tok_s_avg"], label=label, marker="o", color=color
        )
        plt.plot(
            gammas,
            results[label]["gen_s_avg"],
            label=f"{label} (gen)",
            marker="x",
            color=color,
            linestyle="dashed",
        )
        # The matching "*_std" arrays are available in `results` if shaded
        # standard-deviation bands (plt.fill_between) are wanted.

    plt.xlabel("Gamma")
    plt.ylabel("Latency (Tok/s)")
    plt.legend()
    plt.title(f"Dataset: {dataset} Temp: {temperature} Latency on {gpu}")
    plt.savefig(f"{stem}_latency.png", dpi=300, bbox_inches="tight")
    plt.clf()
    matplotlib.pyplot.close()

    # ---- Figure 2: acceptance rate ----
    # This figure used to re-plot gen_s_avg under an "Acceptance Rate" axis
    # label, title and file name; it now plots the acceptance rate itself.
    plt.figure(figsize=(10, 6))
    for label, color in zip(labels, colors):
        plt.plot(
            gammas,
            results[label]["acceptance_rate_mean"],
            label=label,
            marker="o",
            color=color,
        )

    plt.xlabel("Gamma")
    plt.ylabel("Acceptance Rate")
    plt.legend()
    plt.title(f"Dataset: {dataset} Temp: {temperature} Acceptance Rate on {gpu}")
    plt.savefig(f"{stem}_acceptance.png", dpi=300, bbox_inches="tight")
    plt.clf()
    matplotlib.pyplot.close()


def _decorate_dataset_axis(ax, dataset, ylabel=None):
    """Apply the shared gamma x-label, dataset title and y-grid to one subplot."""
    if ylabel is not None:
        ax.set_ylabel(ylabel)
    # Raw string: "\g" is not a valid escape sequence in a normal literal.
    ax.set_xlabel(r"Gamma ($\gamma$)")
    ax.set_title(f"Dataset: {dataset}")
    ax.grid(True, axis="y")


def _finish_mode_figure(lines, modes, title, out_path):
    """Add the suptitle and shared per-mode legend, save the figure, close it."""
    plt.suptitle(title)
    plt.tight_layout()
    plt.figlegend(
        lines,
        get_pretty_names(modes),
        loc="upper center",
        bbox_to_anchor=(0.5, -0.05),
        # The legend column count must be an int; "/" produced a float.
        ncol=max(1, len(modes) // 2),
    )
    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    plt.savefig(out_path, dpi=300, bbox_inches="tight")
    plt.clf()
    matplotlib.pyplot.close()


def plot_all(gpu, temperature, draft_model, target_model):
    """Plot per-dataset throughput and acceptance-rate curves for one model pair.

    For the four datasets and six modes, statistics over gamma 1..8 are
    obtained from ``get_results`` (once per dataset) and four figures with
    one subplot per dataset are written under ``figures/``:

    * ``..._latency_all.png``: total (solid) and generation-only (dashed) tok/s
    * ``..._latency_gen_all.png``: generation-only tok/s, shared y-range
    * ``..._latency_total_all.png``: total tok/s
    * ``..._acceptance_all.png``: mean acceptance rate

    Best-gamma and speedup summaries are printed to stdout along the way.
    """
    datasets = ["gsm8k", "finance-alpaca", "xsum", "openai_humaneval"]
    gammas = [1, 2, 3, 4, 5, 6, 7, 8]
    modes = ["ars", "sps", "usps", "dsps", "dhsps", "pp"]
    colors = ["blue", "green", "purple", "red", "orange", "black"]

    run_tag = f"{gpu} for {draft_model} to {target_model}"
    stem = f"figures/all_d{draft_model}_t{target_model}_{gpu}_t{temperature}"

    # Read the logs once per dataset; every figure and report below reuses
    # these results instead of re-reading all CSV files each time.
    results_by_dataset = {
        dataset: get_results(
            modes, gammas, dataset, gpu, temperature, draft_model, target_model
        )
        for dataset in datasets
    }

    ########## FIRST PLOT: total + generation-only throughput
    fig, axs = plt.subplots(1, len(datasets), figsize=(16, 6))
    axs = axs.ravel()
    lines = []
    for i, dataset in enumerate(datasets):
        results = results_by_dataset[dataset]
        for label, color in zip(modes, colors):
            tok_s = results[label]["tok_s_avg"]
            (line,) = axs[i].plot(
                gammas, tok_s, label=label, marker="o", color=color
            )
            axs[i].plot(
                gammas,
                results[label]["gen_s_avg"],
                label=f"{label} (gen)",
                marker="x",
                color=color,
                linestyle="dashed",
            )
            print(
                f"{run_tag} on {dataset}: {label} {max(tok_s)}, gamma: {gammas[int(np.argmax(tok_s))]}"
            )
            if i == 0:
                # Legend handles are taken from the first subplot only.
                lines.append(line)
        _decorate_dataset_axis(
            axs[i], dataset, "Latency (Tok/s)" if i == 0 else None
        )

    # Use each dataset's own results here; this report used to reuse the
    # last dataset's results for every dataset name.
    for dataset in datasets:
        results = results_by_dataset[dataset]
        for label in modes:
            gen = results[label]["gen_s_avg"]
            print(
                f"GEN {run_tag} on {dataset}: {label} {max(gen)}, gamma: {gammas[int(np.argmax(gen))]}"
            )

    # For each dataset and each of (dsps, dhsps, pp) print:
    # * the best generation throughput and the gamma that achieves it
    # * the speedup over SPS evaluated at SPS's own best gamma
    # * the speedup over ARS
    for dataset in datasets:
        results = results_by_dataset[dataset]
        sps_best = max(results["sps"]["gen_s_avg"])
        ars_best = max(results["ars"]["gen_s_avg"])
        for label in ["dsps", "dhsps", "pp"]:
            gen = results[label]["gen_s_avg"]
            best_gamma = gammas[int(np.argmax(gen))]
            best_speedup = max(gen) / sps_best
            ars_speedup = max(gen) / ars_best
            print(
                f"{run_tag} on {dataset}: {label} {max(gen)}, gamma: {best_gamma}, "
                f"speedup: {best_speedup}, speedup over ars: {ars_speedup}"
            )

    _finish_mode_figure(
        lines, modes, f"Latency on {gpu}", f"{stem}_latency_all.png"
    )

    ########## SECOND PLOTS: generation-only throughput, common y-range
    lines = []
    fig, axs = plt.subplots(1, len(datasets), figsize=(16, 6), sharey=True)
    axs = np.ravel(axs)
    all_values = []
    for i, dataset in enumerate(datasets):
        results = results_by_dataset[dataset]
        for label, color in zip(modes, colors):
            gen = results[label]["gen_s_avg"]
            (line,) = axs[i].plot(
                gammas, gen, label=f"{label} (gen)", marker="o", color=color
            )
            all_values.extend(gen)
            if i == 0:
                lines.append(line)
        _decorate_dataset_axis(
            axs[i], dataset, "Latency (Tok/s)" if i == 0 else None
        )

    y_min, y_max = min(all_values), max(all_values)
    for ax in axs:
        ax.set_ylim(y_min, y_max)

    _finish_mode_figure(
        lines,
        modes,
        f"Latency on {gpu} (Generation Only)",
        f"{stem}_latency_gen_all.png",
    )

    ########## THIRD PLOT: total throughput only
    lines = []
    fig, axs = plt.subplots(1, len(datasets), figsize=(16, 6))
    axs = np.ravel(axs)
    for i, dataset in enumerate(datasets):
        results = results_by_dataset[dataset]
        for label, color in zip(modes, colors):
            (line,) = axs[i].plot(
                gammas,
                results[label]["tok_s_avg"],
                label=label,
                marker="o",
                color=color,
            )
            if i == 0:
                lines.append(line)
        _decorate_dataset_axis(
            axs[i], dataset, "Latency (Tok/s)" if i == 0 else None
        )

    _finish_mode_figure(
        lines, modes, f"Latency on {gpu}", f"{stem}_latency_total_all.png"
    )

    ########## ACCEPTANCE PLOTS
    fig, axs = plt.subplots(1, len(datasets), figsize=(16, 6))
    axs = axs.ravel()
    all_values = []
    lines = []
    for i, dataset in enumerate(datasets):
        results = results_by_dataset[dataset]
        for label, color in zip(modes, colors):
            rates = results[label]["acceptance_rate_mean"]
            (line,) = axs[i].plot(
                gammas, rates, label=label, marker="o", color=color
            )
            all_values.extend(rates)
            if i == 0:
                lines.append(line)
            print(
                f"Acceptance for {run_tag} on {dataset}: {label} {max(rates)} {min(rates)} {gammas[int(np.argmax(rates))]} {gammas[int(np.argmin(rates))]}"
            )
        _decorate_dataset_axis(
            axs[i], dataset, "Acceptance Rate" if i == 0 else None
        )

    # NaN rates are dropped before computing the lower y-limit; fall back to
    # 0.0 if nothing finite remains so min() cannot fail on an empty array.
    finite = np.array(all_values)
    finite = finite[~np.isnan(finite)]
    y_min = finite.min() if finite.size else 0.0
    for ax in axs:
        ax.set_ylim(y_min - 0.05, 1.05)

    # only set ytick labels for the first plot
    for ax in axs[1:]:
        ax.set_yticklabels([])

    _finish_mode_figure(
        lines, modes, f"Acceptance Rate on {gpu}", f"{stem}_acceptance_all.png"
    )


def _save_scatter_grid(
    scatter,
    datasets,
    modes,
    colors,
    markers,
    x_key,
    y_key,
    x_label,
    y_label,
    title,
    out_path,
):
    """Draw one scatter subplot per dataset with a linear fit per mode, then save."""
    fig, axs = plt.subplots(1, len(datasets), figsize=(16, 6))
    axs = np.ravel(axs)
    handles = []
    for i, dataset in enumerate(datasets):
        # enumerate over the full mode list so colours/markers stay tied to a
        # mode's position even though "ars" itself is skipped.
        for j, mode in enumerate(modes):
            if mode == "ars":
                continue
            x = scatter[dataset + mode][x_key]
            y = scatter[dataset + mode][y_key]

            # Degree-1 least-squares fit, drawn without a legend entry.
            fit_function = np.poly1d(np.polyfit(x, y, 1))
            axs[i].plot(
                x, fit_function(x), color=colors[j], alpha=0.5, label="_nolegend_"
            )
            points = axs[i].scatter(
                x, y, label=mode, marker=markers[j], color=colors[j]
            )
            if i == 0:
                handles.append(points)
        axs[i].set_xlabel(x_label)
        axs[i].set_title(f"Dataset: {dataset}")
        axs[i].grid(True)

    axs[0].set_ylabel(y_label)
    plt.suptitle(title)
    legend_modes = [mode for mode in modes if mode != "ars"]
    plt.figlegend(
        handles,
        get_pretty_names(legend_modes),
        loc="upper center",
        bbox_to_anchor=(0.5, -0.05),
        # The legend column count must be an int; "/" produced 2.5.
        ncol=max(1, len(legend_modes) // 2),
    )

    plt.tight_layout()
    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    plt.savefig(out_path, dpi=300, bbox_inches="tight")
    plt.clf()
    matplotlib.pyplot.close()


def plot_speedup_all(gpu, temperature, draft_model, target_model):
    """Report the per-prompt best-gamma speedup and draw per-prompt scatter plots.

    For every dataset and every mode except ``"ars"``, the logs for gamma
    1..8 are loaded with ``get_df``. The "best static gamma" is the gamma
    whose mean ``total_tok_per_sec`` is highest. For ``"sps"``, the geometric
    mean over prompts of (best per-prompt speed across gammas) / (speed at
    the best static gamma) is printed.

    Three figures, each with one subplot per dataset and using the rows
    logged at the best static gamma, are written under ``figures/``:

    * ``..._scatter.png``: total tok/s vs. ``prefill_tokens``
    * ``..._scatter_gentoks.png``: generation tok/s vs. ``prefill_tokens``
    * ``..._scatter_generate.png``: total tok/s vs. ``generate_tokens``
    """
    datasets = ["gsm8k", "finance-alpaca", "xsum", "openai_humaneval"]
    gammas = [1, 2, 3, 4, 5, 6, 7, 8]
    modes = ["ars", "sps", "usps", "dsps", "dhsps", "pp"]
    colors = ["blue", "green", "purple", "red", "orange", "black"]
    dottype = ["o", "x", "s", "D", "P", "X"]

    scatter = {}

    print(f"{gpu} for {draft_model} to {target_model}")
    for dataset in datasets:
        for mode in modes:
            if mode == "ars":
                continue
            # Honour the caller's temperature: the output file names carry
            # it, but the logs used to be read with a hard-coded 0.0.
            per_gamma = [
                get_df(
                    mode, dataset, draft_model, target_model, gpu, temperature, gamma
                )
                for gamma in gammas
            ]

            mean_speeds = [df["total_tok_per_sec"].mean() for df in per_gamma]
            best_idx = int(np.argmax(mean_speeds))
            best_static_gamma = gammas[best_idx]

            if mode == "sps":
                # Rows are gammas, columns are prompts: for each prompt,
                # compare its fastest gamma against the best static gamma.
                speeds = np.stack(
                    [df["total_tok_per_sec"].values for df in per_gamma]
                )
                speedups = speeds.max(axis=0) / speeds[best_idx]
                print(
                    f"{best_static_gamma} speedup for {dataset} {get_pretty_names([mode])[0]}: ",
                    gmean(speedups),
                )

            best = per_gamma[best_idx]
            scatter[dataset + mode] = {
                "prefill": best["prefill_tokens"].values,
                "generate": best["generate_tokens"].values,
                "total tok/s": best["total_tok_per_sec"].values,
                "gen tok/s": best["generate_tok_per_sec"].values,
            }

    stem = f"figures/all_d{draft_model}_t{target_model}_{gpu}_t{temperature}"

    _save_scatter_grid(
        scatter,
        datasets,
        modes,
        colors,
        dottype,
        x_key="prefill",
        y_key="total tok/s",
        x_label="Prompt Token Length",
        y_label="Total Tok/s",
        title=f"Scatterplot of Total Tok/s per Prompt Token Length on {gpu}",
        out_path=f"{stem}_scatter.png",
    )
    _save_scatter_grid(
        scatter,
        datasets,
        modes,
        colors,
        dottype,
        x_key="prefill",
        y_key="gen tok/s",
        x_label="Prompt Token Length",
        y_label="Gen Tok/s",
        title=f"Scatterplot of Gen Tok/s per Prompt Token Length on {gpu}",
        out_path=f"{stem}_scatter_gentoks.png",
    )
    _save_scatter_grid(
        scatter,
        datasets,
        modes,
        colors,
        dottype,
        x_key="generate",
        y_key="total tok/s",
        x_label="Generate Token Length",
        y_label="Total Tok/s",
        # The x-axis here is the generated-token count; the title used to be
        # a copy of the prompt-length figure's title.
        title=f"Scatterplot of Total Tok/s per Generate Token Length on {gpu}",
        out_path=f"{stem}_scatter_generate.png",
    )


# "NVIDIA_RTX_A6000", "Tesla_V100-SXM2-32GB"
if __name__ == "__main__":
    # Command-line options. Only --gpu and --temperature are consumed below;
    # the remaining options are parsed but not read by this driver.
    cli = argparse.ArgumentParser()
    for flag, kind, fallback in (
        ("--gpu", str, "Tesla_V100-SXM2-32GB"),
        ("--temperature", float, 0.0),
        ("--draft_model", str, "TinyLlama-1.1B-Chat-v1.0"),
        ("--target_model", str, "Llama-2-13b-chat-hf"),
        ("--dataset", str, "gsm8k"),
    ):
        cli.add_argument(flag, type=kind, default=fallback)

    args = cli.parse_args()

    a100 = "NVIDIA_A100-SXM4-40GB"
    a40 = "NVIDIA_A40"

    # (draft model, target model) pairs referenced by the job tables below.
    tiny_llama2 = ("TinyLlama-1.1B-Chat-v1.0", "Llama-2-13b-chat-hf")
    tiny_vicuna = ("TinyLlama-1.1B-Chat-v1.0", "vicuna-13b-v1.5")
    llama160_vicuna = ("llama-160m", "vicuna-13b-v1.5")
    stablelm = ("stablelm-base-alpha-3b-v2", "stablelm-base-alpha-7b-v2")
    opt = ("opt-125m", "opt-13b")
    gemma = ("gemma-2b", "gemma-7b")
    gemma_it = ("gemma-1.1-2b-it", "gemma-1.1-7b-it")
    bloom = ("bloom-560m", "bloom-7b1")
    open_llama = ("open_llama_3b", "open_llama_7b")
    dolly = ("dolly-v2-3b", "dolly-v2-12b")

    ###############################################################
    # Speedup/scatter analysis, run GPU by GPU in the listed order.
    fixed_gpu_speedup_pairs = [
        tiny_llama2,
        stablelm,
        opt,
        gemma,
        bloom,
        open_llama,
        dolly,
    ]
    speedup_jobs = [
        (a100, fixed_gpu_speedup_pairs),
        (a40, fixed_gpu_speedup_pairs),
        # The TinyLlama/Llama-2 and dolly pairs are not run for the
        # CLI-selected GPU.
        (args.gpu, [stablelm, opt, gemma, bloom, open_llama]),
    ]
    for device, pairs in speedup_jobs:
        for draft, target in pairs:
            plot_speedup_all(device, args.temperature, draft, target)

    ###############################################################
    # Per-dataset latency/acceptance figures, again GPU by GPU.
    plot_all_jobs = [
        (
            args.gpu,
            [opt, gemma, bloom, open_llama, tiny_vicuna, stablelm],
        ),
        (
            a40,
            [
                tiny_llama2,
                opt,
                gemma,
                gemma_it,
                bloom,
                open_llama,
                dolly,
                llama160_vicuna,
                tiny_vicuna,
                stablelm,
            ],
        ),
        (
            a100,
            [
                tiny_llama2,
                opt,
                gemma,
                bloom,
                open_llama,
                dolly,
                llama160_vicuna,
                tiny_vicuna,
                stablelm,
            ],
        ),
    ]
    for device, pairs in plot_all_jobs:
        for draft, target in pairs:
            plot_all(device, args.temperature, draft, target)

    # datasets = ["gsm8k", "finance-alpaca", "xsum", "openai_humaneval"]
    # for dataset in datasets:
    #     # plot(args.gpu, args.temperature, "opt-125m", "opt-13b", dataset)
    #     # plot(args.gpu, args.temperature, "gemma-2b", "gemma-7b", dataset)
    #     # plot(args.gpu, args.temperature, "bloom-560m", "bloom-7b1", dataset)
    #     # plot(args.gpu, args.temperature, "open_llama_3b", "open_llama_7b", dataset)
    #     # plot(args.gpu, args.temperature, "dolly-v2-3b", "dolly-v2-12b", dataset)
    #     # plot(args.gpu, args.temperature, "llama-160m", "vicuna-13b-v1.5", dataset)
    #     # plot(args.gpu, args.temperature, "TinyLlama-1.1B-Chat-v1.0", "vicuna-13b-v1.5", dataset)
    #     # plot(
    #     #     args.gpu,
    #     #     args.temperature,
    #     #     "stabilityai/stablelm-base-alpha-3b-v2",
    #     #     "stabilityai/stablelm-base-alpha-7b-v2",
    #     #     dataset,
    #     # )
    #     plot("NVIDIA_A40", args.temperature, "opt-125m", "opt-13b", dataset)
    #     # plot("NVIDIA_A40", args.temperature, "gemma-1.1-2b-it", "gemma-1.1-7b-it", dataset)
    #     plot("NVIDIA_A40", args.temperature, "gemma-2b", "gemma-7b", dataset)
    #     plot("NVIDIA_A40", args.temperature, "bloom-560m", "bloom-7b1", dataset)
    #     plot("NVIDIA_A40", args.temperature, "open_llama_3b", "open_llama_7b", dataset)
    #     plot("NVIDIA_A40", args.temperature, "dolly-v2-3b", "dolly-v2-12b", dataset)
    #     plot("NVIDIA_A40", args.temperature, "llama-160m", "vicuna-13b-v1.5", dataset)
    #     plot(
    #         "NVIDIA_A40",
    #         args.temperature,
    #         "TinyLlama-1.1B-Chat-v1.0",
    #         "vicuna-13b-v1.5",
    #         dataset,
    #     )
    #     plot(
    #         "NVIDIA_A40",
    #         args.temperature,
    #         "stablelm-base-alpha-3b-v2",
    #         "stablelm-base-alpha-7b-v2",
    #         dataset,
    #     )
    #     plot("NVIDIA_A100-SXM4-40GB", args.temperature, "opt-125m", "opt-13b", dataset)
    #     # plot(
    #     #     "NVIDIA_A100-SXM4-40GB",
    #     #     args.temperature,
    #     #     "gemma-1.1-2b-it",
    #     #     "gemma-1.1-7b-it",
    #     #     dataset,
    #     # )
    #     plot("NVIDIA_A100-SXM4-40GB", args.temperature, "gemma-2b", "gemma-7b", dataset)
    #     plot(
    #         "NVIDIA_A100-SXM4-40GB",
    #         args.temperature,
    #         "bloom-560m",
    #         "bloom-7b1",
    #         dataset,
    #     )
    #     plot(
    #         "NVIDIA_A100-SXM4-40GB",
    #         args.temperature,
    #         "open_llama_3b",
    #         "open_llama_7b",
    #         dataset,
    #     )
    #     plot(
    #         "NVIDIA_A100-SXM4-40GB",
    #         args.temperature,
    #         "dolly-v2-3b",
    #         "dolly-v2-12b",
    #         dataset,
    #     )
    #     plot(
    #         "NVIDIA_A100-SXM4-40GB",
    #         args.temperature,
    #         "llama-160m",
    #         "vicuna-13b-v1.5",
    #         dataset,
    #     )
    #     plot(
    #         "NVIDIA_A100-SXM4-40GB",
    #         args.temperature,
    #         "TinyLlama-1.1B-Chat-v1.0",
    #         "vicuna-13b-v1.5",
    #         dataset,
    #     )
    #     plot(
    #         "NVIDIA_A100-SXM4-40GB",
    #         args.temperature,
    #         "stablelm-base-alpha-3b-v2",
    #         "stablelm-base-alpha-7b-v2",
    #         dataset,
    #     )
