from typing import List, Optional, Dict, Tuple
from collections.abc import Iterable
import re
import logging
import random

from llm_graph_walk import graph, prompts, llm


# Module-level logger used by SampleSubgraph for progress and warning messages.
logger = logging.getLogger(__name__)
# NOTE(review): this configures the root logger as a side effect of importing the
# module (INFO level, "LEVEL timestamp - message" format). basicConfig does nothing
# if the root logger already has handlers, so an application that configures
# logging before importing this module keeps its own settings.
logging.basicConfig(
    level=logging.INFO,
    format="%(levelname)s %(asctime)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)


class SampleSubgraph:
    def __init__(self, llm_api: llm.LLMAPI, kg_interface: graph.KGInterface):
        self.llm_api = llm_api
        self.kg = kg_interface

    def build_relation_prompt(self, query: str, node_id: int, num_retrieve: int, candidate_list: Iterable[int]) -> str:
        """Assemble the prompt asking the LLM to pick relevant relations.

        The prompt is the text returned by ``prompts.wiki_relation_prompt(num_retrieve)``
        followed by the question, the label of the topic entity and a 1-based
        numbered list of the candidate relation labels.

        Args:
            query: Natural-language question.
            node_id: Id of the topic entity whose relations are being ranked.
            num_retrieve: Forwarded to ``prompts.wiki_relation_prompt``.
            candidate_list: Relation ids to list as options.

        Returns:
            The complete prompt, ending with ``"Relevant relations: "``.
        """
        pieces = [prompts.wiki_relation_prompt(num_retrieve)]
        topic_label = self.kg.get_node_label(node_id)
        pieces.append(f"Question: {query}\n")
        pieces.append(f"Topic Entity: {topic_label}\n")
        pieces.append("Relations to choose from:")
        # One option per line, numbered from 1.
        pieces.extend(
            f"\n{position}. {self.kg.get_relation_label(relation_id)}"
            for position, relation_id in enumerate(candidate_list, start=1)
        )
        pieces.append("\nRelevant relations: ")
        return "".join(pieces)
    
    def build_entity_prompt(self, query: str, candidate_node_ids: Iterable[int], relation_id: int) -> str:
        """Assemble the prompt asking the LLM to score candidate entities.

        The prompt is the text returned by ``prompts.score_entity_prompt()`` followed
        by the question, the relation label and the candidate entity labels joined
        with ``"; "``.

        Args:
            query: Natural-language question.
            candidate_node_ids: Ids of the entities to be scored.
            relation_id: Id of the relation that links to the candidates.

        Returns:
            The complete prompt, ending with ``"Scores: "``.
        """
        instructions = prompts.score_entity_prompt()
        relation_label = self.kg.get_relation_label(relation_id)
        entity_labels = "; ".join(map(self.kg.get_node_label, candidate_node_ids))
        body = (
            f"Question: {query}\n"
            f"Relation: {relation_label}\n"
            f"Entities: {entity_labels}"
        )
        return instructions + body + "\nScores: "
    
    def build_evaluate_complete_prompt(self, query: str, subgraph) -> str:
        """Assemble the prompt asking whether the triples suffice to answer.

        The prompt is the text returned by ``prompts.evaluate_complete_prompt()``
        followed by the question and one ``head, relation, tail`` line (labels, not
        ids) per triple in ``subgraph``.

        Args:
            query: Natural-language question.
            subgraph: Iterable of ``(head_id, relation_id, tail_id)`` triples.

        Returns:
            The complete prompt, ending with ``"Answer: "``.
        """
        pieces = [
            prompts.evaluate_complete_prompt(),
            f"Question: {query}\n",
            "Knowledge Triplets: ",
        ]
        for head_id, relation_id, tail_id in subgraph:
            head_label = self.kg.get_node_label(head_id)
            relation_label = self.kg.get_relation_label(relation_id)
            tail_label = self.kg.get_node_label(tail_id)
            pieces.append(f"{head_label}, {relation_label}, {tail_label}\n")
        pieces.append("Answer: ")
        return "".join(pieces)
    
    def build_answer_prompt(self, query: str, subgraph) -> str:
        """Assemble the prompt asking the LLM to answer from the triples.

        The prompt is the text returned by ``prompts.rog_style_answer_prompt()``
        followed by the question and one ``head, relation, tail`` line (labels, not
        ids) per triple in ``subgraph``.

        Args:
            query: Natural-language question.
            subgraph: Iterable of ``(head_id, relation_id, tail_id)`` triples.

        Returns:
            The complete prompt, ending with ``"Answer: "``.
        """
        preamble = prompts.rog_style_answer_prompt() + f"Question: {query}\nKnowledge Triplets: "
        # Render every triple as a comma-separated line of human-readable labels.
        triple_lines = "".join(
            f"{self.kg.get_node_label(head)}, "
            f"{self.kg.get_relation_label(relation)}, "
            f"{self.kg.get_node_label(tail)}\n"
            for head, relation, tail in subgraph
        )
        return preamble + triple_lines + "Answer: "
    
    def is_complete(self, query: str, subgraph: List[int]) -> bool:
        prompt = self.build_evaluate_complete_prompt(query, subgraph)
        out = self.llm_api(prompt)
        if "yes" in out.lower():
            logger.info(
                f"Information is sufficient\n"
            )
            return True, out
        return False, out
    
    def extract_relations(self, string: str, candidate_list: Iterable[int]) -> Dict[int, float]:
        pattern = r"\s*(?P<relation>[^()]+)\s+\(Score:\s+(?P<score>[0-9.]+)\)"
        relation_scores = dict()
        for match in re.finditer(pattern, string):
            relation = match.group("relation").strip(";").strip()
            score = float(match.group("score"))
            try:
                rel_idx = [l.lower() for l in [self.kg.get_relation_label(rid) for rid in candidate_list]].index(relation.lower())
                relation_scores[candidate_list[rel_idx]] = score
            except ValueError:
                logger.warning(f"Relation '{relation}' is not among the candidates")
        return relation_scores
    
    def extract_entities(self, string: str, candidate_list: Iterable[int]) -> Dict[int, float]:
        entity_scores = dict()
        for node_id in candidate_list:
            pattern = rf"{self.kg.get_node_label(node_id).lower()}:[^\d-]*(?P<score>[0-9.]+)"
            try:
                match = re.search(pattern, string.lower())
                score = float(match.group("score").strip().strip(";"))
                entity_scores[node_id] = score
            except:
                logger.warning(
                    f"Entity '{self.kg.get_node_label(node_id)}' not found in string\n"
                )
                logger.debug(f"'{string}'")
        return entity_scores

    def sample_relations(self, query: str, node_id: int, candidate_list: Optional[Iterable[int]] = None, num_retrieve: int = 3) -> Tuple[Dict[int, float], str, str]:
        if candidate_list is None:
            candidate_list = self.kg.get_entity_relations(node_id)
        head_rels = list(candidate_list["head"].keys())
        tail_rels = list(candidate_list["tail"].keys())
        all_candidate_rels = list(set(head_rels + tail_rels))
        prompt = self.build_relation_prompt(query, node_id, num_retrieve, all_candidate_rels)
        out = self.llm_api(prompt)
        top_scored_relations = self.extract_relations(out, all_candidate_rels)
        top_scored_relations_h = {k: (v, candidate_list["head"][k]) for k, v in top_scored_relations.items() if k in head_rels}
        top_scored_relations_t = {k: (v, candidate_list["tail"][k]) for k, v in top_scored_relations.items() if k in tail_rels}
        return top_scored_relations_h, top_scored_relations_t, out, prompt

    def sample_tail_entities(self, query: str, candidate_tail_nodes: Iterable[int], relation_id: int) -> Tuple[Dict[int, float], List, str, str]:
        if len(candidate_tail_nodes) == 1:
            return {candidate_tail_nodes[0]: 1.0}, "", ""
        else:
            prompt = self.build_entity_prompt(query, candidate_tail_nodes, relation_id)
            out = self.llm_api(prompt)
            scored_entities = self.extract_entities(out, candidate_tail_nodes)
            return scored_entities, out, prompt


    def __call__(
            self,
            query: str,
            seed_node_id: int | str | Iterable[int | str],
            num_neighbours: int = 3,
            num_relations: int = 3,
            max_steps: int = 3
        ):
        subgraph_edges = []  # list of retrieved triples
        processed_nodes = set()
        if isinstance(seed_node_id, int) or isinstance(seed_node_id, str):
            node_id_queue = {seed_node_id: 1.0}
        else:
            node_id_queue = {seed_node: 1.0 for seed_node in seed_node_id}
        for step in range(1, max_steps + 1):
            top_node_id = max(node_id_queue, key=node_id_queue.__getitem__)
            processed_nodes.add(top_node_id)
            node_id_queue.pop(top_node_id)
            logger.info(f"Sampling neighbours of {self.kg.get_node_label(top_node_id)}")
            scored_relations_h, scored_relations_t, *_ = self.sample_relations(
                query,
                top_node_id,
                num_retrieve=num_relations,
            )
            overall_edges: Dict[Tuple[int,int,int], float] = dict()
            for rel, (_, neighbours) in scored_relations_h.items():
                candidate_neighbours = list(set(neighbours) - processed_nodes)
                if len(candidate_neighbours) > 30:
                    candidate_neighbours = [candidate_neighbours[i] for i in random.sample(range(len(candidate_neighbours)), 30)]
                scored_entities, *_ = self.sample_tail_entities(query, candidate_neighbours, rel)
                overall_edges = {**overall_edges, **{(top_node_id, rel, ent): score for ent, score in scored_entities.items()}}
            for rel, (_, neighbours) in scored_relations_t.items():
                candidate_neighbours = list(set(neighbours) - processed_nodes)
                if len(candidate_neighbours) > 30:
                    candidate_neighbours = [candidate_neighbours[i] for i in random.sample(range(len(candidate_neighbours)), 30)]
                scored_entities, *_ = self.sample_tail_entities(query, candidate_neighbours, rel)
                overall_edges = {**overall_edges, **{(ent, rel, top_node_id): score for ent, score in scored_entities.items()}}

            top_edges = sorted(overall_edges, key=overall_edges.__getitem__, reverse=True)[: min(num_neighbours, len(overall_edges))]

            queue_add = {t if h == top_node_id else h: overall_edges[(h, r, t)] for (h, r, t) in top_edges}
            queue_add = {e: sc for e, sc in queue_add.items() if self.kg.get_node_label(e) != e} # if neighbour is a value, do not propagate from it
            node_id_queue = {**node_id_queue, **queue_add}
            subgraph_edges.extend(top_edges)

            complete_out = self.is_complete(query, subgraph_edges)
            if complete_out[0] or len(node_id_queue.keys()) == 0:
                break
        return subgraph_edges, complete_out[1], step

        # # ToG implementation
        # subgraph_edges = []  # list of edge ids
        # processed_nodes = set()
        # if isinstance(seed_node_id, int):
        #     node_id_queue = [seed_node_id]
        # else:
        #     node_id_queue = seed_node_id
        # for step in range(1, max_steps + 1):
        #     all_rels = dict()
        #     for node_id in node_id_queue:
        #         logger.info(f"Exploring neighbours of {self.get_node_label(node_id)}")
        #         scored_relations_h, scored_relations_t, *_ = self.sample_relations(
        #             query,
        #             node_id,
        #             num_retrieve=num_relations,
        #         )
        #         for rel, (sc, neighbours) in scored_relations_h.items():
        #             all_rels[(node_id, rel)] = (sc, neighbours, "head")
        #         for rel, (sc, neighbours) in scored_relations_t.items():
        #             all_rels[(node_id, rel)] = (sc, neighbours, "tail")

        #     overall_edges = dict()
        #     for (node_id, rel), (sc, neighbours, direct) in all_rels.items():
        #         candidate_neighbours = list(set(neighbours) - processed_nodes)
        #         if len(candidate_neighbours) > 30:
        #             candidate_neighbours = [
        #                 candidate_neighbours[i]
        #                 for i in random.sample(range(len(candidate_neighbours)), 30)
        #             ]
        #         scored_entities, *_ = self.sample_tail_entities(
        #             query, candidate_neighbours, rel
        #         )
        #         overall_edges = {
        #             **overall_edges,
        #             **{(node_id, rel, ent, direct): sc*score for ent, score in scored_entities.items()}
        #         }

        #     top_edges = sorted(overall_edges, key=overall_edges.__getitem__, reverse=True)[: min(num_neighbours, len(overall_edges))]

        #     node_id_queue = list(set([e[2] if e[3] == "head" else e[0] for e in top_edges]))
        #     node_id_queue = [e for e in node_id_queue if self.get_node_label(e) != e]
        #     processed_nodes = processed_nodes.union([e[0] if e[3] == "head" else e[2] for e in top_edges])
        #     subgraph_edges.extend([edge[:3] if edge[3] == "head" else edge[:3][::-1] for edge in top_edges])

        #     complete_out = self.is_complete(query, subgraph_edges)
        #     if complete_out[0] or len(node_id_queue) == 0:
        #         break
        # return subgraph_edges, complete_out[1], step