"""
Given a directory of results, plot for each task a bar chart of every model's best score and a line chart of score versus number of model parameters.
"""

import argparse
import os
from typing import Optional

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

from dgeb import TaskResult, get_all_tasks, get_output_folder, get_tasks_by_name

# Ids of every registered task; shown to the user in the --tasks help text.
ALL_TASKS = [registered.metadata.id for registered in get_all_tasks()]


def plot_benchmarks(
    results_dir,
    task_ids: Optional[list[str]] = None,
    output="benchmarks.png",
    model_substring=None,
):
    """Plot the best primary-metric score of each model on each task.

    The figure has two rows and one column per task. The top row is a bar
    chart of the score of every model. The bottom row is a line chart of
    score against the model's number of parameters. The figure is saved to
    ``output`` and then closed.

    Args:
        results_dir: Directory whose entries are model names. Each model's
            per-task result file is located with ``get_output_folder``.
        task_ids: Task ids passed to ``get_tasks_by_name``. ``None`` plots
            every task returned by ``get_all_tasks``.
        output: Path of the image file to write.
        model_substring: Optional substring filter. It may be a list of
            substrings or a single string. A model is kept when its name
            contains at least one of them. ``None`` keeps every model.

    Raises:
        ValueError: If no result files were found for the selected models
            and tasks, so there is nothing to plot.
    """
    tasks = get_all_tasks() if task_ids is None else get_tasks_by_name(task_ids)
    results_df = _collect_results(results_dir, tasks, model_substring)
    if results_df.empty:
        raise ValueError(
            f"No results found in {results_dir!r} for the selected models and tasks"
        )
    _plot_results(results_df, output)


def _collect_results(results_dir, tasks, model_substring) -> pd.DataFrame:
    """Return one row (task, model, num_params, score) per result file found."""
    if isinstance(model_substring, str):
        # A bare string would otherwise be iterated character by character,
        # making the filter match almost every model.
        model_substring = [model_substring]

    rows = []
    # Sorted so that the scan order, and therefore the output, is reproducible.
    for model_name in sorted(os.listdir(results_dir)):
        if model_substring is not None and not any(
            substr in model_name for substr in model_substring
        ):
            continue

        for task in tasks:
            if task.metadata.display_name == "NoOp Task":
                continue
            filepath = get_output_folder(model_name, task, results_dir, create=False)
            # No result file means this model was not evaluated on this task.
            if not os.path.exists(filepath):
                continue

            with open(filepath) as f:
                task_result = TaskResult.model_validate_json(f.read())
            primary_metric_id = task_result.task.primary_metric_id
            # Collect the primary metric from every entry of task_result.results
            # (one per layer_result) and keep the best one.
            main_scores = [
                metric.value
                for layer_result in task_result.results
                for metric in layer_result.metrics
                if metric.id == primary_metric_id
            ]
            if not main_scores:
                # max() of an empty list would abort the whole plot.
                print(f"Skipping {filepath}: no {primary_metric_id!r} metric found")
                continue
            rows.append(
                {
                    "task": task.metadata.display_name,
                    "model": model_name,
                    "num_params": task_result.model["num_params"],
                    "score": max(main_scores),
                }
            )

    return pd.DataFrame(rows, columns=["task", "model", "num_params", "score"])


def _plot_results(results_df: pd.DataFrame, output) -> None:
    """Draw a bar chart and a score-vs-parameters line chart per task, then save."""
    # Order models by ascending parameter count. The model name breaks ties so
    # that bar order is deterministic. Seaborn keeps this order of appearance.
    results_df = results_df.astype({"num_params": int}).sort_values(
        by=["num_params", "model"]
    )
    # Sorted rather than iterating a set, whose order varies between runs.
    task_names = sorted(results_df["task"].unique())
    n_tasks = len(task_names)

    # squeeze=False keeps `axes` 2-D even when there is only one task.
    fig, axes = plt.subplots(2, n_tasks, figsize=(5 * n_tasks, 10), squeeze=False)
    try:
        for col, task_name in enumerate(task_names):
            task_df = results_df[results_df["task"] == task_name]
            bar_ax, line_ax = axes[0][col], axes[1][col]

            sns.barplot(x="model", y="score", data=task_df, ax=bar_ax)
            bar_ax.set_title(task_name)
            # Model names are long; rotate them so they do not overlap.
            bar_ax.tick_params(axis="x", labelrotation=90)

            sns.lineplot(x="num_params", y="score", data=task_df, ax=line_ax)
            line_ax.set_title(task_name)
            line_ax.set_xlabel("Number of parameters")

        fig.tight_layout()
        fig.savefig(output)
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)


if __name__ == "__main__":

    def _comma_separated(value: str) -> list[str]:
        """Split a comma-separated CLI value, stripping whitespace around each item."""
        # Stripping lets "a, b" select "b" instead of the never-matching " b".
        return [item.strip() for item in value.split(",")]

    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "-d",
        "--results_dir",
        type=str,
        default="results",
        help="Directory containing the results of the benchmarking",
    )
    parser.add_argument(
        "-t",
        "--tasks",
        type=_comma_separated,
        default=None,
        help=f"Comma separated list of tasks to plot. Choose from {ALL_TASKS} or do not specify to plot all tasks. ",
    )
    parser.add_argument(
        "-o",
        "--output",
        type=str,
        default="benchmarks.png",
        help="Output file for the plot",
    )
    parser.add_argument(
        "--model_substring",
        type=_comma_separated,
        default=None,
        help="Comma separated list of model substrings. Only plot results for models containing this substring",
    )
    args = parser.parse_args()

    plot_benchmarks(
        results_dir=args.results_dir,
        task_ids=args.tasks,
        output=args.output,
        model_substring=args.model_substring,
    )
