import os, re, argparse
import pandas as pd
import matplotlib.pyplot as plt


# =============== Utilities ===============

def _sniff_delimiter(csv_path: str) -> str:
    """自动识别分隔符"""
    with open(csv_path, "r", encoding="utf-8", errors="ignore") as f:
        line = f.readline()
    if "\t" in line: return "\t"
    if ";" in line: return ";"
    return ","


def _extract_piece_label(name: str) -> str:
    """从文件名里提取 piece 类型"""
    s = str(name).lower()
    if "square" in s: return "square"
    if "parallelogram" in s: return "parallelogram"
    if "small_triangle" in s: return "small_triangle"
    if "medium_triangle" in s: return "medium_triangle"
    if "big_triangle" in s or "large_triangle" in s: return "big_triangle"
    return "unknown"


def _detect_columns_for_iou(csv_path: str, sep: str):
    """检测 iou / name / piece 列"""
    df_head = pd.read_csv(csv_path, sep=sep, nrows=20, on_bad_lines="skip")
    cols = [c.strip().lower() for c in df_head.columns]

    iou_col = None
    name_col = None
    piece_col = None
    for c in df_head.columns:
        cl = c.strip().lower()
        if iou_col is None and ("iou" in cl or cl in ("value", "score", "iou_score")):
            iou_col = c
        if name_col is None and cl in ("name", "id", "sample", "filename"):
            name_col = c
        if piece_col is None and "piece" in cl:
            piece_col = c
    if name_col is None:
        name_col = df_head.columns[0]
    if iou_col is None:
        if len(df_head.columns) >= 2:
            iou_col = df_head.columns[1]
        else:
            raise ValueError("Cannot find IoU column")

    return iou_col, name_col, piece_col


# =============== 功能3：按 piece 聚合 ===============

def mean_iou_by_piece_from_csv(csv_path: str,
                               title: str = "Mean IoU by Piece Type",
                               save_csv: str | None = None,
                               save_png: str | None = None,
                               big_font: bool = False) -> pd.DataFrame:
    """Aggregate IoU per piece type from a CSV and plot the per-piece means.

    The CSV is read in chunks of 200,000 rows, and its delimiter is sniffed
    from the first line. The piece type comes from a column whose name
    contains "piece" when one exists. Otherwise it is inferred from the
    sample name via ``_extract_piece_label``. Rows with a non-numeric IoU
    are skipped.

    Args:
        csv_path: Input CSV file.
        title: Chart title.
        save_csv: If given, the per-piece table is written to this path.
        save_png: If given, the chart is saved to this path (dpi=180).
        big_font: Use larger fonts for title, axis label and ticks.

    Returns:
        DataFrame with columns ``piece``, ``count`` and ``mean_iou``, sorted
        by ``mean_iou`` descending (empty if no valid IoU values exist).
        Rows whose piece could not be determined are kept, under the label
        "average" (renamed from "unknown").

    Raises:
        FileNotFoundError: If ``csv_path`` is not an existing file.
        ValueError: If no IoU column can be identified.
    """
    if not os.path.isfile(csv_path):
        raise FileNotFoundError(csv_path)

    sep = _sniff_delimiter(csv_path)
    iou_col, name_col, piece_col = _detect_columns_for_iou(csv_path, sep)

    # Deduplicate (order-preserving) in case one column fills two roles.
    usecols = list(dict.fromkeys(c for c in (name_col, iou_col, piece_col) if c is not None))
    reader = pd.read_csv(csv_path, sep=sep, usecols=usecols,
                         chunksize=200_000, on_bad_lines="skip", low_memory=True)

    sums: dict[str, float] = {}
    cnts: dict[str, int] = {}
    for chunk in reader:
        vals = pd.to_numeric(chunk[iou_col], errors="coerce")
        if piece_col is not None:
            # Normalise "Small Triangle" / "small-triangle" -> "small_triangle".
            # In a raw string the whitespace class is \s, not \\s: r"[\\s\\-]+"
            # would match a backslash or the letter "s" ("square" -> "_quare").
            # Missing cells become "" and land in the "unknown" bucket below.
            pieces = (chunk[piece_col].fillna("").astype(str)
                      .str.strip().str.lower()
                      .str.replace(r"[\s\-]+", "_", regex=True))
        else:
            pieces = chunk[name_col].astype(str).map(_extract_piece_label)
        pieces = pieces.replace("", "unknown")

        mask = vals.notna()
        grouped = vals[mask].groupby(pieces[mask]).agg(["sum", "count"])
        for piece, total, n in zip(grouped.index, grouped["sum"], grouped["count"]):
            sums[piece] = sums.get(piece, 0.0) + float(total)
            cnts[piece] = cnts.get(piece, 0) + int(n)

    # Explicit columns keep sort_values working when nothing was aggregated.
    df = pd.DataFrame(
        [{"piece": p, "count": c, "mean_iou": sums[p] / c} for p, c in cnts.items()],
        columns=["piece", "count", "mean_iou"],
    )
    df = df.sort_values("mean_iou", ascending=False).reset_index(drop=True)
    # Instead of removing "unknown", rename it to "average".
    # NOTE(review): this bucket holds rows with no recognised piece; confirm
    # that labelling it "average" is what readers of the chart expect.
    df["piece"] = df["piece"].replace("unknown", "average")

    if save_csv:
        df.to_csv(save_csv, index=False)
        print(f"[SAVE] {save_csv}")

    fs_title = 18 if big_font else 12
    fs_label = 14 if big_font else 11
    fs_tick = 12 if big_font else 10
    plt.figure(figsize=(10, 5))
    bars = plt.bar(df["piece"], df["mean_iou"], edgecolor="black")
    for b, v in zip(bars, df["mean_iou"]):
        # Clamp the value label so it stays inside the fixed [0, 1] y-range.
        plt.text(b.get_x() + b.get_width() / 2, min(v + 0.01, 0.98), f"{v:.3f}",
                 ha="center", va="bottom", fontsize=fs_tick)
    plt.ylabel("Mean IoU", fontsize=fs_label)
    plt.ylim(0, 1.0)
    plt.title(title, fontsize=fs_title)
    plt.xticks(rotation=30, ha="right", fontsize=fs_tick)
    plt.tight_layout()
    if save_png:
        plt.savefig(save_png, dpi=180)
        print(f"[SAVE] {save_png}")
    plt.show()
    return df


