import argparse
import json
from pathlib import Path
from typing import List, Dict

import pandas as pd

from unified_analysis.features import read_csv, build_feature_sets, Dataset, _weight_from_row
from unified_analysis.models import make_models
from unified_analysis.evaluate import evaluate_regression, evaluate_classification, _fit_with_sample_weight
from sklearn.metrics import roc_auc_score
from sklearn.utils.class_weight import compute_sample_weight


def main():
    """Train and cross-validate simple models on unified feature sets.

    Steps (paths and switches come from the CLI, see ``_parse_args``):
      1. Load rows from ``--input`` and build the standard feature sets, keeping
         only those named in ``--feature_sets`` (all of them if the list is empty).
      2. If ``--use_top_k_by_corr`` > 0, restrict every set to the top-K features
         by |pearson_r| read from ``correlations_numeric.csv`` in ``--output_dir``.
      3. If it survives the ``--feature_sets`` filter, add the
         ``git_user_task_top10flags`` feature set.
      4. Evaluate every (feature set, model) pair with CV (optionally tuning it
         first) and write ``results_<task>.json`` to ``--output_dir``.
    """
    args = _parse_args()

    rows = read_csv(args.input)
    datasets = build_feature_sets(rows)

    # Filter to requested feature sets; an empty list keeps everything.
    selected_fs = {s.strip() for s in args.feature_sets.split(",") if s.strip()}
    if selected_fs:
        datasets = {k: v for k, v in datasets.items() if k in selected_fs}

    if args.models == "auto":
        model_names = list(make_models(args.task, random_state=args.seed).keys())
    else:
        model_names = [m.strip() for m in args.models.split(",") if m.strip()]
        available = make_models(args.task, random_state=args.seed)
        unknown = [m for m in model_names if m not in available]
        if unknown:
            print(f"Warning: unknown model name(s) for task '{args.task}' will be skipped: {unknown}")

    out_dir = Path(args.output_dir)
    # Create the output dir up front: correlation files may be generated into it below.
    out_dir.mkdir(parents=True, exist_ok=True)

    if args.use_top_k_by_corr > 0:
        datasets = _apply_top_k_filter(datasets, args.use_top_k_by_corr, args.input, out_dir)

    # Build the extra set only when --feature_sets keeps it, so no work (and no
    # correlation-file generation) is spent on a set that would be discarded.
    # It is added after the top-K step and so is not top-K filtered.
    special_name = "git_user_task_top10flags"
    if not selected_fs or special_name in selected_fs:
        try:
            special = _build_git_task_flags_dataset(rows, args.input, out_dir)
        except Exception as e:
            print(f"Warning: failed to build {special_name} feature set: {e}")
        else:
            if special is None:
                print(f"Warning: no columns found for {special_name}; skipping")
            else:
                datasets[special_name] = special
                print(
                    f"Added feature set '{special_name}' with {len(special.feature_names)} "
                    f"features: {special.feature_names}"
                )

    evaluate = evaluate_regression if args.task == "regression" else evaluate_classification
    results: Dict[str, Dict[str, Dict[str, float]]] = {}

    for fs_name, ds in datasets.items():
        if not ds.X:
            continue
        results[fs_name] = {}
        for mname in model_names:
            # Fresh, unfitted instances per run so no fitted state leaks between runs.
            models = make_models(args.task, random_state=args.seed)
            if mname not in models:
                continue
            model = models[mname]
            if args.tune:
                try:
                    model = _tune_model(args.task, mname, model, ds, args.cv_splits, out_dir)
                except Exception as e:
                    print(f"Tuning failed for {fs_name}/{mname}: {e}. Proceeding with base model.")
            oof_path = str(out_dir / f"oof_{args.task}_{fs_name}_{mname}.csv") if args.save_oof else None
            metrics = evaluate(
                model,
                ds.X,
                ds.y,
                cv_splits=args.cv_splits,
                random_state=args.seed,
                sample_weights=ds.w,
                return_oof=args.save_oof,
                oof_output_path=oof_path,
            )
            results[fs_name][mname] = metrics
            print(f"{fs_name} / {mname}: {metrics}")

    with open(out_dir / f"results_{args.task}.json", "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)


def _parse_args(argv=None) -> argparse.Namespace:
    """Build the CLI parser and parse *argv* (``sys.argv[1:]`` when None)."""
    p = argparse.ArgumentParser(description="Train/evaluate simple models on unified features")
    p.add_argument("--input", required=True, help="Path to present_conversations_aggregated.csv")
    p.add_argument("--task", choices=["regression", "classification"], default="regression")
    p.add_argument("--feature_sets", default="first_static,agg_static,first_flags,agg_flags,numeric_only,all")
    p.add_argument("--models", default="auto", help="Comma list or 'auto' to pick defaults per task")
    p.add_argument("--cv_splits", type=int, default=5)
    p.add_argument("--output_dir", default="unified_analysis/outputs")
    p.add_argument("--use_top_k_by_corr", type=int, default=10, help="Restrict features within each set to top-K by absolute correlation; expects correlations_numeric.csv in output_dir (default: 10)")
    p.add_argument("--seed", type=int, default=42, help="Random seed used for CV splits and model initialization")

    p.add_argument("--save_oof", action="store_true", help="Save out-of-fold predictions to CSV per fs/model")
    p.add_argument("--tune", action="store_true", help="Hyperparameter tuning to optimize corr_rating_pearson")
    return p.parse_args(argv)


def _ensure_correlations(corr_path: Path, input_path: str, out_dir: Path, label: str) -> None:
    """Best effort: run ``correlate_features.run`` when *corr_path* does not exist yet.

    Failures are reported with a warning mentioning *label* and otherwise ignored;
    callers handle a still-missing file themselves.
    """
    if corr_path.exists():
        return
    try:
        from unified_analysis import correlate_features as _corr
        _corr.run(input_path, str(out_dir))
    except Exception as e:
        print(f"Warning: could not generate {label} automatically: {e}")


def _top_by_abs_corr(df_corr: pd.DataFrame, k: int) -> List[str]:
    """Return up to *k* ``feature`` values ordered by descending ``|pearson_r|``.

    Rows with a missing ``feature`` or ``pearson_r`` are ignored. The input frame
    is not modified.
    """
    ranked = (
        df_corr.dropna(subset=["feature", "pearson_r"])
        .assign(abs_r=lambda d: d["pearson_r"].abs())
        .sort_values("abs_r", ascending=False)
    )
    return ranked["feature"].head(k).tolist()


def _apply_top_k_filter(datasets: Dict[str, Dataset], top_k: int, input_path: str, out_dir: Path) -> Dict[str, Dataset]:
    """Keep only the top-*top_k* features by |pearson_r| in every dataset.

    Rankings come from ``correlations_numeric.csv`` in *out_dir* (generated when
    missing). Datasets are filtered in place (``feature_names`` and ``X``), and
    those left with no features are dropped from the returned mapping. If the
    rankings cannot be read or applied, a warning is printed and *datasets* is
    returned without pruning.
    """
    corr_path = out_dir / "correlations_numeric.csv"
    _ensure_correlations(corr_path, input_path, out_dir, "correlations file")
    try:
        df_corr = pd.read_csv(corr_path)
        if not ("pearson_r" in df_corr.columns and "feature" in df_corr.columns):
            print(f"Warning: correlations file {corr_path} missing required columns; skipping top-K filter")
            return datasets
        top_features = set(_top_by_abs_corr(df_corr, top_k))
        print(f"Using top {len(top_features)} features by |pearson_r| from {corr_path.name}:")
        print(sorted(top_features))
        for ds in datasets.values():
            ds.feature_names = [f for f in ds.feature_names if f in top_features]
            ds.X = [{k: v for k, v in x.items() if k in top_features} for x in ds.X]
    except FileNotFoundError:
        print(f"Warning: correlations file {corr_path} not found; skipping top-K filter")
        return datasets
    except Exception as e:
        print(f"Warning: failed to apply top-K filter due to error: {e}")
        return datasets

    dropped = [name for name, ds in datasets.items() if not ds.feature_names]
    if dropped:
        print(f"Pruning empty feature sets after top-K filter: {dropped}")
    return {k: v for k, v in datasets.items() if v.feature_names}


def _coerce_cell(value):
    """Convert a raw cell to ``int``/``float`` when possible, else return it as is.

    Strings are stripped first. Plain decimal strings become ``int``. Anything
    else ``float()`` accepts becomes ``float``. Otherwise the (stripped) value is
    returned unchanged.
    """
    if isinstance(value, str):
        value = value.strip()
        # isdecimal, not isdigit: isdigit also accepts e.g. '²', which int() rejects.
        if value.isdecimal():
            return int(value)
    try:
        return float(value)
    except (TypeError, ValueError):
        return value


def _build_git_task_flags_dataset(rows, input_path: str, out_dir: Path):
    """Build the ``git_user_task_top10flags`` dataset, or return None if no column qualifies.

    Features are:
      * every ``git_*`` column;
      * the task-type columns that exist;
      * ``user_message_count`` (if present);
      * the 10 flags with the highest |pearson_r| in ``correlations_flags.csv``,
        falling back to the first 10 columns containing ``flag_``.

    The target is ``average_rating``; rows without a parseable rating are skipped.
    Column names are taken from the first row of *rows*.
    """
    corr_flags_path = out_dir / "correlations_flags.csv"
    _ensure_correlations(corr_flags_path, input_path, out_dir, "correlations (flags)")
    top_flags: List[str] = []
    try:
        df_flags = pd.read_csv(corr_flags_path)
        if "feature" in df_flags.columns and "pearson_r" in df_flags.columns:
            top_flags = _top_by_abs_corr(df_flags, 10)
            print(f"Top {len(top_flags)} flags by |pearson_r|: {top_flags}")
        else:
            print(f"Warning: {corr_flags_path} missing required columns; falling back to first 10 flag columns")
    except Exception as e:
        print(f"Warning: failed to read {corr_flags_path}: {e}")

    all_cols = list(rows[0].keys()) if rows else []
    known_cols = set(all_cols)
    if not top_flags:
        top_flags = [c for c in all_cols if "flag_" in c][:10]
    missing_flags = [c for c in top_flags if c not in known_cols]
    if missing_flags:
        print(f"Warning: top flags not present as input columns (they will be empty): {missing_flags}")
    git_cols = [c for c in all_cols if c.startswith("git_")]
    task_cols = [c for c in ("jsonl_first_task_type", "jsonl_agg_task_type_mode") if c in known_cols]
    user_cols = ["user_message_count"] if "user_message_count" in known_cols else []
    selected_cols = list(dict.fromkeys(git_cols + task_cols + user_cols + top_flags))  # dedupe, keep order
    if not selected_cols:
        return None

    X: List[Dict[str, object]] = []
    y: List[float] = []
    w = []
    for r in rows:
        rating = str(r.get("average_rating") or "").strip()
        if not rating:
            continue
        try:
            target = float(rating)
        except ValueError:
            continue
        features = {}
        for c in selected_cols:
            v = r.get(c)
            if v is not None:  # absent cells are left out of the row's feature dict
                features[c] = _coerce_cell(v)
        X.append(features)
        y.append(target)
        w.append(_weight_from_row(r))
    return Dataset(X=X, y=y, feature_names=selected_cols, w=w)






def _pearson(y_true, y_pred):
    import numpy as np
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    if y_true.std() == 0 or y_pred.std() == 0:
        return 0.0
    return float(np.corrcoef(y_true, y_pred)[0, 1])


def _tune_model(task: str, mname: str, base_model, ds, cv_splits: int, out_dir: Path, random_state: int = 42):
    """Grid-search a small model-specific hyperparameter grid and return the best candidate.

    Scoring:
      * regression: Pearson correlation between the out-of-fold, per-fold
        isotonic-calibrated predictions and the true targets;
      * classification: mean per-fold ROC AUC on the binarized target ``y > 3.0``.

    Args:
        task: "regression" or "classification".
        mname: model key; only keys listed in the grids below are tuned.
        base_model: unfitted estimator. The grid keys use the ``est__`` prefix,
            so it is presumably a Pipeline with a step named "est" — verify
            against ``make_models``.
        ds: dataset exposing ``X`` (sequence of per-row feature dicts), ``y`` and
            optional per-row weights ``w``.
        cv_splits: number of CV folds.
        out_dir: currently unused; kept for call compatibility.
        random_state: seed for fold shuffling (42 was previously hard-coded).

    Returns:
        The best-scoring candidate, left fitted on its last CV fold. Returns
        ``base_model`` unchanged when *mname* has no grid or no candidate
        could be scored.
    """
    from copy import deepcopy
    import itertools
    from sklearn.model_selection import KFold, StratifiedKFold
    import numpy as np

    if task == "regression":
        grids = {
            "ridge": {"est__alpha": [0.1, 0.3, 1.0, 3.0, 10.0]},
            "rf_reg": {"est__max_depth": [None, 6, 10, 16], "est__min_samples_leaf": [1, 3, 5]},
            "hgb_reg": {"est__learning_rate": [0.03, 0.1], "est__max_depth": [3, 6, 10]},
        }
    else:
        grids = {
            "logreg": {"est__C": [0.5, 1.0, 2.0]},
            "rf_clf": {"est__max_depth": [None, 8, 16], "est__min_samples_leaf": [1, 3, 5]},
            "hgb_clf": {"est__learning_rate": [0.03, 0.1], "est__max_depth": [3, 6, 10]},
        }
    if mname not in grids:
        return base_model

    # Fold assignment does not depend on hyperparameters: compute it once so
    # every candidate is scored on identical splits.
    if task == "regression":
        y_bin = None
        folds = list(KFold(n_splits=cv_splits, shuffle=True, random_state=random_state).split(ds.X))
    else:
        # NOTE(review): assumes ds.y holds ratings to threshold at 3.0 — confirm for classification datasets.
        y_bin = (np.asarray(ds.y) > 3.0).astype(int)
        skf = StratifiedKFold(n_splits=cv_splits, shuffle=True, random_state=random_state)
        folds = list(skf.split(np.zeros_like(y_bin), y_bin))

    grid = grids[mname]
    keys = list(grid)
    best_model = base_model
    best_score = float("-inf")
    for values in itertools.product(*(grid[k] for k in keys)):
        model = deepcopy(base_model)
        model.set_params(**dict(zip(keys, values)))
        if task == "regression":
            score = _cv_pearson_score(model, ds, folds)
        else:
            score = _cv_auc_score(model, ds, y_bin, folds)
        # NaN compares False, so unscorable candidates never win.
        if score is not None and score > best_score:
            best_score = score
            best_model = model
    return best_model


def _cv_pearson_score(model, ds, folds):
    """Fit *model* per fold and return Pearson r of the pooled OOF calibrated predictions vs. targets."""
    from sklearn.isotonic import IsotonicRegression
    import numpy as np

    oof_pred = []
    oof_y = []
    for train_idx, test_idx in folds:
        X_train = [ds.X[i] for i in train_idx]
        y_train = np.asarray([ds.y[i] for i in train_idx])
        X_test = [ds.X[i] for i in test_idx]
        sw_train = None if ds.w is None else np.asarray([ds.w[i] for i in train_idx])
        _fit_with_sample_weight(model, X_train, y_train, sample_weight=sw_train)
        pred_test = model.predict(X_test)
        # Best effort: map test predictions through an isotonic fit learned on the
        # training-fold predictions; on any failure keep the raw predictions.
        try:
            pred_train = model.predict(X_train)
            iso = IsotonicRegression(out_of_bounds="clip")
            try:
                iso.fit(pred_train, y_train, sample_weight=sw_train)
            except TypeError:
                iso.fit(pred_train, y_train)
            pred_test = iso.transform(pred_test)
        except Exception:
            pass
        oof_pred.extend(pred_test)
        oof_y.extend(ds.y[i] for i in test_idx)
    return _pearson(oof_y, oof_pred)


def _cv_auc_score(model, ds, y_bin, folds):
    """Fit *model* per fold on *y_bin* and return the mean ROC AUC, or None if no fold could be scored."""
    from sklearn.pipeline import Pipeline
    import numpy as np

    # Look at the final estimator (Pipeline step "est") to see whether it already
    # balances classes itself through class_weight.
    est = model
    if isinstance(model, Pipeline) and "est" in model.named_steps:
        est = model.named_steps["est"]
    has_internal_cw = getattr(est, "class_weight", None) not in (None, {})

    aucs = []
    for train_idx, test_idx in folds:
        X_train = [ds.X[i] for i in train_idx]
        X_test = [ds.X[i] for i in test_idx]
        y_train_bin = y_bin[train_idx]
        y_test_bin = y_bin[test_idx]
        sw_train = None if ds.w is None else np.asarray([ds.w[i] for i in train_idx])

        # Combine balanced class weights with the row weights, unless the
        # estimator already balances (avoids double balancing).
        try:
            cw = None if has_internal_cw else compute_sample_weight(class_weight="balanced", y=y_train_bin)
            if cw is None:
                sw_combined = sw_train
            elif sw_train is None:
                sw_combined = cw
            else:
                sw_combined = cw * sw_train
        except Exception:
            sw_combined = sw_train

        _fit_with_sample_weight(model, X_train, y_train_bin, sample_weight=sw_combined)

        if hasattr(model, "predict_proba"):
            scores = model.predict_proba(X_test)[:, 1]
        elif hasattr(model, "decision_function"):
            scores = model.decision_function(X_test)
        else:
            continue
        try:
            aucs.append(float(roc_auc_score(y_test_bin, scores)))
        except ValueError:
            # e.g. a test fold containing only one class: AUC is undefined.
            pass
    return float(np.mean(aucs)) if aucs else None


if __name__ == "__main__":
    # Run the CLI only when this file is executed as a script, not when it is imported.
    main()
