# script/paper_experiment/main_experiment/run_const_lambda.py
# -*- coding: utf-8 -*-
"""
Lambda sweep for a single dataset (clari_tree_const, greedy_const, streedc, guide),
optionally restricted to a specific outer fold (via --outer_id).

Input:
  data/{dataset}/splits/outer_0..5/{train.csv,test.csv}

For each method:
  - clari_tree_const/greedy_const/streedc: sweep over lambda values
  - guide: sweep over max_nodes values

Output:
  ./results/CR_depth{D}/{dataset}/results_CR_d{D}_outer{outer_id}.csv

Columns:
  dataset, outer, method, depth, lambda,
  leaves, r2_train, r2_test, mse_train, mse_test, train_time_s
"""

from __future__ import annotations
from pathlib import Path
import argparse, csv, time, os
import numpy as np
from tqdm import tqdm

from script.processors.streed import STreeDConstProcessor
from script.processors.claritree import GreedyConstProcessor, CLARITreeConstProcessor
from script.utils.dataio import load_xy_claritree
from script.utils.utils import scorer_r2, scorer_mse
from script.processors.guide_c import GuideConstantProcessor
from script.processors.guide_utils import (
    update_guide_model_r, run_rscript, parse_r2,
    parse_guide_train_r2_and_elapsed, find_guide_training_out
)


# Dataset name -> short tag used in GUIDE work-directory names. Long dataset names
# would otherwise push GUIDE's generated R code file path past its length limit.
GUIDE_DATASET_ALIASES: dict[str, str] = dict(
    california_housing="ca",
    temperature_min="te_min",
    temperature_max="te_max",
)


def time_block(fn):
    """Call ``fn()`` and return ``(result, elapsed_seconds)``.

    Elapsed time is wall-clock time taken from ``time.perf_counter``.
    """
    start = time.perf_counter()
    result = fn()
    elapsed = time.perf_counter() - start
    return result, elapsed


def get_processors(methods: list[str] | None = None):
    """Instantiate the constant-leaf processors to run.

    Args:
        methods: Method names to keep, or ``None`` for every method. Blank
            entries (e.g. from a trailing comma on the CLI) are ignored.

    Returns:
        Dict mapping method name -> freshly constructed processor, in the fixed
        order clari_tree_const, greedy_const, streedc, guide.

    Raises:
        ValueError: if ``methods`` contains a name that is not a known method,
            so a typo fails loudly instead of silently running nothing.
    """
    # Map names to classes (not instances) so only requested processors are built.
    factories = {
        "clari_tree_const": CLARITreeConstProcessor,
        "greedy_const": GreedyConstProcessor,
        "streedc": STreeDConstProcessor,
        "guide": GuideConstantProcessor,
    }
    if methods is None:
        return {name: cls() for name, cls in factories.items()}

    wanted = {m for m in methods if m}
    unknown = sorted(wanted - factories.keys())
    if unknown:
        raise ValueError(
            f"Unknown method(s) {unknown}; valid methods are {list(factories)}"
        )
    return {name: cls() for name, cls in factories.items() if name in wanted}


