from dataclasses import dataclass
from functools import total_ordering
from typing import Any, List, Optional, Set, Tuple

from llm_mcts.mcts_algo.mcts_result import MCTSResult
from llm_mcts.mcts_algo.node import Node
from llm_mcts.mcts_scorer.base import MCTSScorer
from llm_mcts.mcts_scorer.default import DefaultScorer
from llm_mcts.node_ranker.base import Ranker
from llm_mcts.node_ranker.submission import Submission
from llm_mcts.tasks.base import Task


@total_ordering
@dataclass
class SimpleScore:
    """Sort key for ranking nodes: higher ``score`` wins, ties go to the shallower node.

    Only ``__eq__`` and ``__gt__`` are written by hand; ``functools.total_ordering``
    derives ``__lt__``, ``__le__`` and ``__ge__`` from them.
    """

    # Value produced by the scorer for the node; a larger value ranks higher.
    score: float
    # Number of parent hops from the node up to the tree root (the root itself is 0).
    node_depth: int

    def __gt__(self, other: object) -> bool:
        """Return True if ``self`` ranks above ``other``.

        A higher ``score`` ranks higher. When the scores are equal, the entry with
        the smaller ``node_depth`` (closer to the root) ranks higher.
        Comparing against a non-``SimpleScore`` returns ``NotImplemented``, so
        Python raises the usual ``TypeError`` for ordering across types.
        """
        if not isinstance(other, SimpleScore):
            return NotImplemented
        if self.score == other.score:
            return self.node_depth < other.node_depth
        return self.score > other.score

    def __eq__(self, other: object) -> bool:
        """Return True if both ``score`` and ``node_depth`` are equal.

        For other types this returns ``NotImplemented`` rather than failing an
        ``assert``, so ``score == None`` is simply False. The old assert would
        also have been stripped under ``python -O``, leaving an AttributeError.
        """
        if not isinstance(other, SimpleScore):
            return NotImplemented
        return (self.score == other.score) and (self.node_depth == other.node_depth)


class SimpleRanker(Ranker):
    """Rank the nodes of an MCTS run by their scorer value.

    Nodes are ordered best-first using ``SimpleScore``: a higher score comes
    first, and among equal scores the node closer to the root comes first.
    """

    def __init__(self, scorer: Optional[MCTSScorer] = None):
        # Fall back to the default scorer when the caller does not supply one.
        self.scorer = DefaultScorer() if scorer is None else scorer

    def top_k_predictions(
        self, mcts_result: MCTSResult, k: Optional[int] = None
    ) -> List[Node]:
        """Return the first ``k`` ranked nodes, or all of them when ``k`` is None."""
        ranked = self.sorted_nodes_by_score(mcts_result)
        if k is None:
            return ranked
        return ranked[:k]

    def sorted_nodes_by_score(self, mcts_result: MCTSResult) -> List[Node]:
        """Return the nodes yielded by ``mcts_result.map_and_tolist``, sorted best-first.

        NOTE(review): assumes ``map_and_tolist`` applies the callback to every
        node of the search tree; confirm against ``MCTSResult``.
        """
        scored: List[Tuple[SimpleScore, Node]] = list(
            mcts_result.map_and_tolist(lambda node: (self.score(node), node))
        )
        # Key on the SimpleScore only, so Node objects are never compared.
        # Python's sort stays stable with reverse=True, so tied entries keep
        # their input order.
        best_first = sorted(scored, key=lambda pair: pair[0], reverse=True)
        return [node for _score, node in best_first]

    def score(self, node: Node) -> SimpleScore:
        """Build the ranking key for ``node``.

        The depth is the number of ``parent`` hops needed to reach a node whose
        ``is_root()`` is true. The score comes from ``self.scorer.get_score``.
        """
        depth = 0
        cursor = node
        while not cursor.is_root():
            cursor = cursor.parent
            depth += 1
        return SimpleScore(score=self.scorer.get_score(node), node_depth=depth)
