import os
import re
import glob
import argparse
import logging
from dataclasses import dataclass
from typing import List, Optional, Dict, Tuple
import yaml
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

@dataclass
class Config:
    """Runtime settings for the comparative-evaluation plotting script (built by load_config)."""

    # Model names; comparative columns are used only if both of their models are listed.
    models: List[str]
    # Judge names; comparative columns from judges not listed here are ignored.
    judges: List[str]
    # Directory searched for CSV files; also the base for the relative "source_file" paths.
    csv_root: str
    # When true, subdirectories of csv_root are searched as well ("**/*.csv").
    recursive: bool
    # If non-empty, keep only files whose basename contains one of these substrings (case-insensitive).
    filters: List[str]
    # When true, files whose name matches no language entry are dropped instead of labelled "unknown".
    skip_unknown_language: bool
    # Language entries, each a mapping with "name" and "match_substrings" (matched against the file basename).
    languages: List[Dict]
    # "half": a tie counts as half a win for each side; "ignore": ties are excluded from the
    # denominator; any other value keeps ties in the denominator with no credit.
    tie_handling: str
    # Directory that receives the result CSVs and PNG plots.
    output_dir: str
    # Resolution passed to savefig.
    dpi: int
    # seaborn style name passed to sns.set_style.
    style: str
    # When true, per-language tables and the per-language bar grid are produced too.
    per_language: bool
    # seaborn palette used by the bar plots.
    palette: str
    # Logging level name passed to setup_logging.
    log_level: str

# Matches judge-verdict column names of the form
#   answer_<m1>_AND_answer_<m2>_comparison_by_<judge>_judge_winner
# capturing the two compared model names (m1, m2) and the judge name.
# Each captured name is restricted to ASCII letters and digits, so names containing
# "-", "_" or "." will not be recognised.
COL_PATTERN = re.compile(r"^answer_(?P<m1>[A-Za-z0-9]+)_AND_answer_(?P<m2>[A-Za-z0-9]+)_comparison_by_(?P<judge>[A-Za-z0-9]+)_judge_winner$")

def load_config(path: str) -> Config:
    """Load the YAML file at *path* and build a :class:`Config`.

    Required: ``evaluation.models``, ``evaluation.judges``, ``data.csv_root``
    and ``plots.output_dir``. The ``comparative`` and ``logging`` sections and
    the top-level ``languages`` list are optional; when absent (or null) the
    per-key defaults below apply.

    Raises:
        ValueError: if the document is empty or not a mapping.
        KeyError: if a required section or key is missing.
    """
    with open(path, "r", encoding="utf-8") as f:
        raw = yaml.safe_load(f)
    if not isinstance(raw, dict):
        raise ValueError(
            f"Config file {path!r} must contain a YAML mapping, got {type(raw).__name__}"
        )
    evaluation = raw["evaluation"]
    data = raw["data"]
    plots = raw["plots"]
    # Optional sections: `or {}` also covers a key that is present but null.
    comparative = raw.get("comparative") or {}
    logging_cfg = raw.get("logging") or {}
    return Config(
        models=evaluation["models"],
        judges=evaluation["judges"],
        csv_root=data["csv_root"],
        recursive=data.get("recursive", True),
        filters=data.get("include_only_filename_substrings") or [],
        skip_unknown_language=data.get("skip_unknown_language", True),
        # A null `languages:` would otherwise reach infer_language as None and crash when iterated.
        languages=raw.get("languages") or [],
        tie_handling=comparative.get("tie_handling", "half"),
        output_dir=plots["output_dir"],
        dpi=plots.get("dpi", 150),
        style=plots.get("style", "whitegrid"),
        per_language=plots.get("per_language", True),
        palette=plots.get("palette", "muted"),
        log_level=logging_cfg.get("level", "INFO"),
    )

def infer_language(filename: str, lang_cfg: List[Dict], skip_unknown: bool) -> Optional[str]:
    """Guess a file's language from substrings of its basename.

    Entries of *lang_cfg* are tried in order; the first one with any
    ``match_substrings`` item occurring (case-insensitively) in the basename
    wins and its ``name`` is returned. Directory components are ignored.
    Without a match, returns None when *skip_unknown* is true, else "unknown".
    """
    base = os.path.basename(filename).lower()
    for candidate in lang_cfg:
        if any(token.lower() in base for token in candidate["match_substrings"]):
            return candidate["name"]
    return None if skip_unknown else "unknown"

