import dataclasses
import html
import math
from logging import getLogger
from pathlib import Path
from typing import Any, Hashable, Iterable, Iterator, Sequence

import treequest as tq
from joblib import Parallel, delayed, parallel_config
from treequest import ABMCTSM

from bd_mcts.search_algo.base import Result, SearchAlgo, Trial
from bd_mcts.token_sequence import TokenSequence

# Module-level logger named after this module.
logger = getLogger(__name__)
# How many consecutive ask() attempts are made on one search tree before that
# tree is archived and replaced by a freshly initialised one.
_MAX_ASK_ATTEMPTS = 5
# Per-process ABMCTSM instance used by _ask_worker. It is assigned by
# _worker_init_abmctsm, or lazily by _ask_worker itself while still None.
_WORKER_ALGO = None


def _default_demask_schedule() -> list[float]:
    """Return a new list holding the default mask-fraction schedule."""
    return [0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.2]


@dataclasses.dataclass
class ABMCTSDConfig:
    """Settings consumed by ``ABMCTSD``.

    Attributes:
        models: Action names handed to ``ABMCTSM.ask`` as ``actions``.
        gen_length: Mask count assumed for the root (no parent state); the
            schedule fractions are scaled by this value.
        num_func_eval_budget: Starting function-evaluation budget; it is
            reduced in ``ABMCTSD.tell`` and forwarded on every ``Trial``.
        full_rollout: Forwarded unchanged on every ``Trial``.
        demask_schedule: Fractions of ``gen_length`` that become the target
            remaining-mask counts, indexed by the depth of the expanded node.
        min_unmask_num: Lower bound on the tokens demasked in a single step
            (capped by the masks still left in the parent).
    """

    models: list[str]
    gen_length: int
    num_func_eval_budget: int
    full_rollout: bool = True
    # A factory keeps every config instance on its own schedule list.
    demask_schedule: list[float] = dataclasses.field(
        default_factory=_default_demask_schedule
    )
    min_unmask_num: int = 3


@dataclasses.dataclass
class ABMCTSDState:
    """Per-node payload stored in the search tree for one evaluated trial.

    ``ABMCTSD.tell`` builds one instance from each ``Result`` and hands it to
    ``ABMCTSM.tell`` together with the reward.
    """

    # Token sequence reported by the trial (``Result.new_token_seq``); its
    # ``num_masks`` and ``mask_fraction`` are read elsewhere in this module.
    token_sequence: TokenSequence
    # ``Result.clean_pred``: shown as "clean prediction" in the HTML rendering
    # and returned as the candidate text by ``ABMCTSD.top_k``.
    clean_pred: str


def _format_masked_text(token_sequence: TokenSequence) -> str:
    parts: list[str] = []
    for segment in token_sequence.segments:
        if segment.kind == "text":
            parts.append(segment.content)
        elif segment.kind == "mask":
            parts.append("<mask>" * segment.repetition)
    return "".join(parts)


def _format_abmcts_state(state: ABMCTSDState) -> str:
    """Build the HTML snippet shown for one node state in the rendered tree.

    The snippet has three ``<div>`` sections: the mask fraction (three
    decimals), the clean prediction, and the partially masked text. Both text
    values are HTML-escaped and wrapped in a ``pre-wrap`` block.
    """
    partially_masked = _format_masked_text(state.token_sequence)
    clean_html = html.escape(state.clean_pred)
    masked_html = html.escape(partially_masked)

    fraction_section = (
        f"<div><strong>mask_fraction:</strong> "
        f"{state.token_sequence.mask_fraction:.3f}</div>"
    )
    clean_section = (
        f"<div><strong>clean prediction:</strong>"
        f'<div style="white-space: pre-wrap;">{clean_html}</div></div>'
    )
    masked_section = (
        f"<div><strong>partial masked text:</strong>"
        f'<div style="white-space: pre-wrap;">{masked_html}</div></div>'
    )
    return "".join([fraction_section, clean_section, masked_section])


