import os
from typing import Dict, Tuple, List, Optional
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.colors import ListedColormap, BoundaryNorm, LinearSegmentedColormap
import pandas as pd
import numpy as np

from .load_metrics import (
    load_all,
    ordered_seed_labels_for_quality,
    all_metric_rows,
    QUALITIES,
)

# Shared color list for every bar chart in this module; the color-map printouts
# below index it by the position of each hue level.
# NOTE(review): "tab10_light" is not among seaborn's documented built-in palette
# names ("tab10" is) -- confirm a colormap with this name is registered in the
# target environment, otherwise this call may raise ValueError at import time.
PALETTE = sns.color_palette("tab10_light")
# Model order used for the bar charts (x-axis in RQ1 / RQ2 variant, hue in RQ2).
MODEL_ORDER = ["ConvE", "TransE", "DistMult", "RGCN", "Transformer"]
# Dataset order: hue levels in RQ1 and the per-dataset loop in RQ2 / RQ2 variant.
DATASET_ORDER = ["WN18RR", "kinship", "nations", "codex-s"]
# Model priority used to sort the rows of the RQ3 heatmaps.
HEATMAP_MODEL_ORDER = ["TransE", "ConvE", "DistMult", "Transformer", "RGCN"]

# Toggles for including special seeds in best quality visualizations
# (consumed by _best_seed_order_filtered; a False toggle hides that seed label).
SHOW_RAW = False
SHOW_CONTROL_REPROC = False
SHOW_HARDWARE = True

def _best_seed_order_filtered():
    """Return the ordered "best"-quality seed labels, minus hidden special seeds.

    The special labels "Raw", "Control_Reproc" and "Hardware" are kept only
    when their SHOW_* toggle is truthy; every other label is always kept.
    The relative order from ordered_seed_labels_for_quality is preserved.
    """
    # Special label -> whether it should appear in the plots.
    visibility = {
        "Raw": SHOW_RAW,
        "Control_Reproc": SHOW_CONTROL_REPROC,
        "Hardware": SHOW_HARDWARE,
    }
    return [
        label
        for label in ordered_seed_labels_for_quality("best")
        if visibility.get(label, True)
    ]


def _collect_flat_records(all_data) -> List[Dict[str, object]]:
    rows = []
    for (dataset, model), pair_data in all_data.items():
        for q in QUALITIES:
            for seed in ordered_seed_labels_for_quality(q):
                metrics = pair_data.get(q, {}).get(seed, {})
                for m_name, v in metrics.items():
                    rows.append({
                        "dataset": dataset,
                        "model": model,
                        "quality": q,
                        "seed": seed,
                        "metric": m_name,
                        "mean": v.get("mean"),
                        "std": v.get("std"),
                    })
    return rows


# RQ1: one grouped bar chart per metric.
# - Groups along the x-axis: model
# - One bar per dataset within each group

def _draw_dodged_errorbars(ax, frame, x_col, hue_col, x_levels, hue_levels, group_width=0.8):
    """Draw a +/- std error bar over each bar of a dodged seaborn bar plot.

    Bar centers are derived from the dodge layout (slot of the x level plus the
    offset of the hue level inside the group) rather than by walking
    ``ax.patches``. Seaborn emits the patches hue level by hue level, and how it
    represents missing (x, hue) combinations differs between versions, so
    zipping the patches with a list built group by group attaches std values to
    the wrong bars.

    Args:
        ax: Axes returned by ``sns.barplot`` called with ``order=x_levels``,
            ``hue_order=hue_levels``, ``dodge=True`` and ``width=group_width``.
        frame: The plotted rows; needs ``x_col``, ``hue_col`` and numeric
            ``mean`` / ``std`` columns.
        x_col: Column mapped to the x-axis.
        hue_col: Column mapped to hue.
        x_levels: Ordered x categories (position i sits at x == i).
        hue_levels: Ordered hue categories.
        group_width: Total width of one group of bars, in x units.
    """
    if not hue_levels:
        return
    bar_width = group_width / len(hue_levels)
    x_slot = {level: i for i, level in enumerate(x_levels)}
    hue_slot = {level: j for j, level in enumerate(hue_levels)}
    # One bar per (x, hue) combination; keep the first row if duplicates ever occur.
    unique = frame.drop_duplicates(subset=[x_col, hue_col])
    for x_val, hue_val, mean, std in zip(unique[x_col], unique[hue_col], unique["mean"], unique["std"]):
        if x_val not in x_slot or hue_val not in hue_slot:
            continue
        if pd.isna(mean) or pd.isna(std) or std <= 0:
            continue
        center = x_slot[x_val] - group_width / 2 + (hue_slot[hue_val] + 0.5) * bar_width
        ax.errorbar(center, mean, yerr=std, fmt='none', ecolor='black', capsize=3, lw=1)


