import os
import numpy as np
import pandas as pd
from argparse import ArgumentParser, Namespace

from util import config_to_action, get_model_files
from config_generation.util import RESULTS_DIR, CONFIGS_DIR
from solve import solve as lowlevel_solve
from primal_hint import SCIPPrimalHint, GurobiPrimalHint


# =====================
# Seeding utilities
# =====================

def make_seed(seed_base: int, model_idx: int, seed_shift_per_model: int) -> int:
    """Return the deterministic seed for the ``model_idx``-th model.

    The seed is ``seed_base`` offset by ``model_idx`` strides of
    ``seed_shift_per_model``, so each model gets a distinct but reproducible
    seed for a given base.
    """
    offset = model_idx * seed_shift_per_model
    return int(seed_base + offset)


# =====================
# Modes
# =====================

def run_baseline(args: Namespace, model_files: list[str]) -> pd.DataFrame:
    """Solve each model once with the default configuration.

    Each model is solved with a per-model seed from ``make_seed``. The model's
    basename is printed before solving and the returned time after.

    Returns:
        A long-form DataFrame with one row per model and the columns
        ``model_file`` and ``time``.
    """
    common = dict(
        solver=args.solver,
        time_limit=args.time_limit,
        gap_limit=args.gap_limit,
    )
    records: list[dict] = []
    for idx, path in enumerate(model_files):
        print(os.path.basename(path), flush=True)
        per_model_seed = make_seed(args.seed_base, idx, args.seed_shift_per_model)
        # ``None`` in the action slot is how this script requests the default
        # configuration (see the docstring); the third argument is passed as +inf.
        elapsed = lowlevel_solve((None, path, per_model_seed), common, float("inf"))
        records.append({"model_file": path, "time": float(elapsed)})
        print(elapsed, flush=True)
    return pd.DataFrame(records)


def _make_hint_mgrs(args: Namespace, num_actions: int) -> dict:
    """Create one independent primal-hint manager per arm for the configured solver.

    Hint settings are read from optional ``args`` attributes (``hint_k``,
    ``hint_mode``, and for SCIP also ``hint_lp_seconds`` / ``hint_cov_min``),
    falling back to the defaults below when an attribute is absent.

    Returns:
        ``{arm_index: hint_manager}`` for SCIP or Gurobi (solver name matched
        case-insensitively). Returns an empty dict for any other solver.
    """
    solver = args.solver.lower()
    k = getattr(args, "hint_k", 10)
    strategy = getattr(args, "hint_mode", "obj")
    if solver == "scip":
        lp_repair_time = getattr(args, "hint_lp_seconds", 5.0)
        int_coverage_min = getattr(args, "hint_cov_min", 0.95)
        return {
            ai: SCIPPrimalHint(
                k=k,
                strategy=strategy,
                lp_repair_time=lp_repair_time,
                int_coverage_min=int_coverage_min,
            )
            for ai in range(num_actions)
        }
    if solver == "gurobi":
        return {
            ai: GurobiPrimalHint(k=k, strategy=strategy)
            for ai in range(num_actions)
        }
    return {}


def run_mab(args: Namespace, actions: np.ndarray, config_files: np.ndarray,
            model_files: list[str]) -> pd.DataFrame:
    """Online UCB1 over a reduced configuration set (actions).

    Models are processed in order. For each model:
      1. UCB1 picks an arm (a configuration).
      2. The model is solved with that arm's action and the arm's own hint
         manager.
      3. The arm's running-mean reward is updated with ``-runtime``, so faster
         solves score higher.

    Args:
        args: Parsed CLI namespace. Reads ``solver``, ``time_limit``,
            ``gap_limit``, ``alpha``, ``seed_base``, ``seed_shift_per_model``
            and the optional ``hint_*`` attributes.
        actions: Solver actions, one per arm.
        config_files: Config file paths aligned with ``actions``.
        model_files: Model files to solve, in the order they are visited.

    Returns:
        One row per model with columns ``model_file``, ``chosen_arm``,
        ``chosen_config`` (basename of the arm's config file) and ``time``.

    Raises:
        ValueError: If ``actions`` is empty but there are models to solve.
    """
    num_actions = len(actions)
    # Without arms UCB1 cannot pick anything. Fail up front with a clear
    # message instead of an opaque argmax-of-empty-array error mid-run.
    if num_actions == 0 and len(model_files) > 0:
        raise ValueError(
            "run_mab needs at least one configuration (arm) to choose from, got 0"
        )

    # Independent hint buffers per arm (empty dict => no hints are passed).
    hint_mgrs = _make_hint_mgrs(args, num_actions)

    counts = np.zeros(num_actions, dtype=int)
    means = np.zeros(num_actions, dtype=float)

    def choose_arm_ucb1(t: int) -> int:
        # Initialization phase: pull each arm once, lowest index first.
        for a in range(num_actions):
            if counts[a] == 0:
                return a
        # NOTE(review): rewards are negative runtimes, which are not bounded
        # to [0, 1] as classic UCB1 assumes. ``args.alpha`` presumably needs
        # tuning to the reward scale.
        bonuses = args.alpha * np.sqrt(2.0 * np.log(max(2, t)) / np.maximum(1, counts))
        return int(np.argmax(means + bonuses))

    base_args = {"solver": args.solver, "time_limit": args.time_limit, "gap_limit": args.gap_limit}
    rows = []
    t_round = 0
    for mi, mf in enumerate(model_files):
        print(os.path.basename(mf), flush=True)
        t_round += 1
        a = choose_arm_ucb1(t_round)
        seed = make_seed(args.seed_base, mi, args.seed_shift_per_model)
        hint_mgr = hint_mgrs.get(a)  # None when the solver has no hint support
        rt = lowlevel_solve((actions[a], mf, seed), base_args, float("inf"), hint_mgr)
        print(rt, flush=True)

        # Incremental running-mean update of the chosen arm's reward.
        reward = -float(rt)
        counts[a] += 1
        means[a] += (reward - means[a]) / counts[a]
        rows.append(
            {
                "model_file": mf,
                "chosen_arm": int(a),
                "chosen_config": os.path.basename(config_files[a]),
                "time": float(rt),
            }
        )
    return pd.DataFrame(rows)


