import sklearn.metrics as metrics
import pandas as pd
import torch
from sklearn.metrics import confusion_matrix
from typing import List, Dict
import numpy as np
from matplotlib import pyplot as plt
import argparse
import os
from typing import Tuple


def compute_accuracy(data: Dict) -> torch.Tensor:
    """Compute exact-match accuracy for every (epoch, k, seed) run.

    ``data`` is read as ``data[seed][epoch][k]``; each leaf must provide
    ``"predicted_label"`` and ``"gt_label"`` sequences.  Epoch and k keys are
    taken from the first seed in ``data`` and every key axis is sorted.

    Returns:
        Tensor of shape ``(num_epochs, num_k, num_seeds)`` holding the fraction
        of positions at which prediction and ground truth are equal.
    """
    first_seed = list(data)[0]
    epochs = sorted(data[first_seed])
    seeds = sorted(data)
    ks = sorted(data[first_seed][epochs[0]])

    result = torch.zeros(len(epochs), len(ks), len(seeds))
    for epoch_idx, epoch in enumerate(epochs):
        for seed_idx, seed in enumerate(seeds):
            for k_idx, k in enumerate(ks):
                run = data[seed][epoch][k]
                hits = 0
                total = 0
                # zip stops at the shorter sequence, so only paired items count.
                for predicted, expected in zip(run["predicted_label"], run["gt_label"]):
                    hits = hits + (predicted == expected)
                    total += 1
                result[epoch_idx, k_idx, seed_idx] = hits / total

    return result


def compute_confusion_matrix(data: Dict) -> Tuple[torch.Tensor, ...]:
    """Compute binary confusion-matrix rates from predicted scores.

    ``data`` is read as ``data[seed][epoch][k]``; each leaf must provide
    ``"predicted_scores"`` and ``"gt_label"``.  A sample is labelled
    "positive" when its score at index 0 is greater than its score at index 2,
    and "negative" otherwise.
    NOTE(review): what score columns 0 and 2 represent is not visible in this
    file -- confirm they are the positive / negative class scores.

    Returns:
        A 5-tuple ``(true_positive, true_negative, false_positive,
        false_negative, accuracies)``.  Each entry is a tensor of shape
        ``(num_epochs, num_k, num_seeds)``.  The first four are rates: counts
        divided by the number of ground-truth samples of the relevant class.
    """
    seed_keys = list(data.keys())
    epoch_keys = list(data[seed_keys[0]].keys())
    k_keys = list(data[seed_keys[0]][epoch_keys[0]].keys())
    shape = (len(epoch_keys), len(k_keys), len(seed_keys))

    true_positive = torch.zeros(*shape)
    true_negative = torch.zeros(*shape)
    false_positive = torch.zeros(*shape)
    false_negative = torch.zeros(*shape)
    accuracies = torch.zeros(*shape)

    for i_epoch, epoch in enumerate(epoch_keys):
        for i_seed, seed in enumerate(seed_keys):
            for i_key, k in enumerate(data[seed][epoch].keys()):
                run = data[seed][epoch][k]
                predict_labels_per_score = [
                    "positive" if scores[0] > scores[2] else "negative"
                    for scores in run["predicted_scores"]
                ]

                # With labels=[negative, positive] the flattened matrix is
                # ordered tn, fp, fn, tp (rows = ground truth, cols = prediction).
                tn, fp, fn, tp = confusion_matrix(
                    run["gt_label"],
                    predict_labels_per_score,
                    labels=["negative", "positive"],
                ).ravel()

                index = (i_epoch, i_key, i_seed)
                true_positive[index] = tp / (tp + fn)
                true_negative[index] = tn / (tn + fp)
                false_positive[index] = fp / (fp + tn)
                false_negative[index] = fn / (fn + tp)
                accuracies[index] = (tp + tn) / (tp + tn + fp + fn)

    return true_positive, true_negative, false_positive, false_negative, accuracies


