import os
import sys
from dataclasses import dataclass, field, replace
from typing import Dict, List, Optional, Tuple
import torch

sys.path.append(os.path.abspath(os.path.dirname(__file__)))

from ac_function.ac_function import gamma_schedule, maximize_acquisition_adam_multistart
from data.task_manager import TaskManager
from models.gp_model import GPModel
from models.mixture import ModelMixture
from models.deep_ensemble import DeepEnsemble


@dataclass
class EnsembleTrainConfig:
    """
    DeepEnsemble training config for history models.

    Architecture fields (``hidden_dims``, ``activation``, ``num_models``) and
    training fields (``steps``, ``lr``, ``weight_decay``, ``batch_size``,
    ``loss_type``, ``list_size``, ``lists_per_step``, ``use_amp``,
    ``log_every``) are forwarded by the optimizer to ``DeepEnsemble``.

    Any ``loss_type`` other than ``"mse"`` (compared case-insensitively) is
    treated as a list-wise loss and requires ``list_size`` and
    ``lists_per_step``. When given, both must be positive.

    Raises:
        ValueError: if a size, step count or learning rate is not positive,
            ``weight_decay`` is negative, or list-wise settings are missing
            or not positive.
    """

    hidden_dims: List[int] = field(default_factory=lambda: [128, 128])
    activation: str = "gelu"
    num_models: int = 5

    steps: int = 200
    lr: float = 1e-3
    weight_decay: float = 0.0
    batch_size: int = 64

    loss_type: str = "mse"
    list_size: Optional[int] = None
    lists_per_step: Optional[int] = None

    use_amp: bool = True
    log_every: int = 100

    def __post_init__(self) -> None:
        if self.num_models <= 0:
            raise ValueError("num_models must be positive")
        if self.steps <= 0:
            raise ValueError("steps must be positive")
        if self.lr <= 0:
            raise ValueError("lr must be positive")
        if self.weight_decay < 0:
            raise ValueError("weight_decay must be >= 0")
        if self.batch_size <= 0:
            raise ValueError("batch_size must be positive")
        if self.loss_type.lower() != "mse":
            if self.list_size is None or self.lists_per_step is None:
                raise ValueError(
                    "list_size and lists_per_step are required for list-wise losses"
                )
        # Optional list-wise sizes are still sizes: reject 0/negative values
        # instead of letting them reach the training loop.
        if self.list_size is not None and self.list_size <= 0:
            raise ValueError("list_size must be positive")
        if self.lists_per_step is not None and self.lists_per_step <= 0:
            raise ValueError("lists_per_step must be positive")


@dataclass
class HistoryTaskConfig:
    """
    Per-history-task overrides.

    Attributes:
        n_data: How many points to sample for this task when no external
            dataset is supplied. Must be positive.
        value_model: Training settings for the task's value ensemble
            (default: plain MSE regression).
        rank_model: Training settings for the task's rank ensemble
            (default: ``listnet`` loss with ``list_size=32`` and
            ``lists_per_step=8``).
    """

    n_data: int = 256
    value_model: EnsembleTrainConfig = field(
        default_factory=lambda: EnsembleTrainConfig(loss_type="mse")
    )
    rank_model: EnsembleTrainConfig = field(
        default_factory=lambda: EnsembleTrainConfig(
            loss_type="listnet", list_size=32, lists_per_step=8
        )
    )

    def __post_init__(self) -> None:
        requested = self.n_data
        if requested <= 0:
            raise ValueError("n_data must be positive")


def _fmt_train_cfg(cfg: EnsembleTrainConfig) -> str:
    """
    Render a training config as a single ``key=value, ...`` string for logs.

    The list-wise fields (``list_size``, ``lists_per_step``) appear only when
    the loss is not MSE. Values are formatted with ``str``.
    """
    shown = [
        "hidden_dims",
        "activation",
        "num_models",
        "steps",
        "lr",
        "weight_decay",
        "batch_size",
        "loss_type",
    ]
    if cfg.loss_type.lower() != "mse":
        shown += ["list_size", "lists_per_step"]
    shown += ["use_amp", "log_every"]
    return ", ".join(f"{key}={getattr(cfg, key)}" for key in shown)


@dataclass
class TargetGPConfig:
    """
    Target GP training config.

    Attributes:
        n_iter: Optimization iterations per GP fit.
        lr: Learning rate for the GP hyperparameter fit.
        n_restarts: Number of fit restarts.

    Raises:
        ValueError: naming the first field (in declaration order) that is not
            strictly positive.
    """

    n_iter: int = 300
    lr: float = 0.05
    n_restarts: int = 2

    def __post_init__(self) -> None:
        # Checked in declaration order so the first offending field is reported.
        for attr in ("n_iter", "lr", "n_restarts"):
            if getattr(self, attr) <= 0:
                raise ValueError(f"{attr} must be positive")


@dataclass
class AcquisitionOptimizerConfig:
    """
    Gradient-ascent optimizer config for acquisition.

    Attributes:
        steps: Ascent steps per restart.
        lr: Step size.
        print_every: Logging interval, in steps.
        n_restarts: Number of independent starting points.

    Raises:
        ValueError: naming the first field (in declaration order) that is not
            strictly positive.
    """

    steps: int = 100
    lr: float = 0.05
    print_every: int = 50
    n_restarts: int = 16

    def __post_init__(self) -> None:
        for attr in ("steps", "lr", "print_every", "n_restarts"):
            if getattr(self, attr) <= 0:
                raise ValueError(f"{attr} must be positive")


