"""Process and plot evaluation results."""

import argparse
import itertools
import os
import textwrap
import traceback
from dataclasses import dataclass
from dataclasses import field

import matplotlib.pyplot as plt
import numpy as np
import torch
import wandb as wb
from lm_eval.tasks import get_task
from tqdm import tqdm

from utils import aggregate
from utils import PILE_WEIGHTS


#
#   Config
#


# Bar colors of the before/after chart drawn by `plot`: the "before" bar is
# always "neutral"; the "after" bar is "bad" when the metric increased and
# "good" otherwise.
colors = {"neutral": "steelblue", "good": "seagreen", "bad": "firebrick"}

# Base matplotlib font size; the plotting functions add small offsets to it.
FONTSIZE = 8
# Opacity of the shaded error bands drawn with `fill_between`.
ALPHA = 0.1

# Acquisition-function identifiers. Used as the default `acquisition_functions`
# argument of the plotting functions and as keys of the three tables below.
ACQUISITION_FUNCTIONS = [
    "Random",
    "Random-preselected",
    "ITL",
    "ITL-noiseless",
    "UncertaintySampling",
    "NearestNeighbour",
    "ONN",
    "VTL",
    "CTL",
]

# Line / band color per acquisition function.
ACQUISITION_FUNCTION_COLORS = {
    "Random": "gray",
    "Random-preselected" : "gray",
    "ITL": "orange",
    "ITL-noiseless": "purple",
    "UncertaintySampling": "red",
    "NearestNeighbour": "black",
    "ONN": "black",
    "VTL": "green",
    "CTL": "blue",
}

# Matplotlib linestyle per acquisition function; distinguishes the entries that
# share a color above (e.g. "Random" vs "Random-preselected").
ACQUISITION_FUNCTION_STYLES = {
    "Random": "solid",
    "Random-preselected" : "dotted",
    "ITL": "solid",
    "ITL-noiseless": "solid",
    "UncertaintySampling": "solid",
    "NearestNeighbour": "solid",
    "ONN": "dotted",
    "VTL": "solid",
    "CTL": "dashed",
}

# Human-readable label per acquisition function (legend entries and row labels).
ACQUISITION_FUNCTION_TITLES = {
    "Random": "Random",
    "Random-preselected" : "Random preselected",
    "ITL": "ITL",
    "ITL-noiseless": "ITL noiseless",
    "UncertaintySampling": "Uncertainty Sampling",
    "NearestNeighbour": "Nearest Neighbour",
    "ONN": "Nearest Neighbour only",
    "VTL": "VTL",
    "CTL": "CTL",
}

# Metric name -> text used in figure titles. Plotting functions index this with
# `metric_name`, so only these metrics can be plotted without a KeyError.
TITLES = {
    "bits_per_byte": "Bits per byte",
    "byte_perplexity": "Byte perplexity",
    "word_perplexity": "Word perplexity",
    "perplexity": "Perplexity",
}

# Short model name (the values listed for `--model` in the CLI help) -> model
# identifier. NOTE(review): not referenced in the visible part of the file;
# presumably used to build `Config.model` -- confirm against the caller.
MODEL = {
    "gpt2": "gpt2",
    "gpt2-large": "gpt2-large",
    "gptneo": "gptneo",
    "phi3": "microsoft/Phi-3-mini-4k-instruct",
}

# Model identifier (values of MODEL) -> text used in figure titles; indexed
# with `config.model` by the plotting functions.
MODEL_TITLES = {
    "gpt2": "GPT 2",
    "gpt2-large": "GPT large",
    "gptneo": "GPT Neo",
    "microsoft/Phi-3-mini-4k-instruct": "Phi-3",
}

# Dataset identifiers; default `datasets` argument of the plotting functions.
# The plots strip the first five characters ("pile_") to obtain axis titles.
DATASETS = [
    "pile_github",
    "pile_stackexchange",
    "pile_uspto",
    "pile_wikipedia",
    "pile_dm-mathematics",
    "pile_enron",
]

# Display title per dataset.
# NOTE(review): three entries are empty strings and this table is not used in
# the visible part of the file -- confirm whether the titles are still pending.
DATASET_TITLES = {
    "pile_github": "GitHub",
    "pile_stackexchange": "",
    "pile_uspto": "USPTO",
    "pile_wikipedia": "Wikipedia",
    "pile_dm-mathematics": "",
    "pile_enron": "",
}


def parse_args():
    """Build the command line parser and parse ``sys.argv``.

    Returns:
        The ``argparse.Namespace`` with ``results_dir``, ``name``, ``model``,
        ``world_size``, ``bootstrap``, ``error_bars`` and ``plot``.
    """
    # Free-form help text appended after the generated option list.
    usage_notes = textwrap.dedent(
        """\
        Required Arguments are

            --plot  all
                    k_mult
                    gradient
                    batched
                    noise
                    points
                    costs
                    log-costs
                    dot-products

            --model gpt2
                    gpt2-large
                    gptneo
                    phi3

            --name  Name of the experiment defined in the launcher
        """
    )
    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=usage_notes,
    )

    # (flag, add_argument keyword arguments), in the order shown by --help.
    option_specs = [
        ("--results_dir", {"type": str, "default": "results/"}),
        ("--name", {"type": str, "default": "LLM"}),
        ("--model", {"type": str, "default": "gpt2"}),
        ("--world_size", {"type": int, "default": 32}),
        ("--bootstrap", {"action": "store_true"}),
        ("--error_bars", {"action": "store_true"}),
        ("--plot", {"type": str, "default": "all"}),
    ]
    for flag, options in option_specs:
        parser.add_argument(flag, **options)

    return parser.parse_args()


#
#   Wrapper types
#


@dataclass
class Config:
    """Experiment settings to aggregate and plot.

    Every field except ``model`` is a list of candidate values: ``get_tasks``
    iterates over each list to build the task grid, and the plotting functions
    read element ``[0]`` of the lists they do not sweep over.

    The defaults are single-element lists. They were previously bare scalars,
    which contradicted the annotations and made ``Config()`` fail as soon as a
    field was iterated or indexed. ``default_factory`` gives every instance its
    own list, so mutating one config never leaks into another.

    Attributes:
        model: Model identifier; used as key into ``MODEL_TITLES``.
        num_neighbors: Values of the ``n`` component of a task name.
        k_mult: Values of the ``k`` component of a task name.
        gradient_steps: Values of the ``g`` component of a task name.
        batch_size: Values of the ``b`` component of a task name.
        noise: Values of the ``l`` component of a task name.
    """

    model: str = "gpt2"
    num_neighbors: list[int] = field(default_factory=lambda: [50])
    k_mult: list[int] = field(default_factory=lambda: [4])
    gradient_steps: list[int] = field(default_factory=lambda: [1])
    batch_size: list[int] = field(default_factory=lambda: [1])
    noise: list[float] = field(default_factory=lambda: [1.0])