def compute_confusion_matrix_label(data: Dict) -> Tuple[torch.Tensor, ...]:
    """Compute three-way confusion rates from predicted label strings.

    ``data`` is read as ``data[seed][epoch][k]``; each leaf must provide
    ``"predicted_label"`` and ``"gt_label"``.  Predicted labels are stripped of
    surrounding whitespace; anything other than "positive" or "negative" is
    counted as "null".

    Returns:
        A 6-tuple ``(true_positive, true_negative, false_positive,
        false_negative, null_positive, null_negative)``.  Each entry is a
        tensor of shape ``(num_epochs, num_k, num_seeds)``.  ``*_positive``
        outcomes on positive ground truth (tp, fn, null) are divided by the
        number of ground-truth positives, and likewise for negatives.
    """
    seed_keys = list(data.keys())
    epoch_keys = list(data[seed_keys[0]].keys())
    k_keys = list(data[seed_keys[0]][epoch_keys[0]].keys())
    shape = (len(epoch_keys), len(k_keys), len(seed_keys))

    true_positive = torch.zeros(*shape)
    true_negative = torch.zeros(*shape)
    false_positive = torch.zeros(*shape)
    false_negative = torch.zeros(*shape)
    null_positive = torch.zeros(*shape)
    null_negative = torch.zeros(*shape)

    for i_epoch, epoch in enumerate(epoch_keys):
        for i_seed, seed in enumerate(seed_keys):
            for i_key, k in enumerate(data[seed][epoch].keys()):
                run = data[seed][epoch][k]
                stripped = [label.strip() for label in run["predicted_label"]]
                predict_labels_three_class = [
                    label if label in ("positive", "negative") else "null"
                    for label in stripped
                ]

                # Rows are ground truth, columns are predictions, both ordered
                # negative, positive, null.  The third row (ground truth
                # "null") is ignored.
                tn, fp, en, fn, tp, ep, _, _, _ = confusion_matrix(
                    run["gt_label"],
                    predict_labels_three_class,
                    labels=["negative", "positive", "null"],
                ).ravel()
                num_pos = tp + fn + ep
                num_neg = tn + fp + en

                index = (i_epoch, i_key, i_seed)
                true_positive[index] = tp / num_pos
                true_negative[index] = tn / num_neg
                null_positive[index] = ep / num_pos
                false_positive[index] = fp / num_neg
                false_negative[index] = fn / num_pos
                null_negative[index] = en / num_neg

    return (
        true_positive,
        true_negative,
        false_positive,
        false_negative,
        null_positive,
        null_negative,
    )


def ensemble_accuracy(datasets: List[Dict]) -> torch.Tensor:
    """Accuracy of a majority vote over several result dictionaries.

    Every element of ``datasets`` is read as ``d[seed][epoch][k]`` with
    ``"predicted_label"``; the seed / epoch / k keys and the ``"gt_label"``
    ground truth are taken from the first dataset.  Predicted labels are
    stripped; "negative" maps to 0, "positive" to 1 and anything else to 2
    ("null").  Per sample the most frequent id wins, with ties going to the
    lowest id.

    Returns:
        Tensor of shape ``(num_epochs, num_k, num_seeds)`` with the fraction of
        samples whose voted label equals the ground truth.
    """
    reference = datasets[0]
    first_seed = list(reference)[0]
    epochs = list(reference[first_seed])
    seeds = list(reference)
    num_k = len(list(reference[first_seed][epochs[0]]))
    scores = torch.zeros(len(epochs), num_k, len(seeds))

    name_to_id = {"positive": 1, "negative": 0}
    id_to_name = {1: "positive", 0: "negative", 2: "null"}

    def most_common(column):
        # argmax returns the first maximum, so ties resolve to the lowest id.
        return np.bincount(column).argmax()

    for epoch_idx, epoch in enumerate(epochs):
        for seed_idx, seed in enumerate(seeds):
            for k_idx, k in enumerate(reference[seed][epoch]):
                # One row of label ids per ensemble member.
                votes = np.array(
                    [
                        np.array(
                            [
                                name_to_id.get(raw.strip(), 2)
                                for raw in member[seed][epoch][k]["predicted_label"]
                            ]
                        )
                        for member in datasets
                    ]
                )
                winners = np.apply_along_axis(most_common, axis=0, arr=votes)
                voted_labels = [id_to_name[winner] for winner in winners]
                truth = reference[seed][epoch][k]["gt_label"]
                outcomes = [voted == actual for voted, actual in zip(voted_labels, truth)]
                scores[epoch_idx, k_idx, seed_idx] = sum(outcomes) / len(outcomes)

    return scores


