from __future__ import annotations

import math
import random
from dataclasses import replace
from typing import Dict, List, Any, Tuple

import numpy as np

from .benchmark import BenchmarkConfig
from .sweep import evaluate_algo_on_env


def _random_combo(space: Dict[str, List[Any]], rng: random.Random) -> Dict[str, Any]:
    """Draw one configuration by picking a value for every parameter in *space*.

    Parameters are visited in the mapping's iteration order and each value is
    chosen independently with ``rng.choice``, so the result is reproducible
    for a given generator state.
    """
    picked: Dict[str, Any] = {}
    for name, candidates in space.items():
        picked[name] = rng.choice(candidates)
    return picked


def _combo_key(d: Dict[str, Any]) -> str:
    """Build a canonical string identifying a parameter combination.

    Entries are rendered as ``name=repr(value)`` and joined with ``|`` in
    sorted-name order, so two dicts with the same items produce the same key
    regardless of insertion order.
    """
    parts = []
    for name in sorted(d):
        parts.append(f"{name}={d[name]!r}")
    return "|".join(parts)


def _smoothed_ratios(
    space: Dict[str, List[Any]],
    counts_good: Dict[str, Dict[Any, int]],
    counts_bad: Dict[str, Dict[Any, int]],
) -> Dict[str, Dict[Any, float]]:
    """Return, per parameter and value, the add-one-smoothed good/bad frequency ratio."""
    ratios: Dict[str, Dict[Any, float]] = {}
    for p, values in space.items():
        n_values = len(values)
        n_good = sum(counts_good[p].values())
        n_bad = sum(counts_bad[p].values())
        table: Dict[Any, float] = {}
        for v in values:
            l = (counts_good[p].get(v, 0) + 1.0) / (n_good + n_values)
            g = (counts_bad[p].get(v, 0) + 1.0) / (n_bad + n_values)
            table[v] = l / g
        ratios[p] = table
    return ratios


def tpe_discrete(
    base_cfg: BenchmarkConfig,
    algo_name: str,
    space: Dict[str, List[Any]],
    n_init: int = 8,
    n_iter: int = 32,
    gamma: float = 0.2,
    seed: int = 0,
) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
    """Search a discrete hyperparameter grid with a TPE-style good/bad split.

    The search first evaluates ``n_init`` distinct random combinations, then
    runs up to ``n_iter`` guided steps. In each step the trials are split by
    loss into a "good" set (the best ``ceil(gamma * n)`` trials, at least
    one) and a "bad" set (the rest). For every parameter value an
    add-one-smoothed frequency ratio good/bad is computed, and the proposal
    takes, per parameter, the value with the highest ratio. If that
    combination was already evaluated, up to 200 random combinations are
    drawn and one unseen combination is sampled with probability proportional
    to the product of its per-parameter ratios.

    Args:
        base_cfg: Configuration copied (via ``dataclasses.replace``) for each
            evaluation; the copy's ``algos`` and ``algo_params`` are
            overwritten, ``base_cfg`` itself is not reassigned.
        algo_name: Name of the algorithm to tune.
        space: Mapping from parameter name to its list of candidate values.
            Values must be hashable.
        n_init: Number of random warm-up evaluations (capped at the number of
            distinct combinations).
        n_iter: Maximum number of guided evaluations. Fewer are performed
            when every distinct combination has already been evaluated.
        gamma: Fraction of trials treated as "good" in each step.
        seed: Seed for the private ``random.Random`` instance.

    Returns:
        ``(best_params, trials)`` where ``trials`` is a list of
        ``{"params": ..., "loss": ...}`` dicts in evaluation order and
        ``best_params`` is the ``params`` of the lowest-loss trial.

    Raises:
        ValueError: If a parameter has no candidate values, or if no trial
            was evaluated (``n_init`` and ``n_iter`` both non-positive).
    """
    for name, values in space.items():
        if len(values) == 0:
            # An empty list would otherwise yield a ``None`` value that gets
            # evaluated silently before random sampling eventually fails.
            raise ValueError(f"parameter {name!r} has no candidate values")

    rng = random.Random(seed)
    trials: List[Dict[str, Any]] = []
    seen = set()

    def eval_params(p: Dict[str, Any]) -> float:
        # Work on a copy so the caller's config keeps its own settings.
        cfg = replace(base_cfg)
        cfg.algos = [algo_name]
        cfg.algo_params = {algo_name: p}
        finals = evaluate_algo_on_env(cfg, algo_name)
        # The loss is the mean of whatever the evaluation returns; lower is better.
        return float(np.mean(finals))

    # Count *distinct* combinations the same way ``seen`` does (by repr), so
    # duplicated values in a list cannot make the loops below wait for
    # combinations that do not exist.
    domain_size = 1
    for vs in space.values():
        domain_size *= len({repr(v) for v in vs})
    n_init = min(n_init, domain_size)

    # Warm-up: rejection-sample distinct random combinations.
    while len(trials) < n_init:
        cand = _random_combo(space, rng)
        key = _combo_key(cand)
        if key in seen:
            continue
        seen.add(key)
        trials.append({"params": cand, "loss": eval_params(cand)})

    for _ in range(n_iter):
        if len(seen) >= domain_size:
            # Every distinct combination has been evaluated; searching for an
            # unseen one would never terminate.
            break

        losses = np.array([t["loss"] for t in trials], dtype=float)
        k = max(1, int(math.ceil(gamma * len(losses))))
        good_idx = set(np.argsort(losses)[:k].tolist())

        counts_good: Dict[str, Dict[Any, int]] = {p: {} for p in space}
        counts_bad: Dict[str, Dict[Any, int]] = {p: {} for p in space}
        for i, tr in enumerate(trials):
            target = counts_good if i in good_idx else counts_bad
            for p, v in tr["params"].items():
                target[p][v] = target[p].get(v, 0) + 1

        ratios = _smoothed_ratios(space, counts_good, counts_bad)

        # Greedy proposal: best ratio per parameter (first value wins ties).
        proposal: Dict[str, Any] = {
            p: max(values, key=ratios[p].__getitem__)
            for p, values in space.items()
        }

        if _combo_key(proposal) in seen:
            weights: List[float] = []
            combos: List[Dict[str, Any]] = []
            for _attempt in range(200):
                c = _random_combo(space, rng)
                if _combo_key(c) in seen:
                    continue
                prod = 1.0
                for p, v in c.items():
                    prod *= ratios[p][v]
                weights.append(prod)
                combos.append(c)

            if combos:
                s = sum(weights)
                if s <= 0:
                    proposal = combos[0]
                else:
                    # Roulette-wheel selection. Default to the last candidate
                    # so float round-off between ``sum`` and the running
                    # total can never leave the already-seen proposal in place.
                    proposal = combos[-1]
                    r = rng.random() * s
                    acc = 0.0
                    for w, c in zip(weights, combos):
                        acc += w
                        if acc >= r:
                            proposal = c
                            break
            else:
                # No unseen candidate among the samples; keep drawing. This
                # terminates because the exhaustion check above guarantees at
                # least one unseen combination exists.
                while True:
                    c = _random_combo(space, rng)
                    if _combo_key(c) not in seen:
                        proposal = c
                        break

        seen.add(_combo_key(proposal))
        trials.append({"params": proposal, "loss": eval_params(proposal)})

    if not trials:
        raise ValueError("no trials were evaluated; n_init or n_iter must be positive")

    best = min(trials, key=lambda t: t["loss"])["params"]
    return best, trials