def discover_csv_files(root: str, recursive: bool, filters: List[str]) -> List[str]:
    """Return the sorted paths of ``*.csv`` files under *root*.

    With *recursive*, subdirectories are searched too. A non-empty *filters*
    list keeps only files whose basename contains at least one of its
    substrings, compared case-insensitively.
    """
    glob_expr = os.path.join(root, "**/*.csv" if recursive else "*.csv")
    candidates = glob.glob(glob_expr, recursive=recursive)
    needles = [needle.lower() for needle in filters] if filters else []

    def _wanted(candidate: str) -> bool:
        name = os.path.basename(candidate).lower()
        return any(needle in name for needle in needles)

    selected = [c for c in candidates if _wanted(c)] if needles else candidates
    return sorted(selected)

def find_comparative_columns(df: pd.DataFrame) -> List[Tuple[str, str, str, str]]:
    """List the judge-verdict columns of *df* that match ``COL_PATTERN``.

    Returns ``(column_name, model_1, model_2, judge)`` tuples in column order;
    columns that do not match are ignored.
    """
    candidates = ((name, COL_PATTERN.match(name)) for name in df.columns)
    return [
        (name, hit["m1"], hit["m2"], hit["judge"])
        for name, hit in candidates
        if hit
    ]

def process_file(path: str, cfg: Config) -> pd.DataFrame:
    """Turn one CSV file into a table of pairwise judge verdicts ("events").

    Each non-empty cell of a recognised comparative column (see
    ``find_comparative_columns``) becomes one event, provided the column's
    judge is in ``cfg.judges`` and both of its models are in ``cfg.models``.
    Cell values are interpreted as:

    * ``"tie"`` / ``"draw"`` (case-insensitive): a tie;
    * ``"answer_<model>"``: a win for ``<model>``, which must be one of the
      column's two models (otherwise the cell is skipped and a warning logged);
    * anything else: skipped.

    Returns a DataFrame with columns source_file (path relative to
    ``cfg.csv_root``), row_index, language, judge, model_a, model_b, winner
    (None for ties) and is_tie. An empty DataFrame is returned when the file
    cannot be read, has no comparative columns, or its language cannot be
    inferred while ``cfg.skip_unknown_language`` is set.
    """
    log = logging.getLogger(__name__)
    try:
        df = pd.read_csv(path, encoding="utf-8")
    except Exception as exc:
        # Deliberately broad best-effort: one bad file must not abort the run,
        # but report it instead of dropping it silently.
        log.warning("Skipping %s: could not read CSV (%s)", path, exc)
        return pd.DataFrame()
    cols = find_comparative_columns(df)
    if not cols:
        return pd.DataFrame()
    lang = infer_language(path, cfg.languages, cfg.skip_unknown_language)
    if lang is None:
        log.debug("Skipping %s: no configured language matched its filename", path)
        return pd.DataFrame()
    prefix = "answer_"
    rel = os.path.relpath(path, cfg.csv_root)
    rows = []
    for (col, m1, m2, judge) in cols:
        if judge not in cfg.judges:
            continue
        if m1 not in cfg.models or m2 not in cfg.models:
            continue
        foreign = 0
        for idx, val in df[col].items():
            if pd.isna(val):
                continue
            val_str = str(val).strip()
            if val_str.lower() in ("tie", "draw"):
                winner, tie = None, True
            elif val_str.startswith(prefix):
                # Strip only the leading prefix (str.replace would drop every occurrence).
                winner, tie = val_str[len(prefix):], False
                if winner not in (m1, m2):
                    # Such a verdict would count as a match for both models but a win
                    # for neither, silently deflating both win rates.
                    foreign += 1
                    continue
            else:
                continue
            rows.append({
                "source_file": rel,
                "row_index": idx,
                "language": lang,
                "judge": judge,
                "model_a": m1,
                "model_b": m2,
                "winner": winner,
                "is_tie": tie
            })
        if foreign:
            log.warning(
                "%s: %d verdict(s) in column %s name a model other than %s or %s; skipped",
                rel, foreign, col, m1, m2,
            )
    return pd.DataFrame(rows)

def build_events(cfg: Config) -> pd.DataFrame:
    """Collect judge-verdict events from every discovered CSV into one table.

    Files come from ``discover_csv_files`` and are parsed one at a time with
    ``process_file``; files that yield no events are dropped. Returns an empty
    DataFrame when nothing is found, otherwise the concatenation of all
    per-file tables with a fresh 0..n-1 index.
    """
    csv_paths = discover_csv_files(cfg.csv_root, cfg.recursive, cfg.filters)
    per_file = (process_file(csv_path, cfg) for csv_path in csv_paths)
    frames = [frame for frame in per_file if not frame.empty]
    return pd.concat(frames, ignore_index=True) if frames else pd.DataFrame()

