from dataclasses import dataclass
from functools import total_ordering
from typing import Any, List, Optional, Set, Tuple

from llm_mcts.mcts_algo.mcts_result import MCTSResult
from llm_mcts.mcts_algo.node import Node
from llm_mcts.mcts_scorer.base import MCTSScorer
from llm_mcts.node_ranker.base import Ranker
from llm_mcts.node_ranker.submission import Submission
from llm_mcts.tasks.base import Task


def list_to_tuple(nested_list):
    """Return a tuple copy of *nested_list* with every nested list made a tuple.

    Only ``list`` items are recursed into; any other item (including a tuple
    that itself contains lists) is carried over unchanged.
    """
    converted = []
    for element in nested_list:
        if isinstance(element, list):
            element = list_to_tuple(element)
        converted.append(element)
    return tuple(converted)


def tuple_to_list(nested_tuple):
    """Return a list copy of *nested_tuple* with every nested tuple made a list.

    Only ``tuple`` items are recursed into; any other item (including a list
    that itself contains tuples) is carried over unchanged.
    """

    def _convert(element):
        if isinstance(element, tuple):
            return tuple_to_list(element)
        return element

    return list(map(_convert, nested_tuple))


@total_ordering
@dataclass
class RegScore:
    """Depth-regularized ranking key for a node.

    Instances order primarily by ``transform_score`` (higher is greater).
    When two scores are equal, the one with the smaller ``node_depth`` is
    considered greater, so shallower nodes win ties. The remaining rich
    comparisons are derived by ``functools.total_ordering`` from ``__gt__``
    and ``__eq__``.

    Comparing against an object that is not a ``RegScore`` returns
    ``NotImplemented`` so that Python applies its normal fallback
    (``==`` is False, ordering raises ``TypeError``).
    """

    # Score of the node; a larger value ranks higher.
    transform_score: float
    # Depth of the node; used only to break ties between equal scores.
    node_depth: int

    def __gt__(self, other: "RegScore") -> bool:
        """Return True when ``self`` ranks strictly above ``other``."""
        if not isinstance(other, RegScore):
            return NotImplemented
        if self.transform_score == other.transform_score:
            # Tie on score: the shallower node is the better one.
            return self.node_depth < other.node_depth
        return self.transform_score > other.transform_score

    def __eq__(self, other: object) -> bool:
        """Return True when both the score and the depth match."""
        # Do not assert here: asserts vanish under ``-O`` and an equality
        # check against e.g. ``None`` should simply be False, not an error.
        if not isinstance(other, RegScore):
            return NotImplemented
        return (self.transform_score == other.transform_score) and (
            self.node_depth == other.node_depth
        )


class RegularizedRanker(Ranker):
    """
    Rank the transform nodes of an MCTS result and select submissions.

    We avoid overfitting to the input-output demos by penalizing node
    depth: nodes are ordered by ``RegScore``, so among nodes with an equal
    transform score the shallower one is preferred.
    """

    def __init__(self, scorer: MCTSScorer, task: Task):
        """Keep the scorer used for ranking and the task used for evaluation."""
        self.task = task
        self.scorer = scorer

    def top_k_predictions(
        self, mcts_result: MCTSResult, k: int = 3
    ) -> List[Submission]:
        """Return at most ``k`` submissions with distinct test outputs, best first."""
        ranked_nodes = self.sorted_nodes_by_score(mcts_result)
        return self.remove_dups(ranked_nodes, k=k)

    def sorted_nodes_by_score(self, mcts_result: MCTSResult) -> List[Node]:
        """Return the scorable nodes of ``mcts_result`` in descending ``RegScore`` order.

        Nodes for which ``score`` returns ``None`` (non-transform action
        nodes) are left out.
        """
        scored = mcts_result.map_and_tolist(
            lambda node: (self.score(node, scorer=self.scorer), node)
        )
        ranked = sorted(
            (pair for pair in scored if pair[0] is not None),
            key=lambda pair: pair[0],
            reverse=True,
        )
        return [pair[1] for pair in ranked]

    def remove_dups(self, nodes: List[Node], k: int) -> List[Submission]:
        """Walk ``nodes`` in order and collect submissions with unique test outputs.

        Each node's last generation result is evaluated on the task's test
        inputs. Nodes yielding no evaluation results are skipped, and a node
        whose outputs were already produced by an earlier node is dropped.
        Iteration stops once ``k`` distinct outputs have been collected.
        """
        assert k > 0

        seen_outputs: Set[Tuple[Any, ...]] = set()
        unique_submissions: List[Submission] = []
        for candidate in nodes:
            answer = candidate.next_prompt.get_last_generation_result()
            eval_results, score = self.task.evaluate_on_test(llm_answer=answer)
            if len(eval_results) == 0:
                continue

            # A missing answer is represented by the placeholder [[]].
            outputs: List[Any] = [
                [[]] if result.answer is None else result.answer
                for result in eval_results
            ]
            # Tuples are hashable, so they can serve as the dedup key.
            outputs_key = list_to_tuple(outputs)

            if outputs_key not in seen_outputs:
                seen_outputs.add(outputs_key)
                submission = Submission(
                    node=candidate,
                    outputs=outputs_key,
                    score=score,
                )
                unique_submissions.append(submission)

            if len(seen_outputs) == k:
                break

        return unique_submissions

    def score(self, node: Node, scorer: MCTSScorer) -> Optional[RegScore]:
        """Build the ``RegScore`` of ``node``, or ``None`` if it is not a transform node.

        The depth is the number of parent links followed from ``node`` until
        a root node is reached.
        """
        if node.last_action != "transform":
            return None

        depth = 0
        cursor = node
        while not cursor.is_root():
            cursor = cursor.parent
            depth += 1

        transform_score = scorer.get_score(node=node)
        return RegScore(transform_score=transform_score, node_depth=depth)