def _rq1_barplot(frame, group_width=0.8):
    """Open a 10x4 figure and draw the RQ1 bars (x=model, hue=dataset) with error bars."""
    plt.figure(figsize=(10, 4))
    # order / hue_order / dodge / width are explicit so the bar layout is fully
    # determined and the error-bar positions computed below match it.
    ax = sns.barplot(
        data=frame,
        x="model",
        y="mean",
        hue="dataset",
        order=MODEL_ORDER,
        hue_order=DATASET_ORDER,
        dodge=True,
        width=group_width,
        palette=PALETTE,
        errorbar=None,
    )
    _draw_dodged_errorbars(ax, frame, "model", "dataset", MODEL_ORDER, DATASET_ORDER, group_width)
    return ax


def make_rq1(output_dir: str = "process_data/RQ1"):
    """Write the RQ1 figures: one grouped bar chart per metric.

    Uses only the records with quality == "best" and seed == "all". Bars are
    grouped by model (MODEL_ORDER) with one bar per dataset (DATASET_ORDER);
    the std of each record is drawn as an error bar. Metrics with no numeric
    mean for any known model/dataset are skipped. For ambiguity/discrepancy
    metrics an additional "alt" chart of 1 - mean is written.

    Args:
        output_dir: Directory for the PNG files; created if missing.
    """
    os.makedirs(output_dir, exist_ok=True)
    rows = _collect_flat_records(load_all())
    # Explicit columns keep the filters below valid even when no records were loaded
    # (pd.DataFrame([]) would have no columns and df["quality"] would raise KeyError).
    df = pd.DataFrame(rows, columns=["dataset", "model", "quality", "seed", "metric", "mean", "std"])

    # We'll do RQ1 with quality=best and seed=all (most representative), unless seed not present then skip
    dfb = df[(df["quality"] == "best") & (df["seed"] == "all")]

    for metric in all_metric_rows():
        dmf = dfb[dfb["metric"] == metric].copy()
        # Coerce to numeric; non-numeric values become NaN
        dmf["mean"] = pd.to_numeric(dmf["mean"], errors="coerce")
        dmf["std"] = pd.to_numeric(dmf["std"], errors="coerce")
        # Enforce category orders for reproducible layout/colors; values outside
        # the configured orders become NaN and are dropped with the NaN means.
        dmf["model"] = pd.Categorical(dmf["model"], categories=MODEL_ORDER, ordered=True)
        dmf["dataset"] = pd.Categorical(dmf["dataset"], categories=DATASET_ORDER, ordered=True)
        dmf = dmf.dropna(subset=["mean", "model", "dataset"]).sort_values(["model", "dataset"])
        if dmf.empty:
            continue

        slug = metric.replace('@', 'at').replace('/', '_')
        ax = _rq1_barplot(dmf)
        # Force scale for ambiguity/discrepancy and all jaccard-type metrics
        if ("ambiguity" in metric) or ("discrepancy" in metric) or ("jaccard" in metric):
            ax.set_ylim(0.0, 1.0)
        # No titles/legends and no axis labels
        ax.set_title("")
        ax.set_xlabel("")
        ax.set_ylabel("")
        # Print color mapping (dataset -> color); hue levels follow DATASET_ORDER
        mapping = {lvl: PALETTE[i] for i, lvl in enumerate(DATASET_ORDER)}
        print(f"[RQ1][{metric}] color map (dataset): {mapping}")
        # Remove any residual legend and enlarge model ticks
        leg = ax.get_legend()
        if leg is not None:
            leg.remove()
        ax.tick_params(axis='x', labelsize=14)
        # Ensure spines visible and add tiny axis margins to avoid clipping
        for s in ax.spines.values():
            s.set_visible(True)
        ax.margins(x=0.02, y=0.02)
        # Save without visible whitespace but avoid clipping the right edge
        fname = os.path.join(output_dir, f"rq1_{slug}.png")
        plt.savefig(fname, bbox_inches='tight', pad_inches=0.03)
        plt.close()
        print(f"Saved {fname}")

        # Alternate plot with 1 - value for ambiguity/discrepancy
        if ("ambiguity" in metric) or ("discrepancy" in metric):
            dmf_alt = dmf.copy()
            dmf_alt["mean"] = 1.0 - dmf_alt["mean"]
            # std is invariant under 1 - x, so the same std column is reused
            ax = _rq1_barplot(dmf_alt)
            ax.set_ylim(0.0, 1.0)
            plt.title(f"RQ1 - alt(1 - {metric})")
            plt.tight_layout()
            fname = os.path.join(output_dir, f"rq1_alt_{slug}.png")
            plt.savefig(fname)
            plt.close()
            print(f"Saved {fname}")