def pairwise_stats(events: pd.DataFrame, cfg: Config, per_language: bool=False) -> pd.DataFrame:
    """Aggregate judge events into head-to-head statistics per ordered model pair.

    Groups by (judge, model_a, model_b), or (judge, language, model_a, model_b)
    when *per_language* is true. Win rates follow ``cfg.tie_handling``:
    "half" credits each side half a win per tie; "ignore" removes ties from
    the denominator (NaN when a group has only ties); any other value keeps
    ties in the denominator with no credit.

    Returns the group keys plus matches_total, a_wins, b_wins, ties,
    a_win_rate and b_win_rate; an empty DataFrame for empty *events*.
    """
    if events.empty:
        return pd.DataFrame()
    keys = ["judge", "model_a", "model_b"]
    if per_language:
        keys.insert(1, "language")
    records = []
    for group_key, block in events.groupby(keys):
        n_matches = len(block)
        wins_a = (block["winner"] == block["model_a"]).sum()
        wins_b = (block["winner"] == block["model_b"]).sum()
        n_ties = block["is_tie"].sum()
        mode = cfg.tie_handling
        if mode == "ignore":
            decisive = n_matches - n_ties
            denom = decisive if decisive > 0 else np.nan
            credit_a, credit_b = wins_a, wins_b
        elif mode == "half":
            denom = n_matches
            credit_a = wins_a + 0.5 * n_ties
            credit_b = wins_b + 0.5 * n_ties
        else:
            denom = n_matches
            credit_a, credit_b = wins_a, wins_b
        has_denom = bool(denom) and not np.isnan(denom)
        record = dict(zip(keys, group_key))
        record.update(
            matches_total=n_matches,
            a_wins=wins_a,
            b_wins=wins_b,
            ties=n_ties,
            a_win_rate=credit_a / denom if has_denom else np.nan,
            b_win_rate=credit_b / denom if has_denom else np.nan,
        )
        records.append(record)
    return pd.DataFrame(records)

def model_winshare(events: pd.DataFrame, cfg: Config, per_language: bool=False) -> pd.DataFrame:
    """Compute each model's win rate across all of its pairwise matches.

    Every event counts as one match for *both* participants (model_a and
    model_b). Results are grouped by (judge, model), or (judge, language,
    model) when *per_language* is true. ``cfg.tie_handling`` is applied as in
    ``pairwise_stats``: "half" adds half a win per tie, "ignore" removes ties
    from the denominator (NaN if only ties), anything else counts ties as
    matches with no credit.

    Returns the group keys plus wins, ties, matches and win_rate; an empty
    DataFrame for empty *events*.
    """
    if events.empty:
        return pd.DataFrame()
    # Long format, one row per (event, participant), built with vectorised column
    # operations instead of a per-row iterrows() loop.
    per_side = pd.concat(
        [events.assign(model=events[side]) for side in ("model_a", "model_b")],
        ignore_index=True,
    )
    is_tie = per_side["is_tie"].astype(bool)
    per_side["is_win"] = (per_side["winner"] == per_side["model"]) & ~is_tie
    per_side["is_tie"] = is_tie
    group_fields = ["judge", "model"] if not per_language else ["judge", "language", "model"]
    out = []
    for keys, grp in per_side.groupby(group_fields):
        d = dict(zip(group_fields, keys))
        wins = grp.is_win.sum()
        ties = grp.is_tie.sum()
        matches = len(grp)
        if cfg.tie_handling == "half":
            effective = matches
            score = wins + 0.5 * ties
        elif cfg.tie_handling == "ignore":
            effective = matches - ties if matches - ties > 0 else np.nan
            score = wins
        else:
            effective = matches
            score = wins
        win_rate = score / effective if effective and not np.isnan(effective) else np.nan
        out.append({**d,
                    "wins": wins,
                    "ties": ties,
                    "matches": matches,
                    "win_rate": win_rate})
    return pd.DataFrame(out)

def ensure_dir(path: str):
    """Make sure directory *path* exists, creating missing parents; safe to call repeatedly."""
    if not os.path.isdir(path):
        os.makedirs(path, exist_ok=True)