class Task:
    """One evaluation setting: a dataset, an acquisition function and its hyperparameters.

    ``str(task)`` yields the name under which the results of this setting are
    stored, e.g. ``pile_github_ITL_n50_k4_g1_b1_l1.000000``.
    """

    def __init__(
        self,
        dataset: str,
        acquisition_function: str,
        num_neighbors: int,
        k_mult,
        gradient_steps: int,
        batch_size: int,
        noise: float,
    ):
        # What is evaluated ...
        self.dataset, self.acquisition_function = dataset, acquisition_function
        # ... and with which hyperparameters.
        self.num_neighbors, self.k_mult = num_neighbors, k_mult
        self.gradient_steps, self.batch_size = gradient_steps, batch_size
        self.noise = noise

    def __str__(self):
        # Integer fields are rendered with %d and the noise with %f (six decimals).
        name_parts = (
            self.dataset,
            self.acquisition_function,
            self.num_neighbors,
            self.k_mult,
            self.gradient_steps,
            self.batch_size,
            self.noise,
        )
        return "%s_%s_n%d_k%d_g%d_b%d_l%f" % name_parts


def get_tasks(config, datasets, acquisition_functions):
    """Enumerate every task of the experiment grid.

    Builds one ``Task`` per combination of dataset, acquisition function and
    the hyperparameter lists of ``config``. Datasets vary slowest and noise
    values fastest.
    """
    return [
        Task(dataset, algo, neighbors, multiplier, steps, batch, noise_level)
        for dataset in datasets
        for algo in acquisition_functions
        for neighbors in config.num_neighbors
        for multiplier in config.k_mult
        for steps in config.gradient_steps
        for batch in config.batch_size
        for noise_level in config.noise
    ]


#
#   Aggregate stats
#


def pile_all_error_bars(results_dir):
    """Combine the per-task error bars into error bars for ``pile_all``.

    Loads the bootstrap quantiles of every task listed in ``PILE_WEIGHTS`` and
    averages them, weighting each task by ``PILE_WEIGHTS[task] / 100``.

    Returns:
        ``(before_err, after_err)``: the weighted averages of the per-task
        quantile pairs.
    """
    names = list(PILE_WEIGHTS.keys())
    per_task = [pile_task_error_bars(results_dir, name) for name in names]
    before_bars = [bars[0] for bars in per_task]
    after_bars = [bars[1] for bars in per_task]
    weights = [PILE_WEIGHTS[name] / 100 for name in names]
    return (
        np.average(before_bars, axis=0, weights=weights),
        np.average(after_bars, axis=0, weights=weights),
    )


def pile_task_error_bars(results_dir, task_name):
    """Read the stored bootstrap samples of one task and return its error bars.

    The file ``<results_dir>/bootstrap_<task_name>.pth`` holds a sequence whose
    first element is the bootstrap sample before training and whose last
    element is the one after training.

    Returns:
        ``(before_err, after_err)``: the 10% and 90% quantiles of each sample.
    """
    samples = torch.load("%s/bootstrap_%s.pth" % (results_dir, task_name))
    levels = [0.1, 0.9]
    first_sample, last_sample = samples[0], samples[-1]
    return np.quantile(first_sample, levels), np.quantile(last_sample, levels)


def load_bootstrap_error_bars(results_dir, task_name):
    """Return ``(before_err, after_err)`` for a task.

    The synthetic ``"pile_all"`` task has no bootstrap file of its own; its
    error bars are derived from the individual tasks instead.
    """
    if task_name != "pile_all":
        return pile_task_error_bars(results_dir, task_name)
    return pile_all_error_bars(results_dir)


def compute_pile_all(aggregate_stats, metric="bits_per_byte"):
    """Compute the weighted aggregate of a metric for the synthetic ``pile_all`` task.

    Each task's per-step values of ``metric`` are scaled by
    ``PILE_WEIGHTS[key] / 100``, where ``key`` is the first two ``_``-separated
    parts of the task name. Tasks whose number of steps differs from the median
    are left out.

    Returns:
        ``(total, spread)``: per step, the sum and the standard deviation of
        the weighted values across the remaining tasks.
    """
    weighted_series = []
    for task, task_stats in aggregate_stats.items():
        raw_values = [step[metric] for step in task_stats]
        weighted_series.append(
            [
                value * PILE_WEIGHTS["_".join(task.split("_")[0:2])] / 100
                for value in raw_values
            ]
        )

    typical_length = int(np.median([len(series) for series in weighted_series]))
    stacked = np.array(
        [series for series in weighted_series if len(series) == typical_length]
    )
    return np.sum(stacked, axis=0), np.std(stacked, axis=0)


def bootstrap(data, task_name: Task, n_resamples=1000):
    """Bootstrap the aggregated ``bits_per_byte`` of ``data``.

    Args:
        data: Array of per-point metrics; indexed with an integer index array.
        task_name: Task whose ``dataset`` attribute selects the evaluation task.
        n_resamples: Number of resamples to draw.

    Returns:
        Array with one aggregated ``bits_per_byte`` value per resample.
    """
    size = len(data)
    task = get_task(task_name.dataset)(download=False)
    draws = []
    for _ in tqdm(range(n_resamples)):
        # Resample `size` points with replacement.
        picked = np.random.randint(0, size, size)
        draws.append(aggregate(data[picked], task)["bits_per_byte"])
    return np.array(draws)


def compute_bootstrap_error_bars(results_dir, metrics, task_name: Task):
    """Bootstrap the first and last step of ``metrics`` and store the samples.

    ``metrics`` is indexed as ``metrics[:, step]``. The two bootstrap samples
    (before training, after training) are written as a list to
    ``<results_dir>/bootstrap_<task_name>.pth``.
    """
    first_step, last_step = metrics[:, 0], metrics[:, -1]
    samples = [bootstrap(first_step, task_name), bootstrap(last_step, task_name)]
    torch.save(samples, "%s/bootstrap_%s.pth" % (results_dir, task_name))


def _keep_median_length(runs):
    """Return the runs whose length equals the median run length.

    Used to drop runs that did not finish (their record is shorter).
    """
    median = int(np.median([len(run) for run in runs]))
    return [run for run in runs if len(run) == median]


