import argparse
import json
import os
from pathlib import Path
from typing import Dict, List, Any

import math

try:
    import matplotlib.pyplot as plt
except Exception as e:
    plt = None


def _flatten_results(task: str, results: Dict[str, Dict[str, Dict[str, float]]]) -> List[Dict[str, Any]]:
    """Flatten nested ``{feature_set: {model: {metric: value}}}`` results into rows.

    Each row carries ``task``, ``feature_set`` and ``model`` keys followed by
    that model's metric entries.

    Malformed input is tolerated rather than raising ``AttributeError``:
    * a top-level value or a per-feature-set value that is not a dict (``None``,
      or a list coming from a hand-edited results file) is skipped;
    * a metrics value that is not a dict still yields a row, but one with only
      the three identifying keys.

    Args:
        task: Task name stamped on every row.
        results: Parsed results mapping; may be ``None``/empty.

    Returns:
        One dict per (feature_set, model) pair, in the input's iteration order.
    """
    rows: List[Dict[str, Any]] = []
    if not isinstance(results, dict):
        return rows
    for fs, models in results.items():
        if not isinstance(models, dict):
            continue
        for mname, metrics in models.items():
            row: Dict[str, Any] = {
                "task": task,
                "feature_set": fs,
                "model": mname,
            }
            if isinstance(metrics, dict):
                row.update(metrics)
            rows.append(row)
    return rows


def _load_json(path: Path) -> Dict[str, Any]:
    """Load a JSON object from ``path``, treating unusable files as empty.

    Returns ``{}`` when any of the following holds:
    * the file does not exist;
    * the file is zero bytes long;
    * the file is not valid UTF-8 JSON;
    * the top-level JSON value is not an object. Callers iterate the result
      with ``.items()``, so returning a list here would crash them.

    Errors from opening the file (e.g. permission denied) are not suppressed.

    Args:
        path: Location of the JSON file.

    Returns:
        The decoded JSON object, or ``{}``.
    """
    if not path.exists() or path.stat().st_size == 0:
        return {}
    with open(path, "r", encoding="utf-8") as f:
        try:
            data = json.load(f)
        except ValueError:
            # json.JSONDecodeError and UnicodeDecodeError both subclass
            # ValueError: a truncated/corrupt file is treated like a missing one.
            return {}
    return data if isinstance(data, dict) else {}


def _ensure_dir(p: Path) -> None:
    """Create directory ``p`` along with any missing parents; no-op if it already exists."""
    p.mkdir(exist_ok=True, parents=True)


def _write_csv(path: Path, rows: List[Dict[str, Any]], columns: List[str]) -> None:
    """Write ``rows`` to ``path`` as CSV, using ``columns`` as the header.

    Only the listed columns are emitted, in the given order; a row lacking a
    column gets an empty cell. The parent directory is created when missing,
    and an existing file at ``path`` is overwritten.
    """
    import csv

    _ensure_dir(path.parent)
    with open(path, "w", newline="", encoding="utf-8") as handle:
        writer = csv.writer(handle)
        writer.writerow(columns)
        writer.writerows([row.get(col, "") for col in columns] for row in rows)


def _safe_float(x: Any, default: float = float("nan")) -> float:
    """Convert ``x`` to ``float``, returning ``default`` when it cannot be converted.

    Only the errors ``float()`` raises for unconvertible input are caught:
    * ``TypeError``, e.g. ``None``;
    * ``ValueError``, e.g. ``"abc"``;
    * ``OverflowError``, e.g. an int too large for a float.

    A NaN input converts successfully and is returned as NaN, not ``default``.
    """
    try:
        return float(x)
    except (TypeError, ValueError, OverflowError):
        return default


def _summary_columns(task: str) -> List[str]:
    """Return the summary CSV columns for ``task`` (regression vs. any other task)."""
    if task == "regression":
        metrics = ["rmse", "mae", "r2"]
    else:
        metrics = ["accuracy", "f1", "roc_auc", "pr_auc"]
    return ["task", "feature_set", "model", *metrics, "corr_rating_pearson", "corr_rating_spearman"]


def _pearson_sort_key(row: Dict[str, Any]) -> float:
    """Return a row's Pearson correlation as a sort key; unusable values rank last.

    Missing, non-numeric and non-finite values (including NaN parsed from JSON)
    map to ``-inf``. NaN must never reach ``sorted``: every comparison with NaN
    is False, which silently scrambles the order of the surrounding rows.
    """
    value = _safe_float(row.get("corr_rating_pearson"), float("-inf"))
    return value if math.isfinite(value) else float("-inf")


def _save_figure(path: Path) -> None:
    """Lay out, save (150 dpi) and close the current pyplot figure, closing it even on failure."""
    try:
        plt.tight_layout()
        plt.savefig(path, dpi=150)
    finally:
        plt.close()


def _plot_top_pearson(rows_sorted: List[Dict[str, Any]], task: str, top_n: int, plots_dir: Path) -> None:
    """Bar-plot the ``top_n`` best rows by Pearson; nothing is written if that set is empty."""
    top = rows_sorted[: max(0, top_n)]
    if not top:
        return
    labels = [f"{r.get('feature_set','')}/{r.get('model','')}" for r in top]
    vals = [_safe_float(r.get("corr_rating_pearson"), float("nan")) for r in top]
    plt.figure(figsize=(max(6, min(20, 0.6 * len(labels))), 4.5))
    plt.bar(range(len(labels)), vals, color="#1f77b4")
    plt.xticks(range(len(labels)), labels, rotation=60, ha="right")
    plt.ylabel("Pearson correlation with rating")
    plt.title(f"Top-{len(labels)} Pearson across feature sets/models ({task})")
    _save_figure(plots_dir / f"top{len(labels)}_pearson_{task}.png")


