"""
AutoQRA: Automated two-phase framework for joint per-layer precision (q_l)
and adaptation rank (r_l) allocation under a memory budget.

Phase I: multi-fidelity evolutionary search to approximate the Pareto front.
Phase II: local Bayesian optimization to pick a single operating point.

This implementation follows the method description while remaining modular,
so users can plug in their own low-fidelity (LF) and high-fidelity (HF)
evaluation functions.

Quick start (proxy evaluation):
  python -m qwen_lora_importance.autoqra.autoqra \
    --num_layers 28 \
    --bits 2 4 8 \
    --ranks 0 4 8 16 32 \
    --budget_bytes 1.2e9 \
    --importance_json qwen_lora_importance/results_gradients/grads_generation_inference_unfreeze.json \
    --phase1_generations 10 --phase1_pop 40 --phase1_promote 6 \
    --phase2_alpha 0.6

Outputs are written to ./results_autoqra/ as JSON artifacts.
"""

from __future__ import annotations

import argparse
import hashlib
import json
import math
import random
import subprocess
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, Dict, List, Optional, Sequence, Tuple

import numpy as np
from sklearn.neural_network import MLPRegressor
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import Matern, WhiteKernel, ConstantKernel
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np


# =============================
# Utilities
# =============================

def set_seed(seed: int):
    """Seed every random number generator this module relies on.

    Seeds Python's ``random`` (used by the search operators), NumPy's global
    RNG, and PyTorch (used to initialize and train the GELU surrogate), so a
    run with a fixed ``seed`` is reproducible end to end.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    # The surrogate MLP's weight init and training are torch-driven; without
    # this, surrogate predictions (and thus promotion decisions) vary per run.
    torch.manual_seed(seed)


def normalize_minmax(x: np.ndarray) -> np.ndarray:
    """Linearly rescale ``x`` onto [0, 1] using its own minimum and maximum.

    An empty array is returned as-is (same object); an array whose min equals
    its max has no range and maps to all zeros.
    """
    if not x.size:
        return x
    lo = float(np.min(x))
    hi = float(np.max(x))
    if hi == lo:
        return np.zeros_like(x)
    span = hi - lo
    return (x - lo) / span


def non_dominated_sort_constrained(
    points: List[Tuple[float, float]],
    feasible: List[bool],
    violations: Optional[Sequence[float]] = None,
) -> List[List[int]]:
    """Constrained NSGA-II fast non-dominated sorting (2 objectives, minimization).

    Constrained-domination rules:
    1. A feasible candidate dominates any infeasible one.
    2. Between two feasible candidates, standard Pareto dominance applies.
    3. Between two infeasible candidates, the one with the smaller entry in
       ``violations`` dominates. If ``violations`` is not supplied, or the two
       violations are equal, Pareto dominance on the objectives is used.

    Args:
        points: (obj1, obj2) tuples, both to be minimized.
        feasible: Per-candidate feasibility flags, same length as ``points``.
        violations: Optional constraint-violation magnitudes (e.g. bytes over
            budget), same length as ``points``. Only consulted when comparing
            two infeasible candidates.

    Returns:
        Fronts in rank order; each front is a list of indices into ``points``.
        An empty input yields an empty list.

    Raises:
        ValueError: if ``feasible`` or ``violations`` length differs from ``points``.
    """
    n = len(points)
    if len(feasible) != n:
        raise ValueError(f"feasible has {len(feasible)} entries, expected {n}")
    if violations is not None and len(violations) != n:
        raise ValueError(f"violations has {len(violations)} entries, expected {n}")

    def dominates_constrained(p: int, q: int) -> bool:
        """Return True if candidate p constrained-dominates candidate q."""
        p_feasible, q_feasible = bool(feasible[p]), bool(feasible[q])

        # Rule 1: feasibility trumps objective values.
        if p_feasible != q_feasible:
            return p_feasible

        # Rule 3: among infeasible candidates, smaller violation wins.
        if not p_feasible and violations is not None and violations[p] != violations[q]:
            return violations[p] < violations[q]

        # Rule 2 (and tie-break for rule 3): Pareto dominance for minimization,
        # i.e. p <= q in both objectives and p < q in at least one.
        p_better_or_equal = points[p][0] <= points[q][0] and points[p][1] <= points[q][1]
        p_strictly_better = points[p][0] < points[q][0] or points[p][1] < points[q][1]
        return p_better_or_equal and p_strictly_better

    beaten: List[List[int]] = [[] for _ in range(n)]  # indices each p dominates
    n_dom = [0] * n  # how many candidates dominate each index
    fronts: List[List[int]] = [[]]

    for p in range(n):
        for q in range(n):
            if p == q:
                continue
            if dominates_constrained(p, q):
                beaten[p].append(q)
            elif dominates_constrained(q, p):
                n_dom[p] += 1
        if n_dom[p] == 0:
            fronts[0].append(p)

    i = 0
    while fronts[i]:
        next_front: List[int] = []
        for p in fronts[i]:
            for q in beaten[p]:
                n_dom[q] -= 1
                if n_dom[q] == 0:
                    next_front.append(q)
        i += 1
        fronts.append(next_front)
    fronts.pop()  # the loop always terminates by appending an empty front
    return fronts


def non_dominated_sort(points: List[Tuple[float, float]]) -> List[List[int]]:
    """Unconstrained NSGA-II fast non-dominated sort for 2D minimization.

    Kept for backward compatibility: every candidate is treated as feasible and
    the work is delegated to ``non_dominated_sort_constrained``.
    """
    all_feasible = [True for _ in points]
    return non_dominated_sort_constrained(points, all_feasible)


def crowding_distance(points: List[Tuple[float, float]], indices: List[int]) -> Dict[int, float]:
    """NSGA-II crowding distance for the members of one front.

    Members at either extreme of any objective get ``inf``. Each interior
    member accumulates, per objective, the gap between its two sorted
    neighbours divided by that objective's range within the front (a zero
    range divides by 1.0 instead).
    """
    if not indices:
        return {}
    dist: Dict[int, float] = dict.fromkeys(indices, 0.0)
    for obj in (0, 1):
        order = sorted(indices, key=lambda idx: points[idx][obj])
        first, last = order[0], order[-1]
        dist[first] = float("inf")
        dist[last] = float("inf")
        lo, hi = points[first][obj], points[last][obj]
        scale = (hi - lo) if hi != lo else 1.0
        for left, mid, right in zip(order, order[1:], order[2:]):
            dist[mid] += (points[right][obj] - points[left][obj]) / scale
    return dist


def hypervolume_2d(pareto_points: List[Tuple[float, float]], ref: Tuple[float, float]) -> float:
    """Hypervolume (dominated area) of a 2D point set under minimization.

    Computes the area of the union of boxes [x, ref_x] x [y, ref_y] over all
    points. Points are swept in ascending order of the first objective; each
    point contributes only the horizontal strip between its second objective
    and the lowest second objective seen so far. Dominated points, and points
    at or beyond the reference in either objective, therefore add nothing, so
    the input does not need to be pre-filtered to the non-dominated set.

    Args:
        pareto_points: (obj1, obj2) pairs, both to be minimized.
        ref: Reference (nadir) point bounding the measured region.

    Returns:
        The dominated area as a float (0.0 for an empty input).
    """
    if not pareto_points:
        return 0.0
    ref_x, ref_y = ref
    hv = 0.0
    best_y = ref_y  # lowest obj2 among the points swept so far
    # Ascending obj1; equal obj1 values are ordered by obj2 so the best comes first.
    for x, y in sorted(pareto_points, key=lambda p: (p[0], p[1])):
        if x >= ref_x:
            break  # this and every later point lies beyond the reference
        if y < best_y:
            hv += (ref_x - x) * (best_y - y)
            best_y = y
    return float(hv)


# =============================
# Problem definition
# =============================

@dataclass
class AutoQRAConfig:
    """Static description of one AutoQRA search problem.

    Holds the layer count, the candidate bit-widths ``Q`` and LoRA ranks
    ``R``, optional per-layer size tables consumed by ``MemoryModel`` (which
    substitutes uniform placeholders when they are None), and search /
    evaluation settings.
    """
    num_layers: int
    Q: List[int]  # candidate backbone bit-widths for q_l
    R: List[int]  # candidate LoRA ranks for r_l
    lora_precision_bits: int = 16  # p_r
    # NOTE: despite the name, MemoryModel treats these values as per-layer
    # parameter *counts* and multiplies them by q_l / 8 to obtain bytes.
    layer_param_bytes: Optional[List[int]] = None  # |W_l| params -> bytes per q-bit via q/8
    lora_params_per_rank: Optional[List[int]] = None  # per-layer params per rank (|A|+|B|)/r
    # search
    seed: int = 42
    # presumably the byte budget compared against MemoryModel totals -- confirm with caller
    budget_bytes: Optional[float] = None
    # mixed-quant evaluation (optional)
    base_model_id: Optional[str] = None
    task: str = "generation"


class Importance:
    """Separate per-layer sensitivity signals for bit and rank allocation.

    Following Section 3.1 of the paper, ``I_q`` captures each layer's
    sensitivity to quantization (drives bit allocation) and ``I_r`` captures
    how much each layer benefits from adaptation (drives rank allocation).
    They are kept apart because a layer that needs high precision need not
    need a high rank. Both are min-max normalized. ``score`` is a weighted
    blend of the two, retained for legacy callers.
    """

    def __init__(self, bb: np.ndarray, lr: np.ndarray, w_bb: float = 0.5, w_lr: float = 0.5):
        # Backbone metric -> quantization sensitivity; LoRA metric -> rank need.
        self.I_q = normalize_minmax(bb)
        self.I_r = normalize_minmax(lr)
        # Legacy single-signal view: convex-style blend of the two profiles.
        self.score = w_bb * self.I_q + w_lr * self.I_r

    @classmethod
    def from_json(cls, path: Path, num_layers: int, w_bb: float = 0.5, w_lr: float = 0.5):
        """Build profiles from a JSON file of per-layer metrics.

        Reads the ``backbone_metric_per_layer`` and ``lora_metric_per_layer``
        mappings (layer index as string or int key -> number). Missing mappings
        or layers default to 0.0.
        """
        with open(path, "r") as fh:
            payload = json.load(fh)

        def _column(metric: dict) -> np.ndarray:
            # String keys take precedence over int keys (JSON keys are strings).
            col = np.zeros((num_layers,), dtype=np.float32)
            for layer in range(num_layers):
                col[layer] = float(metric.get(str(layer), metric.get(layer, 0.0)))
            return col

        backbone = _column(payload.get("backbone_metric_per_layer", {}))
        lora = _column(payload.get("lora_metric_per_layer", {}))
        return cls(backbone, lora, w_bb=w_bb, w_lr=w_lr)


class MemoryModel:
    """Analytical memory footprint of a (q, r) configuration (paper Eq. 1-3).

    M(C) = sum_l ( m^W_l(q_l) + m^A_l(r_l) + m^meta_l )

    - m^W_l    = N(W_l) * q_l / 8                         quantized backbone
    - m^A_l    = N_rank(l) * r_l * p_r / 8                LoRA adapters
    - m^meta_l = ceil(N(W_l) / block_size) * 2 * meta_bits / 8
                                                          per-block scale + zero-point
    """

    def __init__(self, cfg: AutoQRAConfig, block_size: int = 128, meta_precision_bits: int = 16):
        self.cfg = cfg
        self.block_size = block_size
        self.meta_precision_bits = meta_precision_bits
        n_layers = cfg.num_layers
        # Missing size tables fall back to uniform placeholders: 1e6 backbone
        # parameters and 1e5 LoRA parameters per unit rank, for every layer.
        # The config field is named "bytes" but holds parameter counts.
        self.layer_params = (
            np.full((n_layers,), 1e6, dtype=np.float64)
            if cfg.layer_param_bytes is None
            else np.array(cfg.layer_param_bytes, dtype=np.float64)
        )
        self.lora_params_per_rank = (
            np.full((n_layers,), 1e5, dtype=np.float64)
            if cfg.lora_params_per_rank is None
            else np.array(cfg.lora_params_per_rank, dtype=np.float64)
        )

    def layer_memory_bytes(self, l: int, q_l: int, r_l: int) -> float:
        """Bytes used by layer ``l`` at ``q_l`` bits and LoRA rank ``r_l`` (Eq. 2-3)."""
        n_weights = self.layer_params[l]
        backbone = n_weights * (q_l / 8.0)
        adapter = (self.lora_params_per_rank[l] * r_l) * (self.cfg.lora_precision_bits / 8.0)
        # One scale and one zero-point per block, each meta_precision_bits wide.
        n_blocks = int(np.ceil(n_weights / self.block_size))
        meta = n_blocks * (2 * (self.meta_precision_bits / 8.0))
        return backbone + adapter + meta

    def total_memory_bytes(self, q: Sequence[int], r: Sequence[int]) -> float:
        """Configuration total M(C): the sum of per-layer footprints (Eq. 1)."""
        per_layer = (
            self.layer_memory_bytes(layer, q[layer], r[layer])
            for layer in range(self.cfg.num_layers)
        )
        return float(sum(per_layer))


class ConfigEncoding:
    """Search grid for bit-widths (Q) and ranks (R): normalization and snapping.

    Both grids are stored sorted ascending. ``s_q``/``s_r`` map a value
    linearly so the grid minimum goes to 0 and the maximum to 1 (a single-
    valued grid maps everything to 0.0). ``round_Q``/``round_R`` snap a
    continuous value to the nearest grid point, resolving ties toward the
    smaller value.
    """

    def __init__(self, Q: List[int], R: List[int]):
        self.Q = sorted(Q)
        self.R = sorted(R)
        self.q_min = self.Q[0]
        self.q_max = self.Q[-1]
        self.r_min = self.R[0]
        self.r_max = self.R[-1]

    @staticmethod
    def _unit_scale(value, lo, hi) -> float:
        # A grid with no range has nothing to interpolate over.
        if not hi > lo:
            return 0.0
        return (value - lo) / (hi - lo)

    @staticmethod
    def _nearest(grid, target):
        # min() keeps the first minimum, so ties go to the smaller grid value.
        return min(grid, key=lambda choice: abs(choice - target))

    def s_q(self, q: int) -> float:
        return self._unit_scale(q, self.q_min, self.q_max)

    def s_r(self, r: int) -> float:
        return self._unit_scale(r, self.r_min, self.r_max)

    def round_Q(self, v: float) -> int:
        return self._nearest(self.Q, v)

    def round_R(self, v: float) -> int:
        return self._nearest(self.R, v)


# =============================
# Evaluation (LF/HF)
# =============================

class ProxyEvaluator:
    """Cheap low-fidelity proxy: importance-weighted resource coverage.

    P_low(C) = (1/L) * sum_l I(l) * [s_q(q_l) + s_r(r_l)] / 2

    where I(l) is the combined ``Importance.score``. If ``target_avg_bits`` is
    set, a Gaussian bonus (peak 0.5, unit width) on the distance between the
    mean bit-width and the target is added. When memory M exceeds the budget
    B, a quadratic penalty 2 * ((M - B) / max(B, 1))^2 is subtracted.
    """

    def __init__(self, cfg: AutoQRAConfig, importance: Importance, mem: MemoryModel, enc: ConfigEncoding, target_avg_bits: float = None):
        self.cfg = cfg
        self.I = importance.score.astype(np.float64)
        self.mem = mem
        self.enc = enc
        # None disables the average-bits bonus.
        self.target_avg_bits = target_avg_bits

    def evaluate(self, q: Sequence[int], r: Sequence[int], budget_bytes: Optional[float]) -> Tuple[float, float]:
        """Score configuration (q, r); return ``(proxy_score, memory_bytes)``."""
        n_layers = self.cfg.num_layers
        enc = self.enc
        coverage = sum(
            (float(self.I[layer]) * (enc.s_q(q[layer]) + enc.s_r(r[layer])) * 0.5
             for layer in range(n_layers)),
            0.0,
        )
        score = coverage / max(1, n_layers)

        if self.target_avg_bits is not None:
            gap = abs(sum(q) / len(q) - self.target_avg_bits)
            score = score + 0.5 * np.exp(-0.5 * (gap / 1.0) ** 2)

        footprint = self.mem.total_memory_bytes(q, r)
        over_budget = budget_bytes is not None and footprint > budget_bytes
        if over_budget:
            excess = (footprint - budget_bytes) / max(budget_bytes, 1.0)
            score = score - 2.0 * (excess ** 2)
        return score, footprint


class RealTaskEvaluator:
    """Execute a lightweight SFT -> quant -> lm-eval pipeline to obtain true metrics.

    Each (q, r) configuration is trained with either low-fidelity (LF) or
    high-fidelity (HF) settings inside its own working directory under
    ``output_root``. Results are cached in memory and persisted per candidate
    as ``metrics.json``. The cache-key -> directory mapping is stored in
    ``candidates_index.json``, so a restarted run reuses finished candidates.
    """

    def __init__(
        self,
        preset: str,
        dataset: str,
        lf_sample_ratio: float = 0.05,
        lf_epochs: float = 0.1,
        hf_sample_ratio: float = 0.2,
        hf_epochs: float = 0.5,
        eval_task: str = "",
        eval_shots: int = 0,
        output_root: Optional[Path] = None,
        load_in_4bit: bool = True,
    ):
        """
        Args:
            preset: Model preset forwarded to the training/quant scripts (required).
            dataset: Dataset name forwarded to the training script.
            lf_sample_ratio, lf_epochs: Training budget for LF evaluations.
            hf_sample_ratio, hf_epochs: Training budget for HF evaluations.
            eval_task: lm-eval task name whose accuracy is reported.
            eval_shots: Few-shot count for evaluation.
            output_root: Root directory for candidate work dirs
                (default ``results_real_eval``).
            load_in_4bit: Whether to pass ``--load_in_4bit`` to training.

        Raises:
            ValueError: if ``preset`` is empty.
        """
        if not preset:
            raise ValueError("Preset must be provided when using real-task evaluation.")
        self.preset = preset
        self.dataset = dataset
        # LF (Low-Fidelity): Fast, cheap evaluation
        self.lf_sample_ratio = lf_sample_ratio
        self.lf_epochs = lf_epochs
        # HF (High-Fidelity): Thorough, expensive evaluation
        self.hf_sample_ratio = hf_sample_ratio
        self.hf_epochs = hf_epochs
        self.eval_task = eval_task
        self.eval_shots = eval_shots
        self.load_in_4bit = load_in_4bit
        self.base_dir = Path(output_root) if output_root else Path("results_real_eval")
        self.base_dir.mkdir(parents=True, exist_ok=True)
        self.python = sys.executable or "python"
        self.cache: Dict[str, Dict[str, float]] = {}
        self.index_path = self.base_dir / "candidates_index.json"
        self.dir_index: Dict[str, str] = self._load_dir_index()

    def _load_dir_index(self) -> Dict[str, str]:
        """Load the persisted cache-key -> directory map; empty if absent or unusable."""
        if not self.index_path.exists():
            return {}
        try:
            loaded = json.loads(self.index_path.read_text())
        except (OSError, ValueError) as err:  # JSONDecodeError subclasses ValueError
            print(f"[RealTaskEvaluator] ignoring unreadable index {self.index_path}: {err}", file=sys.stderr)
            return {}
        # A corrupted or foreign file may hold a non-object; never store it as the index.
        if not isinstance(loaded, dict):
            print(f"[RealTaskEvaluator] ignoring malformed index {self.index_path}", file=sys.stderr)
            return {}
        return {str(k): str(v) for k, v in loaded.items()}

    def _hash_config(self, q: Sequence[int], r: Sequence[int]) -> str:
        """Return the SHA-1 hex digest of the canonical JSON form of (q, r)."""
        payload = json.dumps({"q": list(q), "r": list(r)}, sort_keys=True)
        return hashlib.sha1(payload.encode("utf-8")).hexdigest()

    def _run(self, cmd: List[str], cwd: Optional[Path] = None):
        """Run ``cmd`` (argv list, no shell); raises CalledProcessError on failure."""
        subprocess.run(cmd, check=True, cwd=str(cwd) if cwd else None)

    def _extract_accuracy(self, eval_json: Path) -> float:
        """Return the first known accuracy-like metric for ``eval_task`` in ``eval_json``."""
        data = json.loads(eval_json.read_text())
        task_res = data.get("results", {}).get(self.eval_task, {})
        for key in ("acc", "acc_norm", "acc,none", "acc_norm,none", "f1", "exact_match"):
            if key in task_res:
                return float(task_res[key])
        raise RuntimeError(f"Could not find accuracy metric for task '{self.eval_task}' in {eval_json}")

    def _format_dir_name(
        self,
        q: Sequence[int],
        r: Sequence[int],
        key: str,
        generation: Optional[int],
        cand_idx: Optional[int],
        stage: str,
    ) -> str:
        """Build a human-readable directory name summarizing the candidate.

        Encodes generation, candidate index, stage, the distinct bit-widths and
        ranks, their averages and a 6-char hash prefix. Dots become ``p`` and a
        numeric suffix is added while the name already exists under ``base_dir``.
        """
        uniq_q = "-".join(str(x) for x in sorted(set(q)))
        uniq_r = "-".join(str(x) for x in sorted(set(r)))
        avg_q = sum(q) / max(1, len(q))
        avg_r = sum(r) / max(1, len(r))
        gen_tag = "gen{:02d}".format(generation if generation is not None and generation >= 0 else 0)
        idx_tag = f"cand{cand_idx:02d}" if cand_idx is not None else "candXX"
        stage_tag = stage or ("init" if generation is None or generation < 0 else "evo")
        base = (
            f"cand_{gen_tag}_{idx_tag}_{stage_tag}_"
            f"q{uniq_q}_r{uniq_r}_avgq{avg_q:.1f}_avgr{avg_r:.1f}_{key[:6]}"
        )
        sanitized = base.replace(".", "p")
        name = sanitized
        suffix = 1
        while (self.base_dir / name).exists():
            name = f"{sanitized}_{suffix}"
            suffix += 1
        return name

    def _save_dir_index(self):
        """Persist ``dir_index`` to disk (best effort: failures are reported, not raised)."""
        tmp_path = self.index_path.with_name(self.index_path.name + ".tmp")
        try:
            with open(tmp_path, "w") as f:
                json.dump(self.dir_index, f, indent=2)
            # Atomic swap: an interrupted write can never leave a truncated index.
            tmp_path.replace(self.index_path)
        except OSError as err:
            print(f"[RealTaskEvaluator] could not save index {self.index_path}: {err}", file=sys.stderr)

    def evaluate(
        self,
        q: Sequence[int],
        r: Sequence[int],
        high_fidelity: bool = False,
        generation: Optional[int] = None,
        cand_idx: Optional[int] = None,
        stage: str = "",
    ) -> Tuple[float, float]:
        """
        Evaluate a (q, r) configuration with either LF or HF training.

        Args:
            q: Quantization bit-widths per layer
            r: LoRA ranks per layer
            high_fidelity: If True, use HF params (more training); else use LF params
            generation, cand_idx, stage: Metadata for directory naming

        Returns:
            (performance, memory) tuple, where memory is ``peak_mem_bytes`` from
            the training profile.

        Raises:
            subprocess.CalledProcessError: if any pipeline stage fails.
            RuntimeError: if training output or the accuracy metric is missing.
        """
        # Select training parameters based on fidelity level
        if high_fidelity:
            sample_ratio = self.hf_sample_ratio
            epochs = self.hf_epochs
            fidelity_suffix = "_HF"
        else:
            sample_ratio = self.lf_sample_ratio
            epochs = self.lf_epochs
            fidelity_suffix = "_LF"

        # Separate cache entries for LF and HF results of the same config.
        base_key = self._hash_config(q, r)
        cache_key = base_key + fidelity_suffix

        if cache_key in self.cache:
            return self.cache[cache_key]["perf"], self.cache[cache_key]["mem"]

        # Directory naming includes fidelity level
        dir_name = self.dir_index.get(cache_key)
        if dir_name is None:
            dir_name = self._format_dir_name(q, r, base_key, generation, cand_idx, stage)
            dir_name = dir_name + fidelity_suffix.lower()  # Append _hf or _lf to dir name
            self.dir_index[cache_key] = dir_name
            self._save_dir_index()

        work_dir = self.base_dir / dir_name
        metrics_path = work_dir / "metrics.json"
        if metrics_path.exists():
            # Finished in an earlier run: reuse the persisted result.
            metrics = json.loads(metrics_path.read_text())
            self.cache[cache_key] = metrics
            return metrics["perf"], metrics["mem"]

        work_dir.mkdir(parents=True, exist_ok=True)
        config_path = work_dir / "qra_config.json"
        with open(config_path, "w") as f:
            json.dump({"q": list(q), "r": list(r)}, f, indent=2)

        # Stage 1: SFT with the selected fidelity budget.
        train_dir = work_dir / "sft"
        train_cmd = [
            self.python,
            "qwen_lora_importance/train_autoqra_sft.py",
            "--preset",
            self.preset,
            "--dataset",
            self.dataset,
            "--sample_ratio",
            str(sample_ratio),
            "--epochs",
            str(epochs),
            "--qra_config",
            str(config_path),
            "--output_dir",
            str(train_dir),
        ]
        if self.load_in_4bit:
            train_cmd.append("--load_in_4bit")
        self._run(train_cmd)

        profile_path = train_dir / "train_profile.json"
        if not profile_path.exists():
            raise RuntimeError(f"Expected {profile_path} after training but file not found.")
        profile = json.loads(profile_path.read_text())
        mem = float(profile.get("peak_mem_bytes", 0.0))

        # Stage 2: post-training mixed-precision quantization (adapter merged).
        quant_dir = work_dir / "quant"
        quant_cmd = [
            self.python,
            "qwen_lora_importance/autoqra_post_quant.py",
            "--adapter_dir",
            str(train_dir / "final_model"),
            "--qra_config",
            str(config_path),
            "--preset",
            self.preset,
            "--per_channel",
            "--merge",
            "--out_dir",
            str(quant_dir),
        ]
        self._run(quant_cmd)

        # Stage 3: task evaluation of the quantized model.
        eval_out = work_dir / "eval_results.json"
        eval_cmd = [
            self.python,
            "qwen_lora_importance/experiments/eval_tasks.py",
            "--model_path",
            str(quant_dir),
            "--tasks",
            self.eval_task,
            "--shots",
            str(self.eval_shots),
            "--out",
            str(eval_out),
        ]
        self._run(eval_cmd)
        perf = self._extract_accuracy(eval_out)

        metrics = {"perf": perf, "mem": mem}
        with open(metrics_path, "w") as f:
            json.dump(metrics, f, indent=2)
        self.cache[cache_key] = metrics
        return perf, mem


class SurrogatePromotion:
    """Predict HF performance from an LF score plus config descriptors.

    Baseline surrogate (scikit-learn ``MLPRegressor``) for promotion
    decisions. It is refit on all accumulated samples once at least
    ``_MIN_SAMPLES`` HF observations exist; before that, ``predict`` returns
    the LF score unchanged.
    """

    _MIN_SAMPLES = 10  # HF observations required before the regressor is used

    def __init__(self):
        self.model = MLPRegressor(hidden_layer_sizes=(64, 32), random_state=0, max_iter=500)
        self.X: List[List[float]] = []
        self.y: List[float] = []

    def _features(self, plow: float, mem: float, q: Sequence[int], r: Sequence[int], enc: ConfigEncoding, I: np.ndarray) -> List[float]:
        """Feature vector: [P_low, M, q-coverage, r-coverage, *q-histogram, *r-histogram]."""
        # Materialize as lists so tuples, ranges and NumPy arrays (which have
        # no ``.count``) are all accepted, as in SurrogateMLPPromotion.
        q_list = list(q)
        r_list = list(r)
        q_hist = [q_list.count(v) / len(q_list) for v in enc.Q]
        r_hist = [r_list.count(v) / len(r_list) for v in enc.R]
        # importance-weighted coverage
        q_cov = float(np.dot(I, np.array([enc.s_q(x) for x in q_list])) / len(q_list))
        r_cov = float(np.dot(I, np.array([enc.s_r(x) for x in r_list])) / len(r_list))
        return [plow, mem, q_cov, r_cov] + q_hist + r_hist

    def update(self, plow: float, phigh: float, mem: float, q: Sequence[int], r: Sequence[int], enc: ConfigEncoding, I: np.ndarray):
        """Record one (LF, HF) observation and refit once enough data exists."""
        self.X.append(self._features(plow, mem, q, r, enc, I))
        self.y.append(phigh)
        if len(self.y) >= self._MIN_SAMPLES:
            self.model.fit(np.array(self.X), np.array(self.y))

    def predict(self, plow: float, mem: float, q: Sequence[int], r: Sequence[int], enc: ConfigEncoding, I: np.ndarray) -> float:
        """Predicted HF performance; identity fallback (``plow``) until trained."""
        if len(self.y) < self._MIN_SAMPLES:
            return plow
        x = self._features(plow, mem, q, r, enc, I)
        return float(self.model.predict([x])[0])


# =============================
# Improved Surrogate Model (Paper Spec: 2-layer MLP + GELU)
# =============================

class GELUSurrogateNet(nn.Module):
    """Feed-forward LF->HF regressor: two GELU hidden layers and a scalar head.

    Per paper Appendix A: "A 2-layer MLP with GELU is retrained each generation
    on accumulated (Plow, M, Phigh) tuples, with standardized inputs and early
    stopping on a small validation split."
    """

    def __init__(self, input_dim: int, hidden_dims: Tuple[int, int] = (64, 32)):
        super().__init__()
        first, second = hidden_dims
        widths = (input_dim, first, second)
        blocks: List[nn.Module] = []
        # Linear -> GELU for each hidden layer, in input-to-output order.
        for fan_in, fan_out in zip(widths, widths[1:]):
            blocks += [nn.Linear(fan_in, fan_out), nn.GELU()]
        blocks.append(nn.Linear(second, 1))
        self.net = nn.Sequential(*blocks)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the MLP and drop the trailing size-1 output dimension."""
        return torch.squeeze(self.net(x), -1)