def _load_rank_results(results_dir, task_name, world_size):
    """Concatenate the result files ``<task_name>_r<rank>.pth`` of all ranks.

    Missing files are skipped; files that fail to load are reported and skipped.

    Returns:
        ``(metrics, losses, training_costs, retrieval_costs)`` as lists.
    """
    metrics, losses, training_costs, retrieval_costs = [], [], [], []
    for rank in tqdm(range(world_size)):
        results_file = "%s/%s_r%d.pth" % (results_dir, task_name, rank)
        if not os.path.exists(results_file):
            continue
        try:
            results = torch.load(results_file)
            if len(results) != 4:
                raise ValueError(
                    "expected 4 result entries, found %d" % len(results)
                )
            # length of results[0] is the number of eval points
            metrics += results[0]
            losses += results[1]
            training_costs += results[2]
            retrieval_costs += results[3]
        except Exception as e:
            print("Error loading %s" % (results_file))
            print(e)
    return metrics, losses, training_costs, retrieval_costs


def compute_aggregate_stats(
    datasets,
    acquisition_functions,
    results_dir,
    num_neighbors,
    k_mult,
    gradient_steps,
    batch_size,
    noise,
    world_size,
    bootstrap=False,
):
    """Load results from file and aggregate them.

    If ``<results_dir>/aggregate_stats.pth`` exists it is loaded and returned
    as is, regardless of the other arguments. Otherwise the per-rank result
    files of every combination of the given settings are loaded, aggregated
    per training step, and the result is saved to that file.

    Args:
        datasets: Dataset names.
        acquisition_functions: Acquisition function names.
        results_dir: Directory holding the result files and the cache.
        num_neighbors, k_mult, gradient_steps, batch_size, noise: Iterables of
            hyperparameter values; every combination is processed.
        world_size: Number of ranks; one result file is expected per rank.
        bootstrap: When true, also compute and store bootstrap samples for
            each task.

    Returns:
        Dict mapping each task name (``str(Task(...))``) to a list with one
        stats dict per step, plus a ``"pile_all"`` entry holding the weighted
        aggregate over all tasks.
    """
    cache_file = "%s/aggregate_stats.pth" % (results_dir)
    if os.path.exists(cache_file):
        return torch.load(cache_file)

    aggregate_stats = {}
    for dataset in datasets:
        data = get_task(dataset)(download=False)
        settings = itertools.product(
            acquisition_functions,
            num_neighbors,
            k_mult,
            gradient_steps,
            batch_size,
            noise,
        )
        for algo, n, k, g, b, noise_level in settings:
            task = Task(dataset, algo, n, k, g, b, noise_level)
            task_name = str(task)
            metrics, losses, training_costs, retrieval_costs = _load_rank_results(
                results_dir, task_name, world_size
            )

            if not metrics:
                print("No results for %s" % (task_name))
                continue
            if not losses:
                print("No losses for %s" % (task_name))
                continue
            if not training_costs:
                print("No training costs for %s" % (task_name))
                continue

            # Each list has one entry per evaluated point, and each entry has
            # one element per training step. Filter out the runs that didn't
            # finish.
            # 2d array (num_points, num_steps) where each entry is a dict of metrics
            metrics = np.array(_keep_median_length(metrics))
            if bootstrap:
                # Pass the Task itself: the bootstrap code reads `.dataset`,
                # which the plain task-name string does not have.
                compute_bootstrap_error_bars(results_dir, metrics, task)

            losses = np.array(_keep_median_length(losses))
            training_costs = np.array(_keep_median_length(training_costs))
            print(task_name)
            print(metrics.shape)
            print(losses.shape)
            print(training_costs.shape)
            # One retrieval cost per point, repeated for every training step.
            retrieval_costs = np.tile(
                np.array([retrieval_costs]).transpose(),
                (1, training_costs.shape[1]),
            )
            print(retrieval_costs.shape)

            task_stats = []
            try:
                for j in range(metrics.shape[1]):
                    task_stats.append(aggregate(metrics[:, j], data))
                # No loss record before training
                task_stats[0]["training_loss"] = np.nan
                task_stats[0]["training_err"] = np.nan
                task_stats[0]["training_cost"] = np.nan
                task_stats[0]["retrieval_cost"] = np.nan
                for j in range(losses.shape[1]):
                    task_stats[j + 1]["training_loss"] = np.nanmean(losses[:, j])
                    # NOTE(review): divides by the number of steps
                    # (losses.shape[1]), not by the number of points --
                    # confirm this is the intended normalization.
                    task_stats[j + 1]["training_err"] = (
                        np.nanstd(losses[:, j]) / losses.shape[1]
                    )
                for j in range(training_costs.shape[1]):
                    task_stats[j + 1]["training_cost"] = np.nanmean(
                        training_costs[:, j]
                    )
                for j in range(retrieval_costs.shape[1]):
                    task_stats[j + 1]["retrieval_cost"] = np.nanmean(
                        retrieval_costs[:, j]
                    )
            except Exception:
                # Best effort: report the failure and keep whatever was
                # aggregated so far.
                print(traceback.format_exc())
                print("Failed on %s" % (task_name))

            aggregate_stats[task_name] = task_stats

    # compute_pile_all returns (sum, std); only bits_per_byte uses the std.
    pile_all_bpb, pile_all_err = compute_pile_all(
        aggregate_stats, metric="bits_per_byte"
    )
    pile_all_tl, _ = compute_pile_all(aggregate_stats, metric="training_loss")
    pile_all_tc, _ = compute_pile_all(aggregate_stats, metric="training_cost")
    pile_all_rc, _ = compute_pile_all(aggregate_stats, metric="retrieval_cost")
    pile_all = [
        {
            "bits_per_byte": bpb,
            "bits_per_byte_err": err,
            "training_loss": tl,
            "training_cost": tc,
            "retrieval_cost": rc,
        }
        for bpb, err, tl, tc, rc in zip(
            pile_all_bpb, pile_all_err, pile_all_tl, pile_all_tc, pile_all_rc
        )
    ]
    aggregate_stats["pile_all"] = pile_all

    # pickle the results
    torch.save(aggregate_stats, cache_file)

    return aggregate_stats


#
#   Plots
#


