import random
from typing import Dict, Any, Callable, Optional

import numpy as np
from tqdm import tqdm
from bayes_opt import BayesianOptimization
import optuna

from utils_original import *

# Set the routing threshold on every layer that has a router module.
def _set_router_threshold(model, threshold):
                for layer in model.layers:
                    if hasattr(layer, "router"):
                        layer.router.threshold = threshold            
            
# ---------------------------------------------------------
# 1. common objective function
# ---------------------------------------------------------

def eval_threshold_objective(
    threshold: float,
    generator,
    examples,
    device,
    task_key: str,
    task_config: Dict[str, Any],
    max_samples: int = 300,
    metric: str = "accuracy",
) -> float:
    """
    Objective function used by NES / Bayesian Optimization / Optuna.

    The threshold is clamped to [0.0, 1.0] and applied to every router in
    ``generator.model``. The function then evaluates on the first
    ``max_samples`` examples with ``task_config[task_key]["evaluate_fn"]`` and
    returns a reward to MAXIMIZE.

    Args:
        threshold: candidate routing threshold (clamped to [0.0, 1.0]).
        generator: object whose ``.model`` exposes ``layers`` with routers.
        examples: evaluation examples; sliced with ``[:max_samples]``.
        device: forwarded unchanged to the task's ``evaluate_fn``.
        task_key: key into ``task_config``.
        task_config: mapping whose ``[task_key]["evaluate_fn"]`` is called as
            ``evaluate_fn(generator, subset, device, file_path=None)`` and
            must return a dict-like result.
        max_samples: number of leading examples to evaluate; ``None`` = all.
        metric: ``"accuracy"`` (reward = accuracy) or ``"ppl"``
            (reward = -perplexity, so lower perplexity is better).

    Returns:
        The reward as a float. A missing/``None`` accuracy counts as 0.0.

    Raises:
        ValueError: if ``metric`` is unsupported. This is checked before any
            evaluation runs, so no model time is wasted.
        KeyError: if ``metric == "ppl"`` and the result carries no perplexity.
            Defaulting it to 0 would yield reward 0, the best possible value,
            which would make a broken evaluation win the search.
    """
    # Fail fast: validate before touching the model or running evaluation.
    if metric not in ("accuracy", "ppl"):
        raise ValueError(f"Unsupported metric: {metric}")

    thr = float(max(0.0, min(1.0, threshold)))
    _set_router_threshold(generator.model, thr)

    # limit evaluation cost
    subset = examples[:max_samples] if max_samples is not None else examples

    # file_path=None -> skip layer-execution logging
    result = task_config[task_key]["evaluate_fn"](
        generator,
        subset,
        device,
        file_path=None,
    )

    if metric == "accuracy":
        acc = result.get("accuracy")
        # A missing accuracy is the worst score, which is a safe fallback.
        return 0.0 if acc is None else float(acc)

    ppl = result.get("ppl")
    if ppl is None:
        raise KeyError(
            f"evaluate_fn result for task {task_key!r} has no 'ppl' value"
        )
    return -float(ppl)


# ---------------------------------------------------------
# 2. NES based threshold search
# ---------------------------------------------------------

def nes_optimize_threshold(
    generator,
    examples,
    device,
    task_key: str,
    task_config: Dict[str, Any],
    num_iters: int = 10,
    population_size: int = 8,
    init_mean: float = 0.5,
    init_sigma: float = 0.2,
    alpha: float = 0.1,
    sigma_decay: float = 0.95,
    max_samples: int = 300,
    metric: str = "accuracy",
    seed: Optional[int] = None,
) -> float:
    """
    Natural Evolution Strategies (NES) to optimize the router threshold.

    Each iteration does the following:
      1. Draw ``population_size`` standard-normal perturbations ``eps``.
      2. Evaluate the thresholds ``clip(mu + sigma * eps, 0.0, 0.5)`` with
         ``eval_threshold_objective``.
      3. z-normalize the rewards.
      4. Update ``mu += alpha * mean(norm_reward * eps)``, then clip ``mu``
         to [0.0, 0.5].
      5. Multiply ``sigma`` by ``sigma_decay``.

    Args:
        generator, examples, device, task_key, task_config, max_samples,
        metric: forwarded to ``eval_threshold_objective``.
        num_iters: number of NES iterations.
        population_size: candidate thresholds evaluated per iteration.
        init_mean: starting mean ``mu``.
        init_sigma: starting perturbation scale.
        alpha: step size for the mean update.
        sigma_decay: multiplicative decay applied to ``sigma`` each iteration.
        seed: if given, perturbations come from
            ``np.random.default_rng(seed)``, which makes runs reproducible.
            Otherwise the global ``np.random`` state is used, as before.

    Returns:
        The final mean ``mu`` as the best threshold. The model itself is left
        at the last *sampled* threshold, so apply the returned value yourself.
    """
    # A private generator when seeded; otherwise keep using the global NumPy
    # RNG so existing np.random.seed(...) calls keep controlling this search.
    rng = np.random.default_rng(seed) if seed is not None else np.random

    mu = init_mean
    sigma = init_sigma

    for it in range(num_iters):
        # 1) sample perturbations and the (clipped) thresholds they produce
        eps = rng.standard_normal(population_size)
        samples = np.clip(mu + sigma * eps, 0.0, 0.5)

        # 2) evaluate reward for every candidate threshold
        rewards = np.array(
            [
                eval_threshold_objective(
                    thr,
                    generator,
                    examples,
                    device,
                    task_key,
                    task_config,
                    max_samples=max_samples,
                    metric=metric,
                )
                for thr in samples
            ]
        )

        # 3) normalize rewards (1e-8 avoids 0/0 when all rewards tie)
        norm_rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-8)

        # 4) NES update. Use the raw perturbations `eps`, not
        # (clipped_sample - mu). Clipping shrinks that difference near the
        # [0, 0.5] bounds and biases the gradient estimate. The scale matches
        # the previous sum(F * (x - mu)) / (N * sigma) formulation.
        grad_mu = np.sum(norm_rewards * eps) / population_size
        mu = float(np.clip(mu + alpha * grad_mu, 0.0, 0.5))

        # decay sigma
        sigma *= sigma_decay

        print(
            f"[NES iter {it}] mu={mu:.4f}, sigma={sigma:.4f}, "
            f"best_sample_thr={samples[np.argmax(rewards)]:.4f}, "
            f"best_reward={np.max(rewards):.4f}"
        )

    return mu