def plot_overall_winrate(winshare: pd.DataFrame, cfg: Config, out_path: str):
    """Save a grouped bar chart of win_rate per model (bars coloured by judge) to *out_path*.

    Each bar is labelled with its value to two decimals; the y-axis is fixed
    to [0, 1]. Nothing is written when *winshare* is empty.
    """
    if winshare.empty:
        return
    width = 1.6 + 2.2 * len(cfg.models)
    fig, ax = plt.subplots(figsize=(width, 4.5))
    sns.barplot(data=winshare, x="model", y="win_rate", hue="judge",
                palette=cfg.palette, errorbar=None, ax=ax)
    ax.set_ylim(0, 1)
    for bar in ax.patches:
        height = bar.get_height()
        if np.isnan(height):
            continue
        centre = bar.get_x() + bar.get_width() / 2
        ax.annotate(f"{height:.2f}", (centre, height), ha="center", va="bottom",
                    fontsize=9, xytext=(0, 2), textcoords="offset points")
    ax.set_title("Overall Model Win Rate by Judge")
    ax.set_ylabel("Win rate")
    ax.set_xlabel("Model")
    fig.tight_layout()
    fig.savefig(out_path, dpi=cfg.dpi)
    plt.close(fig)

def plot_overall_winrate_language(winshare_lang: pd.DataFrame, cfg: Config, out_path: str):
    """Save a grid of win-rate bar charts, one facet per language, to *out_path*.

    In ``run`` this receives ``model_winshare(..., per_language=True)``, which
    has one row per (judge, language, model), so every bar is a single value;
    error bars are therefore disabled, as in ``plot_overall_winrate``. Bars are
    annotated with their values and the y-axis is fixed to [0, 1]. Nothing is
    written when *winshare_lang* is empty.
    """
    if winshare_lang.empty:
        return
    g = sns.catplot(
        data=winshare_lang, x="model", y="win_rate", hue="judge", col="language",
        kind="bar", palette=cfg.palette, sharey=True, col_wrap=3, errorbar=None,
    )
    g.set(ylim=(0, 1))
    for ax in g.axes.flat:
        for p in ax.patches:
            h = p.get_height()
            if not np.isnan(h):
                ax.annotate(f"{h:.2f}", (p.get_x() + p.get_width() / 2, h), ha="center", va="bottom",
                            fontsize=8, xytext=(0, 2), textcoords="offset points")
    # `FacetGrid.figure` is the supported accessor; `.fig` is a deprecated alias.
    fig = g.figure
    # Leave room above the facets for the suptitle.
    fig.subplots_adjust(top=0.85)
    fig.suptitle("Model Win Rate by Judge & Language")
    g.savefig(out_path, dpi=cfg.dpi)
    plt.close(fig)

def plot_pairwise_heatmaps(pairwise_df: pd.DataFrame, cfg: Config, out_dir: str):
    """Save one ``pairwise_heatmap_<judge>.png`` per judge into *out_dir*.

    Cell (row, column) is the win rate of the row model against the column
    model, using ``cfg.tie_handling`` with the same rules as ``pairwise_stats``.
    If *pairwise_df* contains the same two models in both orientations
    ((A, B) and (B, A), e.g. from comparative columns written in both orders),
    their raw counts are pooled, so neither orientation overwrites the other.
    Pairs without data stay NaN (blank) and the diagonal is fixed at 0.5.

    Reads the count columns a_wins, b_wins, ties and matches_total produced
    by ``pairwise_stats``.
    """
    if pairwise_df.empty:
        return

    def _rate(wins, ties, total):
        # Same tie_handling rules as pairwise_stats, applied to pooled counts.
        if cfg.tie_handling == "half":
            score, denom = wins + 0.5 * ties, total
        elif cfg.tie_handling == "ignore":
            score, denom = wins, total - ties
        else:
            score, denom = wins, total
        return score / denom if denom > 0 else np.nan

    for judge in sorted(pairwise_df.judge.unique()):
        sub = pairwise_df[pairwise_df.judge == judge]
        # Configured models first (in config order), then any unexpected extras.
        seen = set(sub.model_a) | set(sub.model_b)
        models = list(cfg.models) + sorted(seen - set(cfg.models))
        # wins[x][y] = x's wins against y; ties/totals are symmetric per unordered pair.
        wins = pd.DataFrame(0.0, index=models, columns=models)
        ties = pd.DataFrame(0.0, index=models, columns=models)
        totals = pd.DataFrame(0.0, index=models, columns=models)
        for r in sub.itertuples(index=False):
            a, b = r.model_a, r.model_b
            wins.loc[a, b] += r.a_wins
            wins.loc[b, a] += r.b_wins
            for x, y in ((a, b), (b, a)):
                ties.loc[x, y] += r.ties
                totals.loc[x, y] += r.matches_total
        mat = pd.DataFrame(np.nan, index=models, columns=models)
        for a in models:
            for b in models:
                if a != b and totals.loc[a, b] > 0:
                    mat.loc[a, b] = _rate(wins.loc[a, b], ties.loc[a, b], totals.loc[a, b])
        for m in models:
            mat.loc[m, m] = 0.5
        plt.figure(figsize=(1 + 1.1 * len(models), 1 + 0.9 * len(models)))
        sns.heatmap(mat.astype(float), annot=True, fmt=".3f", cmap="viridis", vmin=0, vmax=1)
        plt.title(f"Pairwise Win Rates - Judge: {judge}")
        plt.ylabel("Model")
        plt.xlabel("Opponent")
        plt.tight_layout()
        plt.savefig(os.path.join(out_dir, f"pairwise_heatmap_{judge}.png"), dpi=cfg.dpi)
        plt.close()