def plot(
    results_dir,
    aggregate_stats,
    metric_name,
    config,
    datasets=DATASETS,
    acquisition_functions=ACQUISITION_FUNCTIONS,
    error_bars=False,
):
    """Plot a before/after bar chart for every acquisition function and dataset.

    The figure has one row per acquisition function and one column per
    dataset. Each cell compares the metric at the first and the last step of
    the task built from the first value of every hyperparameter list in
    ``config``. Cells without data are left empty.

    Args:
        results_dir: Directory with the bootstrap files; the figure is written
            to ``<results_dir>/images/``.
        aggregate_stats: Dict mapping ``str(Task)`` to the list of per-step
            stats dicts.
        metric_name: Metric to plot; must be a key of ``TITLES``.
        config: Settings object providing ``model`` and the hyperparameter lists.
        datasets: Dataset names, one column each.
        acquisition_functions: Acquisition function names, one row each.
        error_bars: Draw bootstrap error bars instead of the percentage label.
    """
    n_rows = len(acquisition_functions)
    n_cols = max(len(datasets), 1)
    # squeeze=False keeps `axs` two-dimensional even for a single row/column.
    fig, axs = plt.subplots(
        n_rows, n_cols, figsize=(10, 2 * n_rows), squeeze=False
    )
    plt.rcParams.update({"font.size": FONTSIZE})
    for i, algo in enumerate(acquisition_functions):
        for dataset, ax in zip(datasets, axs[i]):
            try:
                task = Task(
                    dataset,
                    algo,
                    config.num_neighbors[0],
                    config.k_mult[0],
                    config.gradient_steps[0],
                    config.batch_size[0],
                    config.noise[0],
                )
                task_stats = aggregate_stats[str(task)]
                metrics_before = task_stats[0]
                metrics_after = task_stats[-1]
                ax.tick_params(axis="both", which="major", labelsize=8)
                # Drop the "pile_" prefix of the dataset name.
                ax.set_title(dataset[5:])
                before = metrics_before[metric_name]
                after = metrics_after[metric_name]
                # An increase of the metric is highlighted as a regression.
                color = colors["bad"] if after > before else colors["good"]
                print(dataset, before, after)
                if error_bars:
                    before_err, after_err = load_bootstrap_error_bars(
                        results_dir, dataset
                    )
                    # yerr expects distances below / above the bar heights.
                    ax.bar(
                        ["before", "after"],
                        [before, after],
                        yerr=[
                            [before - before_err[0], after - after_err[0]],
                            [before_err[1] - before, after_err[1] - after],
                        ],
                        color=[colors["neutral"], color],
                        capsize=5,
                    )
                else:
                    ax.bar(
                        ["before", "after"],
                        [before, after],
                        color=[colors["neutral"], color],
                    )
                    # Label the "after" bar with its value relative to "before".
                    ax.text(
                        1,
                        after,
                        "%.0f %%" % (100 * after / before),
                        ha="center",
                        va="bottom",
                    )
                ax.set_ylim([0, np.max([before * 1.15, after * 1.15])])
            except KeyError:
                # No results (or no such metric) for this task: leave the cell empty.
                continue

        axs[i][0].set_ylabel(
            ACQUISITION_FUNCTION_TITLES[algo],
            rotation=90,
            size="large",
            ha="center",
            va="center",
            labelpad=5,
        )

    fig.suptitle(
        TITLES[metric_name] + ", " + MODEL_TITLES[config.model], fontsize=FONTSIZE + 2
    )
    plt.tight_layout(rect=[0, 0, 1, 0.98])
    if error_bars:
        plt.savefig(
            "%s/images/before-after-%s-error-bars.pdf" % (results_dir, metric_name)
        )
    else:
        plt.savefig("%s/images/before-after-%s.pdf" % (results_dir, metric_name))
    plt.close()


def plot_gradient(
    results_dir,
    aggregate_stats,
    metric_name,
    config,
    datasets=DATASETS,
    acquisition_functions=ACQUISITION_FUNCTIONS,
    error_bars=False,
):
    """Plot the final metric as a function of the number of gradient steps.

    Draws one panel for each of the first two datasets and one curve per
    acquisition function, with a shaded band of ``metric_name + "_err"``.

    Args:
        results_dir: The figure is written to
            ``<results_dir>/images/<metric_name>.pdf``.
        aggregate_stats: Sequence aligned with ``config.gradient_steps``;
            element ``i`` maps ``str(Task)`` to the per-step stats of the runs
            using ``config.gradient_steps[i]``.
        metric_name: Metric to plot; must be a key of ``TITLES``.
        config: Settings object; all other hyperparameters use their first value.
        datasets: Dataset names; only as many as there are panels are used.
        acquisition_functions: Acquisition function names.
        error_bars: Unused; kept for a uniform plotting interface.
    """
    plt.rcParams.update({"font.size": FONTSIZE + 4})
    fig, axs = plt.subplots(1, 2, figsize=(10, 6))
    xs = config.gradient_steps
    for dataset, ax in zip(datasets, axs.flatten()):
        for algo in acquisition_functions:
            ys = []
            err = []
            # One point per configured gradient-step value (not a fixed count).
            for i, steps in enumerate(xs):
                task = Task(
                    dataset,
                    algo,
                    config.num_neighbors[0],
                    config.k_mult[0],
                    steps,
                    config.batch_size[0],
                    config.noise[0],
                )
                final_stats = aggregate_stats[i][str(task)][-1]
                ys.append(final_stats[metric_name])
                err.append(final_stats[metric_name + "_err"])
            if dataset == datasets[0]:
                ax.set_ylabel(metric_name)
            ax.tick_params(axis="both", which="major", labelsize=8)
            ax.plot(
                xs,
                ys,
                linewidth=1,
                marker=".",
                color=ACQUISITION_FUNCTION_COLORS[algo],
                label=ACQUISITION_FUNCTION_TITLES[algo],
                linestyle=ACQUISITION_FUNCTION_STYLES[algo],
            )
            ax.fill_between(
                xs,
                [y - e for y, e in zip(ys, err)],
                [y + e for y, e in zip(ys, err)],
                color=ACQUISITION_FUNCTION_COLORS[algo],
                alpha=ALPHA,
            )
            ax.title.set_fontsize(10)
            # Drop the "pile_" prefix of the dataset name.
            ax.title.set_text(dataset[5:])
            ax.set_xticks([1, 2, 5, 10])
            ax.set_xlabel("gradient-steps")

    plt.subplots_adjust(bottom=0.4)
    # Every panel holds the same curves, so the first one supplies the legend.
    handles, labels = axs[0].get_legend_handles_labels()
    fig.legend(
        handles,
        labels,
        loc="lower center",
        bbox_to_anchor=(0.5, -0.01),
        ncol=4,
        fontsize=FONTSIZE + 4,
        bbox_transform=fig.transFigure,
    )
    fig.suptitle(
        TITLES[metric_name] + ", " + MODEL_TITLES[config.model], fontsize=FONTSIZE + 6
    )
    plt.tight_layout(rect=[0, 0.08, 1, 1])
    plt.savefig("%s/images/%s.pdf" % (results_dir, metric_name))
    plt.close()