def _accuracy_curves(accuracies, label, color, first_epoch_only=False, max_k=None):
    """Turn one accuracy tensor into plottable curves with a +/- 1 std band.

    ``accuracies`` is indexed ``[epoch, k, run]``; mean and standard deviation
    are taken over the last axis.  When ``max_k`` is given only the first
    ``max_k`` entries of the k axis are kept.  When ``first_epoch_only`` is
    set, at most the first epoch row is returned.

    Returns:
        List of ``(mean, label, color, upper, lower)`` tuples, one per epoch,
        where ``mean`` / ``upper`` / ``lower`` are plain lists over k.
    """
    if max_k is not None:
        accuracies = accuracies[:, :max_k, :]
    mean = torch.mean(accuracies, dim=-1)
    std = torch.std(accuracies, dim=-1)
    upper = mean + std
    lower = mean - std

    num_epochs = mean.shape[0]
    if first_epoch_only:
        num_epochs = min(num_epochs, 1)
    return [
        (
            mean[epoch, :].tolist(),
            label,
            color,
            upper[epoch, :].tolist(),
            lower[epoch, :].tolist(),
        )
        for epoch in range(num_epochs)
    ]


def plot_accuracy(
    lr_solver: torch.Tensor = None,
    ft_full_accuracies: torch.Tensor = None,
    ft_layers_accuracies: torch.Tensor = None,
    ft_head_accuracies: torch.Tensor = None,
    ft_adaptor_accuracies: torch.Tensor = None,
    tuned_accuracies: torch.Tensor = None,
    original_accuracies: torch.Tensor = None,
    lr_adaptor_corr_accuracies: torch.Tensor = None,
    lr_adaptor_noncorr_accuracies: torch.Tensor = None,
    lr_adaptor_corr_repeat_accuracies: torch.Tensor = None,
    solver_type: str = "saga",
    penalty: str = "l2",
    k_range: List[int] = [4, 8, 16, 32, 48, 64],
    dataset: str = "SMS",
    plot_output_dir: str = "plots",
    plot_prefix: str = "acc",
    y_lim: Tuple[float, float] = (0.5, 1.0),
):
    """Plot mean accuracy against the number of in-context examples and save it.

    Every tensor argument is optional and indexed ``[epoch, k, run]``; the mean
    over the last axis is plotted against ``k_range``.  Each tensor supplied
    adds its curve(s) to the figure and a short tag to the output filename.

    Args:
        lr_solver, ..., lr_adaptor_corr_repeat_accuracies: accuracy tensors.
            ``lr_solver`` and ``lr_adaptor_noncorr_accuracies`` are truncated
            to ``len(k_range)`` along the k axis.  The four ``ft_*`` tensors
            contribute only their first epoch.
        solver_type, penalty: text shown in the LR solver legend entry.
        k_range: x values; must match the k axis of every plotted tensor.
        dataset: name used in the title and the filename.
        plot_output_dir: existing directory the PNG is written to.
        plot_prefix: suffix of the filename.
        y_lim: lower and upper y-axis limits.

    The figure is written to
    ``{plot_output_dir}/{dataset}{tags}_accuracy_{plot_prefix}.png``.
    """
    base_color = (0, 119 / 255, 190 / 255)

    def lighten(steps):
        # Shift every RGB channel of the base blue by 0.1 per step.
        return tuple(np.array(base_color) + (steps * 0.1))

    # (filename tag, tensor, legend label, colour, first epoch only, clip to k_range)
    # The order fixes both the filename and the legend order.
    series = [
        (
            "_LRS",
            lr_solver,
            f"LR Solver, Type: {solver_type} Penalty: {penalty}",
            "tab:orange",
            False,
            True,
        ),
        ("_FullFT", ft_full_accuracies, "Fully FT Model", "0.0", True, False),
        (
            "_FTLayers",
            ft_layers_accuracies,
            "FT Layers Model (n_layers=2)",
            "0.2",
            True,
            False,
        ),
        ("_FTADap", ft_adaptor_accuracies, "FT Adaptor Model", "0.4", True, False),
        ("_FTHead", ft_head_accuracies, "FT LM Model", "0.6", True, False),
        (
            "_LR",
            tuned_accuracies,
            "Logistic Regression Adaptor",
            "tab:blue",
            False,
            False,
        ),
        ("_OG", original_accuracies, "Original Model", "tab:red", False, False),
        (
            "_LRADC",
            lr_adaptor_corr_accuracies,
            "LR Adaptor (corrupted naive)",
            lighten(2),
            False,
            False,
        ),
        (
            "_LRADNC",
            lr_adaptor_noncorr_accuracies,
            "LR Adaptor (non-corrupted)",
            lighten(0),
            False,
            True,
        ),
        (
            "_LRADCR",
            lr_adaptor_corr_repeat_accuracies,
            "LR Adaptor (corrupted clever)",
            lighten(1),
            False,
            False,
        ),
    ]

    inputs = ""
    y_values = []
    for tag, accuracies, label, color, first_epoch_only, clip_to_k_range in series:
        if accuracies is None:
            continue
        inputs += tag
        y_values.extend(
            _accuracy_curves(
                accuracies,
                label,
                color,
                first_epoch_only=first_epoch_only,
                max_k=len(k_range) if clip_to_k_range else None,
            )
        )

    # The +/- 1 std band is computed for each curve but not drawn.
    for values, label, color, _upper, _lower in y_values:
        plt.plot(k_range, values, label=label, color=color)

    plt.xlabel("Number of In-Context Examples")
    plt.title(f"Accuracy for {dataset} Dataset")
    plt.ylabel("Accuracy")
    plt.ylim(y_lim[0], y_lim[1])
    plt.xscale('log', base=2)
    plt.legend(bbox_to_anchor=(1.04, 1), loc="upper left")

    # Save before show(): a blocking show() closes the figure, so saving
    # afterwards would write an empty image.
    plt.savefig(
        f"{plot_output_dir}/{dataset}{inputs}_accuracy_{plot_prefix}.png",
        bbox_inches="tight",
    )
    plt.show()
    plt.clf()