class SurrogateMLPPromotion:
    """LF->HF surrogate per paper Eq. 7, used for promotion decisions.

        θ_s* = argmin_θ  Σ_i ρ(Φ_s(x_i) − P(C_i; b_S)) + λ‖θ‖²

    Features:
    - 2-layer MLP with GELU activation (``GELUSurrogateNet``)
    - Huber loss ρ to mitigate low-budget outliers
    - squared-L2 penalty λ‖θ‖² over all network parameters
    - standardized inputs (z-score normalization)
    - early stopping on a validation split
    """

    def __init__(
        self,
        hidden_dims: Tuple[int, int] = (64, 32),
        val_fraction: float = 0.2,
        patience: int = 10,
        lr: float = 1e-3,
        max_epochs: int = 500,
        min_samples: int = 10,
        huber_delta: float = 1.0,  # Huber loss parameter
        l2_lambda: float = 0.01,   # L2 regularization
    ):
        """
        Args:
            hidden_dims: Tuple of hidden layer sizes (default: (64, 32))
            val_fraction: Fraction of data for validation (for early stopping)
            patience: Early stopping patience (epochs without val improvement)
            lr: Learning rate for Adam optimizer
            max_epochs: Maximum training epochs
            min_samples: Minimum samples before training starts
            huber_delta: Threshold for Huber loss (robustness to outliers)
            l2_lambda: Coefficient λ of the squared-L2 penalty λ‖θ‖²
        """
        self.hidden_dims = hidden_dims
        self.val_fraction = val_fraction
        self.patience = patience
        self.lr = lr
        self.max_epochs = max_epochs
        self.min_samples = min_samples
        self.huber_delta = huber_delta
        self.l2_lambda = l2_lambda

        self.scaler = StandardScaler()
        self.model: Optional[GELUSurrogateNet] = None
        self.input_dim: Optional[int] = None

        self.X: List[List[float]] = []
        self.y: List[float] = []
        self.is_fitted = False

    def _build_model(self, input_dim: int) -> GELUSurrogateNet:
        """Create (and store) the network for ``input_dim`` features."""
        self.input_dim = input_dim
        self.model = GELUSurrogateNet(input_dim, self.hidden_dims)
        return self.model

    def _features(
        self,
        plow: float,
        mem: float,
        q: Sequence[int],
        r: Sequence[int],
        enc: ConfigEncoding,
        I: np.ndarray
    ) -> List[float]:
        """Feature vector: [P_low, M, q-coverage, r-coverage, *q-histogram, *r-histogram]."""
        q_list = list(q)
        r_list = list(r)
        q_hist = [q_list.count(v) / len(q_list) for v in enc.Q]
        r_hist = [r_list.count(v) / len(r_list) for v in enc.R]
        # importance-weighted coverage
        q_cov = float(np.dot(I, np.array([enc.s_q(x) for x in q_list])) / len(q_list))
        r_cov = float(np.dot(I, np.array([enc.s_r(x) for x in r_list])) / len(r_list))
        return [plow, mem, q_cov, r_cov] + q_hist + r_hist

    def update(
        self,
        plow: float,
        phigh: float,
        mem: float,
        q: Sequence[int],
        r: Sequence[int],
        enc: ConfigEncoding,
        I: np.ndarray
    ):
        """Add a new training sample and retrain if enough samples exist."""
        self.X.append(self._features(plow, mem, q, r, enc, I))
        self.y.append(phigh)

        if len(self.y) >= self.min_samples:
            self._train()

    def _l2_penalty(self) -> torch.Tensor:
        """Return ‖θ‖²: the sum of squared entries of every model parameter."""
        return torch.stack([p.pow(2).sum() for p in self.model.parameters()]).sum()

    def _train(self):
        """Train the MLP with Huber + λ‖θ‖² loss and early stopping on a val split.

        NOTE(review): the network is warm-started across calls while the scaler
        is refit on all data each call; the paper says the MLP is "retrained
        each generation" -- confirm whether a fresh model per call is intended.
        """
        X = np.array(self.X)
        y = np.array(self.y)

        # Standardize inputs
        X_scaled = self.scaler.fit_transform(X)

        # Train/val split for early stopping (too few samples: reuse train as val)
        if len(X_scaled) < 5:
            X_train, X_val = X_scaled, X_scaled
            y_train, y_val = y, y
        else:
            X_train, X_val, y_train, y_val = train_test_split(
                X_scaled, y, test_size=self.val_fraction, random_state=42
            )

        if self.model is None:
            self._build_model(X.shape[1])

        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
        # Huber loss for robustness to outliers (ρ in Eq. 7)
        criterion = nn.HuberLoss(delta=self.huber_delta)

        best_val_loss = float('inf')
        patience_counter = 0
        best_state = None

        X_train_t = torch.tensor(X_train, dtype=torch.float32)
        y_train_t = torch.tensor(y_train, dtype=torch.float32)
        X_val_t = torch.tensor(X_val, dtype=torch.float32)
        y_val_t = torch.tensor(y_val, dtype=torch.float32)

        for epoch in range(self.max_epochs):
            # Training step: ρ(...) + λ‖θ‖² (squared norm, as in Eq. 7)
            self.model.train()
            optimizer.zero_grad()
            pred = self.model(X_train_t)
            loss = criterion(pred, y_train_t) + self.l2_lambda * self._l2_penalty()
            loss.backward()
            optimizer.step()

            # Validation step (data term only)
            self.model.eval()
            with torch.no_grad():
                val_pred = self.model(X_val_t)
                val_loss = criterion(val_pred, y_val_t).item()

            # Early stopping: keep the best-val weights seen so far
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0
                best_state = {k: v.clone() for k, v in self.model.state_dict().items()}
            else:
                patience_counter += 1
                if patience_counter >= self.patience:
                    break

        if best_state is not None:
            self.model.load_state_dict(best_state)

        self.is_fitted = True

    def predict(
        self,
        plow: float,
        mem: float,
        q: Sequence[int],
        r: Sequence[int],
        enc: ConfigEncoding,
        I: np.ndarray
    ) -> float:
        """Predict HF performance; identity fallback (``plow``) until trained."""
        if not self.is_fitted or len(self.y) < self.min_samples:
            return plow

        x = self._features(plow, mem, q, r, enc, I)
        x_scaled = self.scaler.transform([x])
        x_t = torch.tensor(x_scaled, dtype=torch.float32)

        self.model.eval()
        with torch.no_grad():
            pred = self.model(x_t)
        return float(pred.item())