# =============== 功能1：单 CSV 总体统计 ===============

def summarize_iou_csv(csv_path: str,
                      save_csv: str | None = None,
                      save_png: str | None = None) -> dict:
    """Summarise the IoU column of a single CSV and plot its mean as one bar.

    The delimiter is sniffed by pandas (``sep=None`` with the python engine)
    and malformed lines are skipped. The IoU column is the first whose name
    contains "iou" or equals "value" / "score" / "iou_score". Failing that,
    the second column is used, or the only column for single-column files.
    Non-numeric IoU cells are ignored.

    Args:
        csv_path: Input CSV file.
        save_csv: If given, the one-row statistics table is written here.
        save_png: If given, the chart is saved to this path (dpi=180).

    Returns:
        Dict with keys:
          * ``path``: absolute path of the input.
          * ``count``.
          * ``mean``, ``median``, ``min``, ``max``: NaN when there are no
            numeric values.
          * ``std``: sample standard deviation (ddof=1), 0.0 for fewer than
            two values.
    """
    df = pd.read_csv(csv_path, engine="python", sep=None, on_bad_lines="skip")
    iou_col = next((c for c in df.columns
                    if "iou" in str(c).lower() or str(c).lower() in ("value", "score", "iou_score")),
                   None)
    if iou_col is None:
        # Resolved lazily: next(gen, df.columns[1]) evaluates its default
        # eagerly and raises IndexError on single-column files.
        iou_col = df.columns[1] if len(df.columns) > 1 else df.columns[0]
    vals = pd.to_numeric(df[iou_col], errors="coerce").dropna()
    stats = {
        "path": os.path.abspath(csv_path), "count": int(vals.size),
        "mean": float(vals.mean()), "median": float(vals.median()),
        "std": float(vals.std(ddof=1)) if vals.size > 1 else 0.0,
        "min": float(vals.min()), "max": float(vals.max())
    }
    if save_csv:
        pd.DataFrame([stats]).to_csv(save_csv, index=False)
        print(f"[SAVE] {save_csv}")

    plt.figure(figsize=(4, 5))
    bar = plt.bar(["mean"], [stats["mean"]], edgecolor="black")[0]
    if stats["count"]:
        # Skip the label when the mean is NaN; clamp it inside the [0, 1] range.
        plt.text(bar.get_x() + bar.get_width() / 2, min(bar.get_height() + 0.02, 0.97),
                 f"{stats['mean']:.3f}", ha="center")
    plt.ylabel("Mean IoU")
    plt.ylim(0, 1.0)
    plt.title("IoU Summary")
    plt.tight_layout()
    if save_png:
        plt.savefig(save_png, dpi=180)
        print(f"[SAVE] {save_png}")
    plt.show()
    return stats


# =============== 功能2：多 CSV 对比 ===============