def render_abmcts_tree_html(search_tree: Any, output_basename: str | Path) -> None:
    """Render ``search_tree`` with ``tq.render`` in ``"html"`` format.

    Node states are formatted with ``_format_abmcts_state``. Both arguments
    are forwarded positionally to ``tq.render``; nothing is returned.
    """
    render_options = {
        "format": "html",
        "state_formatter": _format_abmcts_state,
    }
    tq.render(search_tree, output_basename, **render_options)


def _calc_num_tokens_to_demask(
    trial,
    state,
    gen_length: int,
    demask_schedule: list[float],
    min_unmask_num: int,
) -> int | None:
    if gen_length <= 0:
        raise ValueError(f"gen_length must be positive, got {gen_length}")
    if not demask_schedule:
        return None

    # Work in integer mask counts to avoid float rounding (e.g. 507/768 ≈ 0.660156).
    parent_state = trial.parent_state
    if parent_state is None:
        parent_masks = gen_length
        parent_depth = 0
    else:
        parent_masks = int(parent_state.token_sequence.num_masks)
        parent_depth = int(state.tree.get_node(trial.node_to_expand).depth)

    schedule_targets = [
        max(0, min(gen_length, math.ceil(gen_length * float(frac))))
        for frac in demask_schedule
    ]

    if parent_depth >= len(schedule_targets):
        return None

    target_masks: int | None = None
    for idx in range(parent_depth, len(schedule_targets)):
        candidate = schedule_targets[idx]
        if parent_masks > candidate:
            target_masks = candidate
            break
    if target_masks is None:
        return None

    num_tokens_to_demask = parent_masks - target_masks
    if num_tokens_to_demask < int(min_unmask_num):
        num_tokens_to_demask = min(parent_masks, int(min_unmask_num))
    if num_tokens_to_demask <= 0:
        return None
    return num_tokens_to_demask


def _worker_init_abmctsm(config: dict[str, Any]) -> None:
    """Create this process's shared ``ABMCTSM`` from keyword arguments.

    Passed as the ``initializer`` of the joblib parallel configuration in
    ``ABMCTSDParallel.iter_ask``; the instance is stored in the module-level
    ``_WORKER_ALGO`` for ``_ask_worker`` to use.
    """
    global _WORKER_ALGO
    worker_algo = ABMCTSM(**config)
    _WORKER_ALGO = worker_algo


