import importlib
import os
from dataclasses import replace
from typing import List
import numpy as np
import torch
from test_function.task_spec import TaskSpec
from opt_final import (
    AcquisitionConfig,
    AcquisitionOptimizerConfig,
    AnchorConfig,
    EnsembleTrainConfig,
    HistoryTaskConfig,
    OptimizerConfig,
    TargetGPConfig,
    TransferRankBayesOpt,
)


def _tau_tag(tau: float) -> str:
    s = f"{float(tau):.1f}"
    return s.replace("-", "m").replace(".", "p")


def _bounds_to_lows_highs(
    bounds, dim: int, *, device: torch.device, dtype: torch.dtype
) -> torch.Tensor:
    if (
        isinstance(bounds, tuple)
        and len(bounds) == 2
        and not isinstance(bounds[0], (tuple, list))
        and not isinstance(bounds[1], (tuple, list))
    ):
        low, high = float(bounds[0]), float(bounds[1])
        lows = torch.full((dim,), low, device=device, dtype=dtype)
        highs = torch.full((dim,), high, device=device, dtype=dtype)
        return torch.stack([lows, highs], dim=1)
    if isinstance(bounds, (tuple, list)) and len(bounds) == 1 and dim > 1:
        b = bounds[0]
        low, high = float(b[0]), float(b[1])
        lows = torch.full((dim,), low, device=device, dtype=dtype)
        highs = torch.full((dim,), high, device=device, dtype=dtype)
        return torch.stack([lows, highs], dim=1)
    lows = torch.tensor([float(b[0]) for b in bounds], device=device, dtype=dtype)
    highs = torch.tensor([float(b[1]) for b in bounds], device=device, dtype=dtype)
    return torch.stack([lows, highs], dim=1)


def _sample_uniform_raw(n: int, bounds, dim: int, seed: int) -> torch.Tensor:
    """Draw ``n`` points uniformly inside ``bounds`` on the CPU.

    A dedicated CPU generator seeded with ``seed`` is used, so the global
    torch RNG state is not consumed.

    Returns:
        ``float32`` tensor of shape ``(n, dim)``.
    """
    rng = torch.Generator(device="cpu").manual_seed(int(seed))
    box = _bounds_to_lows_highs(
        bounds, dim=dim, device=torch.device("cpu"), dtype=torch.float32
    )
    lower = box[:, 0].view(1, dim)
    upper = box[:, 1].view(1, dim)
    # Unit-cube samples, then an affine map onto [lower, upper].
    unit = torch.rand((int(n), int(dim)), generator=rng, dtype=torch.float32)
    return lower + (upper - lower) * unit


def _load_or_create_seed_data(
    pkl_path: str,
    *,
    bounds,
    dim: int,
    seed: int,
    n_init: int,
    n_hist: int,
):
    """Load the cached seed data at ``pkl_path``, creating it if missing.

    The stored object is a dict with two keys:

    * ``"target_init_X"``: ``(n_init, dim)`` uniform samples drawn with
      ``seed``;
    * ``"history"``: ``{"history_1": (X, y)}`` where ``X`` is ``(n_hist, dim)``
      drawn with ``seed + 12345`` and ``y`` is an all-zero placeholder.

    An existing file is returned as-is; it is not checked against the
    requested sizes.
    """
    parent_dir = os.path.dirname(pkl_path) or "."
    os.makedirs(parent_dir, exist_ok=True)

    if not os.path.exists(pkl_path):
        target_init = _sample_uniform_raw(
            int(n_init), bounds, dim=int(dim), seed=int(seed)
        )
        hist_inputs = _sample_uniform_raw(
            int(n_hist), bounds, dim=int(dim), seed=int(seed) + 12345
        )
        hist_placeholder = torch.zeros((int(n_hist), 1), dtype=torch.float32)
        payload = {
            "target_init_X": target_init,
            "history": {"history_1": (hist_inputs, hist_placeholder)},
        }
        torch.save(payload, pkl_path)
        return payload

    # NOTE(review): weights_only=False performs full unpickling, so only load
    # cache files this script wrote itself.
    try:
        return torch.load(pkl_path, map_location="cpu", weights_only=False)
    except TypeError:
        # Presumably a torch version whose torch.load lacks ``weights_only``.
        return torch.load(pkl_path, map_location="cpu")


