from dataclasses import dataclass
from functools import total_ordering
from typing import Any, List, Optional, Set, Tuple

from llm_mcts.mcts_algo.mcts_result import MCTSResult
from llm_mcts.mcts_algo.node import Node
from llm_mcts.mcts_scorer.base import MCTSScorer
from llm_mcts.node_ranker.base import Ranker
from llm_mcts.node_ranker.submission import Submission
from llm_mcts.tasks.base import Task


@total_ordering
@dataclass
class SWEBenchSimpleScore:
    """Sortable ranking key for a search node: judge score, then depth.

    A score is "greater" (ranks better) when its ``llm_judge_score`` is
    higher. When the judge scores are equal, the score with the smaller
    ``node_depth`` is greater, so shallower nodes win ties.
    ``functools.total_ordering`` derives ``<``, ``<=`` and ``>=`` from
    ``__gt__`` and ``__eq__``.

    Comparing against an object that is not a ``SWEBenchSimpleScore`` follows
    the normal Python protocol. ``==`` returns False, ``!=`` returns True,
    and ordering operators raise ``TypeError``.
    """

    # Primary ranking key; higher is better.
    llm_judge_score: float
    # Tie-breaker used when judge scores are equal; smaller ranks higher.
    node_depth: int

    def __gt__(self, other: object) -> bool:
        """Return True if ``self`` ranks strictly better than ``other``."""
        if not isinstance(other, SWEBenchSimpleScore):
            # Let Python try the reflected operation / raise TypeError instead
            # of failing with an AttributeError on ``other.llm_judge_score``.
            return NotImplemented
        if self.llm_judge_score == other.llm_judge_score:
            # Tie on judge score: prefer the shallower node.
            return self.node_depth < other.node_depth
        return self.llm_judge_score > other.llm_judge_score

    def __eq__(self, other: object) -> bool:
        """Return True if both the judge score and the depth are equal."""
        if not isinstance(other, SWEBenchSimpleScore):
            # Previously an ``assert`` (AssertionError, or silently skipped
            # under ``python -O``); NotImplemented yields a plain False.
            return NotImplemented
        return (self.llm_judge_score == other.llm_judge_score) and (
            self.node_depth == other.node_depth
        )


class SWEBenchSimpleRanker(Ranker):
    """Rank MCTS nodes by LLM-judge score, preferring shallower nodes on ties.

    The ordering is the one defined by ``SWEBenchSimpleScore``. A higher
    ``llm_judge_score`` ranks first. Among equal judge scores, the node with
    fewer parent links to the root ranks first.
    """

    def top_k_predictions(self, mcts_result: MCTSResult) -> List[Node]:
        """Return all nodes of ``mcts_result``, best first.

        No truncation happens here: despite the name, every node is returned
        in ranked order.
        """
        return self.sorted_nodes_by_score(mcts_result)

    def sorted_nodes_by_score(self, mcts_result: MCTSResult) -> List[Node]:
        """Return the nodes of ``mcts_result`` sorted from best to worst.

        The sort is stable, so nodes with equal scores keep the relative order
        in which ``mcts_result.map_and_tolist`` produced them.
        """
        score_node_pairs: List[Tuple[SWEBenchSimpleScore, Node]] = list(
            mcts_result.map_and_tolist(lambda node: (self.score(node), node))
        )
        # Key on the score only so the Node objects never need to be compared.
        score_node_pairs.sort(key=lambda pair: pair[0], reverse=True)
        return [node for _, node in score_node_pairs]

    def score(self, node: Node) -> SWEBenchSimpleScore:
        """Compute the ranking score of ``node``.

        ``node_depth`` is the number of ``parent`` hops from ``node`` up to the
        node whose ``is_root()`` is true. ``llm_judge_score`` is the ``score``
        of the first entry in ``node.eval_results``. It is 0 when
        ``eval_results`` is None or empty.
        """
        node_depth = 0
        current = node
        while not current.is_root():
            node_depth += 1
            current = current.parent
        # Truthiness check covers both None and an empty list; the former
        # ``is not None`` check raised IndexError on an empty list.
        llm_judge_score = node.eval_results[0].score if node.eval_results else 0
        return SWEBenchSimpleScore(
            llm_judge_score=llm_judge_score,
            node_depth=node_depth,
        )
