# script/paper_experiment/parameter_selection/binary_search.py
# -*- coding: utf-8 -*-
"""
Find cost_complexity (lambda) values for constant-leaf and linear-leaf trees
such that the *average* number of leaves across outer splits falls into
given buckets.

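Supports:
  - clari_tree_const   (constant-leaf, C++)
  - greedy_const       (constant-leaf, C++)
  - streedc            (constant-leaf, STreeD)
  - clari_tree         (linear-leaf, C++)
  - greedy             (linear-leaf, C++)
  - streed             (linear-leaf, STreeD)
Use --model_type {constant,linear,both} to restrict which family is run.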

Example:
  python script/paper_experiment/parameter_selection/binary_search.py \
      --dataset auction --depth 4 \
      --buckets 3-5,5-8,8-12 \
      --kappa 1e-5
"""

from __future__ import annotations
from pathlib import Path
from typing import Tuple, List, Dict
import argparse, math, csv
import numpy as np

# --- Processors / loaders ---
from script.processors.claritree import (
    GreedyConstProcessor,
    CLARITreeConstProcessor,
    GreedyProcessor,
    CLARITreeProcessor,
)
from script.processors.streed import STreeDConstProcessor, STreeDProcessor
from script.utils.dataio import load_xy_claritree

# ===============================
# Config
# ===============================
# Tree depth and leaf-count buckets used when the CLI flags are omitted.
DEFAULT_DEPTH: int = 4
DEFAULT_BUCKETS: List[Tuple[int, int]] = [(0, 4), (4, 7), (7, 10)]

# CCP (lambda) search settings: the starting bracket [CCP_MIN, CCP_MAX]
# (also used as clamp limits when the bracket is widened), the number of
# widening rounds, and the number of geometric bisection steps.
CCP_MIN: float = 1e-9
CCP_MAX: float = 1.0
CCP_BRACKET_EXPANSIONS: int = 10
CCP_BINSEARCH_STEPS: int = 16

# Default linear-leaf regularization (ridge/jitter) forwarded as `kappa`.
DEFAULT_KAPPA: float = 1e-16


# ===============================
# Basic helpers
# ===============================
def get_processors() -> Dict[str, object]:
    """
    Instantiate every tree variant the search can run over.

    Returns a dict mapping method name -> freshly constructed processor,
    with the three constant-leaf variants first and the three linear-leaf
    variants after. Remove entries you don't need.
    """
    constant_leaf = [
        ("clari_tree_const", CLARITreeConstProcessor),
        ("greedy_const", GreedyConstProcessor),
        ("streedc", STreeDConstProcessor),
    ]
    linear_leaf = [
        ("clari_tree", CLARITreeProcessor),
        ("greedy", GreedyProcessor),
        ("streed", STreeDProcessor),
    ]
    # Each factory is called exactly once, in listing order.
    return {name: factory() for name, factory in constant_leaf + linear_leaf}


def build_params(depth: int,
                 cost_complexity: float,
                 kappa: float) -> Dict[str, object]:
    """
    Assemble the keyword arguments shared by every processor's ``build``.

    Args:
        depth: maximum tree depth, coerced to ``int``.
        cost_complexity: CCP pruning strength (lambda), coerced to ``float``.
        kappa: linear-leaf regularization / jitter, coerced to ``float``.
            Whether constant-leaf processors use it depends on their own
            parameter mapping, which is not visible here.

    Returns:
        A new dict with keys ``"depth"``, ``"cost_complexity"``, ``"kappa"``.
    """
    return dict(
        depth=int(depth),
        cost_complexity=float(cost_complexity),
        kappa=float(kappa),
    )


def leaves_for_ccp(proc,
                   X_tr,
                   y_tr,
                   depth: int,
                   ccp: float,
                   kappa: float,
                   threshold_mode: str = "full") -> int:
    """
    Fit ``proc`` once with the given CCP and return its leaf count.

    Args:
        proc: a processor exposing ``build(**params)`` and ``fit(model, X, y)``.
        X_tr, y_tr: training features / targets for one fold.
        depth: maximum tree depth.
        ccp: cost-complexity pruning strength (lambda).
        kappa: linear-leaf regularization; also forwarded as
            ``ridge_penalty`` for the linear STreeD processor.
        threshold_mode: ``"full"`` evaluates every candidate threshold.
            ``"threshold"`` makes the C++ linear trees (clari_tree/greedy)
            use ``stride ~= n/5`` and both STreeD variants use
            ``n_thresholds=5``. It has no effect on the constant-leaf
            C++ trees.

    Returns:
        ``art.complexity`` of the fitted artifact, as an ``int``.

    Raises:
        ValueError: if ``threshold_mode`` is not ``"full"`` or ``"threshold"``.
        AttributeError: if the fitted artifact has no ``complexity`` value.
            A missing count used to be reported as 0 leaves, which silently
            corrupts the CCP search.
    """
    # Validate up front. Previously an unknown mode meant "full" for the C++
    # trees but "threshold" for STreeD.
    if threshold_mode not in ("full", "threshold"):
        raise ValueError(
            f"threshold_mode must be 'full' or 'threshold', got {threshold_mode!r}"
        )
    coarse = threshold_mode == "threshold"

    params = build_params(depth=depth, cost_complexity=ccp, kappa=kappa)
    n_train = X_tr.shape[0] if hasattr(X_tr, "shape") else len(X_tr)

    # C++ linear trees: a stride of ~n/5 evaluates ~5 thresholds per feature.
    if isinstance(proc, (CLARITreeProcessor, GreedyProcessor)):
        params["stride"] = max(1, n_train // 5) if coarse else 1

    # STreeD variants: cap the candidate thresholds. The linear variant also
    # takes kappa as its ridge penalty (and no lasso).
    if isinstance(proc, STreeDConstProcessor):
        params["n_thresholds"] = 5 if coarse else n_train
    elif isinstance(proc, STreeDProcessor):
        params["n_thresholds"] = 5 if coarse else n_train
        params.setdefault("ridge_penalty", float(kappa))
        params.setdefault("lasso_penalty", 0.0)

    model = proc.build(**params)
    art = proc.fit(model, X_tr, y_tr)
    complexity = getattr(art, "complexity", None)
    if complexity is None:
        raise AttributeError(
            f"{type(proc).__name__}.fit returned an artifact without a "
            f"'complexity' (leaf count) value"
        )
    return int(complexity)


# ===============================
# Average-leaf oracle over CV folds
# ===============================
def load_outer_train_folds(dataset_dir: Path):
    """
    Load the training portion of every ``outer_<k>`` split of a dataset.

    Reads ``<dataset_dir>/splits/outer_<k>/train.csv`` with
    ``load_xy_claritree`` and visits splits in increasing integer ``k``.

    Args:
        dataset_dir: dataset folder (a ``Path`` or path-like string).

    Returns:
        A list of ``(X_tr, y_tr)`` pairs (train only), possibly empty.
        An empty list is returned when ``splits/`` does not exist, so the
        caller can report a single clear error instead of crashing here.
        Split directories with no ``train.csv``, or whose suffix after the
        last ``_`` is not an integer, are skipped with a message.
    """
    dataset_dir = Path(dataset_dir)
    splits_root = dataset_dir / "splits"
    if not splits_root.is_dir():
        return []

    indexed = []
    for p in splits_root.iterdir():
        if not (p.is_dir() and p.name.startswith("outer_")):
            continue
        try:
            idx = int(p.name.split("_")[-1])
        except ValueError:
            # e.g. "outer_tmp": previously crashed the sort key.
            print(f"[skip] {dataset_dir.name}/{p.name} has no integer fold index")
            continue
        indexed.append((idx, p))
    indexed.sort(key=lambda item: item[0])

    folds = []
    for _, odir in indexed:
        train_csv = odir / "train.csv"
        if not train_csv.exists():
            print(f"[skip] {dataset_dir.name}/{odir.name} missing train.csv")
            continue
        X_tr, y_tr = load_xy_claritree(train_csv)
        folds.append((X_tr, y_tr))
    return folds


def avg_leaves_for_ccp(proc,
                       folds,
                       depth: int,
                       ccp: float,
                       kappa: float,
                       threshold_mode: str = "full") -> float:
    """
    Mean leaf count of ``proc`` over all folds at a single CCP value.

    Each ``(X_tr, y_tr)`` pair in ``folds`` is fitted once through
    ``leaves_for_ccp`` with the same depth, CCP, kappa and threshold mode.

    Returns:
        The arithmetic mean of the per-fold leaf counts as a ``float``,
        or NaN when ``folds`` yields nothing.
    """
    leaf_counts = []
    for X_fold, y_fold in folds:
        leaf_counts.append(
            leaves_for_ccp(proc, X_fold, y_fold, depth, ccp, kappa,
                           threshold_mode=threshold_mode)
        )
    if leaf_counts:
        return float(np.mean(leaf_counts))
    return float("nan")


# ===============================
# CCP search on *average* leaves
# ===============================
def bracket_ccp_for_bucket_cv(proc,
                              folds,
                              depth: int,
                              kappa: float,
                              lo_leaves: int,
                              hi_leaves: int,
                              ccp_min: float = CCP_MIN,
                              ccp_max: float = CCP_MAX,
                              max_expansions: int = CCP_BRACKET_EXPANSIONS,
                              threshold_mode: str = "full"):
    """
    Find CCP bounds ``[c_lo, c_hi]`` such that ``avg_leaves(c_lo) >= hi_leaves``
    and ``avg_leaves(c_hi) <= lo_leaves``.

    The search starts from ``[ccp_min, ccp_max]``. While an endpoint fails its
    condition, that endpoint is widened by a factor of 10 (``c_lo`` down,
    ``c_hi`` up), for at most ``max_expansions`` rounds.

    Widening is clamped to ``min(CCP_MIN, ccp_min)`` below and
    ``max(CCP_MAX, ccp_max)`` above. With the default arguments the bracket
    therefore cannot widen at all.

    Iteration stops early once no failing endpoint can move. The returned
    bracket may still violate either condition, so callers should check it.

    Returns:
      (c_lo, c_hi, L_lo, L_hi), where L_lo / L_hi are the average leaf
      counts at c_lo / c_hi.
    """
    c_lo = float(ccp_min)
    c_hi = float(ccp_max)
    # Widening limits, never tighter than the starting points. Clamping to
    # CCP_MIN / CCP_MAX alone would move c_lo *up* when ccp_min < CCP_MIN
    # (and c_hi *down* when ccp_max > CCP_MAX), narrowing the bracket instead
    # of widening it.
    lo_limit = min(CCP_MIN, c_lo)
    hi_limit = max(CCP_MAX, c_hi)

    L_lo = avg_leaves_for_ccp(proc, folds, depth, c_lo, kappa, threshold_mode=threshold_mode)
    L_hi = avg_leaves_for_ccp(proc, folds, depth, c_hi, kappa, threshold_mode=threshold_mode)

    expansions = 0
    while expansions < max_expansions:
        ok_low = (L_lo >= hi_leaves)
        ok_high = (L_hi <= lo_leaves)
        if ok_low and ok_high:
            break

        moved = False
        if not ok_low:
            # Not enough leaves at low CCP -> decrease CCP toward 0
            new_c_lo = max(lo_limit, c_lo * 0.1)
            if new_c_lo < c_lo:
                c_lo = new_c_lo
                L_lo = avg_leaves_for_ccp(
                    proc, folds, depth, c_lo, kappa, threshold_mode=threshold_mode
                )
                moved = True

        if not ok_high:
            # Too many leaves at high CCP -> increase CCP
            new_c_hi = min(hi_limit, c_hi * 10.0)
            if new_c_hi > c_hi:
                c_hi = new_c_hi
                L_hi = avg_leaves_for_ccp(
                    proc, folds, depth, c_hi, kappa, threshold_mode=threshold_mode
                )
                moved = True

        if not moved:
            # Every failing endpoint is already pinned at its limit, so the
            # remaining rounds would change nothing.
            break

        expansions += 1

    return c_lo, c_hi, L_lo, L_hi


def search_ccp_for_bucket_cv(proc,
                             folds,
                             depth: int,
                             kappa: float,
                             lo_leaves: int,
                             hi_leaves: int,
                             threshold_mode: str = "full"):
    """
    Find a CCP (lambda) whose *average* leaf count over ``folds`` lies in
    ``[lo_leaves, hi_leaves]``.

    Strategy:
      1) Bracket the CCP range on average leaves (``bracket_ccp_for_bucket_cv``).
      2) Run ``CCP_BINSEARCH_STEPS`` geometric bisection steps in CCP space.
         The step directions assume the average leaf count is non-increasing
         in CCP.
      3) Among every CCP evaluated (both bracket endpoints and all midpoints),
         return the one whose average leaf count is inside the bucket and
         closest to its center. If none is inside, return the one closest
         to the center overall.

    Returns:
        ``(ccp, avg_leaves)`` as floats.
    """
    c_lo, c_hi, L_lo, L_hi = bracket_ccp_for_bucket_cv(
        proc, folds, depth, kappa, lo_leaves, hi_leaves,
        threshold_mode=threshold_mode,
    )

    center = 0.5 * (lo_leaves + hi_leaves)
    # Both endpoints are always candidates. Scoring still prefers in-bucket
    # values, but an out-of-bucket endpoint may be the closest fallback.
    # This also keeps ``candidates`` non-empty even with zero bisection steps,
    # where ``min`` would otherwise raise on an empty list.
    candidates = [(c_lo, L_lo), (c_hi, L_hi)]

    ok_bracket = (L_lo >= hi_leaves) and (L_hi <= lo_leaves)
    if ok_bracket:
        left, right = c_lo, c_hi
    else:
        left, right = min(c_lo, c_hi), max(c_lo, c_hi)

    for _ in range(CCP_BINSEARCH_STEPS):
        mid = math.sqrt(left * right)  # geometric mid: CCP spans decades
        L_mid = avg_leaves_for_ccp(proc, folds, depth, mid, kappa, threshold_mode=threshold_mode)
        candidates.append((mid, L_mid))

        if ok_bracket:
            if L_mid >= hi_leaves:
                left = mid
            elif L_mid <= lo_leaves:
                right = mid
            else:
                # Strictly inside the bucket; keep shrinking toward lower CCP.
                right = mid
        else:
            # Not perfectly bracketed; just push toward the bucket center.
            if L_mid > center:
                left = mid
            else:
                right = mid

    def score(leaves: float):
        inside = (lo_leaves <= leaves <= hi_leaves)
        return (0 if inside else 1, abs(leaves - center))

    best_ccp, best_leaves = min(candidates, key=lambda cand: score(cand[1]))
    return float(best_ccp), float(best_leaves)


# ===============================
# Main
# ===============================
def _parse_buckets(spec: str,
                   parser: argparse.ArgumentParser) -> List[Tuple[int, int]]:
    """Parse ``'lo-hi,lo-hi,...'`` into ``[(lo, hi), ...]``; bad input exits via ``parser.error``."""
    buckets: List[Tuple[int, int]] = []
    for raw in spec.split(","):
        tok = raw.strip()
        if not tok:
            continue  # tolerate trailing or doubled commas
        lo_txt, _, hi_txt = tok.partition("-")
        try:
            lo, hi = int(lo_txt), int(hi_txt)
        except ValueError:
            parser.error(f"invalid bucket {tok!r}: expected 'lo-hi' with integer bounds")
        if lo > hi:
            parser.error(f"invalid bucket {tok!r}: lower bound exceeds upper bound")
        buckets.append((lo, hi))
    if not buckets:
        parser.error("--buckets must contain at least one 'lo-hi' bucket")
    return buckets


def _write_csv(path: Path, rows: List[Dict[str, object]]) -> bool:
    """Write ``rows`` to ``path`` as CSV with a fixed header; return False (writing nothing) if empty."""
    if not rows:
        return False
    fieldnames = [
        "dataset",
        "depth",
        "kappa",
        "threshold_mode",
        "method",
        "bucket_lo",
        "bucket_hi",
        "lambda",
        "avg_leaves",
    ]
    with path.open("w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)
    return True


def main():
    """
    CLI entry point: find a CCP (lambda) per leaf bucket for each method.

    For every selected processor and every ``lo-hi`` bucket in ``--buckets``,
    runs ``search_ccp_for_bucket_cv`` on the outer train folds of
    ``<data_root>/<dataset>`` and prints the chosen lambda and the resulting
    average leaf count. Results are written to ``ccp_buckets_constant.csv``
    and ``ccp_buckets_linear.csv`` under the results folder.

    Raises:
        RuntimeError: if no outer train folds are found for the dataset.
    """
    parser = argparse.ArgumentParser(
        description="Find CCP (lambda) per bucket based on average leaf-count over outer folds."
    )
    parser.add_argument("--data_root", type=str, default="./data",
                        help="Root folder containing dataset subfolders.")
    parser.add_argument("--dataset", type=str, required=True,
                        help="Dataset name (subfolder of data_root).")
    parser.add_argument("--depth", type=int, default=DEFAULT_DEPTH,
                        help="Tree depth to use.")
    parser.add_argument("--buckets", type=str,
                        default=",".join(f"{a}-{b}" for a, b in DEFAULT_BUCKETS),
                        help="Comma-separated buckets like '3-5,5-8,8-12'.")
    parser.add_argument("--kappa", type=float, default=DEFAULT_KAPPA,
                        help="Linear-leaf regularization (shared across methods).")
    parser.add_argument("--threshold_mode", type=str, default="full",
                        choices=["full", "threshold"],
                        help=(
                            "Threshold search mode. 'threshold' limits STreeD to 5 "
                            "thresholds per feature and the C++ linear trees "
                            "(clari_tree/greedy) to a stride of ~n/5; 'full' evaluates "
                            "all thresholds. Ignored by the constant-leaf C++ trees."
                        ))
    parser.add_argument("--model_type", type=str, default="both",
                        choices=["constant", "linear", "both"],
                        help="Which family of models to run: constant, linear, or both.")
    parser.add_argument("--results_root", type=str, default=None,
                        help=(
                            "Root folder under which to store per-dataset constant/linear "
                            "CCP bucket settings. "
                            "Default: ./results/CCP_buckets_depth{depth}/{dataset}/"
                        ))
    args = parser.parse_args()

    buckets = _parse_buckets(args.buckets, parser)

    dataset_dir = Path(args.data_root) / args.dataset
    folds = load_outer_train_folds(dataset_dir)
    if not folds:
        raise RuntimeError(f"No outer train folds found under {dataset_dir}/splits")

    # Collect results for saving: split into constant vs linear methods
    constant_methods = {"clari_tree_const", "greedy_const", "streedc"}
    linear_methods = {"clari_tree", "greedy", "streed"}

    # Select which processors to run based on model_type
    all_procs = get_processors()
    if args.model_type == "constant":
        processors = {m: p for m, p in all_procs.items() if m in constant_methods}
    elif args.model_type == "linear":
        processors = {m: p for m, p in all_procs.items() if m in linear_methods}
    else:
        processors = all_procs

    print(f"Dataset: {args.dataset}, depth={args.depth}, kappa={args.kappa:g}")
    print(f"Buckets: {buckets}")
    print(f"Model type: {args.model_type}")
    print(f"Methods: {', '.join(processors.keys())}")
    print()

    const_records: List[Dict[str, object]] = []
    linear_records: List[Dict[str, object]] = []

    for method, proc in processors.items():
        print(f"=== Method: {method} ===")
        for (bl, bh) in buckets:
            ccp, mean_leaves = search_ccp_for_bucket_cv(
                proc,
                folds,
                depth=args.depth,
                kappa=args.kappa,
                lo_leaves=bl,
                hi_leaves=bh,
                threshold_mode=args.threshold_mode,
            )
            print(
                f"bucket [{bl}, {bh}]: "
                f"lambda(cost_complexity) = {ccp:.6g}, "
                f"avg_leaves = {mean_leaves:.2f}"
            )

            record = {
                "dataset": args.dataset,
                "depth": int(args.depth),
                "kappa": float(args.kappa),
                "threshold_mode": args.threshold_mode,
                "method": method,
                "bucket_lo": int(bl),
                "bucket_hi": int(bh),
                "lambda": float(ccp),
                "avg_leaves": float(mean_leaves),
            }
            if method in constant_methods:
                const_records.append(record)
            if method in linear_methods:
                linear_records.append(record)
        print()

    # -------------------------------
    # Save results under results/...
    # -------------------------------
    if const_records or linear_records:
        if args.results_root is None:
            results_root = Path(f"./results/CCP_buckets_depth{args.depth}") / args.dataset
        else:
            results_root = Path(args.results_root) / args.dataset
        results_root.mkdir(parents=True, exist_ok=True)

        outputs = (
            (results_root / "ccp_buckets_constant.csv", const_records),
            (results_root / "ccp_buckets_linear.csv", linear_records),
        )
        for path, rows in outputs:
            if _write_csv(path, rows):
                print(f"Saved {len(rows)} rows to {path}")


if __name__ == "__main__":
    # Script entry point; main() returns None, which SystemExit maps to exit status 0.
    raise SystemExit(main())