def compare_csv_means(csv_paths: list[str],
                      labels: list[str] | None = None,
                      save_csv: str | None = None,
                      save_png: str | None = None) -> pd.DataFrame:
    """Compare the mean IoU of several CSV files with a bar chart.

    For each file the delimiter is sniffed by pandas and malformed lines are
    skipped. The IoU column is the first whose name contains "iou" or equals
    "value" / "score" / "iou_score". Failing that, the second column is used,
    or the only column for single-column files.

    Args:
        csv_paths: CSV files to compare; must not be empty.
        labels: One bar label per CSV. When omitted or empty, each CSV is
            labelled with its parent directory name. If those names are not
            unique, "parent/filename" is used so bars do not overlap.
        save_csv: If given, the comparison table is written here.
        save_png: If given, the chart is saved to this path (dpi=180).

    Returns:
        DataFrame with columns ``label``, ``path`` (absolute), ``count`` and
        ``mean_iou`` (NaN for files without numeric values), in input order.

    Raises:
        ValueError: If ``csv_paths`` is empty or ``labels`` has the wrong length.
    """
    if not csv_paths:
        raise ValueError("csv_paths must not be empty")
    if labels and len(labels) != len(csv_paths):
        raise ValueError(f"expected {len(csv_paths)} labels, got {len(labels)}")

    rows = []
    for p in csv_paths:
        df = pd.read_csv(p, engine="python", sep=None, on_bad_lines="skip")
        iou_col = next((c for c in df.columns
                        if "iou" in str(c).lower() or str(c).lower() in ("value", "score", "iou_score")),
                       None)
        if iou_col is None:
            # Resolved lazily so single-column files do not raise IndexError.
            iou_col = df.columns[1] if len(df.columns) > 1 else df.columns[0]
        vals = pd.to_numeric(df[iou_col], errors="coerce").dropna()
        rows.append({"path": os.path.abspath(p), "count": len(vals), "mean_iou": float(vals.mean())})
    df = pd.DataFrame(rows)

    if not labels:
        labels = [os.path.basename(os.path.dirname(p)) for p in df["path"]]
        if len(set(labels)) < len(labels):
            # Several CSVs share a directory name: equal labels would stack bars.
            labels = [os.path.join(os.path.basename(os.path.dirname(p)), os.path.basename(p))
                      for p in df["path"]]
    df.insert(0, "label", labels)
    if save_csv:
        df.to_csv(save_csv, index=False)
        print(f"[SAVE] {save_csv}")

    plt.figure(figsize=(max(6, 1.8 * len(df)), 5))
    bars = plt.bar(df["label"], df["mean_iou"], edgecolor="black")
    for b, v in zip(bars, df["mean_iou"]):
        if pd.notna(v):  # no value label for files without numeric IoU values
            plt.text(b.get_x() + b.get_width() / 2, v + 0.02, f"{v:.3f}", ha="center")
    plt.ylabel("Mean IoU")
    plt.ylim(0, 1.0)
    plt.title("Mean IoU Comparison")
    plt.xticks(rotation=20, ha="right")
    plt.tight_layout()
    if save_png:
        plt.savefig(save_png, dpi=180)
        print(f"[SAVE] {save_png}")
    plt.show()
    return df

# =============== 功能4：实验折线图（横轴为参数标签，纵轴为 Mean IoU） ===============

def plot_iou_line(labels: list[str], values: list[float],
                  title: str = "Mean IoU by Experiment",
                  save_csv: str | None = None,
                  save_png: str | None = None,
                  big_font: bool = False,
                  y_range: tuple[float, float] | None = (0.5, 1.0)) -> pd.DataFrame:
    """Plot mean IoU per experiment as a line chart.

    The x axis shows the experiment labels and the y axis shows mean IoU.

    Args:
        labels: X-axis labels, one per experiment.
        values: Mean IoU per experiment; same length as ``labels``.
        title: Chart title.
        save_csv: If given, the ``(label, mean_iou)`` table is written here.
        save_png: If given, the chart is saved to this path (dpi=180).
        big_font: Use larger fonts for title, axis label and ticks.
        y_range: Fixed ``(bottom, top)`` y-axis limits (default (0.5, 1.0)).
            Pass ``None`` to let matplotlib choose the limits. A warning is
            printed when values fall outside a fixed range, since those
            points would be invisible.

    Returns:
        DataFrame with columns ``label`` and ``mean_iou``.

    Raises:
        ValueError: If ``labels`` and ``values`` differ in length or are empty.
    """
    if len(labels) != len(values) or len(labels) == 0:
        raise ValueError("labels 与 values 长度必须一致且非空")
    df = pd.DataFrame({"label": labels, "mean_iou": values})
    if save_csv:
        df.to_csv(save_csv, index=False)
        print(f"[SAVE] {save_csv}")

    fs_title = 18 if big_font else 12
    fs_label = 14 if big_font else 11
    fs_tick = 12 if big_font else 10

    xs = range(len(labels))
    plt.figure(figsize=(max(7, 2.0 * len(labels)), 5.0 if big_font else 4.5))
    plt.plot(xs, values, marker="o")
    plt.xticks(ticks=xs, labels=labels, rotation=25, ha="right", fontsize=fs_tick)
    plt.yticks(fontsize=fs_tick)
    if y_range is not None:
        lo, hi = y_range
        plt.ylim(lo, hi)
        outside = [v for v in values if not lo <= v <= hi]
        if outside:
            print(f"[WARN] {len(outside)} value(s) outside y-range {y_range} will not be visible")

    # Clamp value labels below the top of the axes so they do not touch the frame.
    top = plt.ylim()[1]
    for i, v in enumerate(values):
        y_anno = min(v + 0.015, top - 0.02)
        plt.text(i, y_anno, f"{v:.3f}", ha="center", va="bottom", fontsize=fs_tick)
    plt.ylabel("Mean IoU", fontsize=fs_label)
    plt.title(title, fontsize=fs_title)
    plt.grid(True, linestyle=":", linewidth=0.7)
    plt.tight_layout()
    if save_png:
        plt.savefig(save_png, dpi=180)
        print(f"[SAVE] {save_png}")
    plt.show()
    return df


