import re
import time
import click
import pathlib
import pandas as pd
import pickle
import numpy as np
import umap

import matplotlib
import matplotlib.pyplot as plt
from matplotlib.ticker import FuncFormatter
from matplotlib.lines import Line2D
from matplotlib.patches import Patch, Polygon
from matplotlib.colors import hsv_to_rgb
from utils.probe_attributes import *
from utils.subset_construction import *
from utils.phemotype_analysis import *

from sklearn.manifold import TSNE
from sklearn.metrics import silhouette_score
from adjustText import adjust_text
from scipy.stats import spearmanr

from typing import Optional, Dict, List, Sequence, Tuple, Iterable
from sklearn.metrics import pairwise_distances

# Embed fonts as TrueType (Type 42) rather than Type 3 in PDF/PS output.
matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['ps.fonttype'] = 42
# Process-wide default: serif text rendered with Times New Roman.
plt.rcParams['font.family'] = 'serif'
plt.rcParams['font.serif'] = ['Times New Roman']

# Shared font sizes used by the plotting helpers in this module.
TITLE_SIZE=30
LABEL_SIZE=26
TICK_SIZE=25
LEGEND_SIZE=25

# The 20 RGBA colours of matplotlib's "tab20" colormap, in index order 0-19.
color_list = [plt.get_cmap("tab20")(i) for i in range(20)]