def make_rq2_variant(output_dir: str = "process_data/RQ2_variant"):
    """
    Variant of RQ2 with the axes swapped: x = model, hue (color) = seed type.

    For every metric and every dataset in DATASET_ORDER one PNG is written,
    built from the quality == "best" records and the seed labels returned by
    _best_seed_order_filtered(). Ambiguity/discrepancy metrics additionally get
    an "alt" PNG showing 1 - mean. Metric/dataset combinations without any
    plottable row are skipped instead of producing an empty figure.

    Args:
        output_dir: Directory for the PNG files; created if missing.
    """
    os.makedirs(output_dir, exist_ok=True)
    rows = _collect_flat_records(load_all())
    # Explicit columns keep the filters below valid even when no records were loaded
    # (pd.DataFrame([]) would have no columns and df["quality"] would raise KeyError).
    df = pd.DataFrame(rows, columns=["dataset", "model", "quality", "seed", "metric", "mean", "std"])
    best = df[df["quality"] == "best"]

    metrics = list(all_metric_rows())
    datasets = DATASET_ORDER
    seed_order = _best_seed_order_filtered()  # loop-invariant
    group_width = 0.8  # total width of one group of dodged bars, in x units

    def add_errorbars(ax, x_levels, hue_levels, df, x_col, hue_col, y_col="mean", err_col="std"):
        """Draw +/- err_col at the center of every (x, hue) bar.

        Centers are computed from the dodge layout instead of zipping with
        ax.patches: seaborn emits patches hue level by hue level, so an
        x-major list of std values ends up on the wrong bars.
        """
        if not hue_levels:
            return
        bar_width = group_width / len(hue_levels)
        x_slot = {lvl: i for i, lvl in enumerate(x_levels)}
        hue_slot = {lvl: j for j, lvl in enumerate(hue_levels)}
        unique = df.drop_duplicates(subset=[x_col, hue_col])
        for xv, hv, y, yerr in zip(unique[x_col], unique[hue_col], unique[y_col], unique[err_col]):
            if xv not in x_slot or hv not in hue_slot:
                continue
            if pd.isna(y) or pd.isna(yerr) or yerr <= 0:
                continue
            x = x_slot[xv] - group_width / 2 + (hue_slot[hv] + 0.5) * bar_width
            ax.errorbar(x, y, yerr=yerr, fmt='none', ecolor='black', capsize=3, lw=1)

    def render(frame, label, ds, fname, unit_ylim, enlarge_ticks):
        """Draw, decorate and save one chart; ``label`` only affects the printed tag."""
        plt.figure(figsize=(10, 5))
        # Layout parameters are explicit so add_errorbars can reproduce the bar centers.
        ax = sns.barplot(
            data=frame,
            x="model",
            y="mean",
            hue="seed",
            order=MODEL_ORDER,
            hue_order=seed_order,
            dodge=True,
            width=group_width,
            palette=PALETTE,
            errorbar=None,
        )
        add_errorbars(ax, MODEL_ORDER, seed_order, frame, "model", "seed")
        if unit_ylim:
            ax.set_ylim(0.0, 1.0)
        # No titles/legends and no axis labels
        ax.set_title("")
        ax.set_xlabel("")
        ax.set_ylabel("")
        plt.xticks(rotation=0)
        # Print color mapping (seed -> color); wrap around like seaborn does
        # when there are more hue levels than palette entries.
        mapping = {lvl: PALETTE[i % len(PALETTE)] for i, lvl in enumerate(seed_order)}
        print(f"[RQ2_variant][{label}][{ds}] color map (seed): {mapping}")
        if enlarge_ticks:
            # Larger tick labels for models on x-axis
            ax.tick_params(axis='x', labelsize=14)
        leg = ax.get_legend()
        if leg is not None:
            leg.remove()
        for s in ax.spines.values():
            s.set_visible(True)
        ax.margins(x=0.02, y=0.02)
        plt.savefig(fname, bbox_inches='tight', pad_inches=0.03)
        plt.close()
        print(f"Saved {fname}")

    for metric in metrics:
        slug = metric.replace('@', 'at').replace('/', '_')
        is_amb_or_disc = ("ambiguity" in metric) or ("discrepancy" in metric)
        for ds in datasets:
            sub = best[(best["dataset"] == ds) & (best["metric"] == metric)].copy()
            sub["seed"] = pd.Categorical(sub["seed"], categories=seed_order, ordered=True)
            sub["model"] = pd.Categorical(sub["model"], categories=MODEL_ORDER, ordered=True)
            sub["mean"] = pd.to_numeric(sub["mean"], errors="coerce")
            sub["std"] = pd.to_numeric(sub["std"], errors="coerce")
            # Hidden seeds / unknown models became NaN categories above; they
            # would not be drawn anyway, so drop them with the NaN means.
            sub = sub.dropna(subset=["mean", "seed", "model"]).sort_values(["model", "seed"])
            if sub.empty:
                continue
            render(
                sub,
                metric,
                ds,
                os.path.join(output_dir, f"rq2_variant_{slug}_{ds}.png"),
                unit_ylim=is_amb_or_disc or ("jaccard" in metric),
                enlarge_ticks=True,
            )
            # Alt plot 1 - value
            if is_amb_or_disc:
                sub_alt = sub.copy()
                sub_alt["mean"] = 1.0 - sub_alt["mean"]
                render(
                    sub_alt,
                    f"alt(1-{metric})",
                    ds,
                    os.path.join(output_dir, f"rq2_variant_alt_{slug}_{ds}.png"),
                    unit_ylim=True,
                    # The alt chart has never enlarged its tick labels; kept as is.
                    enlarge_ticks=False,
                )

