"""
Aggregate MTEB main_score over many tasks and plot mean results per model·subset.

Folders can be:
  – the default “results/{model}/{task}/…json” produced by mteb’s CLI
  – any flat directory containing JSON files (the script tries hard to guess model & task).

Example result JSON format is documented in the MTEB repo and on the
leader-board discussion.

Written for Python ≥ 3.9, uses only std-lib + numpy + pandas + matplotlib + seaborn.
"""

from __future__ import annotations
import argparse, json, re
from pathlib import Path
from typing import Iterable

import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from typing import List, Optional, Tuple

import seaborn as sns

import os

# ---------------------------------- Global ---------------------------------- #

# Marker shapes passed to seaborn (`markers=`) when plotting one line per hue level.
MARKERS = ["o", "X", "s", "^", "D", "v", "P", "*", "<", ">"]

# Task names grouped under "ViDoRe(v1)" in the fine-grained benchmark groups below.
VIDORE_TASKS = [
    "VidoreArxivQARetrieval",
    "VidoreDocVQARetrieval",
    "VidoreInfoVQARetrieval",
    "VidoreTabfquadRetrieval",
    "VidoreTatdqaRetrieval",
    "VidoreShiftProjectRetrieval",
    "VidoreSyntheticDocQAAIRetrieval",
    "VidoreSyntheticDocQAEnergyRetrieval",
    "VidoreSyntheticDocQAGovernmentReportsRetrieval",
    "VidoreSyntheticDocQAHealthcareIndustryRetrieval",
]

# Task names grouped under "ViDoRe(v2)" in the fine-grained benchmark groups below.
VIDORE_2_TASKS = [
    "Vidore2ESGReportsRetrieval",
    "Vidore2EconomicsReportsRetrieval",
    "Vidore2BioMedicalLecturesRetrieval",
    "Vidore2ESGReportsHLRetrieval",
]

# All document-retrieval tasks (v1 followed by v2).
DOC_TASKS = VIDORE_TASKS + VIDORE_2_TASKS

# Caption-retrieval tasks; the I2T variants are currently disabled.
MSCOCO_TASKS = [
    "MSCOCOT2IRetrieval",
    # "MSCOCOI2TRetrieval",
]

FLICKR_TASKS = [
    "Flickr30kT2IRetrieval",
    # "Flickr30kI2TRetrieval",
]

# All caption-retrieval tasks (MSCOCO followed by Flickr30k).
IM_CAPT_TASKS = MSCOCO_TASKS + FLICKR_TASKS

# Classification tasks reported under "Linear Probe Classification";
# commented-out entries are excluded from the averages.
IMCLASS_LP_TASKS = [
    # "Caltech101",
    # "DTD",
    # "FER2013",
    # "EuroSAT",
    "StanfordCars",
    "Food101Classification",
    # "OxfordFlowersClassification",
    # "OxfordPets",
]

# Classification tasks reported under "Zero-Shot Classification";
# commented-out entries are excluded from the averages.
IM_CLASS_ZS_TASKS = [
    "Caltech101ZeroShot", 
    # "DTDZeroShot", 
    "FER2013ZeroShot", 
    # "EuroSATZeroShot"
]

# Classification tasks used by the fine-grained groups: currently the zero-shot list only
# (this is an alias of IM_CLASS_ZS_TASKS, not a copy).
IMCLASS_TASKS = IM_CLASS_ZS_TASKS #+ IMCLASS_LP_TASKS
# IMCLASS_TASKS = IMCLASS_LP_TASKS

# Each mapping below goes from a displayed benchmark name to the list of task names
# whose scores are averaged under that name.

# Coarse groups, including linear-probe classification.
BENCHMARK_GROUPS = {
    "Document Retrieval": DOC_TASKS,
    "Caption Retrieval": IM_CAPT_TASKS,
    "Zero-Shot Classification": IM_CLASS_ZS_TASKS,
    "Linear Probe Classification": IMCLASS_LP_TASKS,
}

# Coarse groups without linear-probe classification.
BENCHMARK_GROUPS_NO_LP = {
    "Document Retrieval": DOC_TASKS,
    "Caption Retrieval": IM_CAPT_TASKS,
    "Zero-Shot Classification": IM_CLASS_ZS_TASKS,
}

# Coarse groups restricted to the two retrieval families.
BENCHMARK_GROUPS_NO_IMCLASS = {
    "Document Retrieval": DOC_TASKS,
    "Caption Retrieval": IM_CAPT_TASKS,
}