def plot_batched(
    results_dir,
    aggregate_stats,
    metric_name,
    config,
    datasets=DATASETS,
    acquisition_functions=ACQUISITION_FUNCTIONS,
    error_bars=False,
):
    """Plot the final metric as a function of the batch size.

    Draws one panel for each of the first two datasets and one curve per
    acquisition function, with a shaded band of ``metric_name + "_err"``.

    Args:
        results_dir: The figure is written to
            ``<results_dir>/images/<metric_name>.pdf``.
        aggregate_stats: Sequence aligned with ``config.batch_size``; element
            ``i`` maps ``str(Task)`` to the per-step stats of the runs using
            ``config.batch_size[i]``.
        metric_name: Metric to plot; must be a key of ``TITLES``.
        config: Settings object; all other hyperparameters use their first value.
        datasets: Dataset names; only as many as there are panels are used.
        acquisition_functions: Acquisition function names.
        error_bars: Unused; kept for a uniform plotting interface.
    """
    plt.rcParams.update({"font.size": FONTSIZE + 4})
    fig, axs = plt.subplots(1, 2, figsize=(10, 6))
    xs = config.batch_size
    for dataset, ax in zip(datasets, axs.flatten()):
        for algo in acquisition_functions:
            ys = []
            err = []
            for i, batch in enumerate(xs):
                task = Task(
                    dataset,
                    algo,
                    config.num_neighbors[0],
                    config.k_mult[0],
                    config.gradient_steps[0],
                    batch,
                    config.noise[0],
                )
                final_stats = aggregate_stats[i][str(task)][-1]
                ys.append(final_stats[metric_name])
                err.append(final_stats[metric_name + "_err"])
            if dataset == datasets[0]:
                ax.set_ylabel(metric_name)
            ax.tick_params(axis="both", which="major", labelsize=8)
            ax.plot(
                xs,
                ys,
                linewidth=1,
                marker=".",
                color=ACQUISITION_FUNCTION_COLORS[algo],
                label=ACQUISITION_FUNCTION_TITLES[algo],
                linestyle=ACQUISITION_FUNCTION_STYLES[algo],
            )
            ax.fill_between(
                xs,
                [y - e for y, e in zip(ys, err)],
                [y + e for y, e in zip(ys, err)],
                color=ACQUISITION_FUNCTION_COLORS[algo],
                alpha=ALPHA,
            )
            ax.title.set_fontsize(10)
            # Drop the "pile_" prefix of the dataset name.
            ax.title.set_text(dataset[5:])
            ax.set_xticks([1, 10])
            # The x-axis shows batch sizes (it was mislabelled "gradient-steps").
            ax.set_xlabel("batch-size")

    plt.subplots_adjust(bottom=0.4)
    # Every panel holds the same curves, so the first one supplies the legend.
    handles, labels = axs[0].get_legend_handles_labels()
    fig.legend(
        handles,
        labels,
        loc="lower center",
        bbox_to_anchor=(0.5, -0.01),
        ncol=4,
        fontsize=FONTSIZE + 4,
        bbox_transform=fig.transFigure,
    )
    fig.suptitle(
        TITLES[metric_name] + ", " + MODEL_TITLES[config.model], fontsize=FONTSIZE + 6
    )
    plt.tight_layout(rect=[0, 0.08, 1, 1])
    plt.savefig("%s/images/%s.pdf" % (results_dir, metric_name))
    plt.close()


def plot_k(
    results_dir,
    aggregate_stats,
    metric_name,
    config,
    datasets=DATASETS,
    acquisition_functions=ACQUISITION_FUNCTIONS,
    error_bars=False,
):
    """Plot the final metric as a function of the k-multiplier.

    Draws one panel for each of the first three datasets and one curve per
    acquisition function, with a shaded band of ``metric_name + "_err"``.

    Args:
        results_dir: The figure is written to
            ``<results_dir>/images/<metric_name>.pdf``.
        aggregate_stats: Sequence aligned with ``config.k_mult``; element ``i``
            maps ``str(Task)`` to the per-step stats of the runs using
            ``config.k_mult[i]``.
        metric_name: Metric to plot; must be a key of ``TITLES``.
        config: Settings object; all other hyperparameters use their first value.
        datasets: Dataset names; only as many as there are panels are used.
        acquisition_functions: Acquisition function names.
        error_bars: Unused; kept for a uniform plotting interface.
    """
    plt.rcParams.update({"font.size": FONTSIZE + 4})
    fig, axs = plt.subplots(1, 3, figsize=(10, 6))
    xs = config.k_mult
    for dataset, ax in zip(datasets, axs.flatten()):
        for algo in acquisition_functions:
            ys = []
            err = []
            # One point per configured k-multiplier (not a fixed count).
            for i, multiplier in enumerate(xs):
                task = Task(
                    dataset,
                    algo,
                    config.num_neighbors[0],
                    multiplier,
                    config.gradient_steps[0],
                    config.batch_size[0],
                    config.noise[0],
                )
                final_stats = aggregate_stats[i][str(task)][-1]
                ys.append(final_stats[metric_name])
                err.append(final_stats[metric_name + "_err"])
            if dataset == datasets[0]:
                ax.set_ylabel(metric_name)
            ax.tick_params(axis="both", which="major", labelsize=8)
            ax.plot(
                xs,
                ys,
                linewidth=1,
                marker=".",
                color=ACQUISITION_FUNCTION_COLORS[algo],
                label=ACQUISITION_FUNCTION_TITLES[algo],
                linestyle=ACQUISITION_FUNCTION_STYLES[algo],
            )
            ax.fill_between(
                xs,
                [y - e for y, e in zip(ys, err)],
                [y + e for y, e in zip(ys, err)],
                color=ACQUISITION_FUNCTION_COLORS[algo],
                alpha=ALPHA,
            )
            ax.title.set_fontsize(10)
            # Drop the "pile_" prefix of the dataset name.
            ax.title.set_text(dataset[5:])
            ax.set_xticks(range(2, 11, 2))
            ax.set_xlabel("k-multiplier")

    plt.subplots_adjust(bottom=0.4)
    # Every panel holds the same curves, so the first one supplies the legend.
    handles, labels = axs[0].get_legend_handles_labels()
    fig.legend(
        handles,
        labels,
        loc="lower center",
        bbox_to_anchor=(0.5, -0.01),
        ncol=4,
        fontsize=FONTSIZE + 4,
        bbox_transform=fig.transFigure,
    )
    fig.suptitle(
        TITLES[metric_name] + ", " + MODEL_TITLES[config.model], fontsize=FONTSIZE + 6
    )
    plt.tight_layout(rect=[0, 0.08, 1, 1])
    plt.savefig("%s/images/%s.pdf" % (results_dir, metric_name))
    plt.close()