# RQ2: For each metric a plot.
# - Subplot per dataset.
# - Group by seed_type + [Raw, Control_Reproc, Hardware]
# - One bar per model (same color per model)

def make_rq2(output_dir: str = "process_data/RQ2"):
    """Write the RQ2 figures: per metric and dataset, bars grouped by seed type.

    Uses the quality == "best" records. The x-axis holds the seed labels from
    _best_seed_order_filtered(), with one bar per model (hue, MODEL_ORDER) and
    the record's std as error bar. Ambiguity/discrepancy metrics additionally
    get an "alt" PNG showing 1 - mean. Metric/dataset combinations without any
    plottable row are skipped instead of producing an empty figure.

    Args:
        output_dir: Directory for the PNG files; created if missing.
    """
    os.makedirs(output_dir, exist_ok=True)
    rows = _collect_flat_records(load_all())
    # Explicit columns keep the filters below valid even when no records were loaded
    # (pd.DataFrame([]) would have no columns and df["quality"] would raise KeyError).
    df = pd.DataFrame(rows, columns=["dataset", "model", "quality", "seed", "metric", "mean", "std"])
    best = df[df["quality"] == "best"]

    metrics = list(all_metric_rows())
    datasets = DATASET_ORDER
    seed_order = _best_seed_order_filtered()  # loop-invariant
    group_width = 0.8  # total width of one group of dodged bars, in x units

    def add_errorbars(ax, x_levels, hue_levels, df, x_col, hue_col, y_col="mean", err_col="std"):
        """Draw +/- err_col at the center of every (x, hue) bar.

        Centers are computed from the dodge layout instead of zipping with
        ax.patches: seaborn emits patches hue level by hue level, so an
        x-major list of std values ends up on the wrong bars.
        """
        if not hue_levels:
            return
        bar_width = group_width / len(hue_levels)
        x_slot = {lvl: i for i, lvl in enumerate(x_levels)}
        hue_slot = {lvl: j for j, lvl in enumerate(hue_levels)}
        unique = df.drop_duplicates(subset=[x_col, hue_col])
        for xv, hv, y, yerr in zip(unique[x_col], unique[hue_col], unique[y_col], unique[err_col]):
            if xv not in x_slot or hv not in hue_slot:
                continue
            if pd.isna(y) or pd.isna(yerr) or yerr <= 0:
                continue
            x = x_slot[xv] - group_width / 2 + (hue_slot[hv] + 0.5) * bar_width
            ax.errorbar(x, y, yerr=yerr, fmt='none', ecolor='black', capsize=3, lw=1)

    def render(frame, label, ds, fname, unit_ylim):
        """Draw, decorate and save one chart; ``label`` only affects the printed tag."""
        plt.figure(figsize=(10, 5))
        # Layout parameters are explicit so add_errorbars can reproduce the bar centers.
        ax = sns.barplot(
            data=frame,
            x="seed",
            y="mean",
            hue="model",
            order=seed_order,
            hue_order=MODEL_ORDER,
            dodge=True,
            width=group_width,
            palette=PALETTE,
            errorbar=None,
        )
        add_errorbars(ax, seed_order, MODEL_ORDER, frame, "seed", "model")
        if unit_ylim:
            ax.set_ylim(0.0, 1.0)
        # No titles/legends and no axis labels
        ax.set_title("")
        ax.set_xlabel("")
        ax.set_ylabel("")
        plt.xticks(rotation=45)
        # Print color mapping (model -> color); hue levels follow MODEL_ORDER
        mapping = {lvl: PALETTE[i] for i, lvl in enumerate(MODEL_ORDER)}
        print(f"[RQ2][{label}][{ds}] color map (model): {mapping}")
        leg = ax.get_legend()
        if leg is not None:
            leg.remove()
        for s in ax.spines.values():
            s.set_visible(True)
        ax.margins(x=0.02, y=0.02)
        plt.savefig(fname, bbox_inches='tight', pad_inches=0.01)
        plt.close()
        print(f"Saved {fname}")

    for metric in metrics:
        slug = metric.replace('@', 'at').replace('/', '_')
        # NOTE(review): RQ1 and the RQ2 variant also pin jaccard metrics to
        # [0, 1]; this function does not -- confirm whether that is intended.
        is_amb_or_disc = ("ambiguity" in metric) or ("discrepancy" in metric)
        for ds in datasets:
            sub = best[(best["dataset"] == ds) & (best["metric"] == metric)].copy()
            sub["seed"] = pd.Categorical(sub["seed"], categories=seed_order, ordered=True)
            sub["model"] = pd.Categorical(sub["model"], categories=MODEL_ORDER, ordered=True)
            sub["mean"] = pd.to_numeric(sub["mean"], errors="coerce")
            sub["std"] = pd.to_numeric(sub["std"], errors="coerce")
            # Hidden seeds / unknown models became NaN categories above; they
            # would not be drawn anyway, so drop them with the NaN means.
            sub = sub.dropna(subset=["mean", "seed", "model"]).sort_values(["seed", "model"])
            if sub.empty:
                continue
            render(
                sub,
                metric,
                ds,
                os.path.join(output_dir, f"rq2_{slug}_{ds}.png"),
                unit_ylim=is_amb_or_disc,
            )
            # Alt plot 1 - value
            if is_amb_or_disc:
                sub_alt = sub.copy()
                sub_alt["mean"] = 1.0 - sub_alt["mean"]
                render(
                    sub_alt,
                    f"alt(1-{metric})",
                    ds,
                    os.path.join(output_dir, f"rq2_alt_{slug}_{ds}.png"),
                    unit_ylim=True,
                )