def _ask_worker(
    problem_id: Hashable,
    state: Any,
    actions: list[str],
    gen_length: int,
    demask_schedule: list[float],
    min_unmask_num: int,
    num_func_eval_budget: int,
    full_rollout: bool,
) -> tuple[Hashable, Any, Trial, list[Any]]:
    """Produce the next trial for one problem; run as a joblib task.

    Mirrors ``ABMCTSD.ask`` but works on explicit arguments instead of an
    ``ABMCTSD`` instance, and returns the updated search state so the caller
    can store it back.

    Args:
        problem_id: Returned unchanged so the caller can match the result.
        state: Search state to ask from.
        actions: Action names forwarded to ``ABMCTSM.ask``.
        gen_length: See ``_calc_num_tokens_to_demask``.
        demask_schedule: See ``_calc_num_tokens_to_demask``.
        min_unmask_num: See ``_calc_num_tokens_to_demask``.
        num_func_eval_budget: Stored on the trial as ``remaining_func_evals``.
        full_rollout: Stored on the trial as ``full_rollout``.

    Returns:
        ``(problem_id, new_state, trial, archived_states)`` where
        ``archived_states`` lists every tree that was abandoned (in order)
        before a usable trial was found.

    Raises:
        ValueError: If ``gen_length`` is not positive, or if the schedule
            rejects an expansion of the root. The latter depends only on
            ``gen_length`` and ``demask_schedule``, so no fresh tree could
            ever succeed and retrying would loop forever.
    """
    global _WORKER_ALGO
    if _WORKER_ALGO is None:
        # No initializer ran in this process: fall back to a default instance.
        _WORKER_ALGO = ABMCTSM()

    archived_states: list[Any] = []
    while True:
        for _ in range(_MAX_ASK_ATTEMPTS):
            state, trial = _WORKER_ALGO.ask(state=state, actions=actions)
            num_tokens_to_demask = _calc_num_tokens_to_demask(
                trial, state, gen_length, demask_schedule, min_unmask_num
            )
            if num_tokens_to_demask is None:
                if trial.parent_state is None:
                    # A rejected root expansion is deterministic in the
                    # configuration; replacing the tree cannot change it.
                    raise ValueError(
                        "demask_schedule has no target below gen_length, so "
                        "the root can never be expanded "
                        f"(gen_length={gen_length}, "
                        f"demask_schedule={demask_schedule})"
                    )
                continue

            if trial.parent_state is None:
                parent_token_seq = None
            else:
                parent_token_seq = trial.parent_state.token_sequence

            return (
                problem_id,
                state,
                Trial(
                    trial_id=trial.trial_id,
                    parent_token_seq=parent_token_seq,
                    num_tokens_to_demask=num_tokens_to_demask,
                    remaining_func_evals=num_func_eval_budget,
                    full_rollout=full_rollout,
                    action=trial.action,
                ),
                archived_states,
            )

        logger.warning(
            "ask failed after %s attempts due to the depth constraint failure. "
            "will archive this tree and create a new one...",
            _MAX_ASK_ATTEMPTS,
        )
        # Keep the exhausted tree so its nodes still count for top_k, then
        # continue on a fresh one.
        archived_states.append(state)
        state = _WORKER_ALGO.init_tree()


@dataclasses.dataclass(frozen=True)
class ParallelTrial:
    """Immutable pairing of a trial with the problem it was asked for.

    Yielded by ``ABMCTSDParallel.iter_ask``.
    """

    # Key of the problem in the mapping given to ``ABMCTSDParallel``.
    problem_id: Hashable
    # The trial produced for that problem.
    trial: Trial