def plot_generic(
    ft_outputs: torch.Tensor = None,
    tuned_outputs: torch.Tensor = None,
    original_outputs: torch.Tensor = None,
    k_range: List[int] = [4, 8, 16, 32, 48, 64],
    dataset: str = "SMS",
    metric: str = "accuracy",
    plot_output_dir: str = "plots",
):
    """Plot an arbitrary metric against the number of in-context examples.

    Every tensor argument is optional and indexed ``[epoch, k, run]``; it is
    averaged over the last axis and one curve is drawn per epoch.

    Args:
        ft_outputs: plotted in grey, one shade per epoch (shades repeat after
            five epochs).
        tuned_outputs: plotted in blue.
        original_outputs: plotted in red.
        k_range: x values; must match the k axis of every plotted tensor.
        dataset, metric: used in the title, y label and filename.
        plot_output_dir: existing directory the PNG is written to.

    The figure is written to
    ``{plot_output_dir}/{dataset}{tags}_{metric}_plot.png``.
    """
    inputs = ""
    curves = []  # (y values over k, legend label, colour)

    if ft_outputs is not None:
        inputs += "FT"
        greys = ["0.7", "0.6", "0.4", "0.2", "0.0"]
        per_epoch = torch.mean(ft_outputs, dim=-1)
        for epoch in range(per_epoch.shape[0]):
            curves.append(
                (
                    per_epoch[epoch, :].tolist(),
                    f"Finetuned Model (Epoch {epoch + 1})",
                    # Wrap around instead of failing past five epochs.
                    greys[epoch % len(greys)],
                )
            )

    if tuned_outputs is not None:
        inputs += "_LR"
        per_epoch = torch.mean(tuned_outputs, dim=-1)
        for epoch in range(per_epoch.shape[0]):
            curves.append((per_epoch[epoch, :].tolist(), "LR Tuned Model", "tab:blue"))

    if original_outputs is not None:
        inputs += "_OG"
        per_epoch = torch.mean(original_outputs, dim=-1)
        for epoch in range(per_epoch.shape[0]):
            curves.append((per_epoch[epoch, :].tolist(), "Original Model", "tab:red"))

    for values, label, color in curves:
        plt.plot(k_range, values, label=label, color=color)

    plt.xlabel("Number of In-Context Examples")
    plt.title(f"{metric} Rate for {dataset} (2-class) Dataset")
    plt.ylabel(f"{metric}")
    plt.legend()

    # Save before show(): a blocking show() closes the figure, so saving
    # afterwards would write an empty image.
    plt.savefig(f"{plot_output_dir}/{dataset}{inputs}_{metric}_plot.png")
    plt.show()
    plt.clf()