# Fine-grained groups: per-dataset entries plus an "Average ..." entry per family.
# Every task in IMCLASS_TASKS also gets a single-task group named after itself.
BENCHMARK_GROUPS_FINEGRAIN = {
    "ViDoRe(v1)": VIDORE_TASKS,
    "ViDoRe(v2)": VIDORE_2_TASKS,
    "Average Doc.": DOC_TASKS,
    "MSCOCO": MSCOCO_TASKS,
    "Flickr30k": FLICKR_TASKS,
    "Average Capt.": IM_CAPT_TASKS,
    **{k: [k] for k in IMCLASS_TASKS},
    "Average Class.": IMCLASS_TASKS,
}

# Fine-grained groups with the classification entries left out.
BENCHMARK_GROUPS_FINEGRAIN_NO_IMCLASS = {
    "ViDoRe(v1)": VIDORE_TASKS,
    "ViDoRe(v2)": VIDORE_2_TASKS,
    "Average Doc.": DOC_TASKS,
    "MSCOCO": MSCOCO_TASKS,
    "Flickr30k": FLICKR_TASKS,
    "Average Capt.": IM_CAPT_TASKS,
    # **{k: [k] for k in IMCLASS_TASKS},
    # "Average Class.": IMCLASS_TASKS,
}


# ───────────────────────── helpers ────────────────────────────