def _plot_pearson_by_model(by_fs: Dict[str, List[Dict[str, Any]]], task: str, plots_dir: Path) -> None:
    """Write one bar chart per feature set showing each model's Pearson correlation."""
    for fs, frs in by_fs.items():
        labs = [r.get("model", "") for r in frs]
        vals = [_safe_float(r.get("corr_rating_pearson"), float("nan")) for r in frs]
        plt.figure(figsize=(max(4, 1.2 * len(labs)), 3.8))
        plt.bar(range(len(labs)), vals, color="#2ca02c")
        plt.xticks(range(len(labs)), labs, rotation=30, ha="right")
        plt.ylabel("Pearson correlation")
        plt.title(f"Pearson by model ({task}) - {fs}")
        # Feature-set names may contain '/', which would otherwise be read as a directory.
        safe_fs = fs.replace('/', '_')
        _save_figure(plots_dir / f"pearson_by_model_{task}_{safe_fs}.png")


def _plot_pearson_vs_spearman(rows_sorted: List[Dict[str, Any]], task: str, plots_dir: Path) -> None:
    """Scatter-plot Pearson against Spearman for rows where both are finite numbers."""
    xs: List[float] = []
    ys: List[float] = []
    for r in rows_sorted:
        # Missing/unparseable values become NaN and are dropped by the finiteness check.
        px = _safe_float(r.get("corr_rating_pearson"))
        sy = _safe_float(r.get("corr_rating_spearman"))
        if math.isfinite(px) and math.isfinite(sy):
            xs.append(px)
            ys.append(sy)
    if not xs:
        return
    plt.figure(figsize=(5.5, 4.5))
    plt.scatter(xs, ys, s=30, alpha=0.8)
    plt.xlabel("Pearson")
    plt.ylabel("Spearman")
    plt.title(f"Pearson vs Spearman ({task})")
    plt.grid(True, alpha=0.3)
    _save_figure(plots_dir / f"pearson_vs_spearman_{task}.png")


def summarize(output_dir: str, tasks: List[str], top_n: int = 15) -> None:
    """Write summary CSVs and, when matplotlib is importable, plots for each task.

    For every task in ``tasks`` this reads ``<output_dir>/results_<task>.json``
    via ``_load_json`` (a missing, empty or invalid file yields empty outputs).

    It writes two CSVs:

    * ``summary/summary_<task>.csv``: every feature-set/model row, sorted by
      ``corr_rating_pearson`` descending;
    * ``summary/top_by_feature_set_<task>.csv``: the highest-Pearson row of
      each feature set.

    If matplotlib was imported, it also writes under ``plots/``:

    * a top-``top_n`` Pearson bar chart;
    * one Pearson-by-model bar chart per feature set;
    * a Pearson vs. Spearman scatter plot.

    Rows whose Pearson value is missing, non-numeric or non-finite sort last.

    Args:
        output_dir: Directory holding the ``results_<task>.json`` files; the
            ``summary`` and ``plots`` subdirectories are created inside it.
        tasks: Task names; ``"regression"`` selects the regression metric
            columns, any other name the classification columns.
        top_n: Number of bars in the top-N chart; values <= 0 skip that chart.
    """
    out_dir = Path(output_dir)
    summary_dir = out_dir / "summary"
    plots_dir = out_dir / "plots"
    _ensure_dir(summary_dir)
    _ensure_dir(plots_dir)

    for task in tasks:
        results = _load_json(out_dir / f"results_{task}.json")
        rows = _flatten_results(task, results)
        # Best Pearson first; the key keeps NaN out of the comparisons.
        rows_sorted = sorted(rows, key=_pearson_sort_key, reverse=True)

        cols = _summary_columns(task)
        _write_csv(summary_dir / f"summary_{task}.csv", rows_sorted, cols)

        # Group rows by feature set. Each group keeps the Pearson ordering, so
        # its first row is that feature set's best model.
        by_fs: Dict[str, List[Dict[str, Any]]] = {}
        for r in rows_sorted:
            by_fs.setdefault(r.get("feature_set", ""), []).append(r)
        best_by_fs = [frs[0] for frs in by_fs.values()]
        _write_csv(summary_dir / f"top_by_feature_set_{task}.csv", best_by_fs, cols)

        if plt is None:
            continue  # matplotlib unavailable: CSV output only.
        _plot_top_pearson(rows_sorted, task, top_n, plots_dir)
        _plot_pearson_by_model(by_fs, task, plots_dir)
        _plot_pearson_vs_spearman(rows_sorted, task, plots_dir)


def main(argv=None):
    """Command-line entry point: parse arguments and run ``summarize``.

    Args:
        argv: Argument list to parse, without the program name. ``None``
            (the default) makes ``argparse`` read ``sys.argv[1:]``.

    Raises:
        SystemExit: On invalid arguments, including a negative ``--top_n``.
        Exception: Any error from ``summarize``, after its message is printed
            to stderr.
    """
    import sys

    ap = argparse.ArgumentParser(description="Summarize unified_analysis results and generate plots")
    ap.add_argument("--output_dir", default="unified_analysis/outputs")
    ap.add_argument("--tasks", default="regression,classification")
    ap.add_argument("--top_n", type=int, default=15)
    args = ap.parse_args(argv)
    if args.top_n < 0:
        ap.error("--top_n must be >= 0")

    # Comma-separated list; blank entries (e.g. a trailing comma) are ignored.
    tasks = [t.strip() for t in args.tasks.split(',') if t.strip()]

    try:
        summarize(args.output_dir, tasks, top_n=args.top_n)
    except Exception as e:
        # Report on stderr, then re-raise so the traceback and non-zero exit status survive.
        print(f"Summarization failed: {e}", file=sys.stderr)
        raise


# Run the command-line entry point only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()