# =============================
# Search operators
# =============================

def repair_to_budget(
    q: List[int],
    r: List[int],
    enc: ConfigEncoding,
    I_q: np.ndarray,
    I_r: np.ndarray,
    mem: MemoryModel,
    budget_bytes: float,
    epsilon: float = 1e-8
) -> Tuple[List[int], List[int]]:
    """Deterministic repair procedure per Eq. 4-5 in paper.

    While M(C) > B_max, greedily apply the single one-step downgrade that
    minimizes sensitivity per byte saved:

        (ℓ*, t*) = argmin_{ℓ, t∈{q,r}} (I_t(ℓ) + ε) / ΔM_t(ℓ)

    where ΔM_t(ℓ) is the memory saved by moving q_ℓ (or r_ℓ) to the next lower
    grid value. Because total memory is a sum of per-layer terms, ΔM is
    computed from layer ``ℓ`` alone. Ties keep the earliest layer, with q
    tried before r. Values not on the grid are never downgraded.

    Args:
        q, r: Per-layer bit-widths and ranks (not modified).
        enc: Provides the ascending grids ``enc.Q`` and ``enc.R``.
        I_q, I_r: Per-layer quantization / rank sensitivities.
        mem: Memory model with ``layer_memory_bytes`` and ``total_memory_bytes``.
        budget_bytes: Memory budget B_max; ``None`` means unconstrained.
        epsilon: Keeps zero-sensitivity layers from having zero cost.

    Returns:
        New (q, r) lists. If the budget cannot be met even at the grid minimum,
        the fully-downgraded configuration is returned.
    """
    q_out = list(q)
    r_out = list(r)
    if budget_bytes is None:
        return q_out, r_out

    def get_lower(val: int, choices: List[int]) -> Optional[int]:
        """Next lower value in the sorted grid, or None at the minimum / off-grid."""
        try:
            idx = choices.index(val)
        except ValueError:
            return None
        return choices[idx - 1] if idx > 0 else None

    total = mem.total_memory_bytes(q_out, r_out)
    while total > budget_bytes:
        best_cost = float('inf')
        best_layer = None
        best_type = None  # 'q' or 'r'
        best_new_val = None

        for l in range(len(q_out)):
            current = mem.layer_memory_bytes(l, q_out[l], r_out[l])

            # Candidate 1: one step down in bit-width.
            q_lower = get_lower(q_out[l], enc.Q)
            if q_lower is not None:
                delta_M_q = current - mem.layer_memory_bytes(l, q_lower, r_out[l])
                if delta_M_q > 0:  # only moves that actually save memory
                    cost_q = (I_q[l] + epsilon) / delta_M_q
                    if cost_q < best_cost:
                        best_cost, best_layer, best_type, best_new_val = cost_q, l, 'q', q_lower

            # Candidate 2: one step down in rank.
            r_lower = get_lower(r_out[l], enc.R)
            if r_lower is not None:
                delta_M_r = current - mem.layer_memory_bytes(l, q_out[l], r_lower)
                if delta_M_r > 0:
                    cost_r = (I_r[l] + epsilon) / delta_M_r
                    if cost_r < best_cost:
                        best_cost, best_layer, best_type, best_new_val = cost_r, l, 'r', r_lower

        if best_layer is None:
            # Nothing left to downgrade: budget unreachable on this grid.
            break

        if best_type == 'q':
            q_out[best_layer] = best_new_val
        else:
            r_out[best_layer] = best_new_val
        total = mem.total_memory_bytes(q_out, r_out)

    return q_out, r_out