class ABMCTSD(SearchAlgo):
    """Ask/tell search over progressively demasked sequences using ``ABMCTSM``.

    Each ``ask`` obtains a candidate expansion from ``ABMCTSM`` and converts
    it into a ``Trial`` whose demask step follows the configured schedule.
    Trees that repeatedly yield only schedule-rejected expansions are moved
    to ``state_archive`` and replaced by a fresh tree.
    """

    def __init__(self, config: ABMCTSDConfig):
        """Create the underlying ``ABMCTSM`` and an empty search tree."""
        self.actions = config.models
        # Kept as a plain dict so ABMCTSDParallel can compare configs across
        # problems and rebuild an equivalent ABMCTSM in worker processes.
        self._abmcts_config = {
            "enable_pruning": False,
            "model_selection_strategy": "multiarm_bandit_thompson",
        }
        self.abmcts_m = ABMCTSM(**self._abmcts_config)
        self.abmcts_state = self.abmcts_m.init_tree()
        self.full_rollout = config.full_rollout
        self.num_func_eval_budget = config.num_func_eval_budget
        self.demask_schedule = config.demask_schedule
        self.min_unmask_num = config.min_unmask_num
        self.gen_length = config.gen_length

        # Trees abandoned by ask(); still consulted by top_k().
        self.state_archive: list[Any] = []

    def ask(self) -> Trial:
        """Return the next trial to evaluate.

        Up to ``_MAX_ASK_ATTEMPTS`` expansions are requested from the current
        tree. If the schedule rejects all of them, the tree is archived, a
        new tree is started, and the attempts begin again.

        Raises:
            ValueError: If ``gen_length`` is not positive, or if the schedule
                rejects an expansion of the root. The latter depends only on
                ``gen_length`` and ``demask_schedule``, so no fresh tree could
                ever succeed; failing here replaces what used to be unbounded
                recursion ending in ``RecursionError``.
        """
        while True:
            for _ in range(_MAX_ASK_ATTEMPTS):
                abmcts_state, trial = self.abmcts_m.ask(
                    state=self.abmcts_state, actions=self.actions
                )
                self.abmcts_state = abmcts_state

                num_tokens_to_demask = _calc_num_tokens_to_demask(
                    trial,
                    abmcts_state,
                    self.gen_length,
                    self.demask_schedule,
                    self.min_unmask_num,
                )
                if num_tokens_to_demask is None:
                    if trial.parent_state is None:
                        # A rejected root expansion is deterministic in the
                        # configuration; replacing the tree cannot change it.
                        raise ValueError(
                            "demask_schedule has no target below gen_length, "
                            "so the root can never be expanded "
                            f"(gen_length={self.gen_length}, "
                            f"demask_schedule={self.demask_schedule})"
                        )
                    continue

                if trial.parent_state is None:
                    parent_token_seq = None
                else:
                    parent_token_seq = trial.parent_state.token_sequence

                return Trial(
                    trial_id=trial.trial_id,
                    parent_token_seq=parent_token_seq,
                    num_tokens_to_demask=num_tokens_to_demask,
                    remaining_func_evals=self.num_func_eval_budget,
                    full_rollout=self.full_rollout,
                    action=trial.action,
                )

            logger.warning(
                "ask failed after %s attempts due to the depth constraint failure. "
                "will archive this tree and create a new one...",
                _MAX_ASK_ATTEMPTS,
            )
            self.state_archive.append(self.abmcts_state)
            self.abmcts_state = self.abmcts_m.init_tree()

    def tell(self, result: Result) -> None:
        """Record the outcome of a trial.

        Subtracts ``result.num_func_evals`` from the remaining budget and
        reports an ``ABMCTSDState`` plus ``result.reward`` to ``ABMCTSM.tell``
        under ``result.trial_id``.
        """
        self.num_func_eval_budget -= result.num_func_evals

        new_state = ABMCTSDState(
            token_sequence=result.new_token_seq, clean_pred=result.clean_pred
        )
        self.abmcts_state = self.abmcts_m.tell(
            state=self.abmcts_state,
            trial_id=result.trial_id,
            result=(new_state, result.reward),
        )

    def top_k(self, k: int) -> list[tuple[str, float]]:
        """Return the ``k`` best ``(clean_pred, score)`` pairs seen so far.

        ``tq.top_k`` is queried on every archived tree and then on the current
        tree; the merged candidates are sorted by score, highest first.
        """
        candidates: list[tuple[str, float]] = []
        for tree_state in [*self.state_archive, self.abmcts_state]:
            candidates.extend(
                (node_state.clean_pred, score)
                for node_state, score in tq.top_k(
                    state=tree_state, algorithm=self.abmcts_m, k=k
                )
            )
        return sorted(candidates, key=lambda pair: pair[1], reverse=True)[:k]


