import numbers
from typing import Dict, Iterable, List, Set, Tuple

from llm_mcts.mcts_algo.linucb.data_types import NodeIdentifier


class LinUCBNodeIndices:
    """
    Node index handler, responsible for adding and retrieving node information.

    We use the following indices convention:
    The indices start with the GEN nodes of all the LLMs, followed by the
    generated nodes sorted by node index.

    For example:
    (LLM A GEN, LLM B GEN, LLM A s0, LLM B s0, LLM B s1, ...)

    The NodeIdentifiers corresponding to those nodes are:
    ("LLM A", "LLM B", 0, 1, 2, ...)

    i.e. GEN nodes are represented by their model name (str) and generated
    child nodes are represented by integer indices.

    The order of generated nodes (s0, s1, ... nodes) depends on how LinUCB
    picks the nodes, so it differs between problems.
    """

    def __init__(self, model_names: List[str]):
        # (position of the model in model_names, model name) -> integer ids of
        # the generated nodes that model produced. Insertion order follows
        # model_names.
        self.indices: Dict[Tuple[int, str], Set[int]] = {
            (model_idx, model): set() for model_idx, model in enumerate(model_names)
        }
        # Id handed to the next generated node; equals the number of generated
        # nodes registered so far.
        self.next_idx: int = 0

    def add_new_node(self, model_name: str) -> None:
        """
        Register a new generated node produced by ``model_name``.

        The node receives the next sequential integer id (``self.next_idx``),
        which is then incremented.

        Raises:
            RuntimeError: if ``model_name`` is not one of the registered models.
        """
        for idx, model in self.indices:
            if model_name == model:
                self.indices[(idx, model)].add(self.next_idx)
                self.next_idx += 1
                return

        raise RuntimeError(f"model_name {model_name} invalid.")

    def get_one_hot_idxs(self, node_identifier: NodeIdentifier) -> List[int]:
        """
        Return the positions set to 1 in the one-hot feature vector of a node.

        - GEN node (model name, ``str``): ``[model_idx]``.
        - Generated node (integer id): ``[model_idx, num_gen_nodes + id]``,
          where ``model_idx`` is the position of the model that generated it.

        Any integral id type is accepted (``int``, ``numpy.int64``, ...); it is
        normalised to a plain ``int`` so the returned positions are ``int``.

        Raises:
            RuntimeError: if ``node_identifier`` is neither a registered model
                name nor a registered generated-node id.
        """
        # Python dict order is preserved after 3.7, so the first matching model
        # (in model_names order) wins, as in a plain scan.
        # numbers.Integral (rather than int) so NumPy integer scalars, which are
        # not int subclasses, are recognised; bool is still accepted as before.
        if isinstance(node_identifier, numbers.Integral):
            node_id = int(node_identifier)
            for (model_idx, _model), idxs in self.indices.items():
                if node_id in idxs:
                    return [model_idx, self.num_gen_nodes + node_id]
        elif isinstance(node_identifier, str):
            for model_idx, model in self.indices:
                if node_identifier == model:
                    return [model_idx]

        raise RuntimeError(
            f"invalid node_idx {node_identifier}, {type(node_identifier)}"
        )

    def child_idx(self) -> Iterable[int]:
        """Yield the ids of all generated nodes in ascending (creation) order."""
        # total_dim - num_gen_nodes is the number of generated nodes (next_idx).
        yield from range(self.total_dim - self.num_gen_nodes)

    def get_model_name(self, model_idx: int) -> str:
        """
        Return the name of the model at position ``model_idx``.

        Raises:
            RuntimeError: if no model has that position.
        """
        for idx, model_name in self.indices:
            if idx == model_idx:
                return model_name

        raise RuntimeError(
            f"model_idx {model_idx} out of bounds for number of models {len(self.indices)}"
        )

    def _get_all_idxs(self) -> Iterable[NodeIdentifier]:
        """Yield every NodeIdentifier in feature order: GEN model names, then generated-node ids."""
        for _model_idx, model in self.indices:
            yield model
        yield from self.child_idx()

    @property
    def num_gen_nodes(self) -> int:
        """Number of GEN nodes, i.e. one per registered model."""
        return len(self.indices)

    @property
    def total_dim(self) -> int:
        """Length of the one-hot feature vector."""
        # The number of GEN nodes + total number of generated nodes
        return self.num_gen_nodes + self.next_idx