def _accuracy_from_pickle(path):
    """Return the accuracy tensor for a pickled results file, or None if no path is given."""
    if not path:
        return None
    # NOTE: unpickling can execute arbitrary code -- only load trusted files.
    return compute_accuracy(pd.read_pickle(path))


def call_plot(
    dataset="ag_news",
    plot_prefix="",
    lr_solver_outputs: str = None,
    ft_full_outputs: str = None,
    ft_layers_outputs: str = None,
    ft_head_outputs: str = None,
    ft_adaptor_outputs: str = None,
    og_outputs: str = None,
    lr_tuned_outputs: str = None,
    lr_adaptor_outputs_corr: str = None,
    lr_adaptor_outputs_corr_repeat: str = None,
    lr_adaptor_outputs_noncorr: str = None,
    k_range: List[int] = [4, 8, 16, 32, 48, 64],
    solver_type: str = "saga",
    save_dir: str = "plots",
    penalty: str = "l2",
    y_lim: Tuple[float, float] = (0.5, 0.95),
):
    """Load pickled evaluation outputs, compute accuracies and plot them.

    Each ``*_outputs`` argument is an optional path to a pickle whose content
    is nested as seed -> epoch -> k -> dict of label lists.  Paths that are
    ``None`` or empty are skipped and the matching curve is left out.  All
    remaining arguments are forwarded to ``plot_accuracy``.
    """
    plot_accuracy(
        lr_solver=_accuracy_from_pickle(lr_solver_outputs),
        # Fine-tuning baselines.
        ft_full_accuracies=_accuracy_from_pickle(ft_full_outputs),
        ft_layers_accuracies=_accuracy_from_pickle(ft_layers_outputs),
        ft_head_accuracies=_accuracy_from_pickle(ft_head_outputs),
        ft_adaptor_accuracies=_accuracy_from_pickle(ft_adaptor_outputs),
        # Base model and LR-tuned model.
        tuned_accuracies=_accuracy_from_pickle(lr_tuned_outputs),
        original_accuracies=_accuracy_from_pickle(og_outputs),
        # Adaptor variants.
        lr_adaptor_corr_accuracies=_accuracy_from_pickle(lr_adaptor_outputs_corr),
        lr_adaptor_noncorr_accuracies=_accuracy_from_pickle(
            lr_adaptor_outputs_noncorr
        ),
        lr_adaptor_corr_repeat_accuracies=_accuracy_from_pickle(
            lr_adaptor_outputs_corr_repeat
        ),
        dataset=dataset,
        solver_type=solver_type,
        penalty=penalty,
        k_range=k_range,
        plot_prefix=plot_prefix,
        plot_output_dir=save_dir,
        y_lim=y_lim,
    )