def _random_rotation(dim: int, *, seed: int) -> torch.Tensor:
    if dim <= 1:
        return torch.eye(1, dtype=torch.float32)
    g = torch.Generator(device="cpu").manual_seed(int(seed))
    A = torch.randn((dim, dim), generator=g, dtype=torch.float32)
    Q, R = torch.linalg.qr(A)
    d = torch.sign(torch.diag(R))
    d[d == 0] = 1.0
    Q = Q * d
    return Q


def _apply_transform(
    X: torch.Tensor,
    *,
    lows: torch.Tensor,
    highs: torch.Tensor,
    perm: torch.Tensor,
    rot: torch.Tensor,
    flip: torch.Tensor,
    scale: torch.Tensor,
    shift: torch.Tensor,
) -> torch.Tensor:
    X = torch.as_tensor(X, dtype=torch.float32)
    device, dtype = X.device, X.dtype
    perm = perm.to(device=device)
    lows_p = lows.to(device=device, dtype=dtype).index_select(0, perm)
    highs_p = highs.to(device=device, dtype=dtype).index_select(0, perm)
    center_p = (lows_p + highs_p) * 0.5

    Xp = X.index_select(1, perm)
    Xc = Xp - center_p.view(1, -1)
    rot = rot.to(device=device, dtype=dtype)
    Xr = Xc @ rot.T
    flip = flip.to(device=device, dtype=dtype)
    scale = scale.to(device=device, dtype=dtype)
    shift = shift.to(device=device, dtype=dtype)
    Xt = (
        center_p.view(1, -1)
        + shift.view(1, -1)
        + scale.view(1, 1) * (Xr * flip.view(1, -1))
    )
    return torch.max(torch.min(Xt, highs_p.view(1, -1)), lows_p.view(1, -1))


def _count_inversions(a: np.ndarray) -> int:
    a = np.asarray(a, dtype=np.int64)
    n = int(a.size)
    if n <= 1:
        return 0

    buf = np.empty_like(a)

    def sort_count(lo: int, hi: int) -> int:
        if hi - lo <= 1:
            return 0
        mid = (lo + hi) // 2
        inv = sort_count(lo, mid) + sort_count(mid, hi)

        i, j, k = lo, mid, lo
        while i < mid and j < hi:
            if a[i] <= a[j]:
                buf[k] = a[i]
                i += 1
            else:
                buf[k] = a[j]
                j += 1
                inv += mid - i
            k += 1
        while i < mid:
            buf[k] = a[i]
            i += 1
            k += 1
        while j < hi:
            buf[k] = a[j]
            j += 1
            k += 1

        a[lo:hi] = buf[lo:hi]
        return inv

    return int(sort_count(0, n))


def kendall_tau_no_ties(y1: np.ndarray, y2: np.ndarray) -> float:
    """Kendall rank correlation between two equally long samples.

    Uses ``tau = 1 - 4 * inversions / (n * (n - 1))`` with no tie correction.
    Both inputs are flattened to 1-D ``float64`` first.

    Returns:
        The tau value, or ``nan`` when fewer than two points are given.

    Raises:
        ValueError: If the flattened inputs differ in length.
    """
    y1 = np.asarray(y1, dtype=np.float64).reshape(-1)
    y2 = np.asarray(y2, dtype=np.float64).reshape(-1)
    if y1.shape != y2.shape:
        raise ValueError(f"shape mismatch: {y1.shape} vs {y2.shape}")

    n = int(y1.size)
    if n < 2:
        return float("nan")

    # Put y2 in the order that sorts y1, then replace its values by ranks.
    by_y1 = np.argsort(y1, kind="mergesort")
    by_y2 = np.argsort(y2[by_y1], kind="mergesort")
    ranks = np.empty(n, dtype=np.int64)
    # Scattering 0..n-1 through a permutation yields its inverse, which is
    # what an argsort of that permutation would return.
    ranks[by_y2] = np.arange(n, dtype=np.int64)

    discordant = _count_inversions(ranks)
    ordered_pairs = n * (n - 1)
    return float(1.0 - (4.0 * discordant) / ordered_pairs)