# RQ3: per metric heatmap
# - x: [best, median, worst]
# - y: (model, dataset)
# We use seed=all for all qualities as per spec for median/worst. For best, we also use seed=all for comparability.

def make_rq3(output_dir: str = "process_data/RQ3"):
    """Write the RQ3 figures: one annotated heatmap per metric.

    Rows are "model | dataset" pairs ordered by HEATMAP_MODEL_ORDER (unknown
    models last), columns are the qualities best / median / worst, and cells
    hold the mean of the seed == "all" record. Jaccard metrics use a fixed
    color scale on [0, 1]; all other metrics use "RdYlGn" without a color bar.
    Metrics with no numeric value are skipped.

    Args:
        output_dir: Directory for the PNG files; created if missing.
    """
    os.makedirs(output_dir, exist_ok=True)
    rows = _collect_flat_records(load_all())
    # Explicit columns keep the filters below valid even when no records were loaded
    # (pd.DataFrame([]) would have no columns and df["seed"] would raise KeyError).
    df = pd.DataFrame(rows, columns=["dataset", "model", "quality", "seed", "metric", "mean", "std"])

    quality_cols = ["best", "median", "worst"]
    # Select seed=all for all qualities
    dfa = df[(df["seed"] == "all") & (df["quality"].isin(quality_cols))].copy()

    # Fixed continuous gradient for jaccard-like metrics with absolute scale [0,1];
    # built once because it does not depend on the metric.
    positions = [0.0, 0.5, 0.85, 1.0]
    colors = ["darkred", "orange", "yellow", "green"]
    jaccard_cmap = LinearSegmentedColormap.from_list("jaccard_fixed", list(zip(positions, colors)))

    def pair_key(p):
        """Sort key: rank of the pair's model in HEATMAP_MODEL_ORDER, unknown models last."""
        model = p.split(" | ")[0]
        return HEATMAP_MODEL_ORDER.index(model) if model in HEATMAP_MODEL_ORDER else len(HEATMAP_MODEL_ORDER)

    for metric in all_metric_rows():
        sub = dfa[dfa["metric"] == metric].copy()
        if sub.empty:
            continue
        # pivot to (model,dataset) x quality
        sub["pair"] = sub["model"] + " | " + sub["dataset"]
        sub["mean"] = pd.to_numeric(sub["mean"], errors="coerce")
        pivot = sub.pivot_table(index="pair", columns="quality", values="mean")
        # Reorder rows to satisfy model priority for y-axis (stable sort keeps
        # the pivot's own order within one model), then fix the column order.
        pivot = pivot.reindex(sorted(pivot.index, key=pair_key))
        pivot = pivot.reindex(columns=quality_cols)
        # Nothing numeric to show (e.g. every mean failed to parse): skip.
        if pivot.dropna(how="all").empty:
            continue
        plt.figure(figsize=(6, max(4, 0.4 * len(pivot))))
        if "jaccard" in metric:
            sns.heatmap(pivot, annot=True, fmt=".3f", cmap=jaccard_cmap, vmin=0.0, vmax=1.0)
        else:
            sns.heatmap(pivot, annot=True, fmt=".3f", cmap="RdYlGn", cbar=False)
        # No title and no y-axis label.
        # NOTE(review): the x-axis label is left in place here, while the RQ3
        # variant clears it with plt.xlabel("") -- confirm which is intended.
        plt.title("")
        plt.ylabel("")
        fname = os.path.join(output_dir, f"rq3_{metric.replace('@','at').replace('/','_')}.png")
        plt.savefig(fname, bbox_inches='tight', pad_inches=0.01)
        plt.close()
        print(f"Saved {fname}")