def plot_curve(
    results_dir,
    aggregate_stats,
    metric_name,
    config,
    datasets=DATASETS,
    acquisition_functions=ACQUISITION_FUNCTIONS,
    neighbors=50,
):
    """Plot the metric over the steps of each task on a 2 x 3 grid.

    One panel per dataset (at most six) and one curve per acquisition
    function, with a shaded band of ``metric_name + "_err"``. Tasks are built
    from the first value of every hyperparameter list in ``config``; a task
    without data is reported on stdout and skipped. ``neighbors`` sets the
    upper end of the x ticks. The figure is written to
    ``<results_dir>/images/<metric_name>.pdf``.
    """
    plt.rcParams.update({"font.size": FONTSIZE})
    fig, axs = plt.subplots(2, 3, figsize=(10, 6))
    err_name = metric_name + "_err"
    for algo in acquisition_functions:
        for dataset, ax in zip(datasets, axs.flatten()):
            try:
                task = Task(
                    dataset,
                    algo,
                    config.num_neighbors[0],
                    config.k_mult[0],
                    config.gradient_steps[0],
                    config.batch_size[0],
                    config.noise[0],
                )
                key = str(task)
                xs = list(range(len(aggregate_stats[key])))
                ys = [aggregate_stats[key][step][metric_name] for step in xs]
                err = [aggregate_stats[key][step][err_name] for step in xs]
                if dataset == datasets[0]:
                    ax.set_ylabel(metric_name)
                ax.tick_params(axis="both", which="major", labelsize=8)
                line_color = ACQUISITION_FUNCTION_COLORS[algo]
                ax.plot(
                    xs,
                    ys,
                    linewidth=1,
                    color=line_color,
                    label=ACQUISITION_FUNCTION_TITLES[algo],
                    linestyle=ACQUISITION_FUNCTION_STYLES[algo],
                )
                lower = [y - e for y, e in zip(ys, err)]
                upper = [y + e for y, e in zip(ys, err)]
                ax.fill_between(xs, lower, upper, color=line_color, alpha=ALPHA)
                ax.title.set_fontsize(10)
                # Drop the "pile_" prefix of the dataset name.
                ax.title.set_text(dataset[5:])
                ax.set_xticks(range(0, neighbors + 1, 25))
            except KeyError:
                print("No data for " + str(task))
                continue

    # Display xlabel only in the bottom row
    for bottom_ax in axs[-1]:
        bottom_ax.set_xlabel("neighbours")

    plt.subplots_adjust(bottom=0.4)
    handles, labels = axs[0][0].get_legend_handles_labels()
    fig.legend(
        handles,
        labels,
        loc="lower center",
        bbox_to_anchor=(0.5, 0.01),
        ncol=4,
        fontsize=FONTSIZE,
        bbox_transform=fig.transFigure,
    )
    figure_title = TITLES[metric_name] + ", " + MODEL_TITLES[config.model]
    fig.suptitle(figure_title, fontsize=FONTSIZE + 2)
    plt.tight_layout(rect=[0, 0.08, 1, 1])
    plt.savefig("%s/images/%s.pdf" % (results_dir, metric_name))
    plt.close()


def plot_curve_combined(
    results_dir,
    aggregate_stats,
    metric_name,
    config,
    datasets=DATASETS,
    acquisition_functions=ACQUISITION_FUNCTIONS,
    neighbors=50,
):
    """Plot the metric over the steps, averaged over all datasets.

    Draws a single panel with one curve per acquisition function. Curve and
    error band are the unweighted means, across ``datasets``, of the metric
    and of ``metric_name + "_err"``. The number of steps is taken from the
    first dataset.

    Args:
        results_dir: The figure is written to
            ``<results_dir>/images/<metric_name>_combined.pdf``.
        aggregate_stats: Dict mapping ``str(Task)`` to the per-step stats.
        metric_name: Metric to plot; must be a key of ``TITLES``.
        config: Settings object; every hyperparameter uses its first value.
        datasets: Dataset names to average over.
        acquisition_functions: Acquisition function names.
        neighbors: Upper end of the x ticks, which are placed every 25 steps.

    Raises:
        KeyError: If a task has no entry in ``aggregate_stats``.
    """

    def stats_for(dataset, algo):
        # Per-step stats of the task using the first value of each hyperparameter.
        task = Task(
            dataset,
            algo,
            config.num_neighbors[0],
            config.k_mult[0],
            config.gradient_steps[0],
            config.batch_size[0],
            config.noise[0],
        )
        return aggregate_stats[str(task)]

    plt.rcParams.update({"font.size": FONTSIZE})
    fig, axs = plt.subplots(1, 1, figsize=(10, 6))
    for algo in acquisition_functions:
        xs = list(range(len(stats_for(datasets[0], algo))))
        ys, err = [], []
        for dataset in datasets:
            task_stats = stats_for(dataset, algo)
            ys.append([task_stats[x][metric_name] for x in xs])
            err.append([task_stats[x][metric_name + "_err"] for x in xs])
        ys = np.mean(ys, axis=0)
        err = np.mean(err, axis=0)
        axs.plot(
            xs,
            ys,
            linewidth=1,
            color=ACQUISITION_FUNCTION_COLORS[algo],
            label=ACQUISITION_FUNCTION_TITLES[algo],
            linestyle=ACQUISITION_FUNCTION_STYLES[algo],
        )
        axs.fill_between(
            xs,
            ys - err,
            ys + err,
            color=ACQUISITION_FUNCTION_COLORS[algo],
            alpha=ALPHA,
        )

    # Axis decoration does not depend on the acquisition function.
    axs.tick_params(axis="both", which="major", labelsize=8)
    axs.title.set_fontsize(10)
    axs.title.set_text("Average over all datasets")
    # Honour `neighbors` (the ticks were hard-coded to [0, 25, 50] before).
    axs.set_xticks(range(0, neighbors + 1, 25))
    axs.set_xlabel("neighbours")
    axs.set_ylabel(metric_name)

    plt.subplots_adjust(bottom=0.4)
    handles, labels = axs.get_legend_handles_labels()
    fig.legend(
        handles,
        labels,
        loc="lower center",
        bbox_to_anchor=(0.5, 0.01),
        ncol=4,
        fontsize=FONTSIZE,
        bbox_transform=fig.transFigure,
    )
    fig.suptitle(
        TITLES[metric_name] + ", " + MODEL_TITLES[config.model], fontsize=FONTSIZE + 2
    )
    plt.tight_layout(rect=[0, 0.08, 1, 1])
    plt.savefig("%s/images/%s_combined.pdf" % (results_dir, metric_name))
    plt.close()