# =============== Main CLI ===============

if __name__ == "__main__":
    # CLI dispatcher: the first mode whose input option is given is run.
    parser = argparse.ArgumentParser(
        description="IoU tools: (1) single CSV stats; (2) compare CSV means; (3) per-piece stats; "
                    "(4) experiment line plot")
    parser.add_argument("--csv_stats", type=str)
    parser.add_argument("--stats_out_csv", type=str)
    parser.add_argument("--stats_out_png", type=str)
    parser.add_argument("--compare_csvs", type=str, nargs="+")
    parser.add_argument("--compare_labels", type=str)
    parser.add_argument("--compare_out_csv", type=str)
    parser.add_argument("--compare_out_png", type=str)
    parser.add_argument("--piece_csv", type=str)
    parser.add_argument("--piece_out_csv", type=str)
    parser.add_argument("--piece_out_png", type=str)
    parser.add_argument("--exp_labels", type=str, help="逗号分隔的实验标签（横轴，例如：baseline,icl+loop,thr0.95,loop8，可重复增加如icl+loop=0.932）")
    parser.add_argument("--exp_ious", type=str, help="逗号分隔的 Mean IoU 值（纵轴，例如：0.65,0.72,0.74,0.92）")
    parser.add_argument("--line_out_csv", type=str, help="折线图数据导出 CSV 路径")
    parser.add_argument("--line_out_png", type=str, help="折线图图片导出 PNG 路径")
    parser.add_argument("--big_font", action="store_true", help="放大坐标轴与标题字体，适合论文/展示图")
    args = parser.parse_args()

    if args.csv_stats:
        # Default outputs go next to the input CSV.
        base = os.path.dirname(args.csv_stats) or "."
        summarize_iou_csv(args.csv_stats,
                          save_csv=args.stats_out_csv or os.path.join(base, "iou_stats.csv"),
                          save_png=args.stats_out_png or os.path.join(base, "iou_stats.png"))
    elif args.compare_csvs:
        labels = [s.strip() for s in args.compare_labels.split(",")] if args.compare_labels else None
        compare_csv_means(args.compare_csvs, labels=labels,
                          save_csv=args.compare_out_csv or "mean_compare.csv",
                          save_png=args.compare_out_png or "mean_compare.png")
    elif args.piece_csv:
        base = os.path.dirname(args.piece_csv) or "."
        mean_iou_by_piece_from_csv(args.piece_csv,
                                   save_csv=args.piece_out_csv or os.path.join(base, "piece_mean_iou.csv"),
                                   save_png=args.piece_out_png or os.path.join(base, "piece_mean_iou.png"),
                                   big_font=args.big_font)
    elif args.exp_labels or args.exp_ious:
        # Report a usage error instead of silently printing help when only one is given.
        if not (args.exp_labels and args.exp_ious):
            parser.error("--exp_labels and --exp_ious must be given together")
        labels = [s.strip() for s in args.exp_labels.split(",") if s.strip()]
        try:
            values = [float(s) for s in args.exp_ious.split(",") if s.strip()]
        except ValueError as err:
            parser.error(f"--exp_ious must be comma-separated numbers: {err}")
        if len(labels) != len(values):
            parser.error(f"got {len(labels)} labels but {len(values)} IoU values")
        plot_iou_line(labels, values,
                      title="Mean IoU by Experiment",
                      save_csv=args.line_out_csv or "exp_iou_line.csv",
                      save_png=args.line_out_png or "exp_iou_line.png",
                      big_font=args.big_font)
    else:
        parser.print_help()