import os
from dataclasses import dataclass
from typing import Any, NamedTuple

from kg_gen import KGGen

from structured_llmuq.utils.latent_encoder.set import LatentSet

from .set import SetEncoder


class KnowledgeGraphEdge(NamedTuple):
    """One labelled edge of a knowledge graph: ``head --relation--> tail``."""

    head: str  # name of the entity the edge starts from
    tail: str  # name of the entity the edge points to
    relation: str  # label describing how ``head`` relates to ``tail``


@dataclass
class KnowledgeGraph:
    """A knowledge graph represented purely as a set of edges.

    Entities and relation labels are not stored separately; they are derived
    from ``edges`` each time the corresponding property is read.
    """

    # The graph's edges; ``entities`` and ``relations`` are computed from these.
    edges: set[KnowledgeGraphEdge]

    @property
    def entities(self) -> set[str]:
        """Return every entity appearing as the head or tail of some edge."""
        return {entity for edge in self.edges for entity in (edge.head, edge.tail)}

    @property
    def relations(self) -> set[str]:
        """Return the distinct relation labels used by the edges."""
        return {edge.relation for edge in self.edges}


class KnowledgeGraphEncoder(SetEncoder):
    """Latent encoder that maps a free-text answer to a knowledge graph.

    The graph is extracted with ``kg_gen.KGGen`` and represented as a set of
    relations, each rendered as one space-joined string.

    Config keys read here (all optional):
        model_name: model passed to ``KGGen`` (default ``"openai/gpt-4o-mini"``).
        temperature: temperature passed to ``KGGen`` (default ``0.0``).
        openai_api_key_environment_variable: name of the environment variable
            holding the API key (default ``"OPENAI_API_KEY"``).
        cluster: forwarded to ``KGGen.generate`` (default ``False``).
    """

    def __init__(self, config: dict[str, Any]):
        """Initialise the set encoder and the ``KGGen`` extractor from *config*.

        Raises:
            KeyError: if the environment variable holding the API key is unset.
        """
        super().__init__(
            config, initialize_selector_model=False, add_question_to_entailment=False
        )
        self.kg_generator = KGGen(
            model=config.get("model_name", "openai/gpt-4o-mini"),
            temperature=config.get("temperature", 0.0),
            api_key=self._read_api_key(config),
            # NOTE: kg_gen's response cache must be disabled. The installed PyPI
            # release does not accept a ``disable_cache`` argument, so caching
            # was patched out of the installed kg_gen module by hand; re-apply
            # that patch after any reinstall of kg_gen.
        )
        self.cluster = config.get("cluster", False)

    @staticmethod
    def _read_api_key(config: dict[str, Any]) -> str:
        """Return the API key stored in the environment variable named by *config*.

        Raises:
            KeyError: if that environment variable is not set.
        """
        env_var = config.get("openai_api_key_environment_variable", "OPENAI_API_KEY")
        try:
            return os.environ[env_var]
        except KeyError:
            # Keep KeyError for existing callers, but say what is missing and why.
            raise KeyError(
                f"Environment variable {env_var!r} holding the KGGen API key is not set"
            ) from None

    def encode(self, question: str, answer: str) -> dict[Any, float]:
        """Extract a knowledge graph from *answer* and return its relations.

        Args:
            question: Not used by this encoder; only *answer* is sent to KGGen.
            answer: Text from which the knowledge graph is extracted.

        Returns:
            A dict mapping each relation returned by KGGen, with its parts
            joined by single spaces, to the weight ``1.0``.
        """
        graph = self.kg_generator.generate(input_data=answer, cluster=self.cluster)
        # NOTE(review): relations look like (subject, predicate, object) string
        # tuples -- confirm against kg_gen. Joining on " " can map distinct
        # tuples to the same key when an element itself contains spaces; such
        # collisions collapse into a single entry.
        return {" ".join(relation): 1.0 for relation in graph.relations}