def plot_retrieval_costs(
    results_dir,
    aggregate_stats,
    config,
    datasets=DATASETS,
    acquisition_functions=ACQUISITION_FUNCTIONS,
):
    """Save a box plot of retrieval times, one box per acquisition function.

    For every acquisition function, the ``retrieval_cost`` entries found in
    ``aggregate_stats`` are pooled across all ``datasets`` (NaN values are
    dropped). Tasks are identified using the first entry of each
    hyperparameter list in ``config``. Dataset/function combinations whose
    lookup raises ``KeyError`` are skipped. The figure is written to
    ``<results_dir>/images/retrieval-time.pdf``.
    """
    plt.rcParams.update({"font.size": 12})

    def pooled_costs(algo):
        """Gather all non-NaN retrieval costs of ``algo`` across datasets."""
        collected = []
        for dataset in datasets:
            try:
                stats_key = str(
                    Task(
                        dataset,
                        algo,
                        config.num_neighbors[0],
                        config.k_mult[0],
                        config.gradient_steps[0],
                        config.batch_size[0],
                        config.noise[0],
                    )
                )
                raw = [entry["retrieval_cost"] for entry in aggregate_stats[stats_key]]
            except KeyError:
                # No (complete) statistics for this combination: leave it out.
                continue
            collected.extend([value for value in raw if not np.isnan(value)])
        return collected

    costs_per_algo = [pooled_costs(algo) for algo in acquisition_functions]

    plt.figure(figsize=(12, 4))
    plt.ylabel("Retrieval time (sec)")
    plt.boxplot(costs_per_algo, labels=acquisition_functions)
    plt.xticks(rotation=45, ha="right")
    plt.tight_layout()
    plt.savefig("%s/images/retrieval-time.pdf" % (results_dir))
    plt.close()


def get_points(tasks: list[str]) -> tuple[dict[str, list[float]], dict[str, float]]:
    """Collect selection statistics for ``tasks`` from Weights & Biases.

    Every run tagged ``k-mult`` in the ``USER/AFT of LLMs`` project is
    fetched. For each logged history row the task key
    ``"<dataset>_<acquisition_function>_n<num_neighbors>_k<k_mult>_l<noise>"``
    is rebuilt; rows whose key is not among ``tasks`` are ignored.

    Args:
        tasks: Task keys to collect (anything whose ``str()`` is such a key).

    Returns:
        A tuple ``(points, duplicates)`` of dicts keyed by ``str(task)``.
        ``points[key]`` is a list of 100 numbers; entry ``i`` (for
        ``i < k_mult``) is the per-row average of
        ``count(indices <= (i + 1) * 50) / 50``. ``duplicates[key]`` is the
        per-row average fraction of selected indices that repeat an earlier
        selection. Tasks without any matching row keep their zero
        initialisation.
    """
    api = wb.Api()
    runs = api.runs(path="USER/AFT of LLMs", filters={"tags": {"$in": ["k-mult"]}})

    # Normalise once so that every dict below is keyed consistently.
    task_keys = [str(task) for task in tasks]
    total = {key: 0 for key in task_keys}
    duplicates = {key: 0 for key in task_keys}
    points = {key: [0] * 100 for key in task_keys}

    for run in tqdm(runs):
        df_history = run.history(
            keys=[
                "acquisition_function",
                "dataset",
                "num_neighbors",
                "k_mult",
                "noise",
                "indices",
            ]
        )

        for _, row in df_history.iterrows():
            task = "%s_%s_n%d_k%d_l%f" % (
                row["dataset"],
                row["acquisition_function"],
                row["num_neighbors"],
                row["k_mult"],
                row["noise"],
            )
            if task not in total:
                # The tag filter may match runs of tasks that were not
                # requested; skip them instead of failing with a KeyError.
                continue
            total[task] += 1

            indices = row["indices"]
            num_selected = len(indices)
            if num_selected > 0:
                # Equivalent to summing (count - 1) over the distinct indices,
                # but linear instead of quadratic in the number of indices.
                repeated = num_selected - len(set(indices))
                duplicates[task] += repeated / num_selected

            # int(): pandas may hand the value back as a float, which range()
            # rejects.
            for i in range(int(row["k_mult"])):
                # NOTE(review): 50 is hard-coded here; presumably it is the
                # number of neighbours per k multiple — confirm.
                threshold = (i + 1) * 50
                points[task][i] += sum(1 for j in indices if j <= threshold) / 50

    # Turn the accumulated sums into per-row averages. Iterating the dict
    # (rather than ``tasks``) normalises each key exactly once.
    for key, count in total.items():
        if count == 0:
            continue
        points[key] = [value / count for value in points[key]]
        duplicates[key] /= count

    return points, duplicates


def plot_duplicates(
    results_dir,
    duplicates,
    config: Config,
    datasets=DATASETS,
    acquisition_functions=ACQUISITION_FUNCTIONS,
):
    """Render duplicate-selection rates as a table and save it as a PDF.

    The table has one row per acquisition function and one column per
    dataset. The figure is written to ``<results_dir>/images/duplicates.pdf``.

    Args:
        results_dir: Directory that contains the ``images`` output folder.
        duplicates: Mapping from task key to duplicate rate. Keys are looked
            up first in the full form
            ``"<dataset>_<algo>_n<num_neighbors>_k<k_mult>_l<noise>"`` (built
            from the first entry of each list in ``config``) and then in the
            short form ``"<dataset>_<algo>"``. Missing entries are shown as 0.
        config: Hyperparameter configuration used to build the full task key.
            If ``None``, only the short key is tried.
        datasets: Dataset names (table columns).
        acquisition_functions: Acquisition function names (table rows).
    """
    num_rows = len(acquisition_functions)
    num_cols = len(datasets)
    data_matrix = np.zeros((num_rows, num_cols))

    for i, algo in enumerate(acquisition_functions):
        for j, dataset in enumerate(datasets):
            short_key = f"{dataset}_{algo}"
            value = duplicates.get(short_key, 0)
            if config is not None:
                # Task keys elsewhere in this file carry the hyperparameter
                # suffix; without it the lookup never matches and the table
                # would silently show zeros everywhere.
                full_key = "%s_%s_n%d_k%d_l%f" % (
                    dataset,
                    algo,
                    config.num_neighbors[0],
                    config.k_mult[0],
                    config.noise[0],
                )
                value = duplicates.get(full_key, value)
            data_matrix[i, j] = value
    data_matrix = np.round(data_matrix, 3)

    fig, ax = plt.subplots(figsize=(num_cols * 2, num_rows * 1))
    # Only the table should be visible, so hide the axes decorations.
    ax.xaxis.set_visible(False)
    ax.yaxis.set_visible(False)
    ax.set_frame_on(False)

    table = plt.table(
        cellText=data_matrix.astype(str),
        rowLabels=acquisition_functions,
        colLabels=[DATASET_TITLES[d] for d in datasets],
        cellLoc="center",
        loc="center",
        bbox=[0, 0, 1, 1],
    )

    table.auto_set_font_size(False)
    table.set_fontsize(14)
    table.scale(1.2, 0.8)
    for cell in table.get_celld().values():
        cell.set_edgecolor("black")
        cell.set_height(0.12)
        cell.set_width(0.45)

    # NOTE(review): the values are written as given (rounded to 3 decimals),
    # not multiplied by 100 — confirm that "Percentage" is the intended label.
    fig.suptitle("Percentage of duplicates", fontsize=16)
    plt.tight_layout()
    plt.savefig("%s/images/duplicates.pdf" % (results_dir), bbox_inches="tight")
    plt.close()


