from __future__ import annotations

import dataclasses
import math
from logging import getLogger

from bd_mcts.search_algo.base import Result, SearchAlgo, Trial
from bd_mcts.token_sequence import TokenSequence

logger = getLogger(__name__)


@dataclasses.dataclass
class DTSConfig:
    """Settings consumed by ``DTS``.

    The first two fields are required; every other field has a default.
    """

    # Multiplied by mask-fraction differences to get a token count.
    gen_length: int
    # Starting budget; DTS subtracts each result's ``num_func_evals`` from it.
    num_func_eval_budget: int
    # DTS forces this to True and only logs a warning when it is False.
    full_rollout: bool = True
    # Target mask fractions; the entry used for a node is indexed by its depth.
    demask_schedule: list[float] = dataclasses.field(
        default_factory=lambda: [
            0.9,
            0.8,
            0.7,
            0.6,
            0.5,
            0.4,
            0.2,
        ]
    )
    # Copied onto DTS; not read anywhere else in this file.
    min_unmask_num: int = 3
    # Weight of the exploration term in the UCB child score.
    exploration_const: float = 1.0
    # Child cap per node is int(const * visits ** alpha), floored at 1.
    progressive_width_const: float = 2.0
    progressive_width_alpha: float = 0.7
    # Backup rule: infinite -> max of children, 0 -> mean, otherwise
    # log-sum-exp of (lambda * value) divided by lambda.
    backup_lambda: float = float("inf")
    # Selection attempts ``DTS.ask`` makes before it resets the tree.
    max_selection_attempts: int = 20


@dataclasses.dataclass(eq=False)
class DTSNode:
    """A single node of the search tree.

    ``eq=False`` leaves comparison and hashing identity-based, so two nodes
    holding equal field values are still distinct nodes.
    """

    # None for the root created by DTS; otherwise the sequence at this node.
    token_sequence: TokenSequence | None
    clean_pred: str | None
    parent: "DTSNode | None"
    children: list["DTSNode"] = dataclasses.field(default_factory=list)
    visit_count: int = 1
    value_est: float = 0.0
    # Set once the node should no longer be considered during selection.
    is_exhausted: bool = False

    @property
    def depth(self) -> int:
        """Count the parent links between this node and the root (root is 0)."""
        hops = 0
        ancestor = self.parent
        while ancestor is not None:
            hops += 1
            ancestor = ancestor.parent
        return hops

    def is_root(self) -> bool:
        """Return True when the node has no parent."""
        return self.parent is None

    def is_terminal(self) -> bool:
        """Return True when the node has a sequence with no masks left."""
        sequence = self.token_sequence
        if sequence is None:
            return False
        return sequence.num_masks <= 0

    def update_value(self, value: float) -> None:
        """Overwrite the stored value estimate with ``value``."""
        self.value_est = value