def warm_start_from_importance(
    enc: ConfigEncoding,
    I_q: np.ndarray,
    I_r: np.ndarray,
    tau_q: str = "identity",
    tau_r: str = "identity"
) -> Tuple[List[int], List[int]]:
    """Build the initial (q, r) configuration from per-layer importance (Eq. 6).

    For every layer ℓ:

        q_ℓ^(0) = ⌊τ_q(Ĩ_q(ℓ))⌉_Q        r_ℓ^(0) = ⌊τ_r(Ĩ_r(ℓ))⌉_R

    The shaped importance is mapped linearly onto [q_min, q_max] (resp.
    [r_min, r_max]) and snapped onto the ladder with ``enc.round_Q`` /
    ``enc.round_R``. The linear map only stays inside the ladder range when the
    importance is already normalized to [0, 1]; no clipping is done here.

    Args:
        enc: Encoding providing ``q_min``/``q_max``/``r_min``/``r_max`` and the
            ``round_Q``/``round_R`` snapping helpers.
        I_q: Per-layer quantization importance; its length sets the layer count.
        I_r: Per-layer rank importance, indexed by the same layer positions.
        tau_q: Shaping applied to ``I_q``: "identity", "sqrt" (negative inputs
            clamped to 0 first) or "square". Any other name acts as "identity".
        tau_r: Shaping applied to ``I_r``, same options as ``tau_q``.

    Returns:
        Tuple ``(q, r)`` of lists with one entry per layer.
    """
    def _identity(value: float) -> float:
        return value

    shapers = {
        "identity": _identity,
        "sqrt": lambda value: math.sqrt(max(0.0, value)),
        "square": lambda value: value * value,
    }
    shape_q = shapers.get(tau_q, _identity)
    shape_r = shapers.get(tau_r, _identity)

    bits_out: List[int] = []
    ranks_out: List[int] = []
    for layer in range(len(I_q)):
        q_frac = shape_q(float(I_q[layer]))
        r_frac = shape_r(float(I_r[layer]))
        # Linear interpolation between ladder extremes, then snap to the ladder.
        bits_out.append(enc.round_Q(enc.q_min + (enc.q_max - enc.q_min) * q_frac))
        ranks_out.append(enc.round_R(enc.r_min + (enc.r_max - enc.r_min) * r_frac))
    return bits_out, ranks_out


def jitter_configuration(q: List[int], r: List[int], enc: ConfigEncoding, budget_bytes: Optional[float], mem: MemoryModel, max_jitter: int = 4) -> Tuple[List[int], List[int]]:
    """Return a randomly perturbed copy of the configuration (q, r).

    Performs ``max_jitter`` rounds. In each round a random layer is picked and,
    with equal probability, either its bit-width is redrawn from ``enc.Q`` or
    its rank is redrawn from ``enc.R``. When a budget is given and the perturbed
    configuration exceeds it, then with probability 0.5 a random layer (possibly
    the same one) has its bit-width or rank set to the first ladder entry
    (``enc.Q[0]`` / ``enc.R[0]``). This is only a heuristic nudge: the result is
    not guaranteed to fit the budget.

    The input lists are not modified.
    """
    new_q = q.copy()
    new_r = r.copy()
    n_layers = len(q)

    for _round in range(max_jitter):
        layer = random.randrange(n_layers)
        if random.random() < 0.5:
            new_q[layer] = random.choice(enc.Q)
        else:
            new_r[layer] = random.choice(enc.R)

        over_budget = budget_bytes is not None and mem.total_memory_bytes(new_q, new_r) > budget_bytes
        # The coin flip is only drawn when over budget (short-circuit keeps the RNG stream).
        if not over_budget or random.random() >= 0.5:
            continue

        # Push one random layer's bits or rank to the bottom of its ladder.
        victim = random.randrange(n_layers)
        if random.random() < 0.5:
            new_q[victim] = enc.Q[0]
        else:
            new_r[victim] = enc.R[0]

    return new_q, new_r


def mutate_importance_guided(
    q: List[int],
    r: List[int],
    enc: ConfigEncoding,
    I_q: np.ndarray,
    I_r: np.ndarray,
    gamma: float,
    mem: MemoryModel,
    budget_bytes: Optional[float],
    use_coupled: bool = True
) -> Tuple[List[int], List[int]]:
    """Sensitivity-guided mutation with budget-balanced coupling.

    Two mutation types:

    1. Sensitivity-guided single mutation (default path). Pick layer ℓ with
       Pr(ℓ) ∝ max(I_q(ℓ) + I_r(ℓ), 1e-8)^γ. Then, with equal probability,
       move q_ℓ or r_ℓ one step on its ladder: up with probability I_t(ℓ),
       otherwise down. It stays put at the end of the ladder.
    2. Budget-balanced coupled mutation. It runs with probability 0.3, and only
       when ``use_coupled`` is set and a budget is given. Move q_ℓ (ℓ ∝ I_q^γ)
       or r_ℓ (ℓ ∝ I_r^γ) one step up. If the result exceeds the budget, call
       ``repair_to_budget`` to apply memory-decreasing edits. The repair is
       global and may also downgrade the layer that was just upgraded.

    Args:
        q: Parent bit-width per layer (not modified).
        r: Parent rank per layer (not modified).
        enc: Encoding whose ``Q`` / ``R`` lists are the sorted ladders.
        I_q: Per-layer quantization sensitivity (array-like). It is used directly
            as a step-up probability, so values are expected in [0, 1].
        I_r: Per-layer rank sensitivity, same convention as ``I_q``.
        gamma: Exponent sharpening the layer-selection distribution.
        mem: Memory model used for the budget check and the repair.
        budget_bytes: Memory budget; ``None`` disables the coupled path.
        use_coupled: Enable the budget-balanced coupled mutation.

    Returns:
        A new ``(q, r)`` pair of lists.
    """
    # Normalize to float arrays: with plain lists ``I_q + I_r`` would concatenate
    # instead of adding element-wise and break the layer sampling below.
    I_q = np.asarray(I_q, dtype=float)
    I_r = np.asarray(I_r, dtype=float)
    q2 = q.copy()
    r2 = r.copy()
    layers = np.arange(len(q))

    def _sample_layer(weights: np.ndarray) -> int:
        """Draw a layer index with probability ∝ max(weights, 1e-8)^gamma."""
        probs = np.power(np.maximum(weights, 1e-8), gamma)
        probs = probs / probs.sum()
        return int(np.random.choice(layers, p=probs))

    def _step(ladder: Sequence[int], value: int, up: bool) -> int:
        """Return the adjacent ladder value, or ``value`` itself at the ladder's end."""
        idx = ladder.index(value) + (1 if up else -1)
        return ladder[idx] if 0 <= idx < len(ladder) else value

    if use_coupled and budget_bytes is not None and random.random() < 0.3:
        # Primary memory-increasing edit on an important layer. The I_q-based draw
        # is made before choosing q vs r (as before) so the RNG stream is unchanged.
        l_primary = _sample_layer(I_q)
        if random.random() < 0.5:
            q2[l_primary] = _step(enc.Q, q2[l_primary], up=True)
        else:
            l_primary = _sample_layer(I_r)
            r2[l_primary] = _step(enc.R, r2[l_primary], up=True)

        # Compensating memory-decreasing edits to restore feasibility.
        if mem.total_memory_bytes(q2, r2) > budget_bytes:
            q2, r2 = repair_to_budget(q2, r2, enc, I_q, I_r, mem, budget_bytes)
    else:
        # Layer chosen by combined sensitivity; direction guided by the
        # per-type sensitivity of the variable being modified.
        l = _sample_layer(I_q + I_r)
        if random.random() < 0.5:
            q2[l] = _step(enc.Q, q2[l], up=random.random() < float(I_q[l]))
        else:
            r2[l] = _step(enc.R, r2[l], up=random.random() < float(I_r[l]))

    return q2, r2


def crossover_uniform(q1: List[int], r1: List[int], q2: List[int], r2: List[int]) -> Tuple[List[int], List[int]]:
    """Uniform crossover of two parent configurations.

    Each gene position (up to ``len(q1)``) is inherited from parent 1 with
    probability 0.5, otherwise from parent 2. All bit-width genes are drawn
    first, then all rank genes.

    Returns:
        The child ``(q, r)`` as new lists.
    """
    n_genes = len(q1)

    def _mix(first: List[int], second: List[int]) -> List[int]:
        child = []
        for pos in range(n_genes):
            child.append(first[pos] if random.random() < 0.5 else second[pos])
        return child

    child_q = _mix(q1, q2)
    child_r = _mix(r1, r2)
    return child_q, child_r


# =============================
# Phase I: Multi‑fidelity evolutionary search
# =============================

