import pickle
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, List, TypeVar

from llm_mcts.mcts_algo.node import Node

T = TypeVar("T")


@dataclass
class MCTSResult:
    """Result of an MCTS run, represented by the root node of the search tree.

    Offers pickle-based persistence (``save`` / ``load``) and a pre-order
    traversal helper (``map_and_tolist``).
    """

    # Root of the search tree; nodes are traversed via their ``children``.
    root: Node

    def save(self, save_path: Path) -> None:
        """Pickle this result to ``save_path``, overwriting any existing file.

        Args:
            save_path: Destination file path.
        """
        # Context manager guarantees the file is flushed and closed
        # deterministically instead of relying on garbage collection.
        with open(save_path, "wb") as f:
            pickle.dump(self, f)

    @classmethod
    def load(cls, save_path: Path) -> "MCTSResult":
        """Unpickle and return a result previously written by ``save``.

        Warning:
            ``pickle.load`` can execute arbitrary code during unpickling;
            only load files from trusted sources.

        Args:
            save_path: Path to a file produced by ``save``.

        Returns:
            The unpickled object. Its type is not verified here.
        """
        with open(save_path, "rb") as f:
            return pickle.load(f)

    def map_and_tolist(self, op: Callable[[Node], T]) -> List[T]:
        """Apply ``op`` to every node of the tree and collect the results.

        Nodes are visited in pre-order: a node first, then each of its child
        subtrees in the order given by ``node.children``.

        Args:
            op: Function applied to each node.

        Returns:
            The ``op`` results, in pre-order visit order.
        """
        results: List[T] = []
        # Explicit stack rather than recursion so that deep trees cannot
        # exceed Python's recursion limit.
        stack: List[Node] = [self.root]
        while stack:
            node = stack.pop()
            results.append(op(node))
            # Push children reversed so the first child is popped next,
            # reproducing the recursive pre-order visit order.
            stack.extend(reversed(list(node.children)))
        return results