def run(cfg: Config):
    """Build events from the configured CSVs, write summary tables, and plot them.

    Writes to ``cfg.output_dir``:

    * pairwise_results.csv and overall_model_winshare.csv;
    * overall_winrate_bars.png and one pairwise_heatmap_<judge>.png per judge;
    * when ``cfg.per_language`` is set, also pairwise_results_by_language.csv,
      overall_model_winshare_by_language.csv and
      overall_winrate_bars_by_language.png.

    If no events are found, a warning is logged and nothing is written (the
    output directory itself is still created).
    """
    log = logging.getLogger(__name__)
    sns.set_style(cfg.style)
    ensure_dir(cfg.output_dir)

    def out(name: str) -> str:
        return os.path.join(cfg.output_dir, name)

    events = build_events(cfg)
    if events.empty:
        # Without this the script would exit successfully with no output and no hint why.
        log.warning(
            "No usable comparative verdicts found under %s; check csv_root, filename "
            "filters, languages, models and judges in the config. Nothing written.",
            cfg.csv_root,
        )
        return
    log.info("Loaded %d verdicts from %d file(s)", len(events), events["source_file"].nunique())
    pairwise_overall = pairwise_stats(events, cfg, per_language=False)
    pairwise_overall.to_csv(out("pairwise_results.csv"), index=False)
    winshare_overall = model_winshare(events, cfg, per_language=False)
    winshare_overall.to_csv(out("overall_model_winshare.csv"), index=False)
    plot_overall_winrate(winshare_overall, cfg, out("overall_winrate_bars.png"))
    plot_pairwise_heatmaps(pairwise_overall, cfg, cfg.output_dir)
    if cfg.per_language:
        pairwise_lang = pairwise_stats(events, cfg, per_language=True)
        pairwise_lang.to_csv(out("pairwise_results_by_language.csv"), index=False)
        winshare_lang = model_winshare(events, cfg, per_language=True)
        winshare_lang.to_csv(out("overall_model_winshare_by_language.csv"), index=False)
        plot_overall_winrate_language(winshare_lang, cfg, out("overall_winrate_bars_by_language.png"))
    log.info("Results written to %s", cfg.output_dir)

def parse_args(argv: Optional[List[str]] = None):
    """Parse the command-line options.

    Args:
        argv: Arguments to parse. ``None`` (the default) means ``sys.argv[1:]``,
            so existing callers behave exactly as before; passing a list lets
            the script be driven programmatically or from tests.

    Returns:
        An ``argparse.Namespace`` whose ``config`` attribute is the path to the
        YAML config. A missing ``--config`` makes argparse exit with status 2.
    """
    ap = argparse.ArgumentParser(description="Comparative evaluation plotting.")
    ap.add_argument("--config", required=True, help="Path to config YAML")
    return ap.parse_args(argv)

def setup_logging(level: str):
    logging.basicConfig(level=getattr(logging, level.upper(), logging.INFO), format="%(asctime)s %(levelname)s %(message)s")

def main():
    """Command-line entry point: parse arguments, load the config, set up logging, then run."""
    config = load_config(parse_args().config)
    setup_logging(config.log_level)
    run(config)

if __name__ == "__main__":
    main()