#!/usr/bin/env python3

from __future__ import annotations
import argparse
import os
import json
import csv
import re
import warnings
from typing import List, Dict
import torch
import numpy as np
from tqdm.auto import tqdm

# Import all utilities from steer_eval
from steer_eval import (
    dist_is_enabled, dist_init, dist_rank, dist_world, only_rank0, shard_list, set_global_determinism, load_preproc, invert_preproc_step,
    load_tok_mdl, pick_device, make_schedule, pick_consensus_layer, build_vec_bank_from_soft, auto_soft_json, load_hf_dataset_items, load_gpqa_diamond_items,
    LLMRunner, eval_batched_select_vec, reduce_counts,
)


def sanitize_name(s: str) -> str:
    """Return ``str(s)`` with every run of disallowed characters collapsed to ``-``.

    Allowed characters are ASCII letters, digits, and ``_ . : -``; any
    maximal run of other characters is replaced by a single hyphen.
    """
    text = str(s)
    disallowed_run = re.compile(r'[^A-Za-z0-9_.:-]+')
    return disallowed_run.sub('-', text)


def combo_id(alpha: float, schedule_name: str, layer_list: List[int]) -> str:
    """Build the identifier string for one grid configuration.

    The result has the form ``a<alpha>__<schedule>__L<layers>`` where
    ``alpha`` is rendered with the ``g`` format, the schedule name is passed
    through ``sanitize_name`` and the layers are comma-joined.
    """
    layers_part = ",".join(str(layer) for layer in layer_list)
    schedule_part = sanitize_name(schedule_name)
    return f"a{alpha:g}__{schedule_part}__L{layers_part}"


def parse_alphas(s: str) -> List[float]:
    """Parse a comma-separated string into a list of floats.

    Pieces that are empty or whitespace-only are skipped; any other piece
    that ``float`` cannot convert raises ``ValueError``.
    """
    values = []
    for piece in s.split(","):
        if piece.strip():
            values.append(float(piece))
    return values


def parse_schedules(s: str) -> List[str]:
    """Split a comma-separated string into stripped, non-empty names."""
    stripped = (piece.strip() for piece in s.split(","))
    return [name for name in stripped if name]


def parse_layer_sets(s: str) -> List[List[int]]:
    """Parse a pipe-separated specification of layer sets.

    Each ``|``-separated chunk describes one layer set.  Within a chunk,
    comma-separated tokens are either a single index (``"30"``) or an
    inclusive range (``"10-20"``; the endpoints may be given in either
    order).  Tokens of both kinds may be mixed in one chunk, e.g.
    ``"5,10-12"`` yields ``[5, 10, 11, 12]``.

    Empty chunks and empty tokens are ignored, and a chunk that contains no
    layers at all produces no layer set.

    Args:
        s: Specification string such as ``"10-20|5,15|30"``.

    Returns:
        One list of layer indices per non-empty chunk, in input order.

    Raises:
        ValueError: If a token is neither an integer nor an ``a-b`` range of
            integers (this includes negative indices, since ``-`` is the
            range separator).
    """
    layer_sets = []
    for chunk in s.split("|"):
        layers = []
        for token in chunk.split(","):
            token = token.strip()
            if not token:
                continue
            try:
                if "-" in token:
                    lo, hi = (int(part) for part in token.split("-", 1))
                    layers.extend(range(min(lo, hi), max(lo, hi) + 1))
                else:
                    layers.append(int(token))
            except ValueError as err:
                # Re-raise with the offending token so a bad CLI value is
                # easy to locate; the exception type is unchanged.
                raise ValueError(
                    f"invalid layer spec {token!r} in layer set {chunk.strip()!r}"
                ) from err
        # A chunk made only of separators would otherwise yield an empty
        # layer set, i.e. a grid entry that steers nothing.
        if layers:
            layer_sets.append(layers)
    return layer_sets


def _ensure_parent_dir(path: str) -> None:
    """Create the parent directory of *path* when the path names one."""
    parent = os.path.dirname(path)
    # A bare filename has an empty dirname and os.makedirs("") raises
    # FileNotFoundError, so only create a directory when one is given.
    if parent:
        os.makedirs(parent, exist_ok=True)