def get_dot_products():
    """Print dot-product summary statistics for each tagged W&B run.

    Fetches every run in the ``USER/LLM-ITL`` project that carries the
    ``dot_product`` tag and reads its logged ``dot_products`` history. For
    each run with at least one history row, one line is printed containing
    the mean of the per-row means, the overall minimum and the overall
    maximum. Nothing is returned.
    """
    api = wb.Api()
    tagged_runs = api.runs(
        path="USER/LLM-ITL", filters={"tags": {"$in": ["dot_product"]}}
    )

    for run in tagged_runs:
        history = run.history(keys=["dot_products"])
        logged = [row["dot_products"] for _, row in history.iterrows()]
        if not logged:
            # Nothing was logged for this run, so there is nothing to report.
            continue

        overall_mean = np.mean([np.mean(values) for values in logged])
        overall_min = min([min(values) for values in logged])
        overall_max = max([max(values) for values in logged])

        pieces = [
            "mean=",
            str(overall_mean),
            " min=",
            str(overall_min),
            " max=",
            str(overall_max),
        ]
        print("".join(pieces))


if __name__ == "__main__":
    # Script entry point: compute the aggregate statistics for the selected
    # model and produce the plot family requested via ``args.plot``.

    args = parse_args()

    #
    #   Config
    #

    config = Config(
        model=MODEL[args.model],
        num_neighbors=[50],
        k_mult=[4],
        gradient_steps=[1],
        batch_size=[1],
        noise=[1.0],
    )

    datasets = [
        "pile_github",
        "pile_stackexchange",
        "pile_uspto",
        "pile_wikipedia",
        "pile_dm-mathematics",
        "pile_enron",
    ]

    # Acquisition functions shown in the default ("all") plots; uncomment
    # entries to include them.
    acquisition_functions = [
        # "Random",
        # "Random-preselected",
        # "UncertaintySampling",
        "NearestNeighbour",
        # "ONN",
        # "ITL",
        "VTL",
    ]

    # Every plot family is produced once per metric, in this order.
    metric_names = ["bits_per_byte", "byte_perplexity", "word_perplexity"]

    #
    #   Plots
    #

    plt.rcParams.update({"font.size": 11})
    plt.rcParams.update({"font.family": "serif"})

    results_dir = os.path.join(args.results_dir, args.name, config.model)
    image_dir = os.path.join(results_dir, "images")
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(image_dir, exist_ok=True)

    # Statistics are computed for the full module-level dataset and
    # acquisition-function lists; the branches below plot subsets of them.
    aggregate_stats = compute_aggregate_stats(
        DATASETS,
        ACQUISITION_FUNCTIONS,
        results_dir,
        config.num_neighbors,
        config.k_mult,
        config.gradient_steps,
        config.batch_size,
        config.noise,
        args.world_size,
        bootstrap=args.bootstrap,
    )

    if args.plot == "all":
        # Same call order as before: all curves, then all combined curves,
        # then all standard plots.
        for plot_fn in (plot_curve, plot_curve_combined):
            for metric_name in metric_names:
                plot_fn(
                    results_dir,
                    aggregate_stats,
                    metric_name=metric_name,
                    config=config,
                    datasets=datasets,
                    acquisition_functions=acquisition_functions,
                )

        for metric_name in metric_names:
            plot(
                results_dir,
                aggregate_stats,
                metric_name=metric_name,
                config=config,
                datasets=datasets,
                acquisition_functions=acquisition_functions,
                error_bars=args.error_bars,
            )

    elif args.plot == "k_mult":
        datasets = ["pile_github", "pile_uspto", "pile_wikipedia"]

        acquisition_functions = [
            "Random",
            "NearestNeighbour",
            "VTL",
        ]

        for metric_name in metric_names:
            plot_k(
                results_dir,
                aggregate_stats,
                metric_name,
                config,
                datasets,
                acquisition_functions,
                args.error_bars,
            )

    elif args.plot == "gradient":
        datasets = [
            "pile_github",
            "pile_wikipedia",
        ]

        acquisition_functions = [
            "Random",
            "NearestNeighbour",
            "VTL",
        ]

        for metric_name in metric_names:
            plot_gradient(
                results_dir,
                aggregate_stats,
                metric_name,
                config,
                datasets,
                acquisition_functions,
                args.error_bars,
            )

    elif args.plot == "batched":
        datasets = [
            "pile_github",
            "pile_wikipedia",
        ]

        acquisition_functions = [
            "Random",
            "NearestNeighbour",
            "VTL",
        ]

        for metric_name in metric_names:
            plot_batched(
                results_dir,
                aggregate_stats,
                metric_name,
                config,
                datasets,
                acquisition_functions,
                args.error_bars,
            )

    elif args.plot == "noise":
        # No plot is implemented for this option.
        pass

    elif args.plot == "points":
        datasets = [
            "pile_github",
            "pile_stackexchange",
            "pile_uspto",
            "pile_wikipedia",
            "pile_dm-mathematics",
        ]

        acquisition_functions = [
            "VTL",
        ]

        # One task per combination. zip() would stop at the shortest list and
        # therefore yield only a single task here.
        tasks = [
            "%s_%s_n%d_k%d_l%f" % (dataset, acquisition_function, n, k, noise)
            for dataset in datasets
            for acquisition_function in acquisition_functions
            for n in config.num_neighbors
            for k in config.k_mult
            for noise in config.noise
        ]
        points, duplicates = get_points(tasks)

        # Keywords keep datasets / acquisition functions from being passed in
        # swapped order.
        plot_duplicates(
            results_dir,
            duplicates,
            config,
            datasets=datasets,
            acquisition_functions=acquisition_functions,
        )

    elif args.plot in ("costs", "log-costs"):
        # NOTE(review): "log-costs" currently produces exactly the same plot
        # as "costs"; no log scale is applied anywhere in this branch.
        plot_retrieval_costs(
            results_dir, aggregate_stats, config, DATASETS, ACQUISITION_FUNCTIONS
        )

    elif args.plot == "dotproduct":
        # Prints its summary; the call's result is not used.
        get_dot_products()
