""" Calculate Embedding for each node in the graph, and cluster the nodes based on the embeddings. """

from overrides import overrides
from typing import (
    Text,
    List,
    Dict,
    Any
)
from .base_cluster_runner import BaseClusterRunner


class EmbeddingClusterRunner(BaseClusterRunner):
    """Cluster runner intended to embed each choice and cluster by embedding.

    Work in progress: ``run`` builds one text prompt per choice of every
    input item, but the embedding-model call and the clustering step are not
    implemented yet, so ``run`` raises ``NotImplementedError``.
    """

    def __init__(
        self,
        targets: List[int],
        model_name: Text,
    ):
        """
        Args:
            targets: forwarded unchanged to ``BaseClusterRunner.__init__``.
            model_name: name of the embedding model; stored as
                ``self._model_name`` (not yet used by ``run``).
        """
        super().__init__(targets=targets)
        self._model_name = model_name

    @overrides
    def run(self, inputs: List[List[Dict[Text, Any]]]) -> List[List[Dict[Text, Any]]]:
        """Build embedding prompts for every input item (clustering not yet implemented).

        For each item, one prompt is built per choice, where the choices are
        ``item['selected']`` followed by ``item['candidates']``. The prompt
        text is ``"the <answer_type, lowercased> that is <choice>"``. The
        number of prompts per item is recorded alongside the flat prompt
        list, presumably so the model output can later be split back per
        item.

        Args:
            inputs: items to process. NOTE(review): the annotation says
                ``List[List[Dict]]``, but the loop below indexes every
                element of ``inputs`` as a dict (``item['selected']``), so it
                effectively requires ``List[Dict]``. Confirm which is
                intended (the annotation is left as-is to match the base
                class signature).

        Raises:
            KeyError: if an item lacks ``'selected'``, ``'candidates'`` or
                ``'answer_type'``.
            NotImplementedError: always, once prompts are built; the
                embedding and clustering step is not written yet.
        """

        def _prepare_prompts(item: Dict[Text, Any]) -> List[Text]:
            # Selected choices come first, then the remaining candidates.
            choices = item['selected'] + item['candidates']
            return [f"the {item['answer_type'].lower()} that is {c}" for c in choices]

        number_of_choices: List[int] = []
        prompts: List[Text] = []

        for item in inputs:
            extension = _prepare_prompts(item)
            number_of_choices.append(len(extension))
            prompts.extend(extension)

        # TODO: pass all of `prompts` to the model named by `self._model_name`,
        # cluster the resulting embeddings, and regroup results per item using
        # `number_of_choices`. Until then, fail loudly instead of silently
        # returning None (which contradicts the declared return type).
        raise NotImplementedError(
            f"{type(self).__name__}.run: embedding and clustering are not "
            f"implemented yet ({len(prompts)} prompts prepared for "
            f"{len(number_of_choices)} items)."
        )