def build_params(depth: int, lam: float, method: str, n_train: int, threshold_mode: str = "full") -> dict:
    """Assemble the keyword arguments passed to ``proc.build`` for one run.

    Every method receives ``depth``, ``cost_complexity`` (= ``lam``) and ``stride``.

    * clari_tree_const / greedy_const: ``stride`` is 1 in "full" mode (try every
      threshold); otherwise ``max(1, n_train // 20)``, i.e. roughly 20 thresholds.
    * streedc: additionally gets ``n_thresholds`` = ``n_train`` in "full" mode,
      else 20.
    * any other method: ``stride`` is always 1.
    """
    use_all = threshold_mode == "full"

    stride = 1
    if method in {"clari_tree_const", "greedy_const"} and not use_all:
        stride = max(1, n_train // 20)

    params = {"depth": int(depth), "cost_complexity": float(lam), "stride": stride}
    if method == "streedc":
        params["n_thresholds"] = n_train if use_all else 20
    return params


def fit_and_eval(proc, method: str, X_tr, y_tr, X_te, y_te, depth: int, lam: float, threshold_mode: str = "full"):
    """Fit a single constant-leaf tree with ``proc`` and score it on both splits.

    Hyperparameters come from ``build_params``. Only the build + fit step is
    timed; prediction and scoring are excluded from ``train_time_s``.

    Returns:
        Dict with keys leaves, r2_train, r2_test, mse_train, mse_test,
        train_time_s. ``leaves`` is the fitted artifact's ``complexity``
        attribute, or 0 if it has none.
    """
    hyper = build_params(depth, lam, method, n_train=len(X_tr), threshold_mode=threshold_mode)

    def _train():
        # Model construction is included in the timed region.
        model = proc.build(**hyper)
        return proc.fit(model, X_tr, y_tr)

    artifact, fit_seconds = time_block(_train)

    pred_train = proc.predict(artifact.model, X_tr)
    pred_test = proc.predict(artifact.model, X_te)

    scores = {
        "r2_train": float(scorer_r2(y_tr, pred_train)),
        "r2_test": float(scorer_r2(y_te, pred_test)),
        "mse_train": float(scorer_mse(y_tr, pred_train)),
        "mse_test": float(scorer_mse(y_te, pred_test)),
    }
    n_leaves = int(getattr(artifact, "complexity", 0))
    return {"leaves": n_leaves, **scores, "train_time_s": float(fit_seconds)}


# Column order of every results CSV written by run_for_dataset.
_RESULT_FIELDNAMES = [
    "dataset", "outer", "method", "depth", "lambda",
    "leaves", "r2_train", "r2_test", "mse_train", "mse_test", "train_time_s",
]
# Metric columns collected per (method, lambda) and summarised as mean/std.
_METRIC_COLUMNS = ["leaves", "r2_train", "r2_test", "mse_train", "mse_test", "train_time_s"]


def _load_existing_rows(out_csv: Path, outer_id: int | None, methods: list[str]):
    """Collect the rows of an existing ``out_csv`` that this run will not regenerate.

    Per-fold rows are kept unless they belong to a method being re-run (and, when
    ``outer_id`` is given, also to that same outer fold). ``mean``/``std`` rows are
    kept only for methods that are not being re-run.

    Returns:
        ``(rows, agg_rows, ok)``; ``ok`` is False if reading failed part-way, in
        which case the rows read before the failure are still returned.
    """
    rows: list[dict] = []
    agg_rows: list[dict] = []
    try:
        with out_csv.open("r", newline="") as f:
            for row in csv.DictReader(f):
                outer_val = row.get("outer", "")
                rerun = row.get("method", "") in methods
                if outer_val in ("mean", "std"):
                    if not rerun:
                        agg_rows.append(row)
                elif not rerun or (outer_id is not None and outer_val != f"outer_{outer_id}"):
                    rows.append(row)
    except (OSError, csv.Error, ValueError) as e:
        print(f"[warning] Could not read existing file {out_csv}: {e}")
        return rows, agg_rows, False
    return rows, agg_rows, True


def _guide_rows(proc, dataset: str, k_outer: str, train_csv: Path, depth: int,
                guide_nodes: list[int], results_dir: Path):
    """Run GUIDE once per ``max_nodes`` value on one outer fold, yielding one result row each."""
    # Short alias for long dataset names keeps the GUIDE R code file path under its length limit.
    guide_tag = GUIDE_DATASET_ALIASES.get(dataset, dataset)
    for max_nodes in tqdm(guide_nodes, desc="guide max_nodes", leave=False):
        work_dir = results_dir / f"GW_{guide_tag}_d{depth}_{k_outer}_n{max_nodes}"
        work_dir.mkdir(parents=True, exist_ok=True)

        model = proc.build(csv_path=train_csv, depth=depth,
                           max_nodes=max_nodes, work_dir=work_dir)
        art, t_fit = time_block(lambda: proc.fit(model, None, None))

        rfile = work_dir / "guide_model.R"
        update_guide_model_r(rfile)
        out, _ = time_block(lambda: run_rscript(rfile))

        try:
            r2_te = parse_r2(out)
        except Exception:
            # Best effort: record NaN rather than abort the whole sweep.
            r2_te = float("nan")

        # Parse the training output file found in work_dir if there is one,
        # otherwise fall back to the Rscript output.
        train_out_path = find_guide_training_out(work_dir)
        r2_tr, elapsed = parse_guide_train_r2_and_elapsed(
            train_out_path if train_out_path is not None else out
        )

        yield {
            "dataset": dataset, "outer": k_outer,
            "method": "guide", "depth": depth,
            # GUIDE is swept over max_nodes, which is stored in the lambda column.
            "lambda": max_nodes,
            "leaves": int(getattr(art, "complexity", max_nodes)),
            "r2_train": r2_tr, "r2_test": r2_te,
            # MSE is not evaluated for GUIDE here.
            "mse_train": float("nan"), "mse_test": float("nan"),
            # Use the parsed elapsed time when truthy, else the measured fit time.
            "train_time_s": elapsed if elapsed else t_fit,
        }


def _lambda_rows(proc, method: str, dataset: str, k_outer: str, X_tr, y_tr, X_te, y_te,
                 depth: int, lambdas: list[float], threshold_mode: str):
    """Fit ``method`` once per lambda on one outer fold, yielding one result row each."""
    for lam in tqdm(lambdas, desc=f"{method} lambdas", leave=False):
        res = fit_and_eval(proc, method, X_tr, y_tr, X_te, y_te, depth, lam, threshold_mode)
        yield {"dataset": dataset, "outer": k_outer, "method": method,
               "depth": depth, "lambda": lam, **res}


def _summary_rows(dataset: str, depth: int, agg: dict):
    """Yield a ``mean`` row then a ``std`` row (sample std, ddof=1) per (method, lambda) key."""
    def mean(xs): return float(np.mean(xs)) if xs else float("nan")
    def std(xs): return float(np.std(xs, ddof=1)) if len(xs) > 1 else 0.0

    for (method, lam), stats in agg.items():
        ids = {"dataset": dataset, "method": method, "depth": depth, "lambda": lam}
        yield {**ids, "outer": "mean", **{k: mean(xs) for k, xs in stats.items()}}
        yield {**ids, "outer": "std", **{k: std(xs) for k, xs in stats.items()}}


def run_for_dataset(dataset_dir: Path, depth: int,
                    lambdas: list[float], guide_nodes: list[int],
                    out_csv: Path, outer_id: int | None = None,
                    methods: list[str] | None = None,
                    threshold_mode: str = "full"):
    """Run the sweep for one dataset (optionally one outer fold) and write ``out_csv``.

    Args:
        dataset_dir: Folder containing ``splits/outer_*/{train,test}.csv``.
        depth: Tree depth used for every run.
        lambdas: Cost-complexity values for the non-GUIDE methods.
        guide_nodes: ``max_nodes`` values for GUIDE.
        out_csv: Destination CSV; its parent directory is created if needed and
            also hosts one GUIDE work directory per (fold, max_nodes).
        outer_id: Run only ``outer_{outer_id}``; mean/std rows are then skipped.
        methods: Subset of methods to run (``None`` = all). When given and
            ``out_csv`` exists, rows of the other methods (and, with
            ``outer_id``, of other folds) are carried over.
        threshold_mode: ``"full"`` or ``"threshold"``; see ``build_params``.

    The CSV is first written to ``<out_csv>.tmp`` and moved over ``out_csv``
    only after the sweep finishes. An interrupted run therefore leaves the
    previous results intact; the partial ``.tmp`` file is left behind. If the
    previous file could not be read, it is preserved as ``<out_csv>.bak``
    instead of being silently overwritten.
    """
    processors = get_processors(methods)
    splits_root = dataset_dir / "splits"

    # Outer fold directories, ordered numerically (outer_2 before outer_10).
    outers = sorted(
        (p for p in splits_root.iterdir() if p.is_dir() and p.name.startswith("outer_")),
        key=lambda p: int(p.name.split("_")[-1]),
    )
    if outer_id is not None:
        outers = [p for p in outers if p.name == f"outer_{outer_id}"]
        if not outers:
            print(f"[warning] {splits_root}/outer_{outer_id} not found; no folds to run")

    dataset = dataset_dir.name
    out_csv.parent.mkdir(parents=True, exist_ok=True)

    # Rows from a previous run that this (partial) run does not regenerate.
    existing_rows: list[dict] = []
    existing_agg_rows: list[dict] = []
    read_ok = True
    if out_csv.exists() and methods is not None:
        existing_rows, existing_agg_rows, read_ok = _load_existing_rows(out_csv, outer_id, methods)

    tmp_csv = out_csv.with_name(out_csv.name + ".tmp")
    # (method, lambda) -> metric column -> values across outer folds
    agg: dict[tuple, dict[str, list]] = {}

    with tmp_csv.open("w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=_RESULT_FIELDNAMES)
        writer.writeheader()
        # Carried-over per-fold rows go first ...
        writer.writerows(existing_rows)

        for odir in tqdm(outers, desc=f"{dataset}|outer folds"):
            k_outer = odir.name
            train_csv = odir / "train.csv"
            test_csv = odir / "test.csv"
            if not train_csv.exists() or not test_csv.exists():
                print(f"[skip] {dataset}/{k_outer} missing train/test.csv")
                continue

            X_tr, y_tr = load_xy_claritree(train_csv)
            X_te, y_te = load_xy_claritree(test_csv)

            for method, proc in tqdm(processors.items(), desc=f"{dataset}|{k_outer}|methods", leave=False):
                if method == "guide":
                    rows = _guide_rows(proc, dataset, k_outer, train_csv, depth,
                                       guide_nodes, out_csv.parent)
                else:
                    rows = _lambda_rows(proc, method, dataset, k_outer, X_tr, y_tr,
                                        X_te, y_te, depth, lambdas, threshold_mode)
                for row in rows:
                    writer.writerow(row)
                    key = (method, row["lambda"])
                    if key not in agg:
                        agg[key] = {k: [] for k in _METRIC_COLUMNS}
                    for k in _METRIC_COLUMNS:
                        agg[key][k].append(row[k])

        # Only compute mean/std when running all outer folds.
        if outer_id is None:
            writer.writerows(_summary_rows(dataset, depth, agg))

        # ... and carried-over mean/std rows of other methods go last.
        writer.writerows(existing_agg_rows)

    if not read_ok and out_csv.exists():
        backup = out_csv.with_name(out_csv.name + ".bak")
        os.replace(out_csv, backup)
        print(f"[warning] Kept unreadable previous results as {backup}")
    os.replace(tmp_csv, out_csv)


def parse_nodes(arg: str) -> list[int]:
    """Parse the ``--guide_nodes`` CLI value into a list of ``max_nodes`` values.

    Accepts comma-separated items, each either a single integer (``"8"``) or an
    inclusive range ``"lo-hi"`` (``"1-32"``). Items may be mixed, e.g.
    ``"1-4,8,16"`` -> ``[1, 2, 3, 4, 8, 16]``. Whitespace around items is ignored
    and empty items (e.g. a trailing comma) are skipped.

    Raises:
        ValueError: on a malformed item, a range whose start exceeds its end,
            or when the specification yields no values at all.
    """
    nodes: list[int] = []
    for item in arg.split(","):
        item = item.strip()
        if not item:
            continue
        lo, sep, hi = item.partition("-")
        if sep:
            start, stop = int(lo), int(hi)
            if start > stop:
                raise ValueError(f"Invalid --guide_nodes range {item!r}: start exceeds end")
            nodes.extend(range(start, stop + 1))
        else:
            nodes.append(int(item))
    if not nodes:
        raise ValueError(f"--guide_nodes {arg!r} yields no values")
    return nodes


def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for this sweep script."""
    ap = argparse.ArgumentParser(
        description="Lambda sweep for clari_tree_const/greedy_const/streedc/guide on one dataset (constant).")
    ap.add_argument("--data_dir", type=str, required=True,
                    help="Dataset folder, e.g. data/airfoil (must contain /splits/outer_0..5/)")
    ap.add_argument("--depth", type=int, required=True,
                    help="Tree depth for all runs")
    ap.add_argument("--lambdas", type=str, required=True,
                    help="Comma-separated list of lambdas, e.g. '0.001,0.01,0.1,1.0'")
    ap.add_argument("--guide_nodes", type=str, default="1-32",
                    help="Range of max_nodes for GUIDE, e.g. '1-32' or '2,4,8,16'")
    ap.add_argument("--outer_id", type=int, default=None,
                    help="Run only a specific outer fold (0..5)")
    ap.add_argument("--results_dir", type=str, default=None,
                    help="Output directory. Default: ./results/CR_depth{depth}/{dataset}")
    ap.add_argument("--methods", type=str, default=None,
                    help="Comma-separated list of methods to run, e.g. 'guide' or 'clari_tree_const,guide'. Default: all methods")
    ap.add_argument("--threshold_mode", type=str, default="full",
                    choices=["threshold", "full"],
                    help="Threshold search mode for streedc/clari_tree_const/greedy_const. 'threshold' uses n_train//20 stride and 20 thresholds, 'full' uses all thresholds.")
    return ap


def main():
    """CLI entry point: parse arguments, resolve the output CSV path and run the sweep."""
    args = _build_arg_parser().parse_args()

    lambdas = list(map(float, args.lambdas.split(",")))
    dataset_dir = Path(args.data_dir)
    if not dataset_dir.exists():
        raise FileNotFoundError(f"Dataset folder {dataset_dir} not found")

    guide_nodes = parse_nodes(args.guide_nodes)
    methods = (None if args.methods is None
               else [name.strip() for name in args.methods.split(",")])

    if args.results_dir is None:
        results_root = Path(f"./results/CR_depth{args.depth}/{dataset_dir.name}")
    else:
        results_root = Path(args.results_dir)

    # Single-fold runs get their own file via the "_outer{id}" suffix.
    suffix = "" if args.outer_id is None else f"_outer{args.outer_id}"
    out_csv = results_root / f"results_CR_d{args.depth}{suffix}.csv"

    print(f"[{dataset_dir.name}] -> {out_csv}")
    run_for_dataset(dataset_dir, args.depth, lambdas, guide_nodes, out_csv,
                    args.outer_id, methods, args.threshold_mode)
    print("All done.")


if __name__ == "__main__":
    main()