@dataclass
class AnchorConfig:
    """
    Anchor selection config.

    Attributes:
        history_topk_per_task: How many points to take from each history
            task as acquisition anchors. Must be positive.
    """

    history_topk_per_task: int = 1

    def __post_init__(self) -> None:
        k = self.history_topk_per_task
        if k <= 0:
            raise ValueError("history_topk_per_task must be positive")


@dataclass
class AcquisitionConfig:
    """
    Acquisition config.

    Attributes:
        beta_t: Exploration weight passed to the acquisition optimizer (>= 0).
        gamma_mode: ``"schedule"`` derives gamma per iteration from
            ``gamma0``; ``"constant"`` uses ``gamma`` unchanged.
        gamma0: Schedule base value (validated only in ``"schedule"`` mode).
        gamma: Constant value (validated only in ``"constant"`` mode).
        optimizer: Gradient-ascent settings for maximizing the acquisition.
        anchors: Anchor-selection settings.
    """

    beta_t: float = 2.0

    gamma_mode: str = "schedule"  # "schedule" or "constant"
    gamma0: float = 1.0
    gamma: float = 1.0

    optimizer: AcquisitionOptimizerConfig = field(
        default_factory=AcquisitionOptimizerConfig
    )
    anchors: AnchorConfig = field(default_factory=AnchorConfig)

    def __post_init__(self) -> None:
        if self.beta_t < 0:
            raise ValueError("beta_t must be >= 0")
        if self.gamma_mode not in ("schedule", "constant"):
            raise ValueError("gamma_mode must be 'schedule' or 'constant'")
        # Only the gamma parameter that the active mode actually uses is checked.
        active = "gamma0" if self.gamma_mode == "schedule" else "gamma"
        if getattr(self, active) < 0:
            raise ValueError(f"{active} must be >= 0")


@dataclass
class OptimizerConfig:
    """
    Optimizer config.

    Top-level settings for the transfer BO loop. These cover the search
    space, the initial design and iteration budget, history-task data and
    training overrides, the target GP and acquisition settings, and the
    random seed.

    Raises:
        ValueError: for a non-positive ``dim``/``n_init``/``n_iter``/
            ``default_history_n_data``/``calibration_size``. Also raised when
            ``bounds`` is not a ``(low, high)`` pair with ``low < high``,
            ``raw_bounds`` has a length other than ``dim`` or an empty
            interval, ``obs_noise_std`` is negative, or ``design`` is unknown.
    """

    # -------------------
    # Search space
    # -------------------
    dim: int = 6
    # Normalized (low, high) box shared by every dimension; models and the
    # acquisition optimizer work in this space.
    bounds: Tuple[float, float] = (0.0, 1.0)
    # Per-dimension raw (low, high) boxes used to evaluate the objectives;
    # None means the raw space equals ``bounds``.
    raw_bounds: Optional[List[Tuple[float, float]]] = None
    # Z-score observed y values before fitting models.
    normalize_y: bool = True

    # Init and iterations
    n_init: int = 8
    n_iter: int = 20
    # Std of Gaussian noise added to target evaluations (0 disables noise).
    obs_noise_std: float = 0.0

    # -------------------
    # Sampling design
    # -------------------
    # "lhs": Latin Hypercube Sampling
    # "uniform": Uniform random
    design: str = "lhs"

    # -------------------
    # History tasks
    # -------------------
    # Default history data size when no external datasets are provided
    default_history_n_data: int = 200
    # Per-task overrides
    history_tasks: Dict[str, HistoryTaskConfig] = field(default_factory=dict)

    # -------------------
    # Target GP and acquisition
    # -------------------
    target_gp: TargetGPConfig = field(default_factory=TargetGPConfig)
    acq: AcquisitionConfig = field(default_factory=AcquisitionConfig)
    # Upper bound on the number of points used to calibrate mixture experts.
    calibration_size: int = 256

    # -------------------
    # Randomness
    # -------------------
    seed: int = 0

    def __post_init__(self) -> None:
        if self.dim <= 0:
            raise ValueError("dim must be positive")
        # A tuple of the wrong arity would otherwise be silently truncated to
        # its first two entries.
        if len(self.bounds) != 2:
            raise ValueError("bounds must be a (low, high) pair")
        low, high = float(self.bounds[0]), float(self.bounds[1])
        if not low < high:
            raise ValueError("bounds must satisfy low < high")
        if self.raw_bounds is not None:
            if len(self.raw_bounds) != int(self.dim):
                raise ValueError("raw_bounds must have length == dim")
            for i, (l_i, h_i) in enumerate(self.raw_bounds):
                if not float(l_i) < float(h_i):
                    raise ValueError(f"raw_bounds[{i}] must satisfy low < high")
        if self.n_init <= 0 or self.n_iter <= 0:
            raise ValueError("n_init/n_iter must be positive")
        # The noise is applied only when the std is > 0, so a negative value
        # would silently disable noise; reject it explicitly.
        if self.obs_noise_std < 0:
            raise ValueError("obs_noise_std must be >= 0")
        if self.design not in ("lhs", "uniform"):
            raise ValueError("design must be 'lhs' or 'uniform'")
        if self.default_history_n_data <= 0:
            raise ValueError("default_history_n_data must be positive")
        if self.calibration_size <= 0:
            raise ValueError("calibration_size must be positive")