def make_rq3_variant(output_dir: str = "process_data/RQ3_variant"):
    """Write the RQ3 variant figures: per-metric heatmaps for WN18RR and codex-s only.

    Same layout as RQ3 -- rows are "model | dataset" pairs ordered by
    HEATMAP_MODEL_ORDER (unknown models last), columns are best / median /
    worst, cells hold the mean of the seed == "all" record -- but restricted
    to the datasets "WN18RR" and "codex-s". Jaccard metrics use a fixed color
    scale on [0, 1]; other metrics use "RdYlGn". Metrics with no numeric value
    are skipped.

    Args:
        output_dir: Directory for the PNG files; created if missing.
    """
    os.makedirs(output_dir, exist_ok=True)
    rows = _collect_flat_records(load_all())
    # Explicit columns keep the filters below valid even when no records were loaded
    # (pd.DataFrame([]) would have no columns and df["seed"] would raise KeyError).
    df = pd.DataFrame(rows, columns=["dataset", "model", "quality", "seed", "metric", "mean", "std"])

    quality_cols = ["best", "median", "worst"]
    # Select seed=all and only best/median/worst and filter datasets to exclude kinship and nations
    keep = (
        (df["seed"] == "all")
        & df["quality"].isin(quality_cols)
        & df["dataset"].isin(["WN18RR", "codex-s"])
    )
    dfa = df[keep].copy()

    # Anchored gradient for jaccard metrics; built once, it does not depend on the metric.
    positions = [0.0, 0.5, 0.85, 1.0]
    colors = ["darkred", "orange", "yellow", "green"]
    jaccard_cmap = LinearSegmentedColormap.from_list("jaccard_fixed", list(zip(positions, colors)))

    def pair_key(p):
        """Sort key: rank of the pair's model in HEATMAP_MODEL_ORDER, unknown models last."""
        model = p.split(" | ")[0]
        return HEATMAP_MODEL_ORDER.index(model) if model in HEATMAP_MODEL_ORDER else len(HEATMAP_MODEL_ORDER)

    for metric in all_metric_rows():
        sub = dfa[dfa["metric"] == metric].copy()
        if sub.empty:
            continue
        sub["pair"] = sub["model"] + " | " + sub["dataset"]
        sub["mean"] = pd.to_numeric(sub["mean"], errors="coerce")
        pivot = sub.pivot_table(index="pair", columns="quality", values="mean")
        pivot = pivot.reindex(sorted(pivot.index, key=pair_key))
        pivot = pivot.reindex(columns=quality_cols)
        # Nothing numeric to show (e.g. every mean failed to parse): skip.
        if pivot.dropna(how="all").empty:
            continue
        plt.figure(figsize=(6, max(4, 0.4 * len(pivot))))
        if "jaccard" in metric:
            sns.heatmap(pivot, annot=True, fmt=".3f", cmap=jaccard_cmap, vmin=0.0, vmax=1.0)
        else:
            sns.heatmap(pivot, annot=True, fmt=".3f", cmap="RdYlGn")
        # No title and no axis labels
        plt.title("")
        plt.xlabel("")
        plt.ylabel("")
        fname = os.path.join(output_dir, f"rq3_variant_{metric.replace('@','at').replace('/','_')}.png")
        plt.savefig(fname, bbox_inches='tight', pad_inches=0.01)
        plt.close()
        print(f"Saved {fname}")


# Script entry point: regenerate every figure family into its default output
# directory. Each make_* call loads the metrics again via load_all().
if __name__ == "__main__":
    make_rq1()
    make_rq2()
    make_rq2_variant()
    make_rq3()
    make_rq3_variant()