class PhaseIEvolution:
    """Phase I: multi-fidelity evolutionary search over per-layer (q_l, r_l).

    Each generation does the following:

    - Expand every parent into up to five 1-step atomic neighbours.
    - Repair the neighbours to the memory budget and apply the low-bit cap.
    - Score all offspring at low fidelity (LF) and promote a subset to high
      fidelity (HF). With ``multi_fidelity=False``, all offspring are promoted.
    - Select the next population with constrained non-dominated sorting plus
      crowding distance, on the objectives (-performance, memory).

    Args:
        cfg: Search configuration (ladders ``Q``/``R``, ``budget_bytes``,
            ``seed``, ``num_layers``, ``task``, ``base_model_id``, ...).
        importance: Importance profile. ``importance.score`` is used for
            surrogate features and for ordering low-bit promotions.
        real_eval_params: Default ``RealTaskEvaluator`` parameters for
            ``lf_eval_mode='real_task'``.
        lowbit_value: Bit-width treated as "low-bit" by the low-bit cap.
            Defaults to the smallest value in ``cfg.Q``.
        max_lowbit_fraction: Maximum fraction of layers allowed at the low-bit
            value. It is clamped to [0, 1].
        target_avg_bits: Forwarded to ``ProxyEvaluator``.
        I_q: Optional per-layer quantization sensitivity for warm start and
            budget repair. When omitted, ``importance.score`` is used.
        I_r: Optional per-layer rank sensitivity. When omitted,
            ``importance.score`` is used.
    """

    def __init__(
        self,
        cfg: AutoQRAConfig,
        importance: Importance,
        real_eval_params: Optional[Dict] = None,
        lowbit_value: Optional[int] = None,
        max_lowbit_fraction: float = 1.0,
        target_avg_bits: Optional[float] = None,
        I_q: Optional[np.ndarray] = None,
        I_r: Optional[np.ndarray] = None,
    ):
        self.cfg = cfg
        self.importance = importance
        self.enc = ConfigEncoding(cfg.Q, cfg.R)
        self.mem = MemoryModel(cfg)
        self.proxy = ProxyEvaluator(cfg, importance, self.mem, self.enc, target_avg_bits=target_avg_bits)
        self.sur = SurrogateMLPPromotion(hidden_dims=(64, 32), patience=10)
        self.mq_eval = None  # mixed-quant evaluator, created lazily in 'ptq' mode
        self.real_eval_params = real_eval_params
        self.real_eval: Optional[RealTaskEvaluator] = None  # created lazily in 'real_task' mode
        self.lowbit_value = lowbit_value
        self.max_lowbit_fraction = max(0.0, min(1.0, max_lowbit_fraction))
        self.target_avg_bits = target_avg_bits
        # Per-type sensitivities; run() falls back to importance.score when None.
        self.I_q = I_q
        self.I_r = I_r

    def run(
        self,
        pop_size: int = 40,
        generations: int = 15,
        promote_k: int = 6,
        gamma: float = 1.5,
        hv_epsilon: float = 1e-3,
        hv_window: int = 3,
        ref_point: Optional[Tuple[float, float]] = None,
        # Ablation toggles
        use_warm_start: bool = True,
        use_importance_mutation: bool = True,
        use_coupled_mutation: bool = True,
        use_surrogate_promotion: bool = True,
        multi_fidelity: bool = True,
        lf_eval_mode: str = "proxy",  # "proxy", "ptq" or "real_task"
        progress_dir: Optional[Path] = None,
        real_eval_params: Optional[Dict] = None,
    ) -> Dict:
        """Run the Phase I search.

        ``gamma``, ``use_importance_mutation`` and ``use_coupled_mutation`` are
        only echoed in the returned ``config``. Offspring are produced by
        1-step neighbour expansion, not by the mutation operators.

        Progress (``progress.json`` plus two figures) is written best-effort
        after each generation to ``progress_dir``. The default is
        ``./results_autoqra_progress``.

        Returns:
            Dict with these keys:

            - ``pareto``: HF-evaluated records on the first constrained
              non-dominated front.
            - ``all``: every HF-evaluated record.
            - ``hv_hist``
            - ``config``: echo of the run settings.
            - ``stats``: per-generation and total LF/HF counts.

        Raises:
            ValueError: If ``lf_eval_mode='real_task'`` and no
                ``real_eval_params`` were given here or at construction.
        """
        import time

        real_eval_params = real_eval_params or self.real_eval_params
        set_seed(self.cfg.seed)
        nl = self.cfg.num_layers
        I = self.importance.score
        # Sensitivities used by warm start and budget repair.
        I_q = I if self.I_q is None else self.I_q
        I_r = I if self.I_r is None else self.I_r
        budget = self.cfg.budget_bytes
        budget_cap = budget if budget is not None else float("inf")

        def _enforce_lowbit_cap(q_vec: List[int]) -> List[int]:
            """Promote excess low-bit layers in place, most important first."""
            if self.max_lowbit_fraction >= 1.0:
                return q_vec
            low_bit = self.lowbit_value if self.lowbit_value is not None else min(self.enc.Q)
            low_indices = [i for i, bit in enumerate(q_vec) if bit == low_bit]
            allowed = int(math.floor(self.max_lowbit_fraction * len(q_vec)))
            if len(low_indices) <= allowed:
                return q_vec
            higher_choices = sorted(b for b in self.enc.Q if b > low_bit)
            if not higher_choices:
                return q_vec
            promote_bit = higher_choices[0]
            # The least important layers keep the low bit-width.
            low_indices.sort(key=lambda idx: I[idx], reverse=True)
            for idx in low_indices[: len(low_indices) - allowed]:
                q_vec[idx] = promote_bit
            return q_vec

        def _perf(rec: Dict) -> float:
            """Best available performance: HF if present, else LF.

            Uses a membership test rather than ``rec.get("phigh", rec["plow"])``,
            which would eagerly read "plow" and fail on HF-only records.
            """
            return rec["phigh"] if "phigh" in rec else rec["plow"]

        def _update_surrogate(rec: Dict, phigh: float) -> None:
            """Feed one (LF, HF) pair to the promotion surrogate, once per HF evaluation."""
            if use_surrogate_promotion and "plow" in rec:
                self.sur.update(rec["plow"], phigh, rec["mem"], rec["q"], rec["r"], self.enc, I)

        def _task_perf(metric: float) -> float:
            """Convert an evaluator metric to maximize-is-better (negate perplexity)."""
            return -metric if self.cfg.task == "generation" else metric

        def _get_mq_eval():
            """Lazily build the mixed-quant (PTQ) evaluator."""
            if self.mq_eval is None:
                try:
                    from autoqra.mixed_quant_eval import MQEvalConfig, MixedQuantEvaluator
                except ImportError:
                    # Fallback for when running as a script from this directory.
                    autoqra_dir = Path(__file__).parent
                    if str(autoqra_dir) not in sys.path:
                        sys.path.insert(0, str(autoqra_dir))
                    from mixed_quant_eval import MQEvalConfig, MixedQuantEvaluator
                self.mq_eval = MixedQuantEvaluator(MQEvalConfig(
                    model_id=self.cfg.base_model_id,
                    task=self.cfg.task,
                    sample_ratio_lf=0.02,
                    sample_ratio_hf=0.1,
                    per_channel=True,
                ))
            return self.mq_eval

        def eval_and_record(q, r, high_fid: bool = False, generation: Optional[int] = None,
                            cand_idx: Optional[int] = None, stage: str = "",
                            need_lf: bool = True) -> Dict:
            """Evaluate one candidate and return its record.

            The record holds "plow" for LF, "phigh" for HF, and "mem".
            ``need_lf=False`` lets the PTQ evaluator skip the LF pass when the
            caller already has the LF score.
            """
            q = _enforce_lowbit_cap(q.copy())

            if lf_eval_mode == "real_task":
                if self.real_eval is None:
                    if real_eval_params is None:
                        raise ValueError("real_eval_params must be provided when lf_eval_mode='real_task'.")
                    self.real_eval = RealTaskEvaluator(
                        preset=real_eval_params["preset"],
                        dataset=real_eval_params["dataset"],
                        lf_sample_ratio=real_eval_params.get("lf_sample_ratio", 0.05),
                        lf_epochs=real_eval_params.get("lf_epochs", 0.1),
                        hf_sample_ratio=real_eval_params.get("hf_sample_ratio", real_eval_params.get("sample_ratio", 0.2)),
                        hf_epochs=real_eval_params.get("hf_epochs", real_eval_params.get("epochs", 0.5)),
                        eval_task=real_eval_params["eval_task"],
                        eval_shots=real_eval_params["eval_shots"],
                        output_root=Path(real_eval_params["output_root"]),
                        load_in_4bit=real_eval_params.get("load_in_4bit", True),
                    )
                # A single run at the requested fidelity; HF records carry no "plow".
                perf, M = self.real_eval.evaluate(q, r, high_fidelity=bool(high_fid),
                                                  generation=generation, cand_idx=cand_idx, stage=stage)
                return {"q": q, "r": r, ("phigh" if high_fid else "plow"): perf, "mem": M}

            if lf_eval_mode == "ptq" and self.cfg.base_model_id:
                mq_eval = _get_mq_eval()
                rec: Dict = {"q": q, "r": r}
                if need_lf or not high_fid:
                    rec["plow"] = _task_perf(mq_eval.evaluate_candidate(q, high_fidelity=False))
                rec["mem"] = self.mem.total_memory_bytes(q, r)
                if high_fid:
                    rec["phigh"] = _task_perf(mq_eval.evaluate_candidate(q, high_fidelity=True))
                return rec

            # Proxy mode: no real HF evaluator is available.
            plow, M = self.proxy.evaluate(q, r, self.cfg.budget_bytes)
            rec = {"q": q, "r": r, "plow": plow, "mem": M}
            if high_fid:
                # Placeholder: assume phigh ≈ plow with small noise.
                rec["phigh"] = float(plow + np.random.normal(scale=0.02))
            return rec

        # ---- Initial population: warm start (or random) plus jittered copies ----
        if use_warm_start:
            q0, r0 = warm_start_from_importance(self.enc, I_q, I_r)
        else:
            q0 = [random.choice(self.enc.Q) for _ in range(nl)]
            r0 = [random.choice(self.enc.R) for _ in range(nl)]
        q0 = _enforce_lowbit_cap(q0)
        population: List[Tuple[List[int], List[int]]] = [(q0, r0)]
        while len(population) < pop_size:
            q_jit, r_jit = jitter_configuration(q0, r0, self.enc, budget, self.mem)
            population.append((_enforce_lowbit_cap(q_jit), r_jit))

        # ---- Evaluate the initial population: the first promote_k at HF ----
        records: List[Dict] = []
        n_hf_total = 0
        n_lf_total = 0  # counts offspring LF evaluations
        for i, (q, r) in enumerate(population):
            is_hf = i < promote_k
            rec = eval_and_record(q, r, high_fid=is_hf, generation=-1, cand_idx=i, stage="init")
            records.append(rec)
            if is_hf:
                n_hf_total += 1
                _update_surrogate(rec, rec["phigh"])

        hv_hist: List[float] = []
        gen_stats: List[Dict] = []
        prog_dir = progress_dir if progress_dir is not None else Path("./results_autoqra_progress")
        neighbors_per_point = 5  # 1-edit neighbours explored around each parent

        for gen in range(generations):
            t0 = time.time()
            # Log line: "=== Generation i/N ===".
            print(f"\n=== 世代 {gen+1}/{generations} ===")

            # ---- Offspring: 1-step atomic neighbours of every parent ----
            offsprings: List[Tuple[List[int], List[int]]] = []
            for parent_idx, (pq, pr) in enumerate(population):
                # Log line: "parent i/N: generating k neighbours...".
                print(f"  父代 {parent_idx+1}/{len(population)}: 生成 {neighbors_per_point} 个邻居...")
                for mq, mr in generate_k_nearest_atomic_neighbors(pq, pr, self.enc, k=neighbors_per_point):
                    if budget is not None:
                        mq, mr = repair_to_budget(mq, mr, self.enc, I_q, I_r, self.mem, budget)
                    offsprings.append((_enforce_lowbit_cap(mq), mr))

            # ---- LF evaluation of all offspring ----
            off_recs = [
                eval_and_record(q, r, high_fid=False, generation=gen, cand_idx=idx, stage="evo")
                for idx, (q, r) in enumerate(offsprings)
            ]
            n_lf_total += len(off_recs)

            # ---- Choose which offspring get an HF evaluation ----
            if not multi_fidelity:
                # --hf_all: every offspring is evaluated at HF.
                promote_list = list(off_recs)
            elif use_surrogate_promotion:
                scored = [
                    (self.sur.predict(rec["plow"], rec["mem"], rec["q"], rec["r"], self.enc, I), rec)
                    for rec in off_recs
                ]
                scored.sort(key=lambda x: x[0], reverse=True)
                promote_list = [rec for _, rec in scored[:promote_k]]
            else:
                promote_list = sorted(off_recs, key=lambda rec: rec["plow"], reverse=True)[:promote_k]

            for rec in promote_list:
                hf_rec = eval_and_record(rec["q"], rec["r"], high_fid=True, generation=gen,
                                         stage="evo_hf", need_lf=False)
                rec["phigh"] = hf_rec["phigh"]
                _update_surrogate(rec, rec["phigh"])
            n_hf_gen = len(promote_list)
            n_hf_total += n_hf_gen

            records.extend(off_recs)

            # ---- Constrained NSGA-II selection over all records ----
            # Objectives (minimized): (-performance, memory); HF preferred over LF.
            evaluated = records
            objs = [(-_perf(rec), rec["mem"]) for rec in evaluated]
            feasible = [rec["mem"] <= budget_cap for rec in evaluated]
            fronts = non_dominated_sort_constrained(objs, feasible)

            new_pop: List[Tuple[List[int], List[int]]] = []
            for f in fronts:
                if len(new_pop) + len(f) <= pop_size:
                    new_pop.extend((evaluated[i]["q"], evaluated[i]["r"]) for i in f)
                else:
                    cd = crowding_distance(objs, f)
                    f_sorted = sorted(f, key=lambda i: cd[i], reverse=True)
                    needed = pop_size - len(new_pop)
                    new_pop.extend((evaluated[i]["q"], evaluated[i]["r"]) for i in f_sorted[:needed])
                    break
            population = new_pop

            # ---- Hypervolume progress (HF-evaluated Pareto points when available) ----
            hf_idx = {i for i, rec in enumerate(evaluated) if "phigh" in rec}
            if hf_idx:
                pareto_hf = [i for i in fronts[0] if i in hf_idx] if fronts else sorted(hf_idx)
                pareto_pts = [objs[i] for i in pareto_hf]
            else:
                pareto_pts = [objs[i] for i in fronts[0]] if fronts else objs
            if ref_point is None:
                # Reference point must be worse than every point on both objectives.
                worst_cost = max(o[0] for o in objs)  # largest -perf, i.e. worst
                max_mem = max(o[1] for o in objs)
                ref_cost = worst_cost + abs(worst_cost) * 0.1 if worst_cost < 0 else worst_cost * 1.1
                ref = (ref_cost, max_mem * 1.1)  # same order as objs: (cost, mem)
            else:
                ref = ref_point
            hv = hypervolume_2d(pareto_pts, ref)
            hv_hist.append(hv)

            t1 = time.time()
            n_lf_gen = len(off_recs)
            gen_stats.append({
                "generation": gen,
                "n_lf": n_lf_gen,
                "n_hf": n_hf_gen,
                "time_sec": t1 - t0,
                "hv": hv,
                "promotion_ratio": (n_hf_gen / n_lf_gen) if n_lf_gen else 0.0,
            })

            self._save_progress(prog_dir, hv_hist, gen_stats)

            # ---- Early stopping on hypervolume convergence ----
            # Stop only when HV is non-decreasing but the gain over the window is negligible.
            if len(hv_hist) >= hv_window + 1:
                prev = hv_hist[-hv_window - 1]
                curr = hv_hist[-1]
                if prev > 0 and curr >= prev and (curr - prev) / prev < hv_epsilon:
                    print(f"\n[Early Stop] Hypervolume converged: improvement {(curr-prev)/prev:.6f} < {hv_epsilon}")
                    break

        # ---- Final Pareto set over HF-evaluated records ----
        final = [rec for rec in records if "phigh" in rec]
        objs = [(-rec["phigh"], rec["mem"]) for rec in final]
        feasible = [rec["mem"] <= budget_cap for rec in final]
        fronts = non_dominated_sort_constrained(objs, feasible)
        pareto = [final[i] for i in fronts[0]] if fronts else final

        return {
            "pareto": pareto,
            "all": final,
            "hv_hist": hv_hist,
            "config": {
                "pop_size": pop_size,
                "generations": generations,
                "promote_k": promote_k,
                "gamma": gamma,
                "use_warm_start": use_warm_start,
                "use_importance_mutation": use_importance_mutation,
                "use_coupled_mutation": use_coupled_mutation,
                "use_surrogate_promotion": use_surrogate_promotion,
                "multi_fidelity": multi_fidelity,
                "lf_eval_mode": lf_eval_mode,
                "real_eval_task": (real_eval_params or {}).get("eval_task") if lf_eval_mode == "real_task" else None,
            },
            "stats": {
                "gen_stats": gen_stats,
                "n_lf_total": n_lf_total,
                "n_hf_total": n_hf_total,
                "promotion_ratio": (n_hf_total / n_lf_total) if n_lf_total else 0.0,
            }
        }

    @staticmethod
    def _save_progress(prog_dir: Path, hv_hist: List[float], gen_stats: List[Dict]) -> None:
        """Best-effort write of ``progress.json`` and two diagnostic plots.

        Failures are reported on stderr and never abort the search. Figures are
        always closed, so a failing plot does not leak memory across generations.
        """
        try:
            prog_dir.mkdir(parents=True, exist_ok=True)
        except OSError as exc:
            print(f"[warn] could not create progress dir {prog_dir}: {exc}", file=sys.stderr)
            return

        try:
            with open(prog_dir / "progress.json", "w") as f:
                json.dump({"hv_hist": hv_hist, "gen_stats": gen_stats}, f, indent=2)
        except Exception as exc:  # best-effort: never fail the search on logging
            print(f"[warn] could not write progress.json: {exc}", file=sys.stderr)

        try:
            fig_dir = prog_dir / "figures"
            fig_dir.mkdir(parents=True, exist_ok=True)
            sns.set_style("whitegrid")

            fig = plt.figure(figsize=(6, 4))
            try:
                plt.plot(np.arange(len(hv_hist)), hv_hist, marker='o')
                plt.xlabel("Generation")
                plt.ylabel("Pareto Hypervolume")
                plt.tight_layout()
                plt.savefig(fig_dir / "hv_curve.png", dpi=200)
            finally:
                plt.close(fig)

            gens = [g["generation"] for g in gen_stats]
            nlf = [g["n_lf"] for g in gen_stats]
            nhf = [g["n_hf"] for g in gen_stats]
            pr = [g.get("promotion_ratio", 0.0) for g in gen_stats]
            fig, ax1 = plt.subplots(1, 1, figsize=(6, 4))
            try:
                ax1.plot(gens, nlf, label='LF', color='#2E86AB')
                ax1.plot(gens, nhf, label='HF', color='#A23B72')
                ax1.set_xlabel("Generation")
                ax1.set_ylabel("#Evaluations")
                ax2 = ax1.twinx()
                ax2.plot(gens, pr, label='Promotion Ratio', color='#33AA33')
                ax2.set_ylabel("Promotion Ratio")
                lines, labels = ax1.get_legend_handles_labels()
                l2, lb2 = ax2.get_legend_handles_labels()
                ax1.legend(lines + l2, labels + lb2, loc='upper right', fontsize=8)
                plt.tight_layout()
                plt.savefig(fig_dir / "eval_counts.png", dpi=200)
            finally:
                plt.close(fig)
        except Exception as exc:  # best-effort: plotting problems must not stop the search
            print(f"[warn] could not save progress figures: {exc}", file=sys.stderr)