# ---------------------------------------------------------
# 3. Bayesian Optimization based threshold search
# ---------------------------------------------------------

def bayes_optimize_threshold(
    generator,
    examples,
    device,
    task_key: str,
    task_config: Dict[str, Any],
    max_samples: int = 300,
    metric: str = "accuracy",
    init_points: int = 3,
    n_iter: int = 15,
    random_state: int = 42,
) -> float:
    """
    Bayesian Optimization to find the best router threshold in [0.0, 0.5].

    The optimizer runs ``init_points`` random probes followed by ``n_iter``
    guided probes, each scored by ``eval_threshold_objective``. It prints the
    best (threshold, reward) pair and returns that threshold as a float.
    """
    search_space = {"threshold": (0.0, 0.5)}

    # The target is called with keyword arguments named after the keys of
    # `search_space`, so the parameter must be called `threshold`.
    def _target(threshold: float) -> float:
        return eval_threshold_objective(
            threshold,
            generator,
            examples,
            device,
            task_key,
            task_config,
            max_samples=max_samples,
            metric=metric,
        )

    optimizer = BayesianOptimization(
        f=_target,
        pbounds=search_space,
        verbose=2,
        random_state=random_state,
    )
    optimizer.maximize(init_points=init_points, n_iter=n_iter)

    best = optimizer.max
    best_threshold, best_reward = float(best["params"]["threshold"]), float(best["target"])
    print(f"[Bayesian] Best threshold: {best_threshold:.4f}, reward={best_reward:.4f}")
    return best_threshold


# ---------------------------------------------------------
# 4. Optuna based threshold search
# ---------------------------------------------------------

def optuna_optimize_threshold(
    generator,
    examples,
    device,
    task_key: str,
    task_config: Dict[str, Any],
    max_samples: int = 300,
    metric: str = "accuracy",
    n_trials: int = 30,
    seed: Optional[int] = None,
) -> float:
    """
    Optuna (TPE) based search for the best router threshold in [0.0, 0.5].

    Args:
        generator, examples, device, task_key, task_config, max_samples,
        metric: forwarded to ``eval_threshold_objective``.
        n_trials: number of Optuna trials (one evaluation each).
        seed: optional seed for the TPE sampler. This makes the search
            reproducible, like ``random_state`` in the Bayesian search.
            ``None`` keeps Optuna's default unseeded behavior.

    Returns:
        The threshold of the best trial, as a float.
    """

    def objective(trial: optuna.Trial) -> float:
        thr = trial.suggest_float("threshold", 0.0, 0.5)
        return eval_threshold_objective(
            thr,
            generator,
            examples,
            device,
            task_key,
            task_config,
            max_samples=max_samples,
            metric=metric,
        )

    # TPESampler is Optuna's default for single-objective studies. Passing it
    # explicitly only adds the ability to fix its seed.
    sampler = optuna.samplers.TPESampler(seed=seed)
    study = optuna.create_study(direction="maximize", sampler=sampler)
    study.optimize(objective, n_trials=n_trials)

    best_threshold = float(study.best_params["threshold"])
    best_reward = float(study.best_value)
    print(f"[Optuna] Best threshold: {best_threshold:.4f}, reward={best_reward:.4f}")
    return best_threshold


# ---------------------------------------------------------
# 5. Grid search based threshold search (coarse + fine)
# ---------------------------------------------------------