def optimal_subplots_shape(subplots, max_columns=5):
    """
    Choose a ``(rows, cols)`` grid that can hold ``subplots`` axes.

    An exact grid (no empty cell) is preferred, trying ``max_columns`` columns
    first and going down to 3 columns. When no exact grid exists, a single row
    is used if it fits within ``max_columns``; otherwise the grid uses as few
    rows as ``max_columns`` allows and may contain a few empty cells.

    Args:
        subplots: Number of axes to place (must be >= 1).
        max_columns: Maximum number of columns allowed in the grid (must be >= 1).

    Returns:
        Tuple ``(rows, cols)`` with ``rows * cols >= subplots``.

    Raises:
        ValueError: If ``subplots`` or ``max_columns`` is smaller than 1.
    """
    if subplots < 1:
        raise ValueError("subplots must be a positive integer.")
    if max_columns < 1:
        raise ValueError("max_columns must be a positive integer.")
    if subplots == 1:
        return 1, 1

    # Prefer an exact rectangular grid with as many columns as allowed.
    for k in range(max_columns, 2, -1):
        if subplots % k == 0:
            return subplots // k, k

    # Small counts still fit on a single row.
    if subplots <= max_columns:
        return 1, subplots

    # No exact grid: respect max_columns and balance the columns over the
    # minimal number of rows (ceil divisions written with floor division).
    rows = -(-subplots // max_columns)
    cols = -(-subplots // rows)
    return rows, cols

def map_task_to_group(task):
    """Return the retrieval dataset family that `task` belongs to, or "Others" if none matches."""
    # Checked in order; the first family containing the task wins.
    families = (
        ("ViDoRe(v1)", VIDORE_TASKS),
        ("ViDoRe(v2)", VIDORE_2_TASKS),
        ("MSCOCO", MSCOCO_TASKS),
        ("Flickr30k", FLICKR_TASKS),
    )
    for family_name, members in families:
        if task in members:
            return family_name
    return "Others"
    
def add_baselines(
    ax: plt.Axes,
    baseline_df: pd.DataFrame,
    columns: List[str],
    name_col: str = "model",
) -> plt.Axes:
    """
    Draw one dashed horizontal line per row of `baseline_df` on `ax`.

    Args:
        ax: Axes to draw on.
        baseline_df: One row per baseline; must contain `columns` and `name_col`.
        columns: Columns whose row-wise mean gives the height of the line.
        name_col: Column holding the baseline name, used as the line label.

    Returns:
        The same `ax`, to allow chaining.
    """
    for _, brow in baseline_df.iterrows():
        # Rows of mixed-dtype frames come back as object Series; coerce so that
        # the mean is numeric (non-numeric cells are ignored), as in plot_baselines.
        height = pd.to_numeric(brow[columns], errors="coerce").mean(skipna=True)
        ax.axhline(float(height), linestyle="--", linewidth=3, alpha=0.85, label=brow[name_col])
    return ax

def _collect_results(
        paths: Iterable[Path], 
        model: str,
        langs: Iterable[str] | None = None,
    ) -> pd.DataFrame:
    """Read all JSON files and return a dataframe with model / subset / task / score."""
    rows = []
    for p in paths:
        try:
            blob: dict = json.loads(p.read_text())
            task = blob.get("task_name", "unknown")
            splits_scores = blob.get("scores", {})
            for split, scores in splits_scores.items():  
                for entry in scores:
                    lang = entry.get("languages", [None])[0]
                    if langs is not None and lang is not None and lang not in langs:
                        continue
                    rows.append(
                        dict(
                            model=model,
                            task=task,
                            subset=entry.pop("hf_subset", "all"),
                            language=lang,
                            **entry,
                        )
                    )
        except Exception as err:
            print(f"[warn] {p} skipped → {err}")
    if not rows:
        raise RuntimeError("No valid MTEB result files found.")
    return rows

def gather_results(
    results_root: Path,
    models: Iterable[str] | None = None,
    tasks: Iterable[str] | None = None,
    langs: Iterable[str] | None = None,
) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """
    Collect model metadata and per-task scores from a results directory.

    Every sub-directory of `results_root` is treated as one model. Inside it, the
    most recently modified sub-directory is searched recursively for JSON files:
    ``model_meta.json`` provides the metadata (its ``"name"`` is used as model
    name), every other JSON file is read as a task result.

    Args:
        results_root: Directory containing one folder per model.
        models: Optional model names to keep; ``"org/name"`` is matched against
            the folder ``"org__name"``.
        tasks: Optional task names to keep; a result file is kept when it is
            named ``"{task}.json"``.
        langs: Optional languages forwarded to `_collect_results`.

    Returns:
        ``(meta_df, results_df)``: one row per model with its metadata (column
        ``name`` renamed to ``model``), and one row per collected score entry.

    Raises:
        RuntimeError: Propagated from `_collect_results` when the selected files
            of a model contain no usable score.
    """
    path = Path(results_root)

    # Materialise the filters once so that one-shot iterables are not exhausted.
    wanted_dirs = None if models is None else {"__".join(model.split("/")) for model in models}
    wanted_files = None if tasks is None else {f"{t}.json" for t in tasks}

    # gather all JSON paths, one list per model folder
    files = []
    for model_dir in sorted(path.iterdir()):
        if wanted_dirs is not None and model_dir.name not in wanted_dirs:
            continue
        if not model_dir.is_dir():
            continue
        # take latest written subfolder (plain files next to them are ignored)
        subfolders = [d for d in model_dir.iterdir() if d.is_dir()]
        if not subfolders:
            print(f"[warn] No subfolder found in {model_dir}. Skipping.")
            continue
        latest = max(subfolders, key=lambda d: d.stat().st_mtime)
        files.append((model_dir, list(latest.rglob("*.json"))))
    if not files:
        print(f"No files found for {models}")

    model_results = []
    models_metadata = []

    for model_dir, model_files in files:
        # model_meta.json is metadata, not a result file
        meta_file = next((f for f in model_files if f.name == "model_meta.json"), None)
        if meta_file is None:
            print(f"[warn] No model_meta.json found for model {model_dir.name}. Skipping.")
            continue
        result_files = [f for f in model_files if "model_meta.json" not in f.name]

        if wanted_files is not None:
            result_files = [f for f in result_files if f.name in wanted_files]

        if not result_files:
            print(f"[warn] No result files found for model {model_dir.name}. Skipping.")
            continue

        model_meta = json.loads(meta_file.read_text())
        model_name = model_meta.get("name", "unknown")

        models_metadata.append(model_meta)
        model_results.extend(_collect_results(result_files, model=model_name, langs=langs))

    # create DataFrames from collected results
    meta_df = pd.DataFrame(models_metadata).rename(columns={"name": "model"})
    results_df = pd.DataFrame(model_results)

    return meta_df, results_df

def create_leaderboard(
    df: pd.DataFrame,
    benchmarks: Optional[Dict[str, list]] = None,
    align: str = "task",
    metric: str = "main_score",
) -> pd.DataFrame:
    """
    Build a leaderboard (one row per model) from long-format results.

    Scores are averaged per (model, `align` value), multiplied by 100 and laid
    out with one column per `align` value. Missing results are reported on
    stdout and ignored when averaging.

    Args:
        df: Long-format results with at least the columns ``model``, `align`
            and `metric`.
        benchmarks: Optional mapping from benchmark name to the list of columns
            to average. Benchmarks are computed in order, so a benchmark may
            reference one defined before it. When given, only the benchmark
            columns are kept.
        align: Column whose values become the leaderboard columns.
        metric: Column holding the score to aggregate.

    Returns:
        DataFrame with columns ``model``, ``mean_score`` (mean over all other
        score columns) and the score columns, sorted by ``mean_score`` descending.
    """
    # create a pivot table to aggregate scores
    pivot_df = (100 * df.pivot_table(
        index=["model"],
        columns=align,
        values=metric,
        aggfunc="mean"
    )).reset_index()

    # report, per model, the columns without any result
    score_cols = pivot_df.columns.drop("model")
    missing = pivot_df[score_cols].isna()
    has_gap = missing.any(axis=1)
    if has_gap.any():
        print("Warning: some results are missing in the leaderboard.")
        for model, flags in zip(pivot_df.loc[has_gap, "model"], missing[has_gap].to_numpy()):
            print(f"{model}: {score_cols[flags].tolist()}")

    if benchmarks is not None:
        for b, b_tasks in benchmarks.items():
            b_tasks = list(b_tasks)
            absent = [t for t in b_tasks if t not in pivot_df.columns]
            if absent:
                print(f"Warning: no model has results for {absent}; ignored in '{b}'.")
            # reindex (instead of plain selection) yields NaN columns for tasks
            # nobody ran, which the mean then skips instead of raising KeyError
            pivot_df[b] = pivot_df.reindex(columns=b_tasks).mean(axis=1)
        pivot_df = pivot_df[["model"] + list(benchmarks.keys())].copy()

    # add the mean score across all remaining columns
    pivot_df["mean_score"] = pivot_df.drop(columns=["model"]).mean(axis=1)

    # set the mean score as the first column after model
    pivot_df = pivot_df[["model", "mean_score"] + [col for col in pivot_df.columns if col not in ["model", "mean_score"]]]

    return pivot_df.sort_values(by=["mean_score"], ascending=False)


def fit_powerlaw(
    y: np.ndarray,
    steps: np.ndarray,
    n_interp: int = 300,
    log_x: bool = False,
) -> Tuple[np.ndarray, Tuple[np.ndarray, float, float]]:
    """
    Fit ``y = a * steps**b`` by linear regression in log-log space.

    Only points with a positive step and a positive score take part in the fit,
    since the logarithm is undefined elsewhere (e.g. a measurement at step 0).

    Args:
        y: Scores, same shape as `steps`.
        steps: Training steps (x values).
        n_interp: Number of points of the dense grid used to draw the curve.
        log_x: If True the dense grid is log-spaced, otherwise linearly spaced.

    Returns:
        ``(fitted, (dense_steps, a, b))`` where ``dense_steps`` spans the
        smallest to the largest positive step and ``fitted = a * dense_steps**b``.

    Raises:
        ValueError: If fewer than two steps are given, the shapes differ, or
            fewer than two distinct positive steps (with positive score) exist.
    """
    y = np.asarray(y, dtype=float)
    steps = np.asarray(steps, dtype=float)

    if steps.size < 2:
        raise ValueError("Need at least two distinct steps to fit a curve.")
    if y.shape != steps.shape:
        raise ValueError("`y` and `steps` must have the same shape.")

    # Validate the x-range before fitting so that bad input fails clearly.
    positive = steps > 0
    step_min = steps[positive].min() if positive.any() else None
    step_max = steps[positive].max() if positive.any() else None
    if step_min is None or step_max <= step_min:
        raise ValueError("Steps must include at least two positive values to draw the fitted curve.")

    usable = positive & (y > 0)
    if np.unique(steps[usable]).size < 2:
        raise ValueError("Need at least two points with positive step and positive score to fit a power law.")

    # log(y) = b * log(x) + log(a)
    b, log_a = np.polyfit(np.log(steps[usable]), np.log(y[usable]), 1)
    a = np.exp(log_a)

    # Build dense x-grid for plotting the fitted curve
    if log_x:
        dense_steps = np.logspace(np.log10(step_min), np.log10(step_max), n_interp)
    else:
        dense_steps = np.linspace(step_min, step_max, n_interp)

    return a * (dense_steps ** b), (dense_steps, a, b)


# Main series styles: colour and linestyle associated with the in-domain (ID) and
# out-of-domain (OOD) series. The colours are reused for the baseline lines.
ID_COLOR = "blue"
ID_LINESTYLE = "-"
OOD_COLOR = "orange"
OOD_LINESTYLE = "--"


# Build a linestyle map for baselines that avoids main styles "-" and "--"
def _baseline_styles(n: int):
    base = [
        "-.", ":", (0, (1, 1)), (0, (3, 1, 1, 1)),
        (0, (5, 2)), (0, (1, 2, 1, 2)), (0, (3, 2, 3, 2)),
        (0, (5, 1, 1, 1)), (0, (2, 1, 2, 1, 2, 1))
    ]
    while len(base) < n:
        k = len(base) + 1
        base.append((0, (max(1, (k % 6)), 1 + (k % 4))))
    return base[:n]


def plot_baselines(
    baseline_df: pd.DataFrame,
    id_tasks: List[str],
    ood_tasks: Optional[List[str]] = None,
    name_col: str = "model",
    ax: Optional[plt.Axes] = None,
    log_y: bool = False,
    # Baseline parameters
    show_baselines_in_legend: bool = True,
    annotate_baselines: bool = False,
) -> Tuple[plt.Figure, plt.Axes]:
    """
    Draw baselines as horizontal lines: per baseline, one line for the mean over
    the in-domain tasks and, if available, one for the mean over the OOD tasks.

    Each baseline gets its own linestyle (never "-" or "--"); in-domain lines use
    `ID_COLOR`, OOD lines use `OOD_COLOR`.

    Args:
        baseline_df: Wide-format frame, one row per baseline and one column per task.
        id_tasks: In-domain task columns; those missing from `baseline_df` are ignored.
        ood_tasks: Out-of-domain task columns; None means no OOD line.
        name_col: Column with the baseline names; the index is used if it is absent.
        ax: Axes to draw on. Defaults to the current axes.
        log_y: Set to True when the y-axis is logarithmic, so that non-positive
            baseline values are rejected.
        show_baselines_in_legend: Whether the lines get a legend label.
        annotate_baselines: Whether to write the label next to the right edge of the axes.

    Returns:
        ``(figure, axes)`` the baselines were drawn on.

    Raises:
        ValueError: If no task column is found in `baseline_df`, or a baseline
            value is non-positive while `log_y` is True.
    """
    if ax is None:
        ax = plt.gca()

    bl_id_cols = [c for c in id_tasks if c in baseline_df.columns]
    bl_ood_cols = [c for c in (ood_tasks or []) if c in baseline_df.columns]
    if not (bl_id_cols or bl_ood_cols):
        raise ValueError("No baseline task columns match id_tasks/ood_tasks in baseline_df.")

    # Baseline names
    if name_col in baseline_df.columns:
        base_names = baseline_df[name_col].astype(str).tolist()
    else:
        base_names = [str(i) for i in baseline_df.index]

    # One linestyle per distinct name, in order of first appearance.
    unique_names = list(dict.fromkeys(base_names))
    linestyle_map = dict(zip(unique_names, _baseline_styles(len(unique_names))))

    def row_mean(row, cols):
        if not cols:
            return np.nan
        return pd.to_numeric(row[cols], errors="coerce").mean(skipna=True)

    for base_name, (_, row) in zip(base_names, baseline_df.iterrows()):
        bl_ls = linestyle_map[base_name]

        entries = [
            ("ID task", row_mean(row, bl_id_cols)),
            ("OOD task", row_mean(row, bl_ood_cols)),
        ]

        for kind, val in entries:
            if pd.isna(val):
                continue
            if log_y and (val <= 0):
                raise ValueError(
                    f"Baseline '{base_name}' ({kind}) has non-positive value {val} but log_y=True."
                )

            ax.axhline(
                y=float(val),
                linestyle=bl_ls, linewidth=1, alpha=0.9,
                label=(f"{base_name} ({kind})" if show_baselines_in_legend else "_nolegend_"),
                color=OOD_COLOR if kind == "OOD task" else ID_COLOR,
            )

            if annotate_baselines:
                label = f"{base_name} ({kind})"
                # x in axes fraction (right edge), y in data coordinates
                ax.annotate(
                    label, xy=(1.0, float(val)), xycoords=("axes fraction", "data"),
                    xytext=(4, 0), textcoords="offset points", va="center", ha="left",
                )

    return ax.figure, ax

# ------------------------------ scaling ----------------------------- #
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from typing import Dict, List, Tuple, Optional

def compute_relative_regret(
    df: pd.DataFrame,
    score_column: str = "score",
    baseline_column: str = "Tokens",
) -> pd.Series:
    """
    Express every score as a percentage difference to a reference score.

    The reference is the score of the row with the largest value in
    `baseline_column` (first one on ties); rows where that column is missing are
    never chosen as reference.

    Args:
        df: Frame holding `score_column` and `baseline_column`.
        score_column: Column with the scores.
        baseline_column: Column used to pick the reference row.

    Returns:
        Series aligned with `df`: ``100 * (score - reference) / reference``.
        A reference score of 0 yields inf/NaN values.

    Raises:
        ValueError: If `baseline_column` has no non-missing value.
    """
    # Get the baseline row, ignoring missing values (argmax alone would pick a NaN)
    budget = df[baseline_column].to_numpy()
    valid = ~pd.isna(budget)
    if not valid.any():
        raise ValueError(f"Column '{baseline_column}' has no valid value to select the baseline from.")
    baseline_index = np.flatnonzero(valid)[budget[valid].argmax()]
    baseline_values = df[score_column].iloc[baseline_index]

    # Compute relative regret
    regret = 100 * (df[score_column] - baseline_values) / baseline_values

    return regret

# Maximum number of subplot columns requested by plot_scaling when laying out its grid.
MAX_COLUMNS = 5

def plot_scaling(
    df: pd.DataFrame,
    x: str,
    y_cols: str | List[str],
    hue: str = None,
    add_average_column: bool = False,
    baseline_df: Optional[pd.DataFrame] = None,
    sharex: bool = True,
    sharey: bool = True,
    logx: bool = False,
    xticks: Optional[List[float]] = None,
    show_legend: bool = True,
    **kwargs,
):
    """
    Plot one subplot per column of `y_cols` against `x`, with one line per `hue` level.

    Args:
        df: Wide-format data holding `x`, the `y_cols` and optionally `hue`.
        x: Column used for the x-axis.
        y_cols: Column name(s) to plot, one subplot each.
        hue: Optional column whose levels get their own line and marker.
        add_average_column: If True, add an "Average" subplot with the row-wise
            mean of `y_cols` (also computed for `baseline_df`).
        baseline_df: Optional baselines (one row each, with a "model" column and
            the `y_cols`), drawn as horizontal lines.
        sharex, sharey: Forwarded to ``plt.subplots``.
        logx: Use a base-2 logarithmic x-axis.
        xticks: Tick positions; defaults to the unique values of `x`.
        show_legend: If False, no legend is drawn at all.
        **kwargs: Forwarded to ``sns.lineplot``.

    Returns:
        ``(fig, axes)`` where `axes` is the 2-D array returned by ``plt.subplots``.
    """
    y_cols = [y_cols] if isinstance(y_cols, str) else list(y_cols)

    if add_average_column:
        df = df.copy()
        df["Average"] = df[y_cols].mean(axis=1)
        if baseline_df is not None:
            # add average column to baseline_df as well
            baseline_df = baseline_df.copy()
            baseline_df["Average"] = baseline_df[y_cols].mean(axis=1)
        y_cols = y_cols + ["Average"]

    n_groups = len(y_cols)

    rows, cols = optimal_subplots_shape(n_groups, max_columns=MAX_COLUMNS)
    figsize = (14 * cols, 12 * rows + 4)

    fig, axes = plt.subplots(
        nrows=rows,
        ncols=cols,
        figsize=figsize,
        sharex=sharex,
        sharey=sharey,
        squeeze=False,
    )
    axes_list = axes.flatten()

    # `xticks` may be a list or an array, so avoid relying on its truth value
    use_data_ticks = xticks is None or len(xticks) == 0

    for ax, y in zip(axes_list, y_cols):
        sns.lineplot(
            data=df,
            x=x,
            y=y,
            hue=hue,
            style=hue,
            markers=MARKERS,
            ax=ax,
            marker="o" if hue is None else None,
            **kwargs
        )

        # plot baselines as horizontal lines
        if baseline_df is not None:
            ax = add_baselines(ax, baseline_df, [y])

        # set plot labels
        ax.set_title(f"\\textbf{{{y}}}")
        ax.set_xlabel(x)
        ax.set_ylabel("Score")
        ax.set_xticks(df[x].unique() if use_data_ticks else xticks)
        if logx:
            # set to log2 scale
            ax.set_xscale("log", base=2)

        ax.grid(False)

    # hide unused axes and optimize space
    fig.tight_layout()
    for ax in axes_list[n_groups:]:
        ax.axis("off")

    # only one legend, shared by all subplots
    handles, labels = [], []
    if hue is not None:
        # keep only labels in the data (hue levels and baseline names); legend
        # labels are strings, so compare against stringified values
        wanted = {str(v) for v in df[hue].unique()}
        if baseline_df is not None:
            wanted |= set(baseline_df["model"].astype(str))
        handles_labels = {
            l: h for h, l in zip(*axes_list[0].get_legend_handles_labels())
            if str(l) in wanted
        }
        handles, labels = list(handles_labels.values()), list(handles_labels.keys())
        if handles and show_legend:
            fig.legend(handles, labels, loc="lower center", title=None, ncol=min(4, len(handles)))
        fig.subplots_adjust(bottom=0.3)

    # per-axes legends are redundant with the shared one, or unwanted altogether
    if handles or not show_legend:
        for ax in axes_list:
            legend = ax.get_legend()
            if legend is not None:
                legend.remove()

    return fig, axes

def plot_heatmap(
        df,
        benchmark_groups: Optional[Dict[str, List[str]]] = None,
    ):

    """
    Plot a heatmap of mean scores per resolution and benchmark group.

    Args:
        df: DataFrame with a 'resolution' column and one column per benchmark group.
        benchmark_groups: Mapping (or any iterable) whose keys name the columns
            to show. Defaults to every numeric column except 'resolution'.

    Returns:
        ``(fig, ax)`` of the heatmap, with one row per benchmark group and one
        column per resolution.
    """
    if benchmark_groups is None:
        value_cols = [c for c in df.select_dtypes(include="number").columns if c != "resolution"]
    else:
        value_cols = list(benchmark_groups)

    pivot_table = df.pivot_table(index="resolution", values=value_cols, aggfunc="mean")

    # Plot on the axes created here rather than on whatever is current
    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(
        pivot_table.T,
        annot=True,
        cmap="Blues",
        fmt=".1f",
        cbar_kws={'label': 'Score'},
        linewidths=0.5,
        linecolor="white",
        ax=ax,
    )

    # Clean style: remove axis labels
    ax.set_xlabel("")
    ax.set_ylabel("")

    # Adjust tick labels
    ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right")
    ax.set_yticklabels(ax.get_yticklabels(), rotation=0)

    # Shift x-tick labels downward
    ax.tick_params(axis="x", pad=10)

    fig.tight_layout()

    return fig, ax


def plot_barplot(
        df,
        x: str,
        y: str,
        hue: str = None,
        add_average_column: bool = False,
        **kwargs
    ):
    """
    Draw a grouped bar plot of `y` per `x` category, with one bar per `hue` level.

    Args:
        df: Long-format data with the columns `x`, `y` and optionally `hue`.
        x: Column with the categories on the x-axis.
        y: Column with the bar heights.
        hue: Optional column splitting every category into several bars.
        add_average_column: If True, append an "Average" category holding the
            mean of `y` over all rows, per `hue` level when `hue` is given,
            rounded to one decimal.
        **kwargs: Forwarded to ``sns.barplot``.

    Returns:
        The created figure.
    """
    n_groups = len(df[x].unique())
    if add_average_column:
        n_groups += 1
    fig = plt.figure(figsize=(10 * n_groups, 20))

    if add_average_column:
        # df is in long format: average over all rows, separately per hue level if any
        if hue is None:
            avg_df = pd.DataFrame({y: [df[y].mean()]})
        else:
            avg_df = df.groupby(hue)[y].mean().reset_index()
        avg_df[y] = avg_df[y].round(1)
        avg_df[x] = "Average"
        df = pd.concat([df, avg_df], axis=0, ignore_index=True)

    ax = sns.barplot(
        data=df,
        x=x,
        y=y,
        hue=hue,
        errorbar=None,
        **kwargs
    )

    # no x axis label
    ax.set_xlabel("")

    # legend below plot (only when there is something to label)
    handles, _ = ax.get_legend_handles_labels()
    if handles:
        ax.legend(title=None, loc='lower center', ncol=n_groups, bbox_to_anchor=(0.5, -0.3))

    fig.tight_layout()

    return fig

def df_to_latex_table(
    df,
    index_column: str = "model",
    add_average_column: bool = True, 
    model_groups: dict = None,
    results_cols: list[str] = None,
    metadata_cols: list[str] = None,
    col_groups: dict = None,
    caption: str = "Results table",
    label: str = "tab:results",
) -> str:
    """
    Generate a LaTeX table from a DataFrame with grouped models and grouped columns.

    The input DataFrame is left untouched. Numbers are printed with one decimal,
    missing values as "--". The generated code relies on the LaTeX commands
    ``\\rot``, ``\\columncolor`` and on the colors ``avgcol`` / ``grpA``, ``grpB``, ...
    being defined in the document.

    Args:
        df (pd.DataFrame): DataFrame containing model results.
        index_column (str): Column in df that contains model names.
        add_average_column (bool): Append a column with the row-wise mean of
            `results_cols`.
        model_groups (dict): Mapping from group name -> list of model names. If
            None, all rows are printed in the order of df.
        results_cols (list[str]): Columns with numeric results. Defaults to all
            columns except `index_column` and `metadata_cols`.
        metadata_cols (list[str]): Columns with metadata (e.g., Params).
        col_groups (dict): Mapping from group label -> list of column names,
            printed as a header row spanning the result columns.
        caption (str): Table caption.
        label (str): Table label for referencing.

    Returns:
        str: LaTeX table code.

    Raises:
        IndexError: If a model listed in `model_groups` is not found in df.
    """
    metadata_cols = [] if metadata_cols is None else list(metadata_cols)
    if results_cols is None:
        excluded = {index_column, *metadata_cols}
        results_cols = [c for c in df.columns if c not in excluded]
    else:
        results_cols = list(results_cols)

    cols = metadata_cols + results_cols
    colspec = "@{} >{}l " + " ".join("c" for _ in cols)

    # compute average column on a copy so that the caller's frame is not modified
    if add_average_column:
        df = df.copy()
        df["\\textbf{Average}"] = df[results_cols].mean(axis=1, skipna=True)
        cols.append("\\textbf{Average}")
        colspec += " >{\\columncolor{avgcol}}c"

    # --- header row(s) ---
    header = ""
    if col_groups:
        # build the group row: empty cells for the index and every metadata column
        group_cells = []
        for i, (group, group_cols) in enumerate(col_groups.items()):
            span = len(group_cols)
            color = f"grp{chr(65 + i)}"  # default color
            group_cells.append(
                f"\\multicolumn{{{span}}}{{>{{\\columncolor{{{color}}}}}c}}{{\\textbf{{{group}}}}}"
            )
        group_row = "& " * (1 + len(metadata_cols)) + " & ".join(group_cells) + " \\\\\n"
        # add midrule under the result columns
        midrule_debut, midrule_end = 1 + len(metadata_cols) + 1, len(metadata_cols + results_cols) + 1
        group_row += "    \\cmidrule(lr){" + str(midrule_debut) + "-" + str(midrule_end) + "}\n"
        header += group_row

    # second row: actual column names
    header += "& " + " & ".join([f"\\rot{{{c}}}" for c in cols]) + " \\\\\n"

    def format_cell(v):
        if v is None or (pd.api.types.is_scalar(v) and pd.isna(v)):
            return "--"
        if isinstance(v, (float, int, np.floating, np.integer)):
            return f"{float(v):.1f}"
        return str(v)

    def row_to_latex(row):
        row_values = [format_cell(v) for v in row[cols]]
        return f"    {row[index_column]} & " + " & ".join(row_values) + " \\\\\n"

    body_parts = []
    if model_groups is None:
        body_parts.append("\n    \\midrule\n")
        body_parts.extend(row_to_latex(row) for _, row in df.iterrows())
    else:
        for group_name, models in model_groups.items():
            body_parts.append("\n    \\midrule\n")
            body_parts.append(f"    \\multicolumn{{{len(cols)+1}}}{{@{{}}l}}{{\\itshape {group_name}}}\\\\\n")
            for model in models:
                matches = df[df[index_column] == model]
                if matches.empty:
                    raise IndexError(
                        f"Model {model!r} of group {group_name!r} not found in column {index_column!r}."
                    )
                body_parts.append(row_to_latex(matches.iloc[0]))
    body = "".join(body_parts)

    # --- assemble ---
    latex = f"""\\begin{{table*}}[t]
  \\centering
  \\renewcommand{{\\arraystretch}}{{1.2}}
  \\small
  \\resizebox{{\\textwidth}}{{!}}{{%
  \\begin{{tabular}}{{{colspec}}}
    {header}{body}    
    \\bottomrule
  \\end{{tabular}}%
  }}
  \\caption{{{caption}}}
  \\label{{{label}}}
\\end{{table*}}
"""
    return latex