# =====================
# I/O helpers
# =====================

def _actions_fn(args: Namespace):
    """Load the ``.yaml`` configs for this run and convert each into a solver action.

    Configs are read from ``CONFIGS_DIR/<instance_name>/<solver>/<config_name>``.

    Returns:
        ``(actions, config_files)`` as NumPy arrays. ``config_files`` is sorted
        by path, and ``actions[i]`` is built from ``config_files[i]``.
    """
    folder = os.path.join(CONFIGS_DIR, args.instance_name, args.solver, args.config_name)
    candidates = []
    for entry in os.listdir(folder):
        if entry.endswith(".yaml"):
            candidates.append(os.path.join(folder, entry))
    config_files = np.sort(np.array(candidates))
    converted = [config_to_action(path, args.solver) for path in config_files]
    return np.array(converted), config_files


def _save_named(args: Namespace, df: pd.DataFrame, filename: str) -> None:
    """Write ``df`` as CSV under ``RESULTS_DIR/<instance_name>/<solver>/<config_name>/``.

    ``args.seed_base`` is spliced into the filename (``<stem>_seed<N><ext>``)
    so that runs with different seed bases do not overwrite each other. The
    directory is created if missing, and the final path is printed.
    """
    target_dir = os.path.join(RESULTS_DIR, args.instance_name, args.solver, args.config_name)
    os.makedirs(target_dir, exist_ok=True)

    stem, suffix = os.path.splitext(filename)
    destination = os.path.join(target_dir, f"{stem}_seed{args.seed_base}{suffix}")

    df.to_csv(destination, index=False)
    print(f"Saved -> {destination}")


# =====================
# CLI
# =====================
if __name__ == "__main__":
    parser = ArgumentParser()
    # Data settings.
    parser.add_argument("--instance_name", type=str, required=True, help="MILP class (folder name)")
    parser.add_argument("--config_name", type=str, required=True, help="name of configuration set to use")

    # Eval selection: baseline or mab
    parser.add_argument(
        "--mode",
        type=str,
        choices=["baseline", "mab"],
        required=True,
        help="baseline: default only; mab: online UCB1 over models",
    )

    # Solver settings.
    parser.add_argument("--solver", type=str, default="gurobi", help="which MILP solver to use")
    parser.add_argument("--gap_limit", type=float, default=0.0, help="gap limit for solving instances")
    parser.add_argument("--time_limit", type=float, default=400.0, help="time/work budget")

    # Seeding: per-model seed = seed_base + model_idx * seed_shift_per_model.
    parser.add_argument(
        "--seed_base",
        type=int,
        default=1729,
        help="base seed; also used to tag output filenames",
    )
    parser.add_argument(
        "--seed_shift_per_model",
        type=int,
        default=100000,
        help="seed offset between consecutive models",
    )

    # MAB hyperparameter.
    parser.add_argument(
        "--alpha",
        type=float,
        default=1.0,
        help="scale of the UCB1 exploration bonus (mab mode only)",
    )

    # Compatibility with old util.get_model_files (uses args.eval_type).
    parser.add_argument(
        "--eval_type",
        type=str,
        default="eval",
        help="kept for compatibility with util.get_model_files; ignored by this script",
    )

    # Hint config (global; YAML controls only solver params).
    parser.add_argument(
        "--hint_mode",
        choices=["obj", "int_lp", "auto"],
        default="obj",
        help="obj: full copy; int_lp: ints+LP-repair; auto: try obj then fallback to int_lp",
    )
    parser.add_argument("--hint_k", type=int, default=10)
    parser.add_argument("--hint_lp_seconds", type=float, default=5.0)
    parser.add_argument("--hint_cov_min", type=float, default=0.95)

    args = parser.parse_args()

    print("=" * 82, flush=True)
    print(f"Mode = {args.mode} | instance = {args.instance_name} | solver = {args.solver}", flush=True)
    print("=" * 82, flush=True)

    # Discover instance files.
    model_files = get_model_files(args)
    print(
        f"Found {len(model_files)} instance files under DATA_DIR/{args.instance_name}/... (resolved by util.get_model_files)",
        flush=True,
    )
    print("First few files:", [os.path.basename(p) for p in model_files[:5]], flush=True)

    # Dispatch
    if args.mode == "baseline":
        df = run_baseline(args, model_files)
        _save_named(args, df, "baseline_times.csv")
    elif args.mode == "mab":
        # Load and convert the configs exactly once. The previous version
        # called _actions_fn twice, re-parsing every YAML file for nothing.
        actions, config_files = _actions_fn(args)
        print(f"Loaded {len(actions)} configs from {args.config_name}", flush=True)
        df = run_mab(args, actions, config_files, model_files)
        _save_named(args, df, "mab_times.csv")