# =============================
# Phase II: Local BO refinement
# =============================

def atomic_distance(q1: Sequence[int], r1: Sequence[int], q2: Sequence[int], r2: Sequence[int],
                    enc: ConfigEncoding) -> int:
    """Atomic edit distance d_atom(C, C') between two configurations (Eq. 11).

    One atomic edit moves a single variable to an adjacent value on its sorted
    ladder. The distance between two ladder values is therefore the difference
    of their ladder positions. The total is summed over every q_i and r_i for
    i in ``range(len(q1))``.

    Equal values contribute 0 without a ladder lookup.
    """
    def _ladder_gap(ladder: Sequence[int], a: int, b: int) -> int:
        if a == b:
            return 0
        return abs(ladder.index(b) - ladder.index(a))

    total = 0
    for pos in range(len(q1)):
        total += _ladder_gap(enc.Q, q1[pos], q2[pos])
        total += _ladder_gap(enc.R, r1[pos], r2[pos])
    return total


def generate_k_nearest_atomic_neighbors(
    q: List[int],
    r: List[int],
    enc: ConfigEncoding,
    k: int = 5
) -> List[Tuple[List[int], List[int]]]:
    """Return up to ``k`` configurations exactly one atomic edit away from (q, r).

    Every candidate moves a single q_ℓ or r_ℓ to an adjacent value on its
    ladder, so all candidates are at atomic distance 1. Candidates are
    enumerated layer by layer in a fixed order: q down, q up, r down, r up.
    Duplicates are removed while keeping that order, so the candidate list,
    and therefore a seeded ``random.sample``, does not depend on hash ordering.

    When there are more than ``k`` candidates, ``k`` of them are drawn
    uniformly at random with ``random.sample``. No importance weighting is
    applied.

    Args:
        q: Current bit-width per layer (not modified).
        r: Current rank per layer (not modified).
        enc: Encoding whose ``Q`` / ``R`` lists are the sorted ladders.
        k: Maximum number of neighbours to return.

    Returns:
        List of ``(q', r')`` pairs built from fresh lists.
    """
    # dict preserves insertion order -> deterministic, hash-independent dedup.
    candidates: Dict[Tuple[Tuple[int, ...], Tuple[int, ...]], None] = {}
    base_q, base_r = tuple(q), tuple(r)

    for layer in range(len(q)):
        idx_q = enc.Q.index(q[layer])
        for new_idx in (idx_q - 1, idx_q + 1):
            if 0 <= new_idx < len(enc.Q):
                q_new = q.copy()
                q_new[layer] = enc.Q[new_idx]
                candidates.setdefault((tuple(q_new), base_r), None)

        idx_r = enc.R.index(r[layer])
        for new_idx in (idx_r - 1, idx_r + 1):
            if 0 <= new_idx < len(enc.R):
                r_new = r.copy()
                r_new[layer] = enc.R[new_idx]
                candidates.setdefault((base_q, tuple(r_new)), None)

    unique_neighbors = [(list(q_n), list(r_n)) for q_n, r_n in candidates]

    if len(unique_neighbors) <= k:
        return unique_neighbors
    return random.sample(unique_neighbors, k)


def generate_atomic_neighbors(
    q: List[int],
    r: List[int],
    enc: ConfigEncoding,
    delta: int
) -> List[Tuple[List[int], List[int]]]:
    """Configurations reachable from (q, r) with at most ``delta`` atomic edits.

    The behaviour depends on ``delta``:

    * ``delta == 0``: only a copy of (q, r).
    * ``delta <= 2``: exact enumeration through a set, including (q, r)
      itself. Negative values behave like ``delta == 1``.
    * ``delta > 2``: a random sample. The first entry is a copy of (q, r),
      followed by ``min(200, 2**delta)`` random walks. Each walk keeps taking
      a random in-range ladder step while steps remain and a 0.7 coin flip
      succeeds. The sample may contain duplicates.
    """
    if delta == 0:
        return [(q.copy(), r.copy())]

    def _one_step(cur_q: List[int], cur_r: List[int]):
        """Yield every config one ladder step from (cur_q, cur_r), per layer: q-1, q+1, r-1, r+1."""
        for pos in range(len(cur_q)):
            qi = enc.Q.index(cur_q[pos])
            for cand in (qi - 1, qi + 1):
                if 0 <= cand < len(enc.Q):
                    moved_q = cur_q.copy()
                    moved_q[pos] = enc.Q[cand]
                    yield tuple(moved_q), tuple(cur_r)
            ri = enc.R.index(cur_r[pos])
            for cand in (ri - 1, ri + 1):
                if 0 <= cand < len(enc.R):
                    moved_r = cur_r.copy()
                    moved_r[pos] = enc.R[cand]
                    yield tuple(cur_q), tuple(moved_r)

    if delta <= 2:
        reached = set()
        reached.add((tuple(q), tuple(r)))
        reached.update(_one_step(q, r))
        if delta == 2:
            # Expand from a snapshot of the radius-1 ball.
            for ring_q, ring_r in list(reached):
                reached.update(_one_step(list(ring_q), list(ring_r)))
        return [(list(cfg_q), list(cfg_r)) for cfg_q, cfg_r in reached]

    samples = [(q.copy(), r.copy())]
    for _walk in range(min(200, 2 ** delta)):
        walk_q, walk_r = q.copy(), r.copy()
        steps_left = delta
        while steps_left > 0 and random.random() < 0.7:
            pos = random.randrange(len(walk_q))
            if random.random() < 0.5:
                ladder, vec = enc.Q, walk_q
            else:
                ladder, vec = enc.R, walk_r
            target = ladder.index(vec[pos]) + random.choice([-1, 1])
            if 0 <= target < len(ladder):
                vec[pos] = ladder[target]
                steps_left -= 1
        samples.append((walk_q, walk_r))
    return samples


class PhaseIIBO:
    """Phase II: local Bayesian optimization over a neighbor-based trust region.

    Building blocks for the Phase II loop (Eqs. 8-13 of the paper):
    - Scalarized utility f(C; α) = α·P̂(C) - (1-α)·M̂(C), with P̂ and M̂
      min-max normalized using ranges taken from the Phase I Pareto front.
    - GP surrogate (constant × Matérn-5/2 + white noise) over the encoding ψ(C).
    - Trust region Ω_t made of ``k_neighbors`` atomic neighbors generated
      around the incumbent (``propose``) or around every anchor point
      (``propose_multi_start``); each neighbor is repaired to the memory budget
      and kept only if it is feasible afterwards.
    - Expected Improvement acquisition; convergence when max EI < ε_ei.

    Note: this variant uses a fixed neighbor count, not an adaptive trust
    radius. ``max_iterations`` is only stored for the caller's loop.
    """

    def __init__(
        self,
        enc: ConfigEncoding,
        importance: Importance,
        mem_model: MemoryModel,
        budget_bytes: float,
        I_q: np.ndarray,
        I_r: np.ndarray,
        # Trust region: fixed number of atomic neighbors per anchor point.
        k_neighbors: int = 5,
        # Convergence parameters.
        epsilon_ei: float = 1e-4,
        max_iterations: int = 3,
    ):
        """
        Args:
            enc: Configuration encoding; its ``s_q``/``s_r`` scalers are used by
                ``encode`` and it is passed to the neighbor generator.
            importance: Importance profiles; only ``importance.score`` is used
                (as per-layer weights of the coverage features in ``encode``).
            mem_model: Memory model used to check budget feasibility.
            budget_bytes: Memory budget constraint.
            I_q: Quantization sensitivity per layer (passed to ``repair_to_budget``).
            I_r: Rank sensitivity per layer (passed to ``repair_to_budget``).
            k_neighbors: Number of atomic neighbors generated per anchor point.
            epsilon_ei: Convergence threshold on the maximum EI.
            max_iterations: Iteration budget for the caller's BO loop (stored only).
        """
        self.enc = enc
        self.I = importance.score
        self.mem = mem_model
        self.budget_bytes = budget_bytes
        self.I_q = I_q
        self.I_r = I_r

        # Fixed neighbor count instead of an adaptive trust radius.
        self.k_neighbors = k_neighbors

        # Convergence settings.
        self.epsilon_ei = epsilon_ei
        self.max_iterations = max_iterations

        # GP with a Matérn-5/2 kernel (Eq. 9) plus a white-noise term.
        kernel = ConstantKernel(1.0, (1e-2, 1e2)) * Matern(length_scale=1.0, nu=2.5) + WhiteKernel(1e-3)
        self.gp = GaussianProcessRegressor(kernel=kernel, normalize_y=True, random_state=0)

        # Observations used to (re)fit the GP; best_y is the best scalarized utility seen.
        self.X_train: List[np.ndarray] = []
        self.y_train: List[float] = []
        self.best_y = -float('inf')

    def encode(self, q: Sequence[int], r: Sequence[int]) -> np.ndarray:
        """Encode a configuration as ψ(C) (Eq. 9).

        Returns the concatenation
        ``[s_q(q_0), ..., s_q(q_{L-1}), s_r(r_0), ..., s_r(r_{L-1}), q_cov, r_cov]``,
        where ``q_cov``/``r_cov`` are importance-weighted means of the scaled values.
        """
        qn = np.array([self.enc.s_q(x) for x in q], dtype=np.float64)
        rn = np.array([self.enc.s_r(x) for x in r], dtype=np.float64)

        # Importance-weighted coverage features.
        q_cov = float(np.dot(self.I, qn) / len(qn))
        r_cov = float(np.dot(self.I, rn) / len(rn))

        return np.concatenate([qn, rn, [q_cov, r_cov]], axis=0)

    def scalarize(
        self,
        perf: float,
        mem: float,
        alpha: float,
        perf_norm: Tuple[float, float],
        mem_norm: Tuple[float, float]
    ) -> float:
        """Scalarized utility f(C; α) = α·P̂(C) - (1-α)·M̂(C) (Eq. 8).

        P̂ and M̂ are min-max normalized with the given ``(min, max)`` ranges; a
        1e-8 floor on the range width avoids division by zero.
        """
        p_min, p_max = perf_norm
        m_min, m_max = mem_norm

        p_norm = (perf - p_min) / max(p_max - p_min, 1e-8)
        m_norm = (mem - m_min) / max(m_max - m_min, 1e-8)

        # Reward performance, penalize memory.
        return alpha * p_norm - (1 - alpha) * m_norm

    def fit(self, pareto: List[Dict], alpha: float) -> None:
        """Warm-start the GP from Phase I Pareto records (keys "q", "r", "phigh", "mem").

        The performance/memory normalization ranges are computed from ``pareto``
        and stored, together with ``alpha``, for subsequent ``update`` calls.

        Raises:
            ValueError: If ``pareto`` is empty (normalization ranges are undefined).
        """
        if not pareto:
            raise ValueError("PhaseIIBO.fit requires at least one Pareto record")

        perfs = np.array([p["phigh"] for p in pareto], dtype=np.float64)
        mems = np.array([p["mem"] for p in pareto], dtype=np.float64)

        self.perf_norm = (float(np.min(perfs)), float(np.max(perfs)))
        self.mem_norm = (float(np.min(mems)), float(np.max(mems)))
        self.alpha = alpha

        for p in pareto:
            x = self.encode(p["q"], p["r"])
            y = self.scalarize(p["phigh"], p["mem"], alpha, self.perf_norm, self.mem_norm)
            self.X_train.append(x)
            self.y_train.append(y)
            if y > self.best_y:
                self.best_y = y

        self.gp.fit(np.vstack(self.X_train), np.array(self.y_train))

    def expected_improvement(self, mu: np.ndarray, sigma: np.ndarray) -> np.ndarray:
        """Expected Improvement over ``self.best_y`` (Eq. 10).

        EI(C) = σ(C)·[z·Φ(z) + φ(z)], z = (μ(C) - y⁺)/σ(C); candidates with
        σ ≤ 1e-8 get EI = 0.
        """
        from scipy.stats import norm

        ei = np.zeros_like(mu)
        mask = sigma > 1e-8
        if not np.any(mask):
            return ei

        z = np.zeros_like(mu)
        z[mask] = (mu[mask] - self.best_y) / sigma[mask]
        ei[mask] = sigma[mask] * (z[mask] * norm.cdf(z[mask]) + norm.pdf(z[mask]))
        return ei

    def _repair_feasible(
        self, q: List[int], r: List[int]
    ) -> Optional[Tuple[List[int], List[int]]]:
        """Repair (q, r) to the budget; return the repaired pair, or None if still infeasible."""
        q_rep, r_rep = repair_to_budget(
            q, r, self.enc, self.I_q, self.I_r, self.mem, self.budget_bytes
        )
        # Repair may not always succeed, so re-check the budget explicitly.
        if self.mem.total_memory_bytes(q_rep, r_rep) <= self.budget_bytes:
            return q_rep, r_rep
        return None

    def build_trust_region(
        self,
        incumbent: Tuple[List[int], List[int]],
        k_neighbors: int = 5
    ) -> List[Tuple[List[int], List[int]]]:
        """Build a trust region of atomic neighbors around a single configuration.

        Args:
            incumbent: Current best configuration (q, r).
            k_neighbors: Number of neighbors to generate (default: 5).

        Returns:
            The incumbent (first, unless it already appears among the repaired
            neighbors) followed by the neighbors that are feasible after repair.
        """
        q_inc, r_inc = incumbent

        region: List[Tuple[List[int], List[int]]] = []
        for q, r in generate_k_nearest_atomic_neighbors(q_inc, r_inc, self.enc, k=k_neighbors):
            repaired = self._repair_feasible(q, r)
            if repaired is not None:
                region.append(repaired)

        # The incumbent is always a candidate.
        if (q_inc, r_inc) not in region:
            region.insert(0, (q_inc, r_inc))
        return region

    def build_multi_start_trust_region(
        self,
        pareto_points: List[Tuple[List[int], List[int]]],
        k_neighbors: int = 5
    ) -> List[Tuple[List[int], List[int]]]:
        """Build a joint trust region from all anchor (e.g. Pareto) points.

        Neighbors are generated around every anchor, repaired, and kept if
        feasible; each anchor itself is also included (without a feasibility
        check). Duplicates are removed while preserving first-seen order, so
        ties in the acquisition are broken deterministically.

        Args:
            pareto_points: List of (q, r) anchor configurations.
            k_neighbors: Number of neighbors per anchor (default: 5).

        Returns:
            Unique candidate configurations as (q, r) lists.
        """
        # dict keys act as an insertion-ordered set.
        candidates: Dict[Tuple[Tuple[int, ...], Tuple[int, ...]], None] = {}

        for q_p, r_p in pareto_points:
            for q, r in generate_k_nearest_atomic_neighbors(q_p, r_p, self.enc, k=k_neighbors):
                repaired = self._repair_feasible(q, r)
                if repaired is not None:
                    candidates.setdefault((tuple(repaired[0]), tuple(repaired[1])), None)
            candidates.setdefault((tuple(q_p), tuple(r_p)), None)

        return [(list(q), list(r)) for q, r in candidates]

    def _select_max_ei(self, omega: List[Tuple[List[int], List[int]]]) -> Tuple[int, float]:
        """Return (index, EI) of the candidate in ``omega`` with the largest Expected Improvement."""
        X = np.vstack([self.encode(q, r) for q, r in omega])
        mu, sigma = self.gp.predict(X, return_std=True)
        ei = self.expected_improvement(mu, sigma)
        best_idx = int(np.argmax(ei))
        return best_idx, float(ei[best_idx])

    def propose(self, incumbent: Tuple[List[int], List[int]], k_neighbors: int = 5) -> Tuple[List[int], List[int], float]:
        """Propose the neighbor of ``incumbent`` (or the incumbent itself) with maximal EI.

        Returns:
            (q, r, max_ei).
        """
        omega = self.build_trust_region(incumbent, k_neighbors=k_neighbors)
        best_idx, max_ei = self._select_max_ei(omega)
        return omega[best_idx][0], omega[best_idx][1], max_ei

    def propose_multi_start(
        self,
        pareto_points: List[Tuple[List[int], List[int]]],
        k_neighbors: int = 5
    ) -> Tuple[List[int], List[int], float, int]:
        """Propose the candidate with maximal EI from the joint trust region of all anchors.

        Args:
            pareto_points: List of (q, r) anchor configurations.
            k_neighbors: Neighbors per anchor (default: 5).

        Returns:
            (q, r, max_ei, num_candidates): best candidate and the size of the
            searched region. With an empty region, the first anchor is returned
            with EI 0.0 and 0 candidates.
        """
        omega = self.build_multi_start_trust_region(pareto_points, k_neighbors=k_neighbors)

        if not omega:
            q, r = pareto_points[0]
            return q, r, 0.0, 0

        best_idx, max_ei = self._select_max_ei(omega)
        return omega[best_idx][0], omega[best_idx][1], max_ei, len(omega)

    def update(self, q: List[int], r: List[int], perf: float, mem: float) -> bool:
        """Add an observation, refit the GP, and report whether it improved ``best_y``.

        Requires ``fit`` to have been called (for ``alpha`` and the normalization ranges).

        Returns:
            A plain Python ``bool`` (not ``np.bool_``), so callers can serialize it to JSON.
        """
        x = self.encode(q, r)
        y = self.scalarize(perf, mem, self.alpha, self.perf_norm, self.mem_norm)

        # bool(): comparisons involving NumPy scalars yield np.bool_.
        improved = bool(y > self.best_y)
        if improved:
            self.best_y = y

        self.X_train.append(x)
        self.y_train.append(y)
        self.gp.fit(np.vstack(self.X_train), np.array(self.y_train))

        return improved

    def has_converged(self, max_ei: float) -> bool:
        """Convergence criterion (Eq. 13): max EI over the trust region is below ε_ei."""
        return max_ei < self.epsilon_ei