def _sample_uniform(
    n: int, bounds, dim: int, device: torch.device, seed: int
) -> torch.Tensor:
    """Draw ``n`` seeded uniform points inside ``bounds`` and move them to ``device``.

    Sampling always happens on the CPU through ``_sample_uniform_raw``, so the
    same ``seed`` gives the same points on every device.

    Returns:
        ``float32`` tensor of shape ``(n, dim)`` on ``device``.
    """
    # Delegate rather than repeat the sampling code: both helpers must stay
    # in lockstep for a given seed.
    points = _sample_uniform_raw(int(n), bounds, dim=int(dim), seed=int(seed))
    return points.to(device=device)


def _zscore_with_stats(y: torch.Tensor, mean: float, std: float) -> torch.Tensor:
    return (y - float(mean)) / (float(std) + 1e-12)


def _fit_mix_weight_for_tau(
    desired_tau: float,
    y_target: np.ndarray,
    y_alt: np.ndarray,
    sign: float,
    grid_size: int = 41,
) -> float:
    if abs(desired_tau) >= 0.999:
        return 1.0

    best_w = 0.0
    best_err = float("inf")
    for w in np.linspace(0.0, 1.0, int(grid_size)):
        y_hist = float(w) * sign * y_target + (1.0 - float(w)) * y_alt
        tau = kendall_tau_no_ties(y_target, y_hist)
        err = abs(float(tau) - float(desired_tau))
        if err < best_err:
            best_err = err
            best_w = float(w)
    return float(best_w)


def estimate_kendall_tau_between_tasks(
    target,
    history,
    n_points: int = 20000,
    seed: int = 0,
    device: torch.device | None = None,
) -> float:
    """Estimate Kendall's tau between two tasks on shared random inputs.

    ``n_points`` seeded uniform points are drawn inside ``target.bounds``
    and both tasks are evaluated on them with ``noise_std=0.0``.

    Args:
        target: Task supplying ``bounds``, ``dim`` and ``evaluate``.
        history: Task evaluated on the same points via ``evaluate``.
        n_points: Number of sample points.
        seed: Seed for the sample points.
        device: Device for the sample points; defaults to CUDA when
            available, otherwise the CPU.

    Returns:
        Kendall tau (no tie correction) between the two sets of values.
    """
    if device is not None:
        eval_device = device
    else:
        eval_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    probes = _sample_uniform(
        n=int(n_points),
        bounds=target.bounds,
        dim=int(target.dim),
        device=eval_device,
        seed=int(seed),
    )

    def _noise_free_values(task) -> np.ndarray:
        values = task.evaluate(probes, noise_std=0.0)
        return values.view(-1).detach().cpu().numpy()

    target_values = _noise_free_values(target)
    history_values = _noise_free_values(history)
    return kendall_tau_no_ties(target_values, history_values)