class TransferRankBayesOpt:
    """
    BO loop with transfer learning and Kendall weighting.

    Each history task gets two DeepEnsembles:

    - A value model, whose predictions on the target data are scored with
      Kendall's tau.
    - A rank model, which joins a GP fitted on the target data as an expert
      in a ``ModelMixture``.

    Each BO iteration refits the GP, recomputes the Kendall weights,
    maximizes the acquisition over the mixture and evaluates the chosen point
    on the target objective.

    Models work in the normalized box ``cfg.bounds``. Objectives are
    evaluated in the raw box ``cfg.raw_bounds`` (``cfg.bounds`` when unset).
    Raw objective values are kept, and are z-scored before model fitting when
    ``cfg.normalize_y`` is true.
    """

    def __init__(
        self,
        history_tasks,
        target_task,
        config: OptimizerConfig,
        device: Optional[torch.device] = None,
        history_datasets: Optional[Dict[str, Tuple[torch.Tensor, torch.Tensor]]] = None,
        initial_target_X: Optional[torch.Tensor] = None,
    ):
        """
        Args:
            history_tasks: Iterable of history tasks. Each must expose ``name``
                and ``evaluate(X_raw, noise_std=...)``.
            target_task: Target task exposing ``objective(X_raw)``.
            config: Optimizer configuration.
            device: Torch device. Defaults to CUDA when available, else CPU.
            history_datasets: Optional ``{name: (X_raw, y)}``, used instead of
                sampling history data for the named tasks.
            initial_target_X: Optional raw-space initial target design. When
                None, ``cfg.n_init`` points are sampled.
        """
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device
        self.cfg = config

        self.history_tasks = list(history_tasks)
        self.target_task = target_task
        self._provided_history_datasets = history_datasets
        self.initial_target_X = initial_target_X

        self.manager = TaskManager(device=self.device, dtype=torch.float32)

        # Two independent CPU generators. _rng drives design and calibration
        # sampling; _eval_rng drives observation noise only. Noise draws
        # therefore never shift the sampled designs.
        self._rng = torch.Generator(device="cpu").manual_seed(int(self.cfg.seed))
        self._eval_rng = torch.Generator(device="cpu").manual_seed(int(self.cfg.seed))
        self._x_norm_low = float(self.cfg.bounds[0])
        self._x_norm_high = float(self.cfg.bounds[1])
        self._init_target_raw: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
        self._target_y_raw: Optional[torch.Tensor] = None
        self._history_y_raw: Dict[str, torch.Tensor] = {}

        if self.cfg.raw_bounds is None:
            raw_lows = torch.full(
                (int(self.cfg.dim),), float(self.cfg.bounds[0]), dtype=torch.float32
            )
            raw_highs = torch.full(
                (int(self.cfg.dim),), float(self.cfg.bounds[1]), dtype=torch.float32
            )
        else:
            raw_lows = torch.tensor(
                [float(b[0]) for b in self.cfg.raw_bounds], dtype=torch.float32
            )
            raw_highs = torch.tensor(
                [float(b[1]) for b in self.cfg.raw_bounds], dtype=torch.float32
            )

        self._x_raw_low = raw_lows.to(self.device)
        self._x_raw_high = raw_highs.to(self.device)

    # ------------------------------------------------------------------
    # Space transforms
    # ------------------------------------------------------------------
    def _normalize_x(self, x_raw: torch.Tensor) -> torch.Tensor:
        """
        Map raw-space points into the normalized box ``cfg.bounds``.

        Each dimension ``i`` is mapped affinely from
        ``[raw_low_i, raw_high_i]`` to ``[low, high]``. Returns float32 on
        ``self.device``.
        """
        if not isinstance(x_raw, torch.Tensor):
            x_raw = torch.as_tensor(x_raw, device=self.device, dtype=torch.float32)
        x_raw = x_raw.to(self.device, dtype=torch.float32)
        a = float(self._x_norm_low)
        b = float(self._x_norm_high)
        denom = self._x_raw_high - self._x_raw_low
        scale = (b - a) / denom
        return a + (x_raw - self._x_raw_low) * scale

    def _denormalize_x(self, x_norm: torch.Tensor) -> torch.Tensor:
        """Inverse of ``_normalize_x``. Maps normalized points back to raw space."""
        if not isinstance(x_norm, torch.Tensor):
            x_norm = torch.as_tensor(x_norm, device=self.device, dtype=torch.float32)
        x_norm = x_norm.to(self.device, dtype=torch.float32)
        a = float(self._x_norm_low)
        b = float(self._x_norm_high)
        denom = b - a
        scale = (self._x_raw_high - self._x_raw_low) / denom
        return self._x_raw_low + (x_norm - a) * scale

    def _zscore_y_2d(self, y: torch.Tensor) -> torch.Tensor:
        """
        Z-score ``y`` and return it as an ``[N, 1]`` float32 tensor.

        The scale is the unbiased (N-1) standard deviation, clamped to at
        least 1e-12 so that constant data maps to zeros. With fewer than two
        values the unbiased std is NaN, and ``clamp`` does not remove NaN.
        In that case the scale falls back to 1 and the values are only
        centred.
        """
        y_flat = self._ensure_2d_y(y).to(self.device, dtype=torch.float32).view(-1)
        mu = y_flat.mean()
        if y_flat.numel() < 2:
            sigma = torch.ones((), device=y_flat.device, dtype=y_flat.dtype)
        else:
            sigma = torch.clamp(y_flat.std(unbiased=True), min=1e-12)
        return ((y_flat - mu) / sigma).view(-1, 1)

    def _evaluate_target_raw(self, X_raw: torch.Tensor) -> torch.Tensor:
        """
        Evaluate the target objective at raw points.

        The objective is called with a float32 CPU tensor. Its output is
        reshaped to ``[N, 1]``. When ``cfg.obs_noise_std > 0``, Gaussian noise
        drawn from ``_eval_rng`` is added. Returns float32 on ``self.device``.
        """
        if not isinstance(X_raw, torch.Tensor):
            X_raw = torch.as_tensor(X_raw, dtype=torch.float32)
        X_cpu = X_raw.detach().to("cpu", dtype=torch.float32)
        y = self.target_task.objective(X_cpu).view(-1, 1)
        noise_std = float(self.cfg.obs_noise_std)
        if noise_std > 0:
            # Draw on CPU (where _eval_rng lives) and then move the noise.
            # An objective that returns a GPU tensor therefore cannot trigger
            # a generator/device mismatch.
            eps = torch.randn(
                y.shape, generator=self._eval_rng, device="cpu", dtype=y.dtype
            ).to(y.device)
            y = y + noise_std * eps
        return y.to(self.device, dtype=torch.float32)

    # ------------------------------------------------------------------
    # Designs
    # ------------------------------------------------------------------
    def _sample_uniform(self, n: int) -> torch.Tensor:
        """
        Uniform sampling in bounds.
        """
        low, high = self.cfg.bounds
        X = low + (high - low) * torch.rand((n, self.cfg.dim), generator=self._rng)
        return X.to(self.device, dtype=torch.float32)

    def _sample_latin_hypercube(self, n: int) -> torch.Tensor:
        """
        Latin hypercube sampling in ``cfg.bounds``.

        Each dimension is cut into ``n`` equal strata. An independent random
        permutation assigns one stratum per sample, and the point is then
        jittered uniformly inside its stratum.
        """
        if n <= 0:
            raise ValueError("n must be positive")

        dim = int(self.cfg.dim)
        u = torch.rand((n, dim), generator=self._rng)
        perms = torch.stack(
            [torch.randperm(n, generator=self._rng) for _ in range(dim)], dim=1
        ).to(u.dtype)
        X01 = (perms + u) / float(n)

        low, high = float(self.cfg.bounds[0]), float(self.cfg.bounds[1])
        X = low + (high - low) * X01
        return X.to(self.device, dtype=torch.float32)

    def _sample_design(self, n: int) -> torch.Tensor:
        """
        Draw ``n`` normalized points using the configured ``cfg.design``.
        """
        if self.cfg.design == "lhs":
            return self._sample_latin_hypercube(n)
        return self._sample_uniform(n)

    def _ensure_2d_y(self, y: torch.Tensor) -> torch.Tensor:
        """
        Reshape a 1-D ``y`` to ``[N, 1]``. Other ranks are returned unchanged.
        """
        if y.dim() == 1:
            return y.view(-1, 1)
        return y

    # ------------------------------------------------------------------
    # History tasks
    # ------------------------------------------------------------------
    def _history_n_data(self, name: str) -> int:
        """Return the number of points to sample for history task ``name``."""
        task_cfg = self.cfg.history_tasks.get(name)
        if task_cfg is None:
            return int(self.cfg.default_history_n_data)
        return int(task_cfg.n_data)

    def build_history_datasets(self) -> Dict[str, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Build the per-task history datasets and register them with the manager.

        Provided datasets (raw X) are normalized. Otherwise points are sampled
        in normalized space and evaluated noise-free in raw space. Raw y values
        are kept in ``_history_y_raw``. The (optionally z-scored) y is what
        gets returned and registered.

        Returns:
            ``{name: (X_normalized, y_fit)}``.

        Raises:
            ValueError: if a task's y does not hold exactly one value per X row.
        """
        datasets: Dict[str, Tuple[torch.Tensor, torch.Tensor]] = {}
        provided = (
            self._provided_history_datasets
            if self._provided_history_datasets is not None
            else {}
        )
        for task in self.history_tasks:
            name = task.name

            if name in provided:
                X_in, y_in = provided[name]
                X_raw = torch.as_tensor(X_in, device=self.device, dtype=torch.float32)
                X = self._normalize_x(X_raw)
                y = torch.as_tensor(y_in, device=self.device, dtype=torch.float32)
                print(
                    f"[data][history] task={name} source=provided n={int(X_raw.size(0))}"
                )
            else:
                X = self._sample_design(self._history_n_data(name))
                X_raw = self._denormalize_x(X)
                y = task.evaluate(X_raw, noise_std=0.0).to(self.device)
                cfg_source = (
                    "custom"
                    if self.cfg.history_tasks.get(name) is not None
                    else "default"
                )
                print(
                    f"[data][history] task={name} source=sampled cfg_source={cfg_source} n={int(X_raw.size(0))}"
                )

            y = self._ensure_2d_y(y)
            # Catch mismatched X/y early; otherwise they would silently
            # misalign in the model fits.
            if int(y.numel()) != int(X.size(0)):
                raise ValueError(
                    f"history task '{name}': X has {int(X.size(0))} rows but y has "
                    f"{int(y.numel())} values"
                )
            self._history_y_raw[name] = y.detach().clone()
            y_fit = self._zscore_y_2d(y) if self.cfg.normalize_y else y

            datasets[name] = (X, y_fit)
            self.manager.add_history_task(name, X, y_fit)
        return datasets

    def _train_history_ensemble(
        self,
        name: str,
        model_kind: str,
        train_cfg: EnsembleTrainConfig,
        X_cpu: torch.Tensor,
        y_cpu: torch.Tensor,
        cfg_source: str,
    ) -> DeepEnsemble:
        """Log, build and fit one DeepEnsemble (``model_kind`` is "value" or "rank")."""
        print(
            f"[fit][history] task={name} model={model_kind} cfg=({_fmt_train_cfg(train_cfg)}) "
            f"cfg_source={cfg_source} normalize_y={self.cfg.normalize_y} n={int(X_cpu.size(0))}"
        )
        model = DeepEnsemble(
            input_dim=self.cfg.dim,
            hidden_dims=train_cfg.hidden_dims,
            activation=train_cfg.activation,
            num_models=train_cfg.num_models,
            out_dim=1,
            device=self.device,
            dtype=torch.float32,
        )
        model.fit(
            X_cpu,
            y_cpu,
            steps=train_cfg.steps,
            lr=train_cfg.lr,
            batch_size=train_cfg.batch_size,
            weight_decay=train_cfg.weight_decay,
            loss_type=train_cfg.loss_type,
            list_size=train_cfg.list_size,
            lists_per_step=train_cfg.lists_per_step,
            use_amp=train_cfg.use_amp,
            log_every=train_cfg.log_every,
        )
        return model

    def fit_history_models(
        self, history_datasets: Dict[str, Tuple[torch.Tensor, torch.Tensor]]
    ) -> None:
        """
        Train the value and rank models for every history task.

        Tasks without an entry in ``cfg.history_tasks`` use a default
        ``HistoryTaskConfig``. Both models are stored on the manager: the
        value model is used for Kendall weights, the rank model as a mixture
        expert.
        """
        for task in self.history_tasks:
            name = task.name
            X, y = history_datasets[name]
            X_cpu = X.detach().cpu()
            y_cpu = self._ensure_2d_y(y).detach().cpu()

            has_custom_cfg = name in self.cfg.history_tasks
            task_cfg = self.cfg.history_tasks.get(
                name, HistoryTaskConfig(n_data=self.cfg.default_history_n_data)
            )
            cfg_source = "custom" if has_custom_cfg else "default"

            mse_model = self._train_history_ensemble(
                name, "value", task_cfg.value_model, X_cpu, y_cpu, cfg_source
            )
            rank_model = self._train_history_ensemble(
                name, "rank", task_cfg.rank_model, X_cpu, y_cpu, cfg_source
            )

            # Store models for Kendall weights and mixture experts.
            self.manager.set_history_models(name, mse_model, rank_model)

    # ------------------------------------------------------------------
    # Target task
    # ------------------------------------------------------------------
    def init_target(self) -> None:
        """
        Evaluate the initial target design and register it with the manager.

        Uses ``initial_target_X`` (raw space, any array-like) when set.
        Otherwise ``cfg.n_init`` points are sampled in normalized space.
        """
        if self.initial_target_X is not None:
            X0_raw = torch.as_tensor(self.initial_target_X, dtype=torch.float32).to(
                self.device
            )
            print(f"[init] Using provided initial target X (n={int(X0_raw.size(0))})")
            # Models are trained in normalized space.
            X0 = self._normalize_x(X0_raw)
        else:
            X0 = self._sample_design(self.cfg.n_init)
            # Objectives are evaluated in raw space.
            X0_raw = self._denormalize_x(X0)

        y0 = self._ensure_2d_y(self._evaluate_target_raw(X0_raw))
        self._init_target_raw = (X0_raw.detach().clone(), y0.detach().clone())
        self._target_y_raw = y0.detach().clone()
        y0_fit = self._zscore_y_2d(y0) if self.cfg.normalize_y else y0
        self.manager.set_target_data(X0, y0_fit)

    def generate_or_load_data(self, pkl_path: str, load_from_pkl: bool = True) -> None:
        """
        Load or generate history data and target init points.

        With ``load_from_pkl`` the file must exist. Its ``"history"``
        (``{name: (X_raw, y)}``) and ``"target_init_X"`` entries replace the
        provided datasets and the initial design. Otherwise both are sampled
        (the init design in raw space) and saved to ``pkl_path``, overwriting
        any existing file.

        NOTE: ``torch.load`` deserializes via pickle. Only load files you
        trust.

        Raises:
            FileNotFoundError: if loading is requested but ``pkl_path`` is missing.
        """
        if load_from_pkl:
            if not os.path.exists(pkl_path):
                raise FileNotFoundError(
                    f"Requested to load data from {pkl_path}, but file does not exist."
                )
            print(f"Loading data from {pkl_path}")
            data = torch.load(pkl_path, map_location=self.device)
            self._provided_history_datasets = data["history"]
            self.initial_target_X = data["target_init_X"]
            return

        print(f"Generating data and saving to {pkl_path} (Overwrite if exists)")
        raw_history = {}
        for task in self.history_tasks:
            name = task.name
            # Sample normalized, then convert to raw for evaluation.
            X = self._sample_design(self._history_n_data(name))
            X_raw = self._denormalize_x(X)
            y = self._ensure_2d_y(task.evaluate(X_raw, noise_std=0.0).to(self.device))
            raw_history[name] = (X_raw.cpu(), y.cpu())

        self._provided_history_datasets = raw_history

        # Target init design, stored in raw space.
        X_init_norm = self._sample_design(self.cfg.n_init)
        self.initial_target_X = self._denormalize_x(X_init_norm)

        data = {
            "history": raw_history,
            "target_init_X": self.initial_target_X.cpu(),
        }
        os.makedirs(os.path.dirname(pkl_path) or ".", exist_ok=True)
        torch.save(data, pkl_path)

    # ------------------------------------------------------------------
    # Models
    # ------------------------------------------------------------------
    def _fit_target_gp(self) -> GPModel:
        """
        Fit a GP on the current (normalized X, fitted y) target data.
        """
        target = self.manager.get_target_data()
        y_train = target.y.view(-1)

        gp = GPModel(device=self.device)
        gp.fit(
            target.X.detach().cpu(),
            y_train.detach().cpu(),
            n_iter=self.cfg.target_gp.n_iter,
            lr=self.cfg.target_gp.lr,
            n_restarts=self.cfg.target_gp.n_restarts,
            use_robust_init=False,
        )
        return gp

    def _compute_kendall_weights(self) -> Tuple[Dict[str, float], torch.Tensor]:
        """
        Compute each history task's Kendall tau on the current target data.

        The tau values are also stored on the manager.

        Returns:
            ``({name: tau}, weights)``. ``weights`` is a float32 tensor of the
            taus in ``list_history_names()`` order.
        """
        from utils.Kendall import calculate_kendall_tau

        target = self.manager.get_target_data()
        X_target = target.X
        y_target = target.y.view(-1)

        taus_by_name: Dict[str, float] = {}
        taus_list: List[float] = []
        for name in self.manager.list_history_names():
            mse_model, _ = self.manager.get_history_models(name)
            if mse_model is None:
                raise RuntimeError(f"history model for '{name}' is not set")
            mean, _, cov = mse_model.predict(X_target)
            tau_value = float(
                calculate_kendall_tau(X_target, y_target, mean, cov, device=self.device)
            )
            taus_by_name[name] = tau_value
            self.manager.set_history_kendall(name, tau_value)
            taus_list.append(tau_value)

        weights_tensor = torch.tensor(
            taus_list, device=self.device, dtype=torch.float32
        )
        return taus_by_name, weights_tensor

    def _build_fused_model(
        self, target_gp: GPModel, history_weights: torch.Tensor
    ) -> ModelMixture:
        """
        Build the mixture of the target GP (weight 1) and the history rank models.

        Experts that support ``set_prediction_calibration`` are calibrated on
        history plus target X. When that pool exceeds
        ``cfg.calibration_size``, a random subset (drawn from ``_rng``) is
        used instead.
        """
        names = self.manager.list_history_names()
        experts: List[object] = [target_gp]
        for name in names:
            _, rank_model = self.manager.get_history_models(name)
            if rank_model is None:
                raise RuntimeError(f"rank model for '{name}' is not set")
            experts.append(rank_model)

        weights = torch.cat(
            [
                torch.ones(1, device=self.device, dtype=torch.float32),
                history_weights.to(self.device, dtype=torch.float32).view(-1),
            ],
            dim=0,
        )

        # Calibration pool: all non-empty history inputs plus the target inputs.
        X_hist_all: List[torch.Tensor] = []
        for name in names:
            hist = self.manager.get_history_task(name)
            if hist.X.numel() > 0:
                X_hist_all.append(hist.X)

        X_target_all = self.manager.get_target_data().X
        X_calib_all = (
            torch.cat([*X_hist_all, X_target_all], dim=0) if X_hist_all else X_target_all
        )

        n_calib_all = int(X_calib_all.size(0))
        cap = int(self.cfg.calibration_size)
        if n_calib_all <= cap:
            X_calib = X_calib_all
        else:
            idx = torch.randperm(n_calib_all, generator=self._rng)[:cap]
            X_calib = X_calib_all.index_select(0, idx.to(X_calib_all.device))

        for m in experts:
            if hasattr(m, "set_prediction_calibration"):
                m.set_prediction_calibration(X_calib)

        return ModelMixture(models=experts, weights=weights, device=self.device)

    @torch.no_grad()
    def _select_anchors(self, taus: Dict[str, float]) -> Optional[torch.Tensor]:
        """
        Pick up to ``history_topk_per_task`` anchor points per history task.

        A task with negative tau contributes its lowest-y points; any other
        task (including a missing tau, treated as 0) contributes its highest.

        Returns:
            The concatenated anchors, or None when no task has data.
        """
        anchors_all: List[torch.Tensor] = []
        k_per_task = int(self.cfg.acq.anchors.history_topk_per_task)

        for name in self.manager.list_history_names():
            hist = self.manager.get_history_task(name)
            n = int(hist.X.size(0))
            if n <= 0:
                continue
            k = min(k_per_task, n)
            # An anti-correlated task's best points are at the bottom of its y.
            largest = not (taus.get(name, 0.0) < 0)
            top_idx = torch.topk(hist.y.view(-1), k=k, largest=largest).indices
            anchors_all.append(hist.X[top_idx])

        if not anchors_all:
            return None
        return torch.cat(anchors_all, dim=0)

    # ------------------------------------------------------------------
    # BO loop
    # ------------------------------------------------------------------
    def step(
        self, iteration: int
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Dict[str, float]]:
        """
        Run one BO iteration and append the new observation to the target data.

        Returns:
            ``(x_next_norm, x_next_raw, y_next_raw, taus)``. The points are
            flat ``[dim]`` tensors and ``y_next_raw`` is ``[1, 1]``.
        """
        target_gp = self._fit_target_gp()
        taus, hist_weights = self._compute_kendall_weights()
        fused = self._build_fused_model(
            target_gp=target_gp, history_weights=hist_weights
        )

        # Anchors for the Rank Entropy term.
        anchors = self._select_anchors(taus)
        if self.cfg.acq.gamma_mode == "schedule":
            gamma_t = gamma_schedule(iteration, gamma0=self.cfg.acq.gamma0).to(
                self.device
            )
        else:
            gamma_t = torch.as_tensor(
                self.cfg.acq.gamma, device=self.device, dtype=torch.float32
            )

        x_next = maximize_acquisition_adam_multistart(
            mixture_model=fused,
            beta_t=self.cfg.acq.beta_t,
            gamma_t=gamma_t,
            steps=self.cfg.acq.optimizer.steps,
            lr=self.cfg.acq.optimizer.lr,
            n_restarts=self.cfg.acq.optimizer.n_restarts,
            anchors=anchors,
            print_every=self.cfg.acq.optimizer.print_every,
            dim=self.cfg.dim,
            bounds=self.cfg.bounds,
            device=self.device,
            dtype=torch.float32,
            seed=self.cfg.seed + iteration,
        )
        # Flatten to one [dim] point, so a [dim] or [1, dim] optimizer output
        # can be appended as a single row.
        x_next = (
            torch.as_tensor(x_next, dtype=torch.float32)
            .detach()
            .to(self.device)
            .reshape(-1)
        )

        x_next_raw = self._denormalize_x(x_next)
        # Evaluate as a one-row [1, dim] batch, the same layout init_target uses.
        y_next_raw = self._ensure_2d_y(
            self._evaluate_target_raw(x_next_raw.view(1, -1))
        )

        if self._target_y_raw is None:
            raise RuntimeError("target raw y is not initialized")
        target = self.manager.get_target_data()
        X_all = torch.cat([target.X, x_next.view(1, -1)], dim=0)
        y_raw_all = torch.cat([self._target_y_raw, y_next_raw], dim=0)
        self._target_y_raw = y_raw_all.detach().clone()

        # Re-standardize over all observations so far.
        y_fit_all = self._zscore_y_2d(y_raw_all) if self.cfg.normalize_y else y_raw_all
        self.manager.set_target_data(X_all, y_fit_all)

        return x_next, x_next_raw, y_next_raw, taus

    def _best_observed(self) -> Tuple[float, torch.Tensor, torch.Tensor]:
        """
        Return ``(best_y, best_x_norm, best_x_raw)`` over all target evaluations.

        ``best_y`` is the largest raw observation. ``torch.argmax`` returns
        the first maximal index, so ties go to the earliest evaluation.
        """
        if self._target_y_raw is None:
            raise RuntimeError("target raw y is not initialized")
        y_flat = self._target_y_raw.view(-1)
        best_idx = int(torch.argmax(y_flat).item())
        best_y = float(y_flat[best_idx].item())
        best_x_norm = self.manager.get_target_data().X[best_idx]
        return best_y, best_x_norm, self._denormalize_x(best_x_norm)

    def run(self, result_pkl_path: Optional[str] = None) -> Dict[str, object]:
        """
        Run the full BO loop.

        Steps: build and fit the history models, evaluate the initial target
        design, then run ``cfg.n_iter`` iterations. The trajectory is a list
        of ``{"iter", "best_y", "y_new"}`` dicts, with one entry per target
        evaluation. It is saved to ``result_pkl_path`` when given.

        Returns:
            A dict with the target data (normalized and raw), the best point,
            the last Kendall taus, the initial design and the trajectory.
        """
        torch.manual_seed(int(self.cfg.seed))
        self._eval_rng.manual_seed(int(self.cfg.seed))

        history_datasets = self.build_history_datasets()
        self.fit_history_models(history_datasets)

        self.init_target()

        best_y, best_x_norm, best_x = self._best_observed()
        print(f"[init] best_y={best_y:.6f}")

        # Size the initial part of the trajectory from the data that was
        # actually evaluated. A provided or loaded initial design may contain
        # a different number of points than cfg.n_init.
        y_init = self._target_y_raw.view(-1)
        n_init = int(y_init.numel())
        trajectory = []
        current_best = float("-inf")
        for i in range(n_init):
            y_i = float(y_init[i].item())
            if y_i > current_best:
                current_best = y_i
            trajectory.append({"iter": i, "best_y": current_best, "y_new": y_i})

        last_taus: Dict[str, float] = {}
        for it in range(1, self.cfg.n_iter + 1):
            print(f"Iter:{it}:")
            _, _, y_new, taus = self.step(it)
            last_taus = taus

            best_y, best_x_norm, best_x = self._best_observed()

            y_new_val = float(y_new.view(-1)[0].item())
            if y_new_val > current_best:
                current_best = y_new_val
            trajectory.append(
                {"iter": n_init - 1 + int(it), "best_y": current_best, "y_new": y_new_val}
            )

            taus_str = ", ".join([f"{k}:{v:+.3f}" for k, v in taus.items()])
            print(
                f"[iter {it:02d}] y_new={y_new_val:+.6f} best_y={best_y:+.6f} | taus=({taus_str})"
            )

        if result_pkl_path is not None:
            print(f"Saving optimization trajectory to {result_pkl_path}")
            os.makedirs(os.path.dirname(result_pkl_path) or ".", exist_ok=True)
            torch.save(trajectory, result_pkl_path)

        init_raw = self._init_target_raw
        final_target = self.manager.get_target_data()
        return {
            "X_target_norm": final_target.X,
            "X_target": self._denormalize_x(final_target.X),
            "y_target": final_target.y,
            "y_target_raw": self._target_y_raw,
            "best_x_norm": best_x_norm,
            "best_x": best_x,
            "best_y": best_y,
            "last_kendalls": last_taus,
            "init_X": None if init_raw is None else init_raw[0],
            "init_y": None if init_raw is None else init_raw[1],
            "trajectory": trajectory,
        }


def main():
    """
    Example entry point: transfer BO on the ``hartmann3`` test suite.

    For every seed the function works in two phases:

    1. Generate history datasets and an initial target design, and save them
       to ``test_results/<suite>/<suite>_data_seed<seed>.pkl`` next to this
       file.
    2. Build a fresh optimizer that reloads that file and runs the BO loop.
       It saves the trajectory to ``<suite>_result_seed<seed>.pkl`` and
       prints the final result.
    """
    import importlib

    # The suite must exist as a module in the ``test_function`` package and
    # provide ``build_history_tasks`` and ``build_real_task``.
    test_suite = "hartmann3"
    suite = importlib.import_module(f"test_function.{test_suite}")
    build_history_tasks = getattr(suite, "build_history_tasks")
    build_real_task = getattr(suite, "build_real_task")

    problem_dim = 3
    cfg = OptimizerConfig(
        dim=problem_dim,
        bounds=(0.0, 1.0),  # normalized box used by the models
        raw_bounds=[(0.0, 1.0)] * problem_dim,  # raw box used for evaluation
        seed=0,
        obs_noise_std=1e-4,
        n_init=5,
        n_iter=15,
        design="lhs",
        target_gp=TargetGPConfig(n_iter=500, lr=1e-2, n_restarts=3),
        acq=AcquisitionConfig(
            beta_t=2.0,
            gamma_mode="schedule",
            gamma0=0.2,  # used because gamma_mode == "schedule"
            gamma=1.0,  # only used when gamma_mode == "constant"
            optimizer=AcquisitionOptimizerConfig(
                steps=80, lr=0.05, print_every=40, n_restarts=16
            ),
            anchors=AnchorConfig(history_topk_per_task=1),  # top-k per history task
        ),
        calibration_size=20000,  # cap on the calibration set size
        default_history_n_data=200,
        history_tasks={},
    )

    # Tasks are built over the raw box when one is configured.
    task_bounds = cfg.bounds if cfg.raw_bounds is None else cfg.raw_bounds
    history_tasks = build_history_tasks(dim=cfg.dim, bounds=task_bounds)
    target_task = build_real_task(dim=cfg.dim, bounds=task_bounds)

    value_model_cfg = EnsembleTrainConfig(
        hidden_dims=[128, 128],
        activation="gelu",
        num_models=5,
        steps=3000,
        lr=1e-2,
        weight_decay=1e-5,
        batch_size=64,
        loss_type="mse",
        use_amp=True,
        log_every=1000,
    )
    rank_model_cfg = EnsembleTrainConfig(
        hidden_dims=[128, 128],
        activation="gelu",
        num_models=5,
        steps=500,
        lr=1e-2,
        batch_size=64,
        weight_decay=1e-5,
        loss_type="rankcosine",
        list_size=80,  # length of each list
        # Lists per training step. The original note says this amounts to
        # lists_per_step * batch_size lists per step; confirm against
        # DeepEnsemble.fit.
        lists_per_step=32,
        use_amp=True,
        log_every=100,
    )

    # Every history task gets its own copy of the two training configs.
    shared_n_data = int(cfg.default_history_n_data)
    per_task_cfgs = {
        task.name: HistoryTaskConfig(
            n_data=shared_n_data,
            value_model=replace(value_model_cfg),
            rank_model=replace(rank_model_cfg),
        )
        for task in history_tasks
    }
    cfg = replace(cfg, history_tasks=per_task_cfgs)

    # Optional pre-built history datasets ({name: (X_raw, y)}); None = sample.
    history_datasets = None

    seeds = [0]

    results_dir = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "test_results", test_suite
    )
    os.makedirs(results_dir, exist_ok=True)

    def _banner(title):
        # Print a title framed by two rules of '=' characters.
        rule = "=" * 40
        print(f"\n{rule}")
        print(title)
        print(rule)

    def _make_optimizer(seed):
        # Build a seeded config, a fresh optimizer, and the seed's data path.
        seed_cfg = replace(cfg, seed=int(seed))
        optimizer = TransferRankBayesOpt(
            history_tasks=history_tasks,
            target_task=target_task,
            config=seed_cfg,
            history_datasets=history_datasets,
        )
        data_path = os.path.join(
            results_dir, f"{test_suite}_data_seed{seed_cfg.seed}.pkl"
        )
        return seed_cfg, optimizer, data_path

    # Phase 1: (re)generate history datasets and initial designs per seed.
    # Remove this loop to reuse existing data files.
    for seed in seeds:
        _banner(f"Generating PKL with SEED={seed}")
        _, bo, pkl_path = _make_optimizer(seed)
        bo.generate_or_load_data(pkl_path, load_from_pkl=False)

    # Phase 2: run the optimization from the saved data.
    for seed in seeds:
        _banner(f"Running Optimization with SEED={seed}")
        seed_cfg, bo, pkl_path = _make_optimizer(seed)
        bo.generate_or_load_data(pkl_path, load_from_pkl=bool(os.path.exists(pkl_path)))

        result_pkl_path = os.path.join(
            results_dir, f"{test_suite}_result_seed{seed_cfg.seed}.pkl"
        )
        result = bo.run(result_pkl_path=result_pkl_path)

        print(f"\n=== final result (seed={seed}) ===")
        print("best_y:", result["best_y"])
        print("best_x:", result["best_x"].detach().cpu().numpy())
        print("last_kendalls:", result["last_kendalls"])

if __name__ == "__main__":
    # Run the example only when this file is executed as a script, not on import.
    main()