class DTS(SearchAlgo):
    def __init__(self, config: DTSConfig) -> None:
        self.gen_length = config.gen_length
        self.num_func_eval_budget = config.num_func_eval_budget
        if not config.full_rollout:
            logger.warning(
                "DTS always uses full rollout; overriding full_rollout=False"
            )
        self.full_rollout = True
        self.demask_schedule = config.demask_schedule
        self.min_unmask_num = config.min_unmask_num
        self.exploration_const = config.exploration_const
        self.progressive_width_const = config.progressive_width_const
        self.progressive_width_alpha = config.progressive_width_alpha
        self.backup_lambda = config.backup_lambda
        self.max_selection_attempts = config.max_selection_attempts

        self.root = DTSNode(
            token_sequence=None,
            clean_pred=None,
            parent=None,
            visit_count=0,
        )
        self.pending_trials: dict[str, DTSNode] = {}
        self.results: list[tuple[str, float]] = []
        self.trial_id = 0

    def next_trial_id(self) -> str:
        self.trial_id += 1
        return str(self.trial_id)

    def ask(self) -> Trial:
        for _ in range(self.max_selection_attempts):
            node = self._select_expandable_node()
            if node is None:
                self._archive_and_reset_tree()
                continue

            num_tokens_to_demask = self.calc_num_tokens_to_demask(node)
            if num_tokens_to_demask is None or num_tokens_to_demask <= 0:
                self._mark_exhausted(node)
                continue

            trial_id = self.next_trial_id()
            self.pending_trials[trial_id] = node
            remaining_budget = (
                self.num_func_eval_budget if self.num_func_eval_budget > 0 else -1
            )
            parent_token_seq = node.token_sequence if node.parent is not None else None

            return Trial(
                trial_id=trial_id,
                parent_token_seq=parent_token_seq,
                num_tokens_to_demask=num_tokens_to_demask,
                remaining_func_evals=remaining_budget,
                full_rollout=self.full_rollout,
            )

        logger.warning(
            "ask failed to find an expandable node. will archive this tree and restart."
        )
        self._archive_and_reset_tree()
        return self.ask()

    def tell(self, result: Result) -> None:
        parent = self.pending_trials.pop(result.trial_id, None)
        if parent is None:
            raise RuntimeError(
                f"trial_id {result.trial_id} not found in pending trials"
            )

        self.num_func_eval_budget -= result.num_func_evals

        rollout_token_seqs = result.rollout_token_seqs
        rollout_clean_preds = result.rollout_clean_preds

        if self.full_rollout and not rollout_token_seqs:
            raise RuntimeError(
                "DTS expects rollout_token_seqs for full rollout. "
                "Populate Result.rollout_token_seqs with the rollout path."
            )

        if rollout_token_seqs:
            if rollout_clean_preds is not None and len(rollout_clean_preds) != len(
                rollout_token_seqs
            ):
                raise RuntimeError(
                    "rollout_clean_preds length must match rollout_token_seqs length"
                )

            if rollout_clean_preds is None:
                rollout_clean_preds = [None] * (len(rollout_token_seqs) - 1) + [
                    result.clean_pred
                ]

            prev = parent
            nodes: list[DTSNode] = []
            for seq, clean_pred in zip(rollout_token_seqs, rollout_clean_preds):
                node = DTSNode(
                    token_sequence=seq,
                    clean_pred=clean_pred,
                    parent=prev,
                    visit_count=1,
                    value_est=0.0,
                )
                if node.is_terminal():
                    node.is_exhausted = True
                prev.children.append(node)
                nodes.append(node)
                prev = node

            final_node = nodes[-1]
            final_node.update_value(result.reward)
            if final_node.is_terminal():
                final_node.is_exhausted = True
            else:
                logger.warning("DTS rollout did not reach a terminal node.")

            final_clean_pred = final_node.clean_pred
            if final_clean_pred is None:
                final_clean_pred = result.clean_pred

            self.results.append((final_clean_pred, result.reward))
            self._backup(final_node)
            return

        child = DTSNode(
            token_sequence=result.new_token_seq,
            clean_pred=result.clean_pred,
            parent=parent,
            visit_count=1,
            value_est=result.reward,
        )
        if child.is_terminal():
            child.is_exhausted = True
        parent.children.append(child)

        self.results.append((result.clean_pred, result.reward))
        self._backup(child)

    def top_k(self, k: int) -> list[tuple[str, float]]:
        return sorted(self.results, key=lambda x: x[1], reverse=True)[:k]

    def calc_num_tokens_to_demask(self, node: DTSNode) -> int | None:
        if node.parent is None or node.token_sequence is None:
            return int(self.gen_length * (1 - self.demask_schedule[0]))

        depth = node.depth + 1  # root depth is 0
        if depth - 1 < len(self.demask_schedule):
            target_frac = self.demask_schedule[depth - 1]
            parent_mask_frac = node.token_sequence.mask_fraction

            if parent_mask_frac > target_frac:
                return int(self.gen_length * (parent_mask_frac - target_frac))
            else:
                idx = depth - 1 + 1
                while idx < len(self.demask_schedule):
                    if parent_mask_frac > self.demask_schedule[idx]:
                        return int(self.gen_length * (parent_mask_frac - target_frac))
                    idx += 1
        return None

    def _select_expandable_node(self) -> DTSNode | None:
        node = self.root
        node.visit_count += 1
        while True:
            if self._can_expand(node):
                return node

            candidates = [child for child in node.children if not child.is_exhausted]
            if not candidates:
                self._mark_exhausted(node)
                return None

            node = self._select_child_ucb(node, candidates)
            node.visit_count += 1

    def _can_expand(self, node: DTSNode) -> bool:
        if node.is_exhausted:
            return False
        if node.is_terminal():
            node.is_exhausted = True
            return False
        if self.calc_num_tokens_to_demask(node) is None:
            return False
        return len(node.children) < self._max_children(node)

    def _max_children(self, node: DTSNode) -> int:
        effective_visits = max(1, node.visit_count)
        max_children = int(
            self.progressive_width_const
            * (effective_visits**self.progressive_width_alpha)
        )
        return max(1, max_children)

    def _select_child_ucb(self, parent: DTSNode, children: list[DTSNode]) -> DTSNode:
        parent_visits = max(1, parent.visit_count)
        best_child = children[0]
        best_score = -float("inf")
        for child in children:
            child_visits = max(1, child.visit_count)
            ucb_score = child.value_est + self.exploration_const * math.sqrt(
                math.log(parent_visits) / child_visits
            )
            if ucb_score > best_score:
                best_score = ucb_score
                best_child = child
        return best_child

    def _backup(self, node: DTSNode) -> None:
        current = node
        while current.parent is not None:
            parent = current.parent
            if parent.children:
                parent.update_value(self._aggregate_child_values(parent.children))
            current = parent

    def _aggregate_child_values(self, children: list[DTSNode]) -> float:
        values = [child.value_est for child in children]
        if not values:
            return 0.0
        if math.isinf(self.backup_lambda):
            return max(values)
        if self.backup_lambda == 0:
            return sum(values) / len(values)
        scaled = [self.backup_lambda * v for v in values]
        max_scaled = max(scaled)
        lse = max_scaled + math.log(sum(math.exp(v - max_scaled) for v in scaled))
        return lse / self.backup_lambda

    def _mark_exhausted(self, node: DTSNode) -> None:
        node.is_exhausted = True
        current = node.parent
        while current is not None:
            if all(child.is_exhausted for child in current.children):
                current.is_exhausted = True
            else:
                break
            current = current.parent

    def _archive_and_reset_tree(self) -> None:
        self.root = DTSNode(
            token_sequence=None,
            clean_pred=None,
            parent=None,
            visit_count=0,
        )