def plot_ensemble(
    individual_outputs: List[torch.Tensor] = None,
    ensemble_outputs: torch.Tensor = None,
    k_range: List[int] = [4, 8, 16, 32, 48, 64],
    dataset: str = "SMS",
    metric: str = "ensemble",
    plot_output_dir: str = "plots",
):
    """Plot individual models next to their ensemble.

    Every tensor is indexed ``[epoch, k, run]`` and averaged over the last
    axis; one curve is drawn per epoch.

    Args:
        individual_outputs: one tensor per ensemble member, drawn in grey
            (shades repeat after five members).
        ensemble_outputs: tensor for the ensemble, drawn in blue.
        k_range: x values; must match the k axis of every plotted tensor.
        dataset, metric: used in the title, y label and filename.
        plot_output_dir: existing directory the PNG is written to.

    The figure is written to
    ``{plot_output_dir}/{dataset}{tags}_{metric}_ensemble_plot.png``.
    """
    inputs = ""
    curves = []  # (y values over k, legend label, colour)
    greys = ["0.7", "0.6", "0.4", "0.2", "0.0"]

    if individual_outputs is not None:
        inputs += "FT"
        for i, outputs in enumerate(individual_outputs):
            per_epoch = torch.mean(outputs, dim=-1)
            for epoch in range(per_epoch.shape[0]):
                curves.append(
                    (
                        per_epoch[epoch, :].tolist(),
                        f"LR Tuned Model (# {i + 1})",
                        # Wrap around instead of failing past five members.
                        greys[i % len(greys)],
                    )
                )

    if ensemble_outputs is not None:
        inputs += "_LR"
        per_epoch = torch.mean(ensemble_outputs, dim=-1)
        for epoch in range(per_epoch.shape[0]):
            curves.append((per_epoch[epoch, :].tolist(), "Ensemble Model", "tab:blue"))

    for values, label, color in curves:
        plt.plot(k_range, values, label=label, color=color)

    plt.xlabel("Number of In-Context Examples")
    plt.title(f"{metric} for {dataset} (2-class) Dataset")
    plt.ylabel(f"{metric}")
    plt.legend()

    # Save before show(): a blocking show() closes the figure, so saving
    # afterwards would write an empty image.
    plt.savefig(f"{plot_output_dir}/{dataset}{inputs}_{metric}_ensemble_plot.png")
    plt.show()
    plt.clf()


def plot_ensemble_accuracy():
    """Plot the five HATE evaluation runs alongside their majority-vote ensemble.

    Reads one pickle per run index 0-4 from the current working directory,
    computes each run's accuracy and the ensemble accuracy, and hands both to
    ``plot_ensemble``.
    """
    data_files = []
    individual_accuracies = []
    for rs in [0, 1, 2, 3, 4]:
        # The SMS runs use the same file name with "eval_sms_" as prefix.
        path_to_file = f"eval_hate_runid_9aa9372d-384e-4bb1-b710-947c363d4ef5_ckpt_model_2000_rs_{rs}.pkl"
        data = pd.read_pickle(path_to_file)
        individual_accuracies.append(compute_accuracy(data))
        data_files.append(data)

    # Must not be named `ensemble_accuracy`: assigning to that name would make
    # it local to this function and the call would raise UnboundLocalError.
    ensemble_acc = ensemble_accuracy(data_files)
    plot_ensemble(
        individual_accuracies, ensemble_acc, dataset="HATE", metric="Accuracy"
    )