def draw_models_tsne_row_ax(
    axes,
    all_results: dict,
    dataset_names: list[str],
    *,
    metrics: tuple[str,...] = ("RKSP","TYDF","BRDF","UQDF","SP"),
    models: list[str] | None = None,
    random_state: int = 0,
    perplexity: int | None = None,
    label_topk: int | None = None,
    metric: str = "cosine",
    tsne_iters: int = 2500,
):
    """Draw one 2-D t-SNE scatter of models per dataset, one dataset per axis.

    Dataset ``dataset_names[i]`` is drawn on the i-th axis. Its table is read
    from ``all_results[ds]["scores"]["benchmark_scores"]``. If the table has
    no ``model_name`` column, the index is used as the model name. Each model
    (row) is embedded from its *metrics* columns and labelled with its name.
    An axis shows a short placeholder message instead when:

    - the table is missing;
    - a required column is absent;
    - fewer than two models have complete metrics.

    Args:
        axes: A single Axes, a list/tuple of Axes, or an array of Axes
            (flattened in ``.flat`` order).
        all_results: Mapping of dataset name -> result bucket.
        dataset_names: Datasets to draw. Surplus names or surplus axes are ignored.
        metrics: Score columns used as the t-SNE input features.
        models: Optional whitelist of ``model_name`` values to keep.
        random_state: Seed passed to t-SNE.
        perplexity: t-SNE perplexity. Defaults to ``max(5, min(40, (m-1)//3))``
            for ``m`` models. The value is always clamped to at most ``m - 1``,
            because scikit-learn requires ``perplexity < n_samples``.
        label_topk: If given, only the ``label_topk`` points whose nearest
            neighbour in the embedding is farthest away get a text label.
        metric: Distance metric passed to t-SNE.
        tsne_iters: Maximum number of t-SNE optimisation iterations.
    """
    import inspect

    # scikit-learn 1.5 renamed TSNE's ``n_iter`` to ``max_iter`` and deprecated
    # the old name. Pass whichever keyword the installed version accepts.
    tsne_params = inspect.signature(TSNE).parameters
    iter_kwarg = "max_iter" if "max_iter" in tsne_params else "n_iter"

    if hasattr(axes, "flat"):
        axes_list = list(axes.flat)
    elif isinstance(axes, (list, tuple)):
        axes_list = list(axes)
    else:
        axes_list = [axes]

    metric_cols = list(metrics)

    # zip stops at the shorter of the two sequences.
    for ds, ax in zip(dataset_names, axes_list):
        bucket = all_results.get(ds, {})
        bench = bucket.get("scores", {}).get("benchmark_scores")
        if not isinstance(bench, pd.DataFrame):
            ax.text(0.5, 0.5, "Data not available", ha="center", va="center", transform=ax.transAxes)
            ax.set_axis_off()
            continue

        df = bench.copy()
        if 'model_name' not in df.columns:
            idx_name = df.index.name or 'model_name'
            df = df.reset_index().rename(columns={idx_name: 'model_name'})

        need = ['model_name', *metric_cols]
        miss = [c for c in need if c not in df.columns]
        if miss:
            ax.text(0.5, 0.5, f"Missing cols:\n{miss}", ha="center", va="center", transform=ax.transAxes)
            ax.set_axis_off()
            continue

        if models is not None:
            df = df[df['model_name'].isin(models)]

        df = df.dropna(subset=metric_cols)
        if len(df) < 2:
            ax.text(0.5, 0.5, "Not enough models", ha="center", va="center", transform=ax.transAxes)
            ax.set_axis_off()
            continue

        X = df.loc[:, metric_cols].astype(float).copy()

        m = X.shape[0]
        perp = perplexity if perplexity is not None else max(5, min(40, (m - 1) // 3))
        # TSNE raises ValueError unless perplexity < n_samples; small datasets need this clamp.
        perp = min(perp, m - 1)

        xy = TSNE(
            n_components=2,
            perplexity=perp,
            init="pca",
            learning_rate="auto",
            random_state=random_state,
            early_exaggeration=20.0,
            metric=metric,
            verbose=0,
            **{iter_kwarg: tsne_iters},
        ).fit_transform(X)

        ax.scatter(xy[:,0], xy[:,1], s=64, c="#1f77b4",
                   edgecolors="white", linewidths=0.6, alpha=0.9)

        labels = df['model_name'].astype(str).tolist()

        keep = np.ones(len(labels), dtype=bool)
        if label_topk is not None and label_topk < len(labels):
            from sklearn.neighbors import NearestNeighbors
            # Column 1 of the kneighbors distances is the distance to the nearest
            # *other* point (column 0 is the point itself). Label the most isolated points.
            nbrs = NearestNeighbors(n_neighbors=2).fit(xy)
            d1 = nbrs.kneighbors(xy, return_distance=True)[0][:,1]
            order = np.argsort(-d1)
            keep = np.zeros(len(labels), dtype=bool)
            keep[order[:label_topk]] = True

        # Unlabelled points still get an (empty) text so adjust_text sees every point.
        texts = [
            ax.text(x, y, lab if k else "", fontsize=9, ha="center", va="center")
            for (x, y), lab, k in zip(xy, labels, keep)
        ]

        adjust_text(
            texts, x=xy[:,0], y=xy[:,1], ax=ax,
            expand_points=(1.35, 1.45), expand_text=(1.15, 1.25),
            force_text=0.9, force_points=0.25,
            only_move={"points":"y","text":"xy"},
            arrowprops=dict(arrowstyle="-", lw=0.6, color="0.35", alpha=0.7),
            lim=300, precision=0.01
        )

        ax.set_title(ds, fontsize=TITLE_SIZE)
        ax.tick_params(axis='both', which='major', labelsize=TICK_SIZE)
        ax.grid(False)

def plot_quadrant_ax(ax: plt.Axes, df: pd.DataFrame, x_col: str, y_col: str, title: str = "Meme Quadrant Plot") -> plt.Axes:
    """Scatter ``df[y_col]`` against ``df[x_col]`` on *ax*, with labels and a dotted grid.

    Args:
        ax: Target 2-D axis.
        df: DataFrame holding both columns.
        x_col / y_col: Column names, also used as the axis labels.
        title: Axis title.

    Returns:
        *ax*, for chaining.
    """
    marker_color = color_list[2]

    # ``c`` must be a one-row 2-D colour. A bare RGBA tuple would be colour-mapped
    # as data values whenever the number of points is 3 or 4.
    ax.scatter(df[x_col], df[y_col], c=[marker_color], edgecolors=marker_color, alpha=0.7, s=80)

    ax.set_xlabel(x_col, fontsize=LABEL_SIZE)
    ax.set_ylabel(y_col, fontsize=LABEL_SIZE)
    ax.set_title(title, fontsize=TITLE_SIZE)
    ax.tick_params(axis='both', which='major', labelsize=TICK_SIZE)
    ax.grid(True, linestyle=':', alpha=0.6)

    return ax

def plot_phemotype_share_pies(
    all_calculate_results_dict: dict,
    *,
    dims=("RKSP","TYDF","BRDF","UQDF","SP"),
    title="Phemotype Composition (Average on All Datasets)",
    figsize=(12, 4),
    dpi=300,
    stage_order=("Base","CoT","IR"),
    stage_regex=r"\((IR|CoT)\)",
    color_map=None,
    label_map=None,
    save_path: str | None = None
):
    """Draw one pie chart per inference stage showing each phemotype's share.

    Models are read from
    ``all_calculate_results_dict["_global"]["scores"]["model_avg_scores"]``.
    A ``model_name`` column, if present, becomes the index. Each model is
    assigned a stage by searching its name with *stage_regex*: a captured
    ``"IR"`` or ``"CoT"`` tag selects that stage, and anything else is
    ``"Base"``. Per stage, the column means of *dims* are normalised to sum
    to 1. A stage with no models is shown as a "No models" placeholder.

    Args:
        all_calculate_results_dict: Result buckets keyed by dataset; only ``"_global"`` is used.
        dims: Score columns to show; columns missing from the table are skipped.
        title: Figure title; pass ``None`` or ``""`` to omit it.
        figsize, dpi: Passed to ``plt.subplots``.
        stage_order: Stages to draw, left to right, one pie each.
        stage_regex: Pattern whose first capture group is the stage tag.
        color_map: Optional dim -> colour mapping (unmapped dims are grey).
        label_map: Optional dim -> legend label mapping.
        save_path: If given, the figure is also saved there.

    Returns:
        ``(fig, pie_axes, stage_share)``, where:

        - *pie_axes* is a tuple of the pie Axes in *stage_order* order;
        - *stage_share* is the stage x dim table of normalised shares.

    Raises:
        ValueError: If the model-average table is missing or contains none of *dims*.
    """
    bucket = all_calculate_results_dict.get("_global", {}).get("scores", {})
    model_avg = bucket.get("model_avg_scores", None)
    if not isinstance(model_avg, pd.DataFrame):
        raise ValueError(
            "Expected a DataFrame at all_calculate_results_dict['_global']['scores']['model_avg_scores']"
        )

    df = model_avg.copy()
    if "model_name" in df.columns:
        df = df.set_index("model_name")

    dims = tuple(d for d in dims if d in df.columns)
    if not dims:
        raise ValueError("None of the requested dims are columns of model_avg_scores")
    dims_list = list(dims)
    df = df.dropna(subset=dims_list)

    stage_pattern = re.compile(stage_regex)

    def _stage_from_name(name) -> str:
        """Map a model name to "IR", "CoT" or "Base" via *stage_regex*."""
        m = stage_pattern.search(str(name))
        if m and m.group(1) in ("IR", "CoT"):
            return m.group(1)
        return "Base"

    df["stage"] = [_stage_from_name(i) for i in df.index]

    stage_means = df.groupby("stage")[dims_list].mean()
    stage_counts = df.groupby("stage").size()

    # Stages without models get zero rows so that every requested stage has an entry.
    for st in stage_order:
        if st not in stage_means.index:
            stage_means.loc[st] = 0.0
            stage_counts.loc[st] = 0
    stage_means = stage_means.loc[list(stage_order)]
    stage_counts = stage_counts.loc[list(stage_order)]

    # replace(0, 1) avoids division by zero; all-zero rows stay all zero.
    stage_share = stage_means.div(stage_means.sum(axis=1).replace(0, 1), axis=0)

    default_color_map = {
        "RKSP": "#ff7f0e",
        "TYDF": "#2ca02c",
        "BRDF": "#9467bd",
        "UQDF": "#8c564b",
        "SP":   "#e377c2",
    }
    color_map = color_map or default_color_map
    label_map = label_map or {
        "RKSP":"Vigilance", "TYDF":"Mastery", "BRDF":"Transfer",
        "UQDF":"Ingenuity", "SP":"Astuteness"
    }
    # Wedges and legend patches share one colour list (same grey fallback).
    colors = [color_map.get(d, "#999999") for d in dims_list]

    n_stages = len(stage_order)
    fig, axes = plt.subplots(
        1, n_stages + 1, figsize=figsize, dpi=dpi, squeeze=False,
        gridspec_kw={"width_ratios": [1] * n_stages + [0.4], "wspace": 0.01},
    )
    *pie_axes, ax_leg = list(axes[0])
    ax_leg.axis("off")

    for st, ax in zip(stage_order, pie_axes):
        shares = stage_share.loc[st].to_numpy(dtype=float)
        n_models = int(stage_counts.loc[st])

        if n_models == 0 or not np.all(np.isfinite(shares)) or shares.sum() <= 0:
            # An empty stage has all-zero shares, which cannot be drawn as a pie.
            ax.set_axis_off()
            ax.text(0.5, 0.5, "No models", ha="center", va="center",
                    transform=ax.transAxes, fontsize=18)
        else:
            _, _, autotexts = ax.pie(
                shares,
                labels=None,
                # Hide percentage labels on thin wedges (< 6%) to avoid overlap.
                autopct=lambda p: f"{p:.0f}%" if p >= 6 else "",
                startangle=90, counterclock=False,
                colors=colors, wedgeprops={"edgecolor":"white", "linewidth": 1.5}
            )
            for t in autotexts:
                t.set_fontsize(18)
                t.set_color("black")

        ax.text(0.5, -0.12, f"{st} ({n_models} models)",
                ha="center", va="center", transform=ax.transAxes,
                fontsize=22)

    legend_pchs = [plt.Rectangle((0, 0), 1, 1, color=c) for c in colors]
    legend_labs = [label_map.get(d, d) for d in dims_list]
    ax_leg.legend(
        legend_pchs, legend_labs,
        ncol=1, frameon=False, loc="center left",
        bbox_to_anchor=(-0.1, 0.5),
        fontsize=20, handlelength=1.4,
    )

    if title:
        fig.suptitle(title, fontsize=24)

    # top=0.8 leaves room above the pies for the figure title.
    fig.subplots_adjust(
        right=0.92,
        bottom=0.1,
        top=0.8,
        wspace=0.03
    )

    if save_path:
        fig.savefig(save_path, bbox_inches="tight", pad_inches=0.05)

    return fig, tuple(pie_axes), stage_share

def plot_quadrant(df: pd.DataFrame, x_col: str, y_col: str, title: str = "Meme Quadrant Plot") -> plt.Figure:
    """Build a standalone 12x9-inch figure with a single quadrant scatter.

    Delegates the drawing to ``plot_quadrant_ax`` and returns the new figure
    (the caller is responsible for saving/closing it).
    """
    figure, target_ax = plt.subplots(figsize=(12, 9))
    plot_quadrant_ax(target_ax, df, x_col, y_col, title)
    return figure

def plot_probe_traits_3d_ax(ax, metrics_df,
                            x_col="difficulty",
                            y_col="risk",
                            z_col="uniqueness",
                            title="Probe Traits 3D (Stylish RGB)"):
    """Scatter three columns of *metrics_df* on a 3-D axis, colouring each point by position.

    Each point's RGB colour is its (x, y, z) values min-max scaled to [0, 1].
    A column whose finite values are all equal keeps its raw value, clipped
    into [0, 1]. Non-finite entries get channel value 0, so matplotlib always
    receives valid colours. The first and last tick on each axis are dropped
    (when more than two exist), and the view is fixed at elev=25, azim=40.

    Args:
        ax: A 3-D axis (``projection='3d'``); ``text2D``, ``set_zlabel`` and
            the axis panes are used.
        metrics_df: DataFrame containing *x_col*, *y_col* and *z_col*.
        x_col, y_col, z_col: Columns plotted on the x, y and z axes.
        title: Text drawn at the top-centre of the axis.

    Returns:
        *ax*.
    """
    def _to_unit(values):
        """Min-max scale *values* into [0, 1] for use as one RGB channel."""
        v = np.asarray(values, dtype=float)
        finite = np.isfinite(v)
        if not finite.any():
            return np.zeros_like(v)
        lo, hi = v[finite].min(), v[finite].max()
        if hi > lo:
            scaled = (v - lo) / (hi - lo)
        else:
            # Constant column: no range to scale by. Clip so the channel stays a valid RGB value.
            scaled = np.clip(v, 0.0, 1.0)
        return np.where(finite, scaled, 0.0)

    xs = metrics_df[x_col].values
    ys = metrics_df[y_col].values
    zs = metrics_df[z_col].values

    colors = np.stack([_to_unit(xs), _to_unit(ys), _to_unit(zs)], axis=1)

    ax.scatter(xs, ys, zs,
               c=colors, s=70,
               alpha=0.85, edgecolors="k", linewidths=0.3)

    ax.set_xlabel(x_col, fontsize=LABEL_SIZE, labelpad=17)
    ax.set_ylabel(y_col, fontsize=LABEL_SIZE, labelpad=17)
    ax.set_zlabel(z_col, fontsize=LABEL_SIZE, labelpad=11)
    ax.text2D(0.5, 0.965, title, transform=ax.transAxes,
              ha='center', va='top', fontsize=TITLE_SIZE-3)

    ax.xaxis.pane.set_edgecolor('w')
    ax.yaxis.pane.set_edgecolor('w')
    ax.zaxis.pane.set_edgecolor('w')
    ax.xaxis.pane.set_alpha(0.1)
    ax.yaxis.pane.set_alpha(0.1)
    ax.zaxis.pane.set_alpha(0.1)
    ax.tick_params(axis='both', which='major', labelsize=TICK_SIZE-1)
    ax.grid(color="gray", linestyle="--", linewidth=0.5, alpha=0.5)

    # Drop the outermost ticks to reduce clutter at the cube corners.
    xticks = ax.get_xticks()
    yticks = ax.get_yticks()
    zticks = ax.get_zticks()
    if len(xticks) > 2:
        ax.set_xticks(xticks[1:-1])
    if len(yticks) > 2:
        ax.set_yticks(yticks[1:-1])
    if len(zticks) > 2:
        ax.set_zticks(zticks[1:-1])

    ax.view_init(elev=25, azim=40)

    return ax

# Grey used for any cluster label that is not among the 20 most frequent labels.
FALLBACK_COLOR = "#B7B7B7" 

def _top20_tab20_colors_for_embeddings(labels):
    """Give the 20 most frequent labels distinct tab20 colours and the rest FALLBACK_COLOR.

    Args:
        labels: Iterable of integer-convertible cluster labels.

    Returns:
        ``(color_map, top20)``, where:

        - *color_map* maps ``int(label)`` to a colour. The top-20 entries come
          first, in frequency order, followed by the rest in ascending label order.
        - *top20* is the list of the (up to) 20 most frequent labels.
    """
    label_arr = np.asarray(labels)
    unique_labels, freq = np.unique(label_arr, return_counts=True)
    by_frequency = np.argsort(freq)[::-1]
    top20 = unique_labels[by_frequency[:20]]

    palette = plt.get_cmap("tab20", 20)
    color_map = {int(cid): palette(rank % 20) for rank, cid in enumerate(top20)}

    # Every remaining label falls back to grey.
    for cid in unique_labels:
        color_map.setdefault(int(cid), FALLBACK_COLOR)

    return color_map, list(top20)

def _mds_from_dist(D, n_components=2):
    """Embed points from a pairwise distance matrix with classical MDS.

    The squared distances are double-centred:
    ``B = -1/2 * J @ D**2 @ J`` with ``J = I - 1/Q``. The result is the top
    *n_components* eigenvectors of B, each scaled by the square root of its
    eigenvalue (negative eigenvalues are clipped to 0).

    Args:
        D: (Q, Q) symmetric distance matrix (``eigh`` assumes symmetry).
        n_components: Number of output dimensions.

    Returns:
        (Q, n_components) coordinate array; the sign of each axis is arbitrary.
    """
    size = D.shape[0]
    centering = np.eye(size) - np.ones((size, size)) / size
    gram = -0.5 * centering @ (D ** 2) @ centering

    eigvals, eigvecs = np.linalg.eigh(gram)
    top = np.argsort(eigvals)[::-1][:n_components]

    scale = np.sqrt(np.clip(eigvals[top], 0, None))
    return eigvecs[:, top] * scale

def _to_hamming_dist(X_or_S):
    """Return a square pairwise distance matrix from a distance, similarity or feature matrix.

    The input is interpreted by the first rule that applies:

    * Square and symmetric (atol 1e-7) with an all-zero diagonal (atol 1e-6):
      already a distance matrix; returned as float.
    * Square and symmetric with an all-one diagonal: a similarity matrix;
      converted with ``1 - S``.
    * Otherwise: rows are samples. Entries are binarised with ``> 0.5``, and
      the pairwise Hamming distance between rows is returned.
    """
    arr = np.asarray(X_or_S)
    n_rows, n_cols = arr.shape

    # Short-circuit keeps allclose from seeing mismatched shapes for non-square input.
    looks_pairwise = n_rows == n_cols and np.allclose(arr, arr.T, atol=1e-7)
    if looks_pairwise:
        diagonal = np.diag(arr)
        if np.allclose(diagonal, 0.0, atol=1e-6):
            return arr.astype(float)
        if np.allclose(diagonal, 1.0, atol=1e-6):
            return (1.0 - arr).astype(float)

    binary = (arr > 0.5).astype(np.uint8)
    return pairwise_distances(binary, metric="hamming").astype(float)

def plot_embedding_ax(ax: plt.Axes, X_embed: np.ndarray, labels: np.ndarray, color_map: dict, title: str,
                      point_size: int = 80, alpha: float = 0.85) -> plt.Axes:
    """Scatter a 2-D embedding on *ax*, one colour per label.

    Args:
        ax: Target axis.
        X_embed: Array whose first two columns are the x/y coordinates.
        labels: Array of integer-convertible labels, aligned with *X_embed* rows.
        color_map: ``int(label) -> colour``; missing labels use FALLBACK_COLOR.
        title: Axis title.
        point_size, alpha: Marker size and opacity.

    Returns:
        *ax*.
    """
    for cluster_id in np.unique(labels):
        in_cluster = labels == cluster_id
        colour = color_map.get(int(cluster_id), FALLBACK_COLOR)
        ax.scatter(
            X_embed[in_cluster, 0],
            X_embed[in_cluster, 1],
            s=point_size,
            c=[colour],
            alpha=alpha,
            edgecolors='none',
            rasterized=True,
        )

    ax.set_title(title, fontsize=TITLE_SIZE, pad=22)
    ax.grid(True, linestyle="--", alpha=0.2)
    ax.tick_params(axis='both', which='major', labelsize=TICK_SIZE+1)

    return ax

def calculate_all_embeddings(X_or_S, random_state=123, umap_n_neighbors=15, umap_min_dist=0.1, perplexity=20):
    """Compute 2-D t-SNE, UMAP and classical-MDS embeddings from one distance matrix.

    *X_or_S* is first turned into a square distance matrix D by
    ``_to_hamming_dist``. D is then used as a precomputed distance for all
    three methods.

    Args:
        X_or_S: Distance, similarity or feature matrix (see ``_to_hamming_dist``).
        random_state: Seed for t-SNE and UMAP.
        umap_n_neighbors, umap_min_dist: UMAP hyper-parameters.
        perplexity: t-SNE perplexity, clamped to at most ``n - 1`` because
            scikit-learn requires ``perplexity < n_samples``.

    Returns:
        dict with keys ``"tsne"``, ``"umap"`` and ``"pca"``, each a 2-D coordinate array.
        Despite its name, ``"pca"`` holds classical-MDS coordinates from ``_mds_from_dist``.
    """
    D = _to_hamming_dist(X_or_S)
    n_samples = D.shape[0]

    # Without this clamp, TSNE raises ValueError for datasets with <= perplexity items.
    tsne_perplexity = min(perplexity, max(n_samples - 1, 1))

    X_tsne = TSNE(n_components=2, metric="precomputed",
                  perplexity=tsne_perplexity, init="random",
                  random_state=random_state).fit_transform(D)

    reducer = umap.UMAP(n_components=2, metric="precomputed",
                        n_neighbors=umap_n_neighbors, min_dist=umap_min_dist,
                        random_state=random_state)
    X_umap = reducer.fit_transform(D)

    X_pca = _mds_from_dist(D, 2)

    return {
        "tsne": X_tsne,
        "umap": X_umap,
        "pca": X_pca,
    }

def plot_similarity_heatmap_minimal_ax(
    ax: plt.Axes,
    S: np.ndarray,
    result_df,
    show_blocks: bool = True,
    add_colorbar: bool = True,
    cax: Optional[plt.Axes] = None,
    vmin: float = 0.0,
    vmax: float = 1.0,
    cmap: str = "viridis",
    title: str = "Similarity Heatmap",
):
    """Draw *S* as a heatmap with rows and columns grouped by cluster.

    Rows and columns are reordered by ``result_df["cluster"]`` (ascending).
    Within each cluster they are ordered by ``result_df["typicality"]``
    (descending) when that column exists. Row ``i`` of *S* is assumed to
    correspond to row ``i`` of *result_df*.

    Args:
        ax: Target axis.
        S: Square similarity matrix (ndarray or anything ``np.asarray`` accepts,
            e.g. a DataFrame).
        result_df: DataFrame with an integer-convertible ``cluster`` column and
            an optional ``typicality`` column.
        show_blocks: Draw white lines around each cluster's block.
        add_colorbar: Add a colourbar, in *cax* if given or stolen from *ax* otherwise.
        cax: Optional axis for the colourbar.
        vmin, vmax, cmap: Passed to ``imshow``.
        title: Axis title.

    Returns:
        ``(ax, order, im)``, where:

        - *order* is the permutation applied to S's rows/columns;
        - *im* is the ``AxesImage``.

    Raises:
        ValueError: If *S* is not square with one row per row of *result_df*.
    """
    S_arr = np.asarray(S)
    labels = result_df["cluster"].astype(int).to_numpy()
    rep = np.asarray(result_df.get("typicality", np.zeros(len(labels))), dtype=float)

    n = len(labels)
    if S_arr.ndim != 2 or S_arr.shape != (n, n):
        raise ValueError(
            f"S must be a square matrix with one row per result_df row; "
            f"got shape {S_arr.shape} for {n} rows"
        )

    # lexsort uses the LAST key as primary: cluster ascending, then typicality descending.
    order = np.lexsort((-rep, labels))
    S_ord = S_arr[np.ix_(order, order)]
    labels_ord = labels[order]

    im = ax.imshow(S_ord, vmin=vmin, vmax=vmax, cmap=cmap, interpolation="nearest")

    if add_colorbar:
        fig = ax.get_figure()
        if cax is not None:
            cbar = fig.colorbar(im, cax=cax)
        else:
            cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
        cbar.set_label("Similarity")

    if show_blocks:
        # labels_ord is sorted, so each cluster occupies one contiguous index range.
        for c in np.unique(labels_ord):
            idx = np.flatnonzero(labels_ord == c)
            for edge in (idx[0] - 0.5, idx[-1] + 0.5):
                ax.axhline(edge, color="white", lw=1)
                ax.axvline(edge, color="white", lw=1)

    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title(title, fontsize=TITLE_SIZE, pad=22)

    return ax, order, im

def draw_phemotype_radar_by_mode_row_ax(
    axes,
    *,
    all_results: dict,
    dataset_names: list[str],
    model_col: str = "model_name",
    metrics: tuple[str, ...] = ("RKSP","TYDF","BRDF","UQDF","SP"),
    normalize: bool = False,
    fill: bool = False,
    linewidth: float = 2.0,
    markers: bool = True,
    legend_on: bool = True,
    title: Optional[str] = None,
):
    """Draw, per dataset, a radar chart of mean metric scores for IR / CoT / other models.

    Each model is classified from a trailing parenthesised tag in its name,
    compared case-insensitively:

    - ``"... (IR)"`` -> IR;
    - ``"... (CoT)"`` -> CoT;
    - anything else -> Other.

    The mean of every metric per class is drawn as one closed polygon. The
    axes must be polar, because ``set_rmin`` is called on them.

    Args:
        axes: A single polar Axes, a list/tuple of them, or an array
            (flattened in ``.flat`` order).
        all_results: Mapping of dataset name -> result bucket (see ``_get_model_scores_df``).
        dataset_names: Datasets to draw, paired positionally with *axes*.
        model_col: Column holding model names; filled from the index when absent.
        metrics: Metric columns, one radar spoke each.
        normalize: Min-max scale each metric across the classes to [0, 1]
            (metrics that are constant across classes become 0).
        fill: Shade each polygon.
        linewidth, markers: Line styling.
        legend_on: Add a legend to the first dataset's axis (when that dataset is drawn).
        title: Title for every axis; defaults to the dataset name.

    Returns:
        dict mapping each drawn dataset to its class x metric mean table
        (class name in the ``_cat_`` column).
    """
    if hasattr(axes, "flat"):
        axes_list = list(axes.flat)
    elif isinstance(axes, (list, tuple)):
        axes_list = list(axes)
    else:
        axes_list = [axes]

    _label_map = {
        "RKSP": "Vigilance",
        "TYDF": "Mastery",
        "BRDF": "Transfer",
        "UQDF": "Ingenuity",
        "SP":   "Astuteness",
    }
    tick_labels = [ _label_map.get(m, m) for m in metrics ]

    K = len(metrics)
    angles = np.linspace(0, 2*np.pi, K, endpoint=False)
    # Repeat the first angle so each polygon closes on itself.
    angles_closed = np.concatenate([angles, angles[:1]])

    color_cycle = plt.rcParams['axes.prop_cycle'].by_key().get('color', ['C0','C1','C2'])
    cat_colors = {
        "IR":    (color_cycle[0] if len(color_cycle)>0 else 'C0'),
        "CoT":   (color_cycle[1] if len(color_cycle)>1 else 'C1'),
        "Other": (color_cycle[2] if len(color_cycle)>2 else 'C2'),
    }

    def _classify(name: str) -> str:
        """Map a model name to "IR", "CoT" or "Other" from its trailing "(...)" tag."""
        m = re.search(r"\(([^()]*)\)\s*$", str(name))
        tag = (m.group(1).strip().upper() if m else "")
        # ``tag`` is upper-cased, so compare against "COT" (not "CoT").
        if tag == "IR":
            return "IR"
        if tag == "COT":
            return "CoT"
        return "Other"

    def _minmax01(mat: pd.DataFrame) -> pd.DataFrame:
        """Column-wise min-max scaling; zero-range columns are divided by 1."""
        mn, mx = mat.min(0), mat.max(0)
        den = (mx - mn).replace(0, 1.0)
        return (mat - mn) / den

    used_data: dict[str, pd.DataFrame] = {}
    n = min(len(dataset_names), len(axes_list))

    for i in range(n):
        ds = dataset_names[i]
        ax = axes_list[i]

        bucket = all_results.get(ds, {})
        df = _get_model_scores_df(bucket)
        if df is not None and model_col not in df.columns:
            # _get_model_scores_df moves model_name into the index; restore it as a column.
            df[model_col] = df.index.astype(str)
        if df is None or df.empty:
            ax.text(0.5, 0.5, f"No model scores for {ds}", ha="center", va="center", transform=ax.transAxes)
            ax.set_axis_off()
            continue

        miss = [c for c in metrics if c not in df.columns]
        if model_col not in df.columns or miss:
            ax.text(0.5, 0.5, f"Missing columns in {ds}: { [model_col] + miss }", ha="center", va="center", transform=ax.transAxes)
            ax.set_axis_off()
            continue

        tmp = df.copy()
        tmp["_cat_"] = tmp[model_col].map(_classify)
        grp = tmp.groupby("_cat_", as_index=False)[list(metrics)].mean(numeric_only=True)
        if grp.empty:
            ax.text(0.5, 0.5, f"No categories in {ds}", ha="center", va="center", transform=ax.transAxes)
            ax.set_axis_off()
            continue

        M = grp.set_index("_cat_").astype(float)
        if normalize:
            M = _minmax01(M)

        ax.set_xticks(angles)
        ax.set_xticklabels(tick_labels, fontsize=LABEL_SIZE-2)
        if (M.values >= 0).all():
            ax.set_rmin(0.0)
        ax.tick_params(axis='both', which='major', labelsize=TICK_SIZE)

        handles, labels = [], []
        for cat in ["IR", "CoT", "Other"]:
            if cat not in M.index:
                continue
            vals = M.loc[cat, list(metrics)].values.astype(float)
            vals_c = np.concatenate([vals, vals[:1]])
            h, = ax.plot(angles_closed, vals_c, linewidth=linewidth,
                         marker=('o' if markers else None),
                         color=cat_colors[cat], label=cat)
            if fill:
                ax.fill(angles_closed, vals_c, alpha=0.12, color=cat_colors[cat])
            handles.append(h)
            labels.append(cat)

        ax.set_title(title if title is not None else ds, pad=12, fontsize=TITLE_SIZE-2)

        if legend_on and i == 0 and handles:
            ax.legend(handles, labels, loc="upper right", ncol=1, frameon=True, fontsize=max(LEGEND_SIZE-4, 6))

        used_data[ds] = M.reset_index()

    return used_data

def _get_model_scores_df(bucket: dict) -> Optional[pd.DataFrame]:
    """Return a copy of the first per-model score table found in a result bucket.

    ``bucket["scores"]`` is searched for these keys, in priority order:
    ``"benchmark_scores"``, ``"model_scores"``, ``"model_avg_scores"``. The
    first value that is a DataFrame wins. A ``model_name`` column, if present,
    becomes the index.

    Returns:
        The copied table, or None in any of these cases:

        - *bucket* is not a dict;
        - ``bucket["scores"]`` is not a dict;
        - none of the keys holds a DataFrame.
    """
    if not isinstance(bucket, dict):
        return None
    scores = bucket.get("scores", {})
    if not isinstance(scores, dict):
        return None

    candidates = (scores.get(key) for key in ("benchmark_scores", "model_scores", "model_avg_scores"))
    table = next((c for c in candidates if isinstance(c, pd.DataFrame)), None)
    if table is None:
        return None

    if "model_name" in table.columns:
        table = table.set_index("model_name")
    return table.copy()

def draw_metric_lines_over_models_row_ax(
    axes,
    *,
    all_results: dict,
    dataset_names: list[str],
    metrics: tuple[str, ...] = ("ACC","RKSP","TYDF","BRDF","UQDF","SP","Composite"),
    order_by: str | None = "ACC",
    y_range: tuple[float, float] | None = None,
    legend_on: bool = False,
    line_width: float = 2.0,
    tick_rotation: int = 60,
    title: Optional[str] = None,
    layout: str = "row",
):
    """Plot one line per metric across models, one dataset per axis.

    For each dataset, the score table from ``_get_model_scores_df`` is
    optionally sorted (descending) by *order_by*. Each requested metric that
    is present is then drawn as a line over the models on the x-axis.
    ``"ACC"`` and ``"Composite"`` are always plotted first. The y-axis spans
    0-1 (unless *y_range* is given) and tick labels show values x100.

    Args:
        axes: A single Axes, a list/tuple of Axes, or an array of Axes.
        all_results: Mapping of dataset name -> result bucket.
        dataset_names: Datasets to draw, paired positionally with the flattened axes.
        metrics: Metric columns to plot (missing ones are skipped).
        order_by: Column to sort models by (descending); ignored if absent or None.
        y_range: Optional ``(ymin, ymax)`` overriding the default 0-1 limits.
        legend_on: Draw a legend on every plotted axis.
        line_width: Line width.
        tick_rotation: Rotation of the model-name tick labels.
        title: Title for every axis; defaults to the dataset name.
        layout: How a 2-D grid of axes is flattened. ``"row"`` reads row by
            row; anything else reads column by column.

    Returns:
        dict mapping each drawn dataset to the (sorted) table that was plotted.
    """
    _label_map = {
        "RKSP": "Vigilance",
        "TYDF": "Mastery",
        "BRDF": "Transfer",
        "UQDF": "Ingenuity",
        "SP":   "Astuteness",
        "Composite": "Phemotype Composite Score",
    }
    # Line colours in plotting order (ACC / Composite, when present, take the first two).
    lines_colors = [color_list[k] for k in (0, 6, 2, 4, 8, 10, 12)]
    _priority = ["ACC", "Composite"]

    layout_key = layout.lower() if isinstance(layout, str) else "row"
    if isinstance(axes, np.ndarray):
        axes_list = list(axes.ravel(order="C" if layout_key == "row" else "F"))
    elif hasattr(axes, "flat"):
        shape = getattr(axes, "shape", None)
        if layout_key == "col" and shape is not None and len(shape) == 2:
            n_rows, n_cols = shape
            axes_list = [axes[r, c] for c in range(n_cols) for r in range(n_rows)]
        else:
            axes_list = list(axes.flat)
    elif isinstance(axes, (list, tuple)):
        axes_list = list(axes)
    else:
        axes_list = [axes]

    used_data: dict[str, pd.DataFrame] = {}

    # zip stops at the shorter of the two sequences.
    for ds, ax in zip(dataset_names, axes_list):
        df = _get_model_scores_df(all_results.get(ds, {}))

        if df is None or df.empty:
            ax.text(0.5, 0.5, f"No model scores for {ds}", ha="center", va="center", transform=ax.transAxes)
            ax.set_axis_off()
            continue

        available = [c for c in metrics if c in df.columns]
        cols = [c for c in _priority if c in available] + [c for c in available if c not in _priority]
        if not cols:
            ax.text(0.5, 0.5, f"No requested metrics in {ds}", ha="center", va="center", transform=ax.transAxes)
            ax.set_axis_off()
            continue

        if order_by is not None and order_by in df.columns:
            df = df.sort_values(by=order_by, ascending=False)
        df_plot = df[cols]

        x = np.arange(len(df_plot.index))
        model_names = list(df_plot.index)

        for j, m in enumerate(cols):
            ax.plot(
                x,
                df_plot[m].astype(float).values,
                label=_label_map.get(m, m),
                linewidth=line_width,
                color=lines_colors[j % len(lines_colors)],
            )

        ax.set_title(title if title is not None else ds, fontsize=TITLE_SIZE+2, pad=10)
        ax.set_xticks(x)
        ax.set_xticklabels(model_names, rotation=tick_rotation, ha='right')

        # Scores live on a 0-1 scale; tick labels display them x100.
        ax.set_ylim(0, 1)
        ax.yaxis.set_major_formatter(FuncFormatter(lambda v, _: f"{v*100:.0f}"))
        if y_range is not None:
            ax.set_ylim(*y_range)

        ax.grid(True, alpha=0.35)
        ax.set_ylabel("Metric Score", fontsize=LABEL_SIZE+1)
        ax.set_xlim(-0.5, len(df_plot.index) - 0.5)
        ax.tick_params(axis='y', which='major', pad=3, labelsize=TICK_SIZE-2)
        ax.tick_params(axis='x', labelsize=TICK_SIZE-4, pad=1)

        if legend_on:
            ax.legend(frameon=False, ncol=min(4, len(cols)), loc="best",
                      fontsize=LABEL_SIZE-3, handletextpad=1, columnspacing=1)

        used_data[ds] = df_plot

    return used_data

from string import ascii_lowercase

def _add_panel_letters(axes, labels=None, x=0.5, y=-0.14,
                       ha="center", va="top",
                       fontsize=LABEL_SIZE+2, fontweight="bold",
                       color="black"):
    """Write a panel tag such as "(a)", "(b)", ... on each axis.

    Args:
        axes: A single Axes, a list/tuple of Axes, or an array (``.flat`` order).
        labels: Tags to write; defaults to "(a)", "(b)", ... one per axis.
        x, y: Position in axes-fraction coordinates (default: centred, below the axis).
        ha, va, fontsize, fontweight, color: Text styling.

    Tags are drawn with ``clip_on=False`` so they may sit outside the axis area.
    """
    if hasattr(axes, "flat"):
        targets = list(axes.flat)
    elif isinstance(axes, (list, tuple)):
        targets = list(axes)
    else:
        targets = [axes]

    if labels is None:
        labels = ["(" + ascii_lowercase[k] + ")" for k in range(len(targets))]

    text_style = dict(ha=ha, va=va, fontsize=fontsize,
                      fontweight=fontweight, color=color, clip_on=False)

    for target, tag in zip(targets, labels):
        # 3-D axes expose text2D for 2-D placement (their ``text`` takes x, y, z);
        # plain axes fall back to ``text``.
        writer = getattr(target, "text2D", None)
        if not callable(writer):
            writer = target.text
        writer(x, y, tag, transform=target.transAxes, **text_style)

def _save_png_and_pdf(fig, png_path, pdf_path=None, dpi=300, **savefig_kwargs):
    """Save *fig* to *png_path* and, when *pdf_path* is given, also to *pdf_path*.

    Both ``fig.savefig`` calls receive the same *dpi* and extra keyword
    arguments. The PNG is always written first.
    """
    destinations = [png_path] if pdf_path is None else [png_path, pdf_path]
    for destination in destinations:
        fig.savefig(destination, dpi=dpi, **savefig_kwargs)

def plot_combined_visualizations(all_results: dict, all_dataset_names: list,
                                 output_dir: pathlib.Path, pdf_dir: Optional[pathlib.Path] = None):
    """Render the cross-dataset comparison figures and save each as PNG (+ optional PDF).

    Files are written to *output_dir* as ``<stem>.png`` and, when *pdf_dir* is
    given, to *pdf_dir* as ``<stem>.pdf``. Every figure is closed after saving.
    Stems:

    * ``global_metric_lines``: metric curves for the ``"_global"`` bucket.
    * ``combined_metric_lines``: one metric-curve panel per dataset, stacked vertically.
    * ``combined_models_tsne``: t-SNE of models, one panel per dataset.
    * ``combined_embedding_plots``: t-SNE of the clustered items, coloured by cluster.
    * ``combined_cluster_heatmaps``: cluster-ordered similarity heatmaps with a shared colourbar.
    * ``combined_3d_<x>_<y>_<z>``: one 3-D scatter figure per triple of item metrics.

    Datasets missing from *all_results*, or lacking ``clustering_results`` /
    ``full_metrics``, get a placeholder panel instead of raising KeyError.
    """
    from itertools import combinations

    dataset_names = all_dataset_names
    num_datasets = len(dataset_names)
    if num_datasets == 0:
        print("No datasets found to visualize.")
        return

    print(f"\nGenerating combined visualizations for {num_datasets} datasets: {', '.join(dataset_names)}")

    def _save(fig, stem):
        """Save *fig* as <output_dir>/<stem>.png (and <pdf_dir>/<stem>.pdf), then close it."""
        _save_png_and_pdf(
            fig,
            output_dir / f"{stem}.png",
            (pdf_dir / f"{stem}.pdf") if pdf_dir is not None else None,
            dpi=300,
            bbox_inches='tight',
        )
        plt.close(fig)

    line_metrics = ("ACC","RKSP","TYDF","BRDF","UQDF","SP","Composite")

    # Metric curves averaged over all datasets.
    fig_lines_global, ax = plt.subplots(1, 1, figsize=(22, 8), dpi=300)
    draw_metric_lines_over_models_row_ax(
        ax,
        all_results=all_results,
        dataset_names=["_global"],
        metrics=line_metrics,
        order_by="ACC",
        y_range=None,
        legend_on=True,
        line_width=2.4,
        tick_rotation=60,
        title="Accuracy v.s. Phemotypes (Averaged Over All Datasets)",
    )
    _save(fig_lines_global, "global_metric_lines")

    # Per-dataset metric curves, one row per dataset.
    fig_lines, axes_lines = plt.subplots(num_datasets, 1, figsize=(20, 10 * num_datasets),
                                         gridspec_kw={"hspace": 0.95}, squeeze=False, dpi=300)
    draw_metric_lines_over_models_row_ax(
        axes_lines,
        all_results=all_results,
        dataset_names=dataset_names,
        metrics=line_metrics,
        order_by="ACC",
        y_range=None,
        legend_on=False,
        line_width=2.4,
        tick_rotation=60,
        layout="col",
    )
    _save(fig_lines, "combined_metric_lines")

    # t-SNE of models per dataset.
    fig_tsne, axes_tsne = plt.subplots(
        1, num_datasets, figsize=(16 * num_datasets, 12),
        squeeze=False, dpi=300
    )
    draw_models_tsne_row_ax(
        axes_tsne[0],
        all_results=all_results,
        dataset_names=dataset_names,
    )
    _add_panel_letters(axes_tsne[0], y=-0.19)
    fig_tsne.subplots_adjust(bottom=0.185, wspace=0.1)
    _save(fig_tsne, "combined_models_tsne")

    # t-SNE of clustered items, coloured by cluster.
    fig_embed, axes_embed = plt.subplots(1, num_datasets, figsize=(10 * num_datasets, 8), squeeze=False)
    for ax, ds_name in zip(axes_embed[0], dataset_names):
        cluster_results = all_results.get(ds_name, {}).get('clustering_results')
        if cluster_results:
            S = cluster_results['similarity']
            result_df = cluster_results['result_df']
            labels = result_df['cluster'].astype(int).values
            coords = calculate_all_embeddings(S)
            color_map, _ = _top20_tab20_colors_for_embeddings(labels)
            plot_embedding_ax(ax, coords['tsne'], labels, color_map, title=f"{ds_name}", point_size=30)
        else:
            ax.text(0.5, 0.5, "Data not available", ha='center', va='center', transform=ax.transAxes)
    _add_panel_letters(axes_embed[0], y=-0.19)
    fig_embed.subplots_adjust(bottom=0.185)
    _save(fig_embed, "combined_embedding_plots")

    # Similarity heatmaps sharing one colour scale and one colourbar (last column).
    vmin, vmax, cmap = 0.0, 1.0, "viridis"

    fig_hm_row = plt.figure(figsize=(8 * num_datasets + 1.5, 8), dpi=300)
    gs = fig_hm_row.add_gridspec(1, num_datasets + 1,
                                 width_ratios=[1]*num_datasets + [0.045],
                                 wspace=0.12)

    im_last = None
    axes_hm = []
    for i, ds_name in enumerate(dataset_names):
        ax = fig_hm_row.add_subplot(gs[0, i])
        axes_hm.append(ax)
        cluster_results = all_results.get(ds_name, {}).get('clustering_results')
        if cluster_results:
            _, _, im = plot_similarity_heatmap_minimal_ax(
                ax, cluster_results['similarity'], cluster_results['result_df'],
                show_blocks=False, add_colorbar=False,
                vmin=vmin, vmax=vmax, cmap=cmap, title=ds_name
            )
            im_last = im
        else:
            ax.text(0.5, 0.5, "Data not available", ha='center', va='center', transform=ax.transAxes)

    cax = fig_hm_row.add_subplot(gs[0, -1])
    if im_last is not None:
        # All heatmaps use the same vmin/vmax/cmap, so any image can drive the shared colourbar.
        cbar = fig_hm_row.colorbar(im_last, cax=cax)
        cbar.set_label("Similarity", fontsize=LABEL_SIZE, labelpad=12)
        cbar.ax.tick_params(labelsize=TICK_SIZE)

    _add_panel_letters(axes_hm, y=-0.08)
    fig_hm_row.subplots_adjust(bottom=0.12)
    _save(fig_hm_row, "combined_cluster_heatmaps")

    # One 3-D figure per triple of item metrics.
    print("\nGenerating all 2D and 3D metric combination plots...")
    metrics_to_combine = ['difficulty', 'uniqueness', 'risk', 'surprise', 'typicality', 'bridge']

    for x_col, y_col, z_col in combinations(metrics_to_combine, 3):
        fig_3d_comb = plt.figure(figsize=(8 * num_datasets, 9))
        axes3 = []
        for i, ds_name in enumerate(dataset_names):
            ax = fig_3d_comb.add_subplot(1, num_datasets, i + 1, projection='3d')
            axes3.append(ax)
            full_metrics = all_results.get(ds_name, {}).get('full_metrics')
            if full_metrics is not None and all(c in full_metrics.columns for c in [x_col, y_col, z_col]):
                plot_probe_traits_3d_ax(ax, full_metrics,
                                        x_col=x_col, y_col=y_col, z_col=z_col,
                                        title=f"{ds_name}")
            else:
                ax.text2D(0.5, 0.5, f"Data or columns\n({x_col}, {y_col}, {z_col})\nnot available",
                          ha='center', va='center', transform=ax.transAxes, fontsize=16)

        _add_panel_letters(axes3, y=-0.05)
        fig_3d_comb.subplots_adjust(bottom=0.07, wspace=0.085)
        # Invisible text left of the first panel; with bbox_inches='tight' it widens the saved left margin.
        axes3[0].text2D(-0.10, 0.5, " ", transform=axes3[0].transAxes)
        _save(fig_3d_comb, f"combined_3d_{x_col}_{y_col}_{z_col}")

    print("All combined visualization figures have been saved .")

def _read_optional_csv(path: pathlib.Path) -> pd.DataFrame | None:
    """Read an index-keyed CSV, returning ``None`` if the file is missing or empty.

    ``_load_or_compute_item_metrics`` writes an empty file as the cache marker for
    "no debug metrics"; ``pd.read_csv`` raises ``EmptyDataError`` on such a file,
    so that case is mapped back to ``None`` instead of crashing a cached rerun.
    """
    if not path.exists():
        return None
    try:
        return pd.read_csv(path, index_col=0)
    except pd.errors.EmptyDataError:
        return None


def _build_error_matrix(group_df: pd.DataFrame) -> pd.DataFrame:
    """Pivot one dataset's rows into a question x model error matrix of discriminative items.

    Columns keep the order in which models first appear in ``group_df``. Duplicate
    (question, model) rows are averaged (``pivot_table``'s default aggregation) and
    missing cells are filled with 0, i.e. counted as "no error"
    (NOTE(review): confirm that treating missing answers as correct is intended).
    Questions answered correctly by every model, or incorrectly by every model, are
    dropped because they cannot separate models.
    """
    model_order = group_df['model_name'].drop_duplicates().tolist()
    error_matrix = group_df.pivot_table(
        index='question_hash', columns='model_name', values='error', fill_value=0
    )
    error_matrix = error_matrix[[m for m in model_order if m in error_matrix.columns]]

    errors_per_item = error_matrix.sum(axis=1)
    all_correct = errors_per_item == 0
    all_wrong = errors_per_item == error_matrix.shape[1]
    print(f"Dropping {int(all_correct.sum())} all-correct and "
          f"{int(all_wrong.sum())} all-wrong items")
    return error_matrix[~(all_correct | all_wrong)]


def _load_or_compute_item_metrics(
    correct_matrix: pd.DataFrame,
    meme_dir: pathlib.Path,
    dataset_name: str,
    recompute: bool,
) -> tuple[pd.DataFrame, pd.DataFrame | None]:
    """Return ``(metrics, debug_metrics)`` for one dataset, reusing the CSV cache unless *recompute*."""
    metrics_path = meme_dir / f"{dataset_name}_metrics.csv"
    dbg_path = meme_dir / f"{dataset_name}_metrics_debug.csv"

    if not recompute and metrics_path.exists():
        return pd.read_csv(metrics_path, index_col=0), _read_optional_csv(dbg_path)

    print("now computing probe attrs")
    metrics, debug_metrics = analyze_test_items(correct_matrix)
    metrics.to_csv(metrics_path, index=True)
    if debug_metrics is None:
        # Empty file = "no debug metrics"; also overwrites any stale debug CSV.
        dbg_path.write_text("")
    else:
        debug_metrics.to_csv(dbg_path, index=True)
    return metrics, debug_metrics


def _load_or_compute_clustering(
    correct_matrix: pd.DataFrame,
    meme_dir: pathlib.Path,
    dataset_name: str,
    recompute: bool,
) -> dict:
    """Return ``compute_clustering_metrics(..., return_all=True)`` output, cached as a pickle.

    A ``"similarity"`` alias for ``"similarity_matrix"`` is added after (not before)
    the result is cached.
    """
    pkl_path = meme_dir / f"{dataset_name}_rep_bridge_communities.pkl"
    if not recompute and pkl_path.exists():
        # NOTE: pickle.load can execute arbitrary code; only load caches this script wrote.
        with open(pkl_path, "rb") as f:
            result = pickle.load(f)
    else:
        result = compute_clustering_metrics(correct_matrix, return_all=True)
        with open(pkl_path, "wb") as f:
            pickle.dump(result, f)

    if "similarity" not in result and "similarity_matrix" in result:
        result["similarity"] = result["similarity_matrix"]
    return result


def _save_subset_vs_composite_corr(
    benchmark_scores: pd.DataFrame,
    acc_table: pd.DataFrame,
    out_path: pathlib.Path,
) -> None:
    """Save the Spearman correlation between Composite score and subset accuracy.

    Only models present in both tables are used; nothing is written when fewer
    than two models are shared.
    """
    common_models = benchmark_scores.index.intersection(acc_table.index)
    if len(common_models) <= 1:
        return

    comp_scores = benchmark_scores.loc[common_models, "Composite"].rank(ascending=False)
    subset_accs = acc_table.loc[common_models, "subset_acc"].rank(ascending=False)
    rho, pval = spearmanr(comp_scores, subset_accs)

    corr_df = pd.DataFrame({
        "rho": [rho],
        "pval": [pval],
        "n_models": [len(common_models)]
    }, index=["Spearman_Correlation"])
    corr_df.to_csv(out_path)


def _process_dataset(
    dataset_name: str,
    group_df: pd.DataFrame,
    ds_bucket: dict,
    meme_dir: pathlib.Path,
    recompute: bool,
) -> pd.DataFrame:
    """Compute and save all per-dataset artifacts, fill *ds_bucket*, and return the benchmark scores."""
    error_matrix = _build_error_matrix(group_df)
    correct_matrix = 1 - error_matrix

    # Accuracy over the discriminative items only (all-correct/all-wrong rows were dropped).
    model_accuracies = 1 - error_matrix.mean(axis=0)
    print(model_accuracies)

    ds_bucket["matrices"] = {
        "error_matrix":       error_matrix.copy(),
        "correct_matrix":     correct_matrix.copy(),
    }

    metrics, debug_metrics = _load_or_compute_item_metrics(
        correct_matrix, meme_dir, dataset_name, recompute
    )

    start = time.time()
    rc_br_all_result = _load_or_compute_clustering(correct_matrix, meme_dir, dataset_name, recompute)
    rc_br_result = rc_br_all_result["result_df"]
    rc_br_result.to_csv(meme_dir.joinpath(f"{dataset_name}_rep_bridge_communities.csv"), index=True)
    print(f"Representativeness & Bridge computed in {time.time() - start:.3f}s")

    # Per-item table: probe metrics + surprise diagnostics + clustering columns + per-model correctness.
    full_metrics = metrics.copy()
    if debug_metrics is not None:
        for col in ("surprise", "surprise_error", "surprise_success"):
            if col in debug_metrics.columns:
                full_metrics[col] = debug_metrics[col]
    for col in rc_br_result.columns:
        full_metrics[col] = rc_br_result[col]
    for model in correct_matrix.columns:
        full_metrics[model] = correct_matrix[model]
    full_metrics.to_csv(meme_dir.joinpath(f"{dataset_name}_metrics_full.csv"), index=True)

    ds_bucket['full_metrics'] = full_metrics.copy()
    ds_bucket['clustering_results'] = rc_br_all_result.copy()

    benchmark_scores, item_weights, spearman_corr = calculate_metrics(
        correct_matrix, full_metrics, X_full=correct_matrix
    )
    benchmark_scores.to_csv(meme_dir.joinpath(f"{dataset_name}_benchmark_scores.csv"), index=True)
    item_weights.to_csv(meme_dir.joinpath(f"{dataset_name}_item_weights.csv"), index=True)
    spearman_corr.to_csv(meme_dir.joinpath(f"{dataset_name}_spearman_corr.csv"), index=True)

    ds_bucket["scores"].update({
        "model_accuracies": model_accuracies.copy(),
        "benchmark_scores": benchmark_scores.copy(),
        "item_weights":     item_weights.copy()
    })

    _, manifest, acc_table = build_subset_by_attributes(correct_matrix, full_metrics, k=100)
    manifest.to_csv(meme_dir.joinpath(f"{dataset_name}_subset_manifest.csv"), index=True)
    acc_table.to_csv(meme_dir.joinpath(f"{dataset_name}_subset_acc_table.csv"), index=True)

    _save_subset_vs_composite_corr(
        benchmark_scores, acc_table,
        meme_dir.joinpath(f"{dataset_name}_subset_vs_composite_corr.csv"),
    )
    return benchmark_scores


def _aggregate_global_scores(
    all_benchmark_scores_list: list[tuple[str, pd.DataFrame]],
    meme_dir: pathlib.Path,
) -> dict:
    """Average benchmark scores per model across datasets, rank them, save CSVs.

    Returns the bucket stored under ``"_global"`` in the results dict.
    """
    keys   = [name for name, _ in all_benchmark_scores_list]
    frames = [df.copy() for _, df in all_benchmark_scores_list]
    combined_scores = pd.concat(frames, keys=keys, names=["dataset", "model"])

    base_metrics = ['Composite', 'RKSP', 'TYDF', 'BRDF', 'UQDF', 'SP']
    candidates = (['ACC'] if 'ACC' in combined_scores.columns else []) + base_metrics
    cols_to_avg = [c for c in candidates
                   if c in combined_scores.columns and np.issubdtype(combined_scores[c].dtype, np.number)]

    model_avg_scores = combined_scores[cols_to_avg].groupby(level="model").mean()

    # Return value is discarded, so add_ranks presumably adds '<col>_rank' columns in place.
    add_ranks(model_avg_scores, cols=cols_to_avg, rank_method='min', suffix='_rank')

    ordered_cols = [name for c in cols_to_avg for name in (c, c + '_rank')]
    model_avg_scores = model_avg_scores.reindex(columns=ordered_cols)

    if 'ACC' in model_avg_scores.columns:
        model_avg_scores = model_avg_scores.sort_values(by='ACC', ascending=False)

    model_avg_scores.to_csv(meme_dir.joinpath("model_avg_benchmark_scores.csv"), index=True)

    corr_df = compute_metric_rank_correlations(
        model_avg_scores,
        metrics=base_metrics,
        acc_rank_col='ACC_rank',
        method='pearson'
    )
    corr_df.to_csv(meme_dir.joinpath("model_avg_metric_rank_corr.csv"))

    return {
        "dataset": "_global",
        "matrices": {},
        "scores": {
            "model_avg_scores": model_avg_scores.copy(),
            "model_avg_corr":   corr_df.copy(),
        }
    }


@click.command()
@click.option("--all-results-file", required=True, type=click.Path(path_type=pathlib.Path, exists=True), help="All results file containing model answer data.")
@click.option("--output-dir", required=True, type=click.Path(path_type=pathlib.Path), help="Output directory to save results.")
@click.option("--recompute", is_flag=True, default=False, help="Force recomputation and ignore cached files.")
def main(all_results_file: pathlib.Path, output_dir: pathlib.Path, recompute: bool):
    """Compute per-dataset probe attributes and benchmark scores, then render figures.

    Layout under *output_dir*: ``meme/`` holds CSV/pickle artifacts (reused as a
    cache unless ``--recompute``), ``figures/`` PNGs, ``pdfs/`` PDFs; ``debug/`` is
    created but not written to here.
    """
    fig_dir = output_dir.joinpath("figures")
    meme_dir = output_dir.joinpath("meme")
    debug_dir = output_dir.joinpath("debug")
    pdf_dir = output_dir.joinpath("pdfs")
    for directory in (fig_dir, meme_dir, debug_dir, pdf_dir):
        directory.mkdir(parents=True, exist_ok=True)

    all_results_df = pd.read_csv(all_results_file)
    all_benchmark_scores_list: list[tuple[str, pd.DataFrame]] = []
    all_calculate_results_dict: dict = {}

    for dataset_name, group_df in all_results_df.groupby('dataset'):
        print("=" * 50)
        ds_bucket = all_calculate_results_dict.setdefault(dataset_name, {
            "dataset": dataset_name,
            "matrices": {},
            "scores": {}
        })
        benchmark_scores = _process_dataset(dataset_name, group_df, ds_bucket, meme_dir, recompute)
        all_benchmark_scores_list.append((dataset_name, benchmark_scores))

    if all_benchmark_scores_list:
        all_calculate_results_dict["_global"] = _aggregate_global_scores(
            all_benchmark_scores_list, meme_dir
        )

    reason_mode_phemotypes_compare_fig, _, _ = plot_phemotype_share_pies(
        all_calculate_results_dict,
        dims=("RKSP","TYDF","BRDF","UQDF","SP"),
        title="Phemotype Composition Across Reasoning Modes (Averaged Over All Datasets)",
    )
    _save_png_and_pdf(
        reason_mode_phemotypes_compare_fig,
        fig_dir.joinpath("reasoning_mode_comparison_share.png"),
        pdf_dir.joinpath("reasoning_mode_comparison_share.pdf"),
        dpi=300,
        bbox_inches='tight'
    )
    plt.close(reason_mode_phemotypes_compare_fig)

    dataset_names = [name for name in all_calculate_results_dict if name != '_global']
    if dataset_names:
        plot_combined_visualizations(all_calculate_results_dict, dataset_names, fig_dir, pdf_dir)

# Script entry point: click parses sys.argv and invokes the `main` command.
if __name__ == "__main__":
    main()