# =============================
# Orchestrator
# =============================

class AutoQRA:
    """Orchestrator: Phase I evolutionary search followed by optional Phase II BO refinement.

    JSON artifacts and figures are written to the ``outdir`` passed to :meth:`run`.
    """

    def __init__(
        self,
        cfg: AutoQRAConfig,
        importance_json: Path,
        real_eval_params: Optional[Dict] = None,
        lowbit_value: Optional[int] = None,
        max_lowbit_fraction: float = 1.0,
        target_avg_bits: Optional[float] = None,
    ):
        """
        Args:
            cfg: Search-space / budget configuration.
            importance_json: Path to the importance JSON loaded via ``Importance.from_json``.
            real_eval_params: Optional real-task evaluation settings, forwarded to Phase I.
            lowbit_value: Low bit value whose usage Phase I limits (forwarded as-is).
            max_lowbit_fraction: Maximum fraction of layers at ``lowbit_value`` (forwarded as-is).
            target_avg_bits: Optional target average bit width (forwarded as-is).
        """
        self.cfg = cfg
        self.imp = Importance.from_json(importance_json, num_layers=cfg.num_layers)
        self.enc = ConfigEncoding(cfg.Q, cfg.R)
        self.mem = MemoryModel(cfg)
        self.real_eval_params = real_eval_params
        self.lowbit_value = lowbit_value
        self.max_lowbit_fraction = max_lowbit_fraction
        self.target_avg_bits = target_avg_bits

    def run(self, outdir: Path, phase1_kwargs: Dict, phase2_alpha: Optional[float] = None) -> Dict:
        """Run Phase I and, if requested, Phase II, writing artifacts to ``outdir``.

        Args:
            outdir: Output directory (created if missing).
            phase1_kwargs: Keyword arguments forwarded to ``PhaseIEvolution.run``. Its
                ``lf_eval_mode`` entry (default ``"proxy"``) also selects the Phase II evaluator.
            phase2_alpha: Weight α of the Phase II scalarized utility; ``None`` skips Phase II.

        Returns:
            Dict with keys ``phase1`` (raw Phase I result), ``phase2_best`` (selected
            configuration or ``None``) and ``phase2_history`` (per-iteration records).
        """
        outdir.mkdir(parents=True, exist_ok=True)

        # Phase I: multi-fidelity evolutionary search.
        evo = PhaseIEvolution(
            self.cfg,
            self.imp,
            real_eval_params=self.real_eval_params,
            lowbit_value=self.lowbit_value,
            max_lowbit_fraction=self.max_lowbit_fraction,
            target_avg_bits=self.target_avg_bits,
        )
        res1 = evo.run(
            **phase1_kwargs,
            real_eval_params=self.real_eval_params,
        )
        pareto = res1["pareto"]
        lf_mode = phase1_kwargs.get("lf_eval_mode", "proxy")

        self._save_phase1(outdir, res1)
        self._plot_phase1(outdir / "figures", res1)

        skipped = {"phase1": res1, "phase2_best": None, "phase2_history": []}
        if phase2_alpha is None:
            return skipped
        if len(pareto) < 3:
            print(f"Skipping Phase II: need at least 3 Pareto points, got {len(pareto)}.")
            return skipped
        if lf_mode == "real_task":
            print("Skipping Phase II: real-task evaluation already supplies high-fidelity metrics.")
            return skipped

        best, phase2_history = self._run_phase2(outdir, pareto, phase2_alpha, lf_mode)
        return {"phase1": res1, "phase2_best": best, "phase2_history": phase2_history}

    @staticmethod
    def _save_phase1(outdir: Path, res1: Dict) -> None:
        """Write the Phase I Pareto front, statistics and (when present) all evaluated points."""
        with open(outdir / "phase1_pareto.json", "w") as f:
            json.dump({"pareto": res1["pareto"], "hv_hist": res1["hv_hist"], "config": res1["config"]}, f, indent=2)
        with open(outdir / "phase1_stats.json", "w") as f:
            json.dump(res1["stats"], f, indent=2)
        # Optional artifact: check for the key first so a missing entry does not
        # leave an empty phase1_all.json behind.
        if "all" in res1:
            try:
                with open(outdir / "phase1_all.json", "w") as f:
                    json.dump(res1["all"], f, indent=2)
            except Exception as err:  # best-effort artifact; report instead of aborting
                print(f"Warning: could not write phase1_all.json: {err}")

    @staticmethod
    def _plot_phase1(fig_dir: Path, res1: Dict) -> None:
        """Render Phase I figures (best-effort): hypervolume curve and LF/HF evaluation counts."""
        try:
            fig_dir.mkdir(parents=True, exist_ok=True)
            hv_hist = res1["hv_hist"]
            gen_stats = res1["stats"].get("gen_stats", [])
            sns.set_style("whitegrid")

            plt.figure(figsize=(6, 4))
            plt.plot(np.arange(len(hv_hist)), hv_hist, marker='o')
            plt.xlabel("Generation")
            plt.ylabel("Pareto Hypervolume")
            plt.tight_layout()
            plt.savefig(fig_dir / "hv_curve.png", dpi=200)
            plt.close()

            if gen_stats:
                gens = [g["generation"] for g in gen_stats]
                nlf = [g["n_lf"] for g in gen_stats]
                nhf = [g["n_hf"] for g in gen_stats]
                pr = [g.get("promotion_ratio", 0.0) for g in gen_stats]
                fig, ax1 = plt.subplots(1, 1, figsize=(6, 4))
                ax1.plot(gens, nlf, label='LF', color='#2E86AB')
                ax1.plot(gens, nhf, label='HF', color='#A23B72')
                ax1.set_xlabel("Generation")
                ax1.set_ylabel("#Evaluations")
                # The promotion ratio shares the x-axis but has its own y-scale.
                ax2 = ax1.twinx()
                ax2.plot(gens, pr, label='Promotion Ratio', color='#33AA33')
                ax2.set_ylabel("Promotion Ratio")
                h1, l1 = ax1.get_legend_handles_labels()
                h2, l2 = ax2.get_legend_handles_labels()
                ax1.legend(h1 + h2, l1 + l2, loc='upper right', fontsize=8)
                plt.tight_layout()
                plt.savefig(fig_dir / "eval_counts.png", dpi=200)
                plt.close(fig)
        except Exception as err:  # plotting is optional and must not abort the run
            print(f"Warning: failed to render Phase I figures: {err}")

    def _make_phase2_evaluator(self, lf_mode: str) -> Callable[[List[int], List[int]], Tuple[float, float]]:
        """Build a callable ``(q, r) -> (phigh, memory_bytes)`` for Phase II candidates.

        In PTQ mode (with a base model id), a single ``MixedQuantEvaluator`` is
        created and reused. Otherwise the proxy evaluator is used, with Gaussian
        noise (σ = 0.02) added to its score.
        """
        if lf_mode == "ptq" and self.cfg.base_model_id:
            try:
                from autoqra.mixed_quant_eval import MQEvalConfig, MixedQuantEvaluator
            except ImportError:
                autoqra_dir = Path(__file__).parent
                if str(autoqra_dir) not in sys.path:
                    sys.path.insert(0, str(autoqra_dir))
                from mixed_quant_eval import MQEvalConfig, MixedQuantEvaluator
            mq_eval = MixedQuantEvaluator(MQEvalConfig(
                model_id=self.cfg.base_model_id,
                task=self.cfg.task,
                sample_ratio_lf=0.02,
                sample_ratio_hf=0.1,
                per_channel=True,
            ))

            def evaluate_ptq(q: List[int], r: List[int]) -> Tuple[float, float]:
                metric_hf = mq_eval.evaluate_candidate(q, high_fidelity=True)
                # Generation metrics are negated so that "higher is better" holds
                # (presumably a loss/perplexity — TODO confirm in mixed_quant_eval).
                phigh = -metric_hf if self.cfg.task == "generation" else metric_hf
                return float(phigh), float(self.mem.total_memory_bytes(q, r))

            return evaluate_ptq

        proxy = ProxyEvaluator(self.cfg, self.imp, self.mem, self.enc)

        def evaluate_proxy(q: List[int], r: List[int]) -> Tuple[float, float]:
            plow, mem_bytes = proxy.evaluate(q, r, self.cfg.budget_bytes)
            return float(plow + np.random.normal(scale=0.02)), float(mem_bytes)

        return evaluate_proxy

    def _run_phase2(
        self,
        outdir: Path,
        pareto: List[Dict],
        phase2_alpha: float,
        lf_mode: str,
    ) -> Tuple[Dict, List[Dict]]:
        """Refine the Phase I front with multi-start trust-region BO and save the results.

        Returns:
            (best, history): the selected configuration record and the
            per-iteration records. They are also written to
            ``phase2_selected.json`` and ``phase2_history.json``.
        """
        print(f"\n=== 阶段二：贝叶斯精炼 (α={phase2_alpha}) ===")

        bo = PhaseIIBO(
            enc=self.enc,
            importance=self.imp,
            mem_model=self.mem,
            budget_bytes=self.cfg.budget_bytes or float('inf'),
            I_q=self.imp.I_q,
            I_r=self.imp.I_r,
            k_neighbors=5,  # fixed: 5 nearest neighbors per anchor point
            epsilon_ei=1e-4,
            max_iterations=3,  # 3 BO iterations
        )

        # Warm-start the GP from the Phase I Pareto front.
        bo.fit(pareto, alpha=phase2_alpha)

        # Incumbent: the Pareto point with the best scalarized utility.
        incumbent_rec = max(
            pareto,
            key=lambda p: bo.scalarize(p["phigh"], p["mem"], phase2_alpha, bo.perf_norm, bo.mem_norm)
        )
        incumbent = (incumbent_rec["q"], incumbent_rec["r"])

        print(f"起始配置: 性能={incumbent_rec['phigh']:.4f}, 内存={incumbent_rec['mem']:.2e}")

        # Anchor points of the multi-start trust region; each improvement is appended.
        pareto_configs = [(p["q"], p["r"]) for p in pareto]

        print(f"使用多起点Trust Region: {len(pareto_configs)}个Pareto点 × {bo.k_neighbors}邻居")

        # Built lazily so that a model is only loaded if a candidate is actually evaluated.
        evaluate = None
        history: List[Dict] = []
        for t in range(bo.max_iterations):
            q_new, r_new, max_ei, num_candidates = bo.propose_multi_start(
                pareto_configs, k_neighbors=bo.k_neighbors
            )

            print(f"\n第 {t+1}/{bo.max_iterations} 次迭代:")
            print(f"  搜索 {num_candidates} 个候选 (来自{len(pareto_configs)}个Pareto点)")
            print(f"  最大 EI={max_ei:.6f}")

            # Convergence check (Eq. 13).
            if bo.has_converged(max_ei):
                print(f"  收敛：最大 EI < {bo.epsilon_ei}")
                break

            if evaluate is None:
                evaluate = self._make_phase2_evaluator(lf_mode)
            phigh, M = evaluate(q_new, r_new)

            utility = float(bo.scalarize(phigh, M, phase2_alpha, bo.perf_norm, bo.mem_norm))

            print(f"  评估结果: 性能={phigh:.4f}, 内存={M:.2e}, 效用={utility:.4f}")

            improved = bool(bo.update(q_new, r_new, phigh, M))

            if improved:
                print(f"  ✓ 发现改进！新的最佳效用: {bo.best_y:.4f}")
                incumbent = (q_new, r_new)
                incumbent_rec = {"q": list(q_new), "r": list(r_new), "phigh": phigh, "mem": M}
                pareto_configs.append((list(q_new), list(r_new)))
            else:
                print(f"  ✗ 无改进 (最佳: {bo.best_y:.4f})")

            history.append({
                "iteration": t + 1,
                "q": list(q_new),
                "r": list(r_new),
                "phigh": phigh,
                "mem": M,
                "utility": utility,
                "max_ei": max_ei,
                "num_candidates": num_candidates,
                "num_pareto_points": len(pareto_configs),
                "improved": improved,
            })

        best = {
            "q": list(incumbent[0]),
            "r": list(incumbent[1]),
            "phigh": incumbent_rec["phigh"],
            "mem": incumbent_rec["mem"],
            "utility": float(bo.best_y),
            "alpha": phase2_alpha,
        }

        print(f"\n=== 阶段二完成 ===")
        print(f"最佳配置: 性能={best['phigh']:.4f}, 内存={best['mem']:.2e}, 效用={best['utility']:.4f}")

        with open(outdir / "phase2_selected.json", "w") as f:
            json.dump(best, f, indent=2)
        with open(outdir / "phase2_history.json", "w") as f:
            json.dump(history, f, indent=2)

        return best, history