def main():
    """Run a grid search over steering parameters from the command line.

    Parses the CLI arguments, loads the evaluation items, evaluates an
    unsteered baseline, then evaluates every combination of steer mode,
    alpha, schedule and layer set.  Counts are combined across ranks with
    ``reduce_counts``; rank 0 prints progress and writes the summary JSON
    (and optional CSV).  When ``--details_dir`` is given, every rank writes
    its own per-run detail files there.

    Raises:
        SystemExit: If no valid steer mode is selected, an ``--auto_layers``
            strategy is invalid or unavailable in the stats file, the
            requested baseline vector is missing, or the soft-edge JSON
            cannot be found.
    """
    ap = argparse.ArgumentParser(description="Grid search over steering parameters")

    # Input files
    ap.add_argument("--stats_npz", required=True, help="Stats NPZ file")
    ap.add_argument("--model_npz", required=True, help="Model NPZ file (scaler/PCA)")

    # Steering configuration
    ap.add_argument("--use_strategy", default="baseline_all_steps",
                    help="Baseline vector strategy")
    ap.add_argument("--steer_modes", default=None,
                    help="Comma-separated: vec_base,soft_prob,soft_argmax (default: all)")
    ap.add_argument("--soft_json", default=None,
                    help="Path to soft_edges_top3.json (auto-detect if not provided)")
    ap.add_argument("--edge_vec_prefix", default="edge_delta",
                    help="Prefix for edge vectors in stats")

    # Model configuration
    ap.add_argument("--gen_model", default="bespokelabs/Bespoke-Stratos-7B")
    ap.add_argument("--tokenizer", default=None)
    ap.add_argument("--device", default="cuda:0")
    ap.add_argument("--dtype", default="float16", choices=["float16", "bfloat16", "float32"])

    # Generation parameters
    ap.add_argument("--gen_temperature", type=float, default=0.6)
    ap.add_argument("--gen_top_p", type=float, default=0.95)
    ap.add_argument("--gen_top_k", type=int, default=None)
    ap.add_argument("--min_p", type=float, default=None)
    ap.add_argument("--gen_max_new_tokens", type=int, default=2000)
    ap.add_argument("--system_text", default="detailed thinking on")
    ap.add_argument("--final_boxed_hint", action="store_true")

    # Dataset selection (HuggingFace)
    ap.add_argument("--hf_dataset", default=None)
    ap.add_argument("--hf_config", default=None)
    ap.add_argument("--hf_split", default="test")
    ap.add_argument("--hf_prompt_key", default="question")
    ap.add_argument("--hf_answer_key", default="answer")
    ap.add_argument("--hf_seed", type=int, default=None)
    ap.add_argument("--hf_skip_first", type=int, default=0)
    ap.add_argument("--hf_filter_answer_types", default=None)
    ap.add_argument("--hf_filter_difficulties", default=None)
    ap.add_argument("--eval_diamond", action="store_true")
    ap.add_argument("--diamond_split", choices=["train", "test"], default="train")
    ap.add_argument("--diamond_n", type=int, default=100)
    ap.add_argument("--diamond_seed", type=int, default=None)
    ap.add_argument("--diamond_skip_first", type=int, default=0)

    ap.add_argument("--max_eval", type=int, default=100)
    # NOTE: results are expected to be re-judged later by separate code.
    ap.add_argument("--metric", choices=["em", "numeric", "regex"], default="numeric")
    ap.add_argument("--regex_answer", default=None)
    ap.add_argument("--regex_pred", default=None)
    ap.add_argument("--progress", action="store_true")

    # Grid parameters
    ap.add_argument("--alphas", required=True, help="Comma-separated alpha values")
    ap.add_argument("--schedules", required=True, help="Comma-separated schedule names")
    ap.add_argument("--layer_sets", required=True,
                    help="Pipe-separated layer sets (e.g., '10-20|5,15|30')")

    # Auto layer selection
    ap.add_argument("--auto_layers",
                    default="none",
                    help="Override layer_sets with consensus layer(s) from stats. "
                         "Options: 'none', 'consensus_first_change', 'consensus_last_change', "
                         "or comma-separated combination like 'consensus_first_change,consensus_last_change'")

    # Output configuration
    ap.add_argument("--out_json", required=True)
    ap.add_argument("--out_csv", default=None)
    ap.add_argument("--details_dir", default=None,
                    help="Directory for per-run detailed results")

    # Runtime configuration
    ap.add_argument("--seed", type=int, default=42)
    ap.add_argument("--batch_size", type=int, default=1)
    ap.add_argument("--step_aware", action="store_true",
                    help="Enable step-aware steering (detect thinking completion)")

    args = ap.parse_args()

    # ========================================================================
    # Setup
    # ========================================================================
    if args.steer_modes is None:
        modes = ["vec_base", "soft_prob", "soft_argmax"]
    else:
        modes = [m.strip() for m in args.steer_modes.split(",") if m.strip()]
    # Unknown mode names are dropped; an empty selection is an error.
    modes = [m for m in modes if m in {"vec_base", "soft_prob", "soft_argmax"}]
    if not modes:
        raise SystemExit("[error] No valid steer modes selected")

    dist_init()
    set_global_determinism(args.seed, strict=False)

    # Every rank writes its own detail files below, so every rank makes sure
    # the directory exists (exist_ok makes concurrent creation harmless).
    if args.details_dir:
        os.makedirs(args.details_dir, exist_ok=True)

    if args.auto_layers != "none":
        strategies = [s.strip() for s in args.auto_layers.split(",") if s.strip()]
        # Maps each accepted strategy to the key requested from the stats file.
        strategy_keys = {
            "consensus_first_change": "consensus_first_change_layers",
            "consensus_last_change": "consensus_last_change_layers",
        }

        layers_to_use = []
        for strategy in strategies:
            if strategy not in strategy_keys:
                raise SystemExit(
                    f"[error] Invalid auto_layers strategy: {strategy}. "
                    f"Must be one of: {', '.join(strategy_keys)}"
                )

            key = strategy_keys[strategy]
            layer = pick_consensus_layer(args.stats_npz, which=key)
            if layer is None:
                raise SystemExit(
                    f"[error] auto_layers={strategy} requested, "
                    f"but {key} not found/valid in stats"
                )
            layers_to_use.append(layer)
            if only_rank0():
                print(f"[auto_layers] {strategy} → layer {layer}")

        layers_to_use = sorted(set(layers_to_use))

        # Each selected layer becomes its own single-layer set.
        args.layer_sets = "|".join(map(str, layers_to_use))
        if only_rank0():
            print(f"[auto_layers] Final layer set: {args.layer_sets}")
            print(f"[auto_layers] This will create {len(layers_to_use)} separate layer runs per mode")

    # ========================================================================
    # Load dataset
    # ========================================================================
    if args.hf_dataset:
        # Only rank 0 loads the dataset; other ranks receive it by broadcast.
        items_full = None
        if only_rank0():
            items_full = load_hf_dataset_items(
                ds_name=args.hf_dataset,
                ds_config=args.hf_config,
                split=args.hf_split,
                prompt_key=args.hf_prompt_key,
                answer_key=args.hf_answer_key,
                n=args.max_eval,
                seed=(args.hf_seed if args.hf_seed is not None else args.seed),
                skip_first=args.hf_skip_first,
                filter_answer_types=(
                    args.hf_filter_answer_types.split(",")
                    if args.hf_filter_answer_types
                    else None
                ),
                filter_difficulties=(
                    args.hf_filter_difficulties.split(",")
                    if args.hf_filter_difficulties
                    else None
                ),
            )
        # Broadcast to all ranks
        if dist_world() > 1:
            obj = [items_full]
            torch.distributed.broadcast_object_list(obj, src=0)
            items_full = obj[0]
    elif args.eval_diamond:
        items_full = load_gpqa_diamond_items(
            n=args.diamond_n,
            split=args.diamond_split,
            seed=(args.diamond_seed if args.diamond_seed is not None else args.seed),
            skip_first=args.diamond_skip_first
        )
        # For this dataset any metric other than "numeric" is forced to "regex".
        if args.metric != "numeric":
            args.metric = "regex"
    else:
        items_full = []

    if not items_full:
        if only_rank0():
            print("[grid] No evaluation items")
        return

    # Shard dataset across ranks
    rank, world = dist_rank(), dist_world()
    items = shard_list(items_full, rank, world)

    # Label used in log lines to identify the evaluated dataset.
    dataset_label = args.hf_dataset or ('diamond' if args.eval_diamond else 'custom')

    # ========================================================================
    # Initialize runner
    # ========================================================================
    runner = LLMRunner(
        model_name=args.gen_model,
        tokenizer_name=args.tokenizer,
        temperature=args.gen_temperature,
        top_p=args.gen_top_p,
        max_new_tokens=args.gen_max_new_tokens,
        device=args.device,
        dtype=args.dtype,
        top_k=args.gen_top_k,
        system_text=args.system_text,
        final_boxed_hint=args.final_boxed_hint,
        min_p=args.min_p,
    )

    # ========================================================================
    # Baseline evaluation
    # ========================================================================
    if only_rank0():
        print(f"[grid] Running baseline evaluation on {len(items)} items (rank {rank})")

    base_acc, base_rows, base_total_gen_toks, _ = eval_batched_select_vec(
        runner,
        items,
        metric=args.metric,
        schedule=None,
        base_seed=args.seed,
        batch_size=args.batch_size,
        step_aware=False,
        mode="none",
        regex_answer=args.regex_answer,
        regex_pred=args.regex_pred,
        show_progress=args.progress
    )

    # Save baseline details
    if args.details_dir:
        with open(os.path.join(args.details_dir, f"baseline.rank{rank}.json"), "w") as f:
            json.dump({
                "rank": rank,
                "kind": "baseline",
                "n_local": len(items),
                "rows": base_rows
            }, f, indent=2)

    # Reduce baseline across ranks.  The local accuracy is turned back into a
    # correct-count so counts (not ratios) are combined.
    local_base_correct = int(round(base_acc * len(items)))
    g_base_acc, g_total_N, g_base_tok_total = reduce_counts(
        args.device, local_base_correct, len(items), int(base_total_gen_toks)
    )

    if only_rank0():
        print(
            f"[grid-baseline] model={args.gen_model} "
            f"| dataset={dataset_label} "
            f"| seed={args.seed} → acc={g_base_acc:.4f}, N={g_total_N}, "
            f"gen_toks_total={g_base_tok_total}"
        )

    # ========================================================================
    # Load steering vectors
    # ========================================================================
    alphas = parse_alphas(args.alphas)
    schedules = parse_schedules(args.schedules)
    layer_sets = parse_layer_sets(args.layer_sets)
    results = []

    vec_hidden_single = None
    if "vec_base" in modes:
        z = np.load(args.stats_npz, allow_pickle=True)
        key = f"vec::{args.use_strategy}"
        if key not in z.files:
            raise SystemExit(f"[error] {key} not found in {args.stats_npz}")
        scaler, pca = load_preproc(args.model_npz)
        vec_hidden_single = invert_preproc_step(z[key], scaler, pca)

    vec_bank, vec_probs = None, None
    if "soft_prob" in modes or "soft_argmax" in modes:
        soft_json = args.soft_json or auto_soft_json(args.stats_npz)
        if not soft_json or not os.path.isfile(soft_json):
            raise SystemExit(
                "[error] soft JSON not found. Pass --soft_json or place "
                "soft_edges_top3.json next to stats"
            )
        vec_bank, vec_probs = build_vec_bank_from_soft(
            args.stats_npz, args.model_npz, soft_json, prefix=args.edge_vec_prefix
        )

    # ========================================================================
    # Grid search
    # ========================================================================
    # Steer mode name -> vector-selection mode passed to the evaluator.
    selection_modes = {
        "vec_base": "single",
        "soft_prob": "prob",
        "soft_argmax": "argmax"
    }
    # The (alpha, schedule, layer set) grid is identical for every mode.
    grid = [
        (a, s, ls)
        for a in alphas
        for s in schedules
        for ls in layer_sets
    ]

    for mode in modes:
        sel_mode = selection_modes[mode]

        sweep_iter = grid
        if only_rank0() and args.progress:
            sweep_iter = tqdm(sweep_iter, desc=f"Grid[{mode}]", unit="cfg")

        for alpha, sched_name, layer_list in sweep_iter:
            # Use the schedule requested on the command line; hard-coding one
            # name here would make every --schedules entry run (and overwrite)
            # the same configuration.
            schedule = make_schedule(sched_name, layer_list, alpha)

            steer_acc, steer_rows, steer_total_gen_toks, _ = eval_batched_select_vec(
                runner,
                items,
                metric=args.metric,
                schedule=schedule,
                base_seed=args.seed,
                batch_size=args.batch_size,
                step_aware=args.step_aware,
                mode=sel_mode,
                vec_hidden_single=vec_hidden_single,
                vec_bank=vec_bank,
                vec_probs=vec_probs,
                regex_answer=args.regex_answer,
                regex_pred=args.regex_pred,
                show_progress=False
            )

            # Save details
            combo = combo_id(alpha, sched_name, layer_list)
            if args.details_dir:
                with open(
                    os.path.join(args.details_dir, f"{combo}.{mode}.rank{rank}.json"),
                    "w"
                ) as f:
                    json.dump({
                        "rank": rank,
                        "kind": f"steered:{mode}",
                        "alpha": float(alpha),
                        "schedule": sched_name,
                        "layers": layer_list,
                        "n_local": len(items),
                        "rows": steer_rows
                    }, f, indent=2)

            # Reduce across ranks
            local_steer_correct = int(round(steer_acc * len(items)))
            g_steer_acc, gN, g_steer_tok_total = reduce_counts(
                args.device, local_steer_correct, len(items), int(steer_total_gen_toks)
            )

            if only_rank0():
                res = {
                    "mode": mode,
                    "alpha": float(alpha),
                    "schedule": sched_name,
                    "layers": layer_list,
                    "n": int(gN),
                    "baseline_acc": float(g_base_acc),
                    "steered_acc": float(g_steer_acc),
                    "delta_acc": float(g_steer_acc - g_base_acc),
                    "baseline_gen_tokens_total": int(g_base_tok_total),
                    "steered_gen_tokens_total": int(g_steer_tok_total),
                    # max(1, N) guards against division by zero.
                    "delta_gen_tokens_avg": float(
                        (g_steer_tok_total - g_base_tok_total) / max(1, int(gN))
                    ),
                    "use_strategy": args.use_strategy if mode == "vec_base" else None,
                    "edge_vec_prefix": (
                        args.edge_vec_prefix if mode.startswith("soft_") else None
                    ),
                }
                results.append(res)
                print(
                    f"[grid-improve:{mode}] model={args.gen_model} "
                    f"| dataset={dataset_label} "
                    f"| seed={args.seed} | {sched_name} | L={layer_list} | α={alpha} "
                    f"→ acc={g_steer_acc:.4f} (Δ={res['delta_acc']:+.4f}), "
                    f"Δtok_avg={res['delta_gen_tokens_avg']:+.2f}"
                )

    # ========================================================================
    # Save results
    # ========================================================================
    if only_rank0():
        _ensure_parent_dir(args.out_json)
        with open(args.out_json, "w") as f:
            json.dump({
                # Cast like the per-run rows so the values are plain JSON
                # numbers regardless of what reduce_counts returns.
                "baseline": {
                    "acc": float(g_base_acc),
                    "n": int(g_total_N),
                    "gen_tokens_total": int(g_base_tok_total)
                },
                "runs": results
            }, f, indent=2)
        print(f"[grid] wrote {args.out_json}")

        if args.out_csv:
            _ensure_parent_dir(args.out_csv)
            all_keys = set()
            for r in results:
                all_keys.update(r.keys())
            preferred = [
                "mode", "alpha", "schedule", "layers", "n",
                "baseline_acc", "steered_acc", "delta_acc",
                "baseline_gen_tokens_total", "steered_gen_tokens_total",
                "delta_gen_tokens_avg",
                "use_strategy", "edge_vec_prefix"
            ]
            # Preferred columns first (when present), then any others sorted.
            cols = [k for k in preferred if k in all_keys] + [
                k for k in sorted(all_keys) if k not in preferred
            ]
            with open(args.out_csv, "w", newline="") as f:
                w = csv.DictWriter(f, fieldnames=cols, extrasaction="ignore")
                w.writeheader()
                for r in results:
                    row = {}
                    for k in cols:
                        v = r.get(k)
                        # Flatten the layer list into a single CSV cell.
                        if k == "layers" and isinstance(v, (list, tuple)):
                            v = ",".join(map(str, v))
                        row[k] = v
                    w.writerow(row)
            print(f"[grid] wrote {args.out_csv}")


if __name__ == "__main__":
    try:
        main()
    finally:
        # Best-effort teardown of the distributed process group.  It runs in
        # ``finally`` so it happens whether or not main() raised.
        try:
            import torch.distributed as dist
            if dist_is_enabled():
                dist.barrier()
                dist.destroy_process_group()
        except Exception as exc:
            # Cleanup problems must not replace the outcome of main(), but
            # swallowing them silently hides broken teardown; surface them
            # as a warning instead.
            warnings.warn(f"[grid] distributed cleanup failed: {exc!r}")