def build_correlated_tasks(
    target_task,
    alt_task,
    taus,
    calib_n: int = 20000,
    seed: int = 0,
):
    """Build one history task per requested Kendall tau.

    Each history objective is
    ``w * sign * z(target) + (1 - w) * z(alt)``, where ``z`` standardises with
    the mean/std measured on ``calib_n`` seeded calibration points, ``sign``
    is ``-1`` for negative taus, and ``w`` is fitted on those calibration
    points by ``_fit_mix_weight_for_tau``.

    Returns:
        List of ``TaskSpec`` objects named ``history_tau_<tau>``, sharing the
        target's ``dim`` and ``bounds``, in the order of ``taus``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    X_calib = _sample_uniform(
        n=int(calib_n),
        bounds=target_task.bounds,
        dim=int(target_task.dim),
        device=device,
        seed=int(seed),
    )

    calib_target = target_task.evaluate(X_calib, noise_std=0.0).view(-1)
    calib_alt = alt_task.evaluate(X_calib, noise_std=0.0).view(-1)

    # Calibration statistics are frozen here and reused by every objective.
    t_mean = float(calib_target.mean().item())
    t_std = float(calib_target.std().item())
    a_mean = float(calib_alt.mean().item())
    a_std = float(calib_alt.std().item())

    z_target_np = _zscore_with_stats(calib_target, t_mean, t_std).detach().cpu().numpy()
    z_alt_np = _zscore_with_stats(calib_alt, a_mean, a_std).detach().cpu().numpy()

    def _mixture_objective(weight, direction):
        """Bind one (weight, direction) pair into an objective function."""
        weight = float(weight)
        direction = float(direction)

        def _objective(X: torch.Tensor) -> torch.Tensor:
            X = torch.as_tensor(X, dtype=torch.float32)
            raw_target = target_task.evaluate(X, noise_std=0.0).view(-1)
            raw_alt = alt_task.evaluate(X, noise_std=0.0).view(-1)

            z_target = _zscore_with_stats(raw_target, t_mean, t_std)
            z_alt = _zscore_with_stats(raw_alt, a_mean, a_std)

            mixed = weight * direction * z_target + (1.0 - weight) * z_alt
            return mixed.unsqueeze(-1)

        return _objective

    tasks: List[TaskSpec] = []
    for raw_tau in taus:
        tau_f = float(raw_tau)
        direction = -1.0 if tau_f < 0.0 else 1.0
        weight = _fit_mix_weight_for_tau(tau_f, z_target_np, z_alt_np, sign=direction)
        spec = TaskSpec(
            name=f"history_tau_{tau_f:.1f}",
            dim=target_task.dim,
            bounds=target_task.bounds,
            objective=_mixture_objective(weight, direction),
        )
        tasks.append(spec)
    return tasks


def _build_auto_alt_task(
    target_task, *, seed: int, calib_n: int, n_basis: int, n_trials: int
):
    """Build an "alt_auto" task that ranks points very differently from the target.

    The alternative objective is a weighted sum of ``n_basis`` terms. Each
    term is the target evaluated at a randomly transformed copy of the input
    (permutation, orthogonal matrix, sign flips, scalar scale, shift; see
    ``_apply_transform``), standardised with the mean/std measured on the
    calibration sample. The weight vector is the one, out of ``n_trials``
    standard-normal draws, whose combination has the smallest absolute
    Kendall tau against the target on the calibration sample.

    Args:
        target_task: Task exposing ``dim``, ``bounds`` and
            ``evaluate(X, noise_std=...)``.
        seed: Base seed for the calibration sample and all random draws.
        calib_n: Number of calibration points.
        n_basis: Number of transformed basis terms; must be at least 1
            (``np.stack`` fails on an empty list).
        n_trials: Number of random weight vectors to try.

    Returns:
        ``TaskSpec`` named ``"alt_auto"`` with the target's ``dim`` and
        ``bounds``.
    """
    from test_function.task_spec import TaskSpec

    dim = int(target_task.dim)
    bd = _bounds_to_lows_highs(
        target_task.bounds, dim=dim, device=torch.device("cpu"), dtype=torch.float32
    )
    lows = bd[:, 0]
    highs = bd[:, 1]
    widths = highs - lows

    # Calibration sample and the target's values on it (CPU, noise-free).
    X_calib = _sample_uniform_raw(
        int(calib_n), target_task.bounds, dim=dim, seed=int(seed)
    )
    y_target = target_task.evaluate(X_calib, noise_std=0.0).view(-1).detach().cpu()
    y_target_np = y_target.numpy()

    # One generator drives the permutation/flip/scale/shift draws and, later,
    # the weight trials, so the order of draws below determines the result.
    g = torch.Generator(device="cpu").manual_seed(int(seed) + 17)
    perms = []
    rots = []
    flips = []
    scales = []
    shifts = []
    basis_mu = []
    basis_std = []
    basis_z_np = []

    for i in range(int(n_basis)):
        perm = torch.randperm(dim, generator=g, dtype=torch.int64)
        # Each coordinate is flipped (-1) or kept (+1) with probability 1/2.
        flip = torch.where(torch.rand((dim,), generator=g) < 0.5, -1.0, 1.0).to(
            dtype=torch.float32
        )
        # Scalar scale exp(u * 0.25) with u uniform in [-1, 1).
        scale = torch.exp((torch.rand((), generator=g) * 2.0 - 1.0) * 0.25).to(
            dtype=torch.float32
        )
        # Per-dimension shift of at most a quarter of that dimension's width.
        shift = (torch.rand((dim,), generator=g) * 2.0 - 1.0) * (0.25 * widths)
        # The orthogonal matrix uses its own seed and does not consume g.
        rot = _random_rotation(dim, seed=int(seed) + 1000 + i)

        X_t = _apply_transform(
            X_calib,
            lows=lows,
            highs=highs,
            perm=perm,
            rot=rot,
            flip=flip,
            scale=scale,
            shift=shift,
        )
        # Standardise this basis term on the calibration sample; mu and std
        # are kept so the objective can apply the same scaling later.
        y_i = target_task.evaluate(X_t, noise_std=0.0).view(-1).detach().cpu()
        mu = y_i.mean()
        std = y_i.std()
        z = (y_i - mu) / (std + 1e-12)

        perms.append(perm)
        rots.append(rot)
        flips.append(flip)
        scales.append(scale)
        shifts.append(shift)
        basis_mu.append(mu)
        basis_std.append(std)
        basis_z_np.append(z.numpy())

    # Z_np is (n_basis, calib_n); Z is its transpose, (calib_n, n_basis).
    Z_np = np.stack(basis_z_np, axis=0)
    Z = torch.from_numpy(Z_np).to(dtype=torch.float32).t().contiguous()

    # Random search for the weights whose combination has the smallest
    # absolute Kendall tau against the target.
    best_alpha = None
    best_tau_abs = float("inf")
    for _ in range(int(n_trials)):
        alpha = torch.randn((int(n_basis),), generator=g, dtype=torch.float32)
        y_alt = (Z @ alpha).numpy()
        tau = float(kendall_tau_no_ties(y_target_np, y_alt))
        tau_abs = abs(tau)
        if tau_abs < best_tau_abs:
            best_tau_abs = tau_abs
            best_alpha = alpha.detach().clone()

    # No trial was accepted (e.g. n_trials == 0 or every tau was NaN):
    # fall back to the first basis term alone.
    if best_alpha is None:
        best_alpha = torch.zeros((int(n_basis),), dtype=torch.float32)
        best_alpha[0] = 1.0

    mus = torch.stack([m.to(dtype=torch.float32) for m in basis_mu], dim=0)
    stds = torch.stack([s.to(dtype=torch.float32) for s in basis_std], dim=0)
    alpha = best_alpha.to(dtype=torch.float32)

    def _objective(X: torch.Tensor) -> torch.Tensor:
        # Evaluate sum_i alpha[i] * z_i(X) on X's device; returns shape (n, 1).
        X = torch.as_tensor(X, dtype=torch.float32)
        device, dtype = X.device, X.dtype
        out = torch.zeros((X.size(0),), device=device, dtype=dtype)
        mus_d = mus.to(device=device, dtype=dtype)
        stds_d = stds.to(device=device, dtype=dtype)
        alpha_d = alpha.to(device=device, dtype=dtype)
        lows_d = lows.to(device=device, dtype=dtype)
        highs_d = highs.to(device=device, dtype=dtype)

        for i in range(int(n_basis)):
            # Same transform and calibration statistics as used for basis i.
            Xt = _apply_transform(
                X,
                lows=lows_d,
                highs=highs_d,
                perm=perms[i],
                rot=rots[i],
                flip=flips[i],
                scale=scales[i],
                shift=shifts[i],
            )
            yi = target_task.evaluate(Xt, noise_std=0.0).view(-1)
            zi = (yi - mus_d[i]) / (stds_d[i] + 1e-12)
            out = out + alpha_d[i] * zi

        return out.unsqueeze(-1)

    return TaskSpec(
        name="alt_auto", dim=dim, bounds=target_task.bounds, objective=_objective
    )


def main():
    """Run the transfer-BO experiment for a sweep of history/target correlations.

    For each desired Kendall tau, a single history task is built from the
    target task and an automatically generated alternative task. For each
    seed, cached seed data is loaded (or created), history values are
    computed with that history task, and ``TransferRankBayesOpt`` is run.
    Results are written under ``test_results/<test_suite>/`` next to this
    file and the best point is printed.
    """
    # The suite name selects the module test_function.<suite>, which must
    # expose build_real_task.
    test_suite = "schwefel"
    test_module = importlib.import_module(f"test_function.{test_suite}")
    build_real_task = getattr(test_module, "build_real_task")

    from test_function.task_spec import TaskSpec

    default_history_n_data = 200
    taus: List[float] = [-1.0, -0.5, 0.0, 0.5, 1.0]
    seeds = [0]

    # NOTE(review): bounds=(0.0, 1.0) next to raw_bounds presumably means the
    # optimizer works in a normalised cube and maps to raw_bounds; confirm
    # in opt_final.
    cfg = OptimizerConfig(
        dim=6,
        bounds=(0.0, 1.0),
        raw_bounds=[(-500, 500) for _ in range(6)],
        seed=0,
        obs_noise_std=1e-4,
        n_init=5,
        n_iter=4,
        design="lhs",
        target_gp=TargetGPConfig(
            n_iter=80,
            lr=1e-2,
            n_restarts=1,
        ),
        acq=AcquisitionConfig(
            beta_t=2.0,
            gamma_mode="schedule",
            gamma0=0.2,
            gamma=1.0,
            optimizer=AcquisitionOptimizerConfig(
                steps=25,
                lr=0.05,
                print_every=40,
                n_restarts=6,
            ),
            anchors=AnchorConfig(
                history_topk_per_task=1,
            ),
        ),
        calibration_size=50000,
        default_history_n_data=int(default_history_n_data),
        history_tasks={},
    )

    # Tasks and seed data are defined on raw_bounds when it is provided.
    task_bounds = cfg.raw_bounds if cfg.raw_bounds is not None else cfg.bounds
    target_task = build_real_task(dim=cfg.dim, bounds=task_bounds)
    alt_task = _build_auto_alt_task(
        target_task,
        seed=0,
        calib_n=3000,
        n_basis=6,
        n_trials=64,
    )
    base_dir = os.path.dirname(os.path.abspath(__file__))
    results_dir = os.path.join(base_dir, "test_results", test_suite)
    os.makedirs(results_dir, exist_ok=True)
    # In-memory memo keyed by file path, so each seed file is loaded or
    # created at most once across the tau sweep.
    seed_data_cache = {}

    # Training settings for the history value and rank model ensembles;
    # a copy of each is attached to every HistoryTaskConfig below.
    history_value_model_cfg = EnsembleTrainConfig(
        hidden_dims=[128, 128],
        activation="gelu",
        num_models=5,
        steps=200,
        lr=1e-2,
        weight_decay=1e-5,
        batch_size=64,
        loss_type="mse",
        use_amp=True,
        log_every=100,
    )
    history_rank_model_cfg = EnsembleTrainConfig(
        hidden_dims=[128, 128],
        activation="gelu",
        num_models=5,
        steps=120,
        lr=1e-2,
        batch_size=64,
        weight_decay=1e-5,
        loss_type="listnet",
        list_size=80,
        lists_per_step=32,
        use_amp=True,
        log_every=50,
    )

    for tau in taus:
        tag = _tau_tag(tau)
        # Build the single history task for this tau.
        correlated = build_correlated_tasks(
            target_task=target_task,
            alt_task=alt_task,
            taus=[float(tau)],
            calib_n=3000,
            seed=0,
        )
        if len(correlated) != 1:
            raise RuntimeError("expected exactly one history task")

        # Re-wrap under the fixed name "history_1", which is also the key
        # used for the history config and datasets below.
        base_hist = correlated[0]
        history_task = TaskSpec(
            name="history_1",
            dim=base_hist.dim,
            bounds=base_hist.bounds,
            objective=base_hist.objective,
        )
        history_tasks = [history_task]
        # Report the tau actually achieved on a fresh sample.
        tau_est = estimate_kendall_tau_between_tasks(
            target_task, history_task, n_points=3000, seed=0
        )
        print(f"[history] desired_tau={tau:+.1f} estimated_tau={tau_est:+.4f}")

        cfg_tau = replace(
            cfg,
            history_tasks={
                history_task.name: HistoryTaskConfig(
                    n_data=default_history_n_data,
                    value_model=replace(history_value_model_cfg),
                    rank_model=replace(history_rank_model_cfg),
                )
            },
        )

        for seed in seeds:
            print(f"\n{'=' * 40}")
            print(f"Running Optimization with tau={tau:+.1f} seed={seed}")
            print(f"{'=' * 40}")

            cfg_i = replace(cfg_tau, seed=int(seed))

            # The seed-data file name depends only on suite and seed, so all
            # tau values share the same initial and history inputs.
            pkl_path = os.path.join(
                results_dir, f"{test_suite}_data_seed{cfg_i.seed}.pkl"
            )
            if pkl_path not in seed_data_cache:
                seed_data_cache[pkl_path] = _load_or_create_seed_data(
                    pkl_path,
                    bounds=task_bounds,
                    dim=int(cfg_i.dim),
                    seed=int(cfg_i.seed),
                    n_init=int(cfg_i.n_init),
                    n_hist=int(default_history_n_data),
                )

            data = seed_data_cache[pkl_path]
            init_X_raw = data["target_init_X"]
            init_X_raw = torch.as_tensor(init_X_raw, dtype=torch.float32)
            # A cached file may have been created with a different n_init.
            if int(init_X_raw.size(0)) != int(cfg_i.n_init):
                raise ValueError(
                    f"target_init_X size mismatch: got {int(init_X_raw.size(0))}, expected cfg.n_init={int(cfg_i.n_init)} "
                    f"(file={pkl_path})"
                )
            # Take the inputs stored under "history_1", or under the first
            # key when that name is absent.
            raw_history = data["history"]
            if "history_1" in raw_history:
                X_hist_raw = raw_history["history_1"][0]
            else:
                first_key = next(iter(raw_history.keys()))
                X_hist_raw = raw_history[first_key][0]

            # Only the stored inputs are used; the history values are
            # recomputed with the history task built for this tau.
            X_hist_raw = torch.as_tensor(X_hist_raw, dtype=torch.float32)
            y_hist_raw = history_task.evaluate(X_hist_raw, noise_std=0.0).detach().cpu()
            print(
                f"[data] seed={cfg_i.seed} init_n={int(init_X_raw.size(0))} "
                f"hist_n={int(X_hist_raw.size(0))} hist_y(mean={float(y_hist_raw.mean()):+.4f}, std={float(y_hist_raw.std()):.4f})"
            )
            history_datasets = {"history_1": (X_hist_raw, y_hist_raw)}

            bo = TransferRankBayesOpt(
                history_tasks=history_tasks,
                target_task=target_task,
                config=cfg_i,
                history_datasets=history_datasets,
                initial_target_X=init_X_raw,
            )

            # The result file name carries both the seed and the tau tag.
            result_pkl_path = os.path.join(
                results_dir, f"{test_suite}_result_seed{cfg_i.seed}_t{tag}.pkl"
            )
            result = bo.run(result_pkl_path=result_pkl_path)

            print(f"\n=== final result (tau={tau:+.1f} seed={seed}) ===")
            print("best_y:", result["best_y"])
            print("best_x:", result["best_x"].detach().cpu().numpy())


# Run the experiment only when this file is executed as a script.
if __name__ == "__main__":
    main()