# =============================
# CLI
# =============================

def main():
    """CLI entry point: parse arguments, build an ``AutoQRAConfig`` and run AutoQRA."""
    p = argparse.ArgumentParser(description="AutoQRA: joint precision+rank allocation under memory budget")
    p.add_argument("--num_layers", type=int, default=28)
    p.add_argument("--bits", type=int, nargs="+", default=[2, 3, 4, 6, 8])
    p.add_argument("--ranks", type=int, nargs="+", default=[4, 6, 8, 10, 12, 16])
    p.add_argument("--lora_precision_bits", type=int, default=16)
    p.add_argument("--budget_bytes", type=float, default=None)
    p.add_argument("--layer_params_json", type=str, default=None, help="JSON with keys layer_param_counts, lora_params_per_rank")
    p.add_argument("--importance_json", type=str, required=True)
    p.add_argument("--seed", type=int, default=42)
    # Mixed-quant evaluation options.
    p.add_argument("--lf_eval_mode", type=str, choices=["proxy", "ptq", "real_task"], default="proxy",
                   help="Evaluation mode: proxy (importance coverage), ptq (mixed quant), or real_task (full SFT+lm-eval)")
    p.add_argument("--preset", type=str, default=None, help="Optional preset to resolve base model id for ptq evaluator")
    p.add_argument("--base_model_id", type=str, default=None, help="Explicit base model id for ptq evaluator")
    p.add_argument("--task", type=str, choices=["generation", "classification"], default="generation")
    p.add_argument("--lowbit_value", type=int, default=None,
                   help="Explicit low bit value to limit (default: min of --bits)")
    p.add_argument("--max_lowbit_fraction", type=float, default=0.5,
                   help="Maximum fraction of layers allowed to use the lowest bit (0-1]. Default 0.5")
    p.add_argument("--target_avg_bits", type=float, default=None,
                   help="Target average bits for proxy evaluation (e.g., 6.0). Adds bonus for matching target.")
    # Real-task evaluation parameters.
    p.add_argument("--real_eval_task", type=str,
                   choices=["winogrande", "arc_challenge", "arc_easy", "boolq", "gsm8k", "hellaswag", "openbookqa", "piqa"],
                   default="winogrande",
                   help="Downstream lm-eval task to optimize when lf_eval_mode=real_task")
    p.add_argument("--real_eval_dataset", type=str, choices=["alpaca", "hc3"], default="alpaca",
                   help="SFT dataset used during real-task evaluation (train_autoqra_sft.py)")
    # Low-fidelity (LF) parameters: fast, cheap evaluation.
    p.add_argument("--real_eval_lf_sample_ratio", type=float, default=0.10,
                   help="LF: Sample ratio for fast SFT (default 0.10 = 10%%, per paper spec)")
    p.add_argument("--real_eval_lf_epochs", type=float, default=0.2,
                   help="LF: Epochs for fast SFT (default 0.2, per paper spec)")
    # High-fidelity (HF) parameters: thorough, expensive evaluation.
    p.add_argument("--real_eval_hf_sample_ratio", type=float, default=1.0,
                   help="HF: Sample ratio for thorough SFT (default 1.0 = 100%%, per paper spec)")
    p.add_argument("--real_eval_hf_epochs", type=float, default=1.0,
                   help="HF: Epochs for thorough SFT (default 1.0)")
    # Legacy parameters, kept for backward compatibility.
    p.add_argument("--real_eval_sample_ratio", type=float, default=None,
                   help="[DEPRECATED] Use --real_eval_hf_sample_ratio instead")
    p.add_argument("--real_eval_epochs", type=float, default=None,
                   help="[DEPRECATED] Use --real_eval_hf_epochs instead")
    # Other real-task parameters.
    p.add_argument("--real_eval_shots", type=int, default=0,
                   help="Few-shot setting for lm-eval harness when lf_eval_mode=real_task")
    p.add_argument("--real_eval_output", type=str, default="./qwen_lora_importance/results_real_eval",
                   help="Directory to cache real-task evaluation artifacts")
    p.add_argument("--real_eval_no_4bit", action="store_true",
                   help="Disable --load_in_4bit during the quick SFT stage (defaults to enabled)")
    # Phase I.
    p.add_argument("--phase1_pop", type=int, default=40)
    p.add_argument("--phase1_generations", type=int, default=12)
    p.add_argument("--phase1_promote", type=int, default=6)
    p.add_argument("--phase1_gamma", type=float, default=1.5)
    # Ablations.
    p.add_argument("--no_warm_start", action="store_true")
    p.add_argument("--no_importance_mutation", action="store_true")
    p.add_argument("--no_coupled_mutation", action="store_true")
    p.add_argument("--no_surrogate_promotion", action="store_true")
    p.add_argument("--hf_all", action="store_true", help="Disable multi-fidelity; evaluate HF for all offspring (expensive)")
    # Phase II.
    p.add_argument("--phase2_alpha", type=float, default=None)
    # Output.
    p.add_argument("--outdir", type=str, default="./qwen_lora_importance/results_autoqra")
    args = p.parse_args()

    # Enforce the range documented in the --max_lowbit_fraction help text.
    if not 0.0 < args.max_lowbit_fraction <= 1.0:
        p.error(f"--max_lowbit_fraction must be in (0, 1], got {args.max_lowbit_fraction}")

    # Optional layer metadata.
    layer_param_bytes = None
    lora_params_per_rank = None
    if args.layer_params_json:
        with open(args.layer_params_json, "r") as f:
            meta = json.load(f)
        # "layer_param_counts" holds parameter counts (not bytes); the memory model
        # is expected to scale them by q/8 to obtain bytes.
        layer_param_bytes = meta.get("layer_param_counts")
        lora_params_per_rank = meta.get("lora_params_per_rank")

    # Infer the layer count from the importance JSON. When it is available (non-empty),
    # it takes precedence over --num_layers so the config matches the importance profile.
    inferred_layers = None
    try:
        with open(args.importance_json, "r") as f:
            bdict = json.load(f).get("backbone_metric_per_layer", {})
        inferred_layers = len(bdict)
    except (OSError, ValueError, AttributeError, TypeError):
        # Best-effort inference only: fall back to --num_layers.
        pass

    # Resolve the base model id for the PTQ evaluator from --preset if not given explicitly.
    base_model_id = args.base_model_id
    if args.lf_eval_mode == "ptq":
        if not base_model_id and args.preset:
            try:
                from ..model_presets import get_preset as _get_preset
            except ImportError:
                # Relative import fails when not run as part of the package.
                from qwen_lora_importance.model_presets import get_preset as _get_preset
            base_model_id = _get_preset(args.preset).hf_id
        if not base_model_id:
            print("Warning: lf_eval_mode=ptq but no base model id resolved; pass --base_model_id or --preset.")

    real_eval_params = None
    if args.lf_eval_mode == "real_task":
        if not args.preset:
            raise ValueError("lf_eval_mode=real_task requires --preset to load the base model.")

        # Legacy flags, when given, override the HF settings.
        hf_sample_ratio = args.real_eval_hf_sample_ratio
        hf_epochs = args.real_eval_hf_epochs
        if args.real_eval_sample_ratio is not None:
            hf_sample_ratio = args.real_eval_sample_ratio
        if args.real_eval_epochs is not None:
            hf_epochs = args.real_eval_epochs

        real_eval_params = {
            "preset": args.preset,
            "dataset": args.real_eval_dataset,
            "lf_sample_ratio": args.real_eval_lf_sample_ratio,
            "lf_epochs": args.real_eval_lf_epochs,
            "hf_sample_ratio": hf_sample_ratio,
            "hf_epochs": hf_epochs,
            # Legacy keys kept for backward compatibility.
            "sample_ratio": hf_sample_ratio,
            "epochs": hf_epochs,
            "eval_task": args.real_eval_task,
            "eval_shots": args.real_eval_shots,
            "output_root": args.real_eval_output,
            "load_in_4bit": not args.real_eval_no_4bit,
        }

    cfg = AutoQRAConfig(
        num_layers=inferred_layers or args.num_layers,
        Q=args.bits,
        R=args.ranks,
        lora_precision_bits=args.lora_precision_bits,
        layer_param_bytes=layer_param_bytes,
        lora_params_per_rank=lora_params_per_rank,
        seed=args.seed,
        budget_bytes=args.budget_bytes,
        base_model_id=base_model_id,
        task=args.task,
    )

    runner = AutoQRA(
        cfg,
        importance_json=Path(args.importance_json),
        real_eval_params=real_eval_params,
        lowbit_value=args.lowbit_value,
        max_lowbit_fraction=args.max_lowbit_fraction,
        target_avg_bits=args.target_avg_bits,
    )
    runner.run(
        outdir=Path(args.outdir),
        phase1_kwargs=dict(
            pop_size=args.phase1_pop,
            generations=args.phase1_generations,
            promote_k=args.phase1_promote,
            gamma=args.phase1_gamma,
            use_warm_start=not args.no_warm_start,
            use_importance_mutation=not args.no_importance_mutation,
            use_coupled_mutation=not args.no_coupled_mutation,
            use_surrogate_promotion=not args.no_surrogate_promotion,
            multi_fidelity=not args.hf_all,
            lf_eval_mode=args.lf_eval_mode,
        ),
        phase2_alpha=args.phase2_alpha,
    )

    print(f"Saved AutoQRA results to {args.outdir}")


# Script entry point: run the CLI only when the module is executed directly.
if __name__ == "__main__":
    main()