class ABMCTSDParallel:
    """Run ask/tell for several ``ABMCTSD`` problems, optionally in processes.

    ``tell`` always runs in the calling process. ``ask`` runs in joblib
    worker processes when more than one worker is useful, and falls back to
    plain in-process calls otherwise.
    """

    def __init__(
        self, problems: dict[Hashable, ABMCTSD], *, ask_workers: int = 1
    ) -> None:
        """Store the problems and derive the shared worker configuration.

        Raises:
            ValueError: If ``ask_workers`` is below 1 or the problems do not
                all share the same ``ABMCTSM`` configuration.
        """
        if ask_workers < 1:
            raise ValueError("ask_workers must be >= 1")
        self._problems = problems
        self._ask_workers = ask_workers
        self._worker_config = self._resolve_worker_config(problems)

    def _resolve_worker_config(
        self, problems: dict[Hashable, ABMCTSD]
    ) -> dict[str, Any]:
        """Return the single ``ABMCTSM`` config shared by all problems.

        Workers build one ``ABMCTSM`` per process, so every problem must use
        the same configuration. An empty mapping yields ``{}``.
        """
        # Sorted item tuples make the dict configs hashable and comparable.
        configs = {
            tuple(sorted(algo._abmcts_config.items())) for algo in problems.values()
        }
        if len(configs) > 1:
            raise ValueError(
                "ABMCTSDParallel requires identical ABMCTSM configs across problems."
            )
        if not configs:
            return {}
        return dict(configs.pop())

    def iter_ask(
        self, problem_ids: Sequence[Hashable] | None = None
    ) -> Iterator[ParallelTrial]:
        """Yield one ``ParallelTrial`` per requested problem id.

        Args:
            problem_ids: Ids to ask for; ``None`` means every known problem.
                An id may appear more than once to request several trials.

        Yields:
            Trials in request order on the in-process path, or in completion
            order on the worker path. Repeated ids are asked in-process after
            the worker results, because two workers starting from the same
            search state would each return a state that is missing the other
            one's trial.

        Raises:
            KeyError: If any requested id is unknown (raised on first
                iteration, since this is a generator).
        """
        ids = list(problem_ids) if problem_ids is not None else list(self._problems)
        if not ids:
            return
        missing = [pid for pid in ids if pid not in self._problems]
        if missing:
            raise KeyError(f"unknown problem ids: {missing}")

        # Split into first occurrences (safe to dispatch concurrently) and
        # repeats (must see the state produced by the earlier ask).
        first_ids: list[Hashable] = []
        repeated_ids: list[Hashable] = []
        seen: set[Hashable] = set()
        for pid in ids:
            (repeated_ids if pid in seen else first_ids).append(pid)
            seen.add(pid)

        n_jobs = min(self._ask_workers, len(first_ids))
        if n_jobs <= 1:
            for problem_id in ids:
                algo = self._problems[problem_id]
                trial = algo.ask()
                yield ParallelTrial(problem_id=problem_id, trial=trial)
            return

        with (
            parallel_config(
                backend="loky",
                n_jobs=n_jobs,
                prefer="processes",
                initializer=_worker_init_abmctsm,
                initargs=(self._worker_config,),
            ),
            Parallel(return_as="generator_unordered") as parallel,
        ):
            results = parallel(
                delayed(_ask_worker)(
                    problem_id,
                    self._problems[problem_id].abmcts_state,
                    self._problems[problem_id].actions,
                    self._problems[problem_id].gen_length,
                    self._problems[problem_id].demask_schedule,
                    self._problems[problem_id].min_unmask_num,
                    self._problems[problem_id].num_func_eval_budget,
                    self._problems[problem_id].full_rollout,
                )
                for problem_id in first_ids
            )
            for problem_id, new_state, trial, archived_states in results:
                # Workers operate on copies; write their results back here.
                algo = self._problems[problem_id]
                algo.abmcts_state = new_state
                if archived_states:
                    algo.state_archive.extend(archived_states)
                yield ParallelTrial(problem_id=problem_id, trial=trial)

        # Every first occurrence has been merged by now, so repeats continue
        # from the up-to-date state.
        for problem_id in repeated_ids:
            trial = self._problems[problem_id].ask()
            yield ParallelTrial(problem_id=problem_id, trial=trial)

    def ask(self, problem_ids: Sequence[Hashable] | None = None) -> list[ParallelTrial]:
        """Eagerly collect ``iter_ask`` into a list."""
        return list(self.iter_ask(problem_ids))

    def tell(self, problem_id: Hashable, result: Result) -> None:
        """Forward ``result`` to the ``tell`` of the given problem."""
        self._problems[problem_id].tell(result)

    def tell_many(self, results: Iterable[tuple[Hashable, Result]]) -> None:
        """Call ``tell`` for each ``(problem_id, result)`` pair, in order."""
        for problem_id, result in results:
            self.tell(problem_id, result)