def _scan_thresholds(
    generator,
    eval_fn: Callable,
    eval_examples,
    device,
    thresholds,
    metric: str,
    should_maximize: bool,
    best_metric: float,
    best_thr: float,
    desc: str,
):
    """Evaluate each threshold in order and return the updated ``(best_metric, best_thr)``.

    A threshold replaces the incumbent only when it is strictly better.
    Results whose ``metric`` entry is missing/``None`` are ignored.
    """
    for raw_t in tqdm(thresholds, desc=desc, ncols=100):
        # Plain float rather than np.float64, matching eval_threshold_objective.
        t = float(raw_t)
        _set_router_threshold(generator.model, t)
        results = eval_fn(generator, eval_examples, device, file_path=None, quiet=True)
        metric_val = results.get(metric)
        if metric_val is None:
            continue
        improved = metric_val > best_metric if should_maximize else metric_val < best_metric
        if improved:
            best_metric, best_thr = metric_val, t
    return best_metric, best_thr


def evaluate_best_threshold(
    generator,
    examples,
    device,
    task: str,
    num_samples: int = 300,
    grid_step1: float = 0.01,
    grid_step2: float = 0.05,
    max_threshold: float = 0.5,
) -> float:
    """
    Two-phase grid search for the best threshold:
      1) coarse search over ``grid_step2, 2*grid_step2, ... <= max_threshold``
      2) fine search with step ``grid_step1`` over
         ``[max(grid_step1, best - grid_step2), min(max_threshold, best + grid_step2))``
         (upper end exclusive) around the best coarse threshold.

    For 'alpaca', the metric is perplexity (ppl, to be minimized).
    For other tasks, the metric is accuracy (to be maximized).

    Evaluation uses a fixed random subset of ``min(num_samples, len(examples))``
    examples, drawn with a private ``random.Random(42)``. This gives the same
    subset as before without reseeding the caller's global ``random`` state.

    Args:
        generator: object whose ``.model`` exposes ``layers`` with routers.
        examples: sequence of evaluation examples for ``task``.
        device: forwarded unchanged to the task's eval function.
        task: one of the keys of the internal eval-function map below.
        num_samples: maximum size of the evaluation subset.
        grid_step1: fine-search step.
        grid_step2: coarse-search step, also the fine-search half-width.
        max_threshold: upper bound of the search range (default 0.5).

    Returns:
        The best threshold found, as a float. If no evaluation produced the
        metric, returns ``grid_step2``. The model is left at the last
        evaluated threshold, not necessarily the returned one.

    Raises:
        ValueError: if ``task`` has no registered eval function.
    """

    # metric & optimization direction
    if task == "alpaca":
        metric = "ppl"
        should_maximize = False  # minimize ppl
        print("✅ Task 'alpaca' detected. Optimizing for MINIMUM Perplexity (ppl).")
    else:
        metric = "accuracy"
        should_maximize = True   # maximize accuracy
        print(f"✅ Task '{task}' detected. Optimizing for MAXIMUM {metric}.")

    # eval function map
    eval_fn_map: Dict[str, Callable] = {
        "alpaca": evaluate_alpaca,
        "piqa": evaluate_piqa,
        "boolq": evaluate_boolq,
        "mmlu": evaluate_mmlu,
        "siqa": evaluate_siqa,
        "hellaswag": evaluate_hellaswag,
        "winogrande": evaluate_winogrande,
        "arcc": evaluate_science_multiple_choice,
        "arce": evaluate_science_multiple_choice,
        "openbookqa": evaluate_science_multiple_choice,
    }
    # Explicit raise instead of `assert`, which is stripped under `python -O`.
    if task not in eval_fn_map:
        raise ValueError(f"No eval function defined for task: {task}")
    eval_fn = eval_fn_map[task]

    # Fixed subset via a private RNG. This draws the same sample as
    # random.seed(42); random.sample(...) but leaves global state untouched.
    k = min(num_samples, len(examples))
    short_eval = random.Random(42).sample(examples, k)

    # init best metric
    best_metric = -float("inf") if should_maximize else float("inf")
    best_thr = grid_step2

    # ---- Phase 1: coarse search ----
    # The tiny epsilon makes max_threshold itself inclusive despite float
    # rounding, without admitting values past it.
    coarse_thresholds = np.arange(grid_step2, max_threshold + 1e-9, grid_step2)
    best_metric, best_thr = _scan_thresholds(
        generator, eval_fn, short_eval, device, coarse_thresholds,
        metric, should_maximize, best_metric, best_thr, "Coarse search",
    )

    print(f"Coarse search best threshold: {best_thr:.3f} ({metric.upper()}: {best_metric:.4f})")

    # ---- Phase 2: fine search ----
    low = max(grid_step1, best_thr - grid_step2)
    high = min(max_threshold, best_thr + grid_step2)
    fine_thresholds = np.arange(low, high, grid_step1)
    best_metric, best_thr = _scan_thresholds(
        generator, eval_fn, short_eval, device, fine_thresholds,
        metric, should_maximize, best_metric, best_thr, "Fine search",
    )

    print(f"\n🏆 Best threshold found: {best_thr:.4f} with {metric.upper()} {best_metric:.4f}")
    return float(best_thr)