# Default result-file names for the dataset / model pair below.
ds = "sms"
model = "gpt-neo-125m"

# Run identifier shared by the LR-tuned and LR-solver result files.
_run_tag = (
    "runid_gpt-neo-125m_ag_news_s42_k128_e100_None_ckpt_pytorch_model"
    "_rs_9_mgpt-neo-125m_adaptor_ftmodel"
)

ft_outputs = "EleutherAI-" + model + "_train_" + ds + "_eval_" + ds + "_.pkl"
lr_tuned_outputs = "eval_" + ds + "_" + _run_tag + ".pkl"
og_outputs = "./eval_outputs/EleutherAI-gpt-neo-125M_eval_" + ds + "_rs_0.pkl"
lr_solver_outputs = "LRSolver_" + ds + "_" + _run_tag + ".pkl"

# call_plot(
#     dataset=ds,
#     lr_solver_outputs=lr_solver_outputs,
#     ft_outputs=ft_outputs,
#     og_outputs=og_outputs,
#     lr_tuned_outputs=lr_tuned_outputs,
#     plot_prefix="ftmodel",
# )


# d32_nocorrupt
if __name__ == "__main__":
    # Command-line entry point: parse the options and draw the accuracy plot.
    import ast

    parser = argparse.ArgumentParser(
        description=(
            "Plot accuracy against the number of in-context examples "
            "from pickled evaluation outputs."
        )
    )

    # Optional result pickles; a curve is drawn only for the paths given.
    for flag in (
        "--ftfull_outputs_path",
        "--ftlayers_outputs_path",
        "--fthead_outputs_path",
        "--ftadaptor_outputs_path",
        "--original_outputs_path",
        "--lr_solver_outputs_path",
        "--lr_adaptor_outputs_cor_path",
        "--lr_adaptor_outputs_noncor_path",
        "--lr_adaptor_outputs_cor_repeat_path",
        "--lr_tuned_path",
    ):
        parser.add_argument(flag, type=str, default=None)

    parser.add_argument("--solver_type", type=str, default="saga")
    parser.add_argument("--penalty", type=str, default="l2")
    # Python list literal, e.g. "[4, 8, 16]".
    parser.add_argument(
        "--k_range",
        type=str,
        default="[4, 8, 16, 32, 48, 64, 96, 128]",
    )
    parser.add_argument("--prefix", type=str, default="")
    parser.add_argument("--save_dir", type=str, default="./plots")
    parser.add_argument("--dataset", type=str, default="SMS")
    # Python tuple literal, e.g. "(0.5, 1.0)".
    parser.add_argument("--y_lim", type=str, default="(0.5, 1.0)")

    args = parser.parse_args()

    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(args.save_dir, exist_ok=True)

    # literal_eval accepts only Python literals, so a command-line string
    # cannot execute arbitrary code the way eval() would.
    k_range = ast.literal_eval(args.k_range)
    y_lim = ast.literal_eval(args.y_lim)

    call_plot(
        dataset=args.dataset,
        lr_solver_outputs=args.lr_solver_outputs_path,
        ft_full_outputs=args.ftfull_outputs_path,
        ft_layers_outputs=args.ftlayers_outputs_path,
        ft_head_outputs=args.fthead_outputs_path,
        ft_adaptor_outputs=args.ftadaptor_outputs_path,
        lr_tuned_outputs=args.lr_tuned_path,
        lr_adaptor_outputs_corr=args.lr_adaptor_outputs_cor_path,
        lr_adaptor_outputs_noncorr=args.lr_adaptor_outputs_noncor_path,
        lr_adaptor_outputs_corr_repeat=args.lr_adaptor_outputs_cor_repeat_path,
        og_outputs=args.original_outputs_path,
        plot_prefix=args.prefix,
        k_range=k_range,
        solver_type=args.solver_type,
        penalty=args.penalty,
        save_dir=args.save_dir,
        y_lim=y_lim,
    )
