import logging
import random
from collections.abc import Iterable
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch

from llm_graph_walk import graph, text_encoder

# Textual name of the end-of-path marker relation. The class in this module
# identifies that relation by id (the last relation id of the graph) rather
# than by this string.
END_REL = "END OF HOP"

# Root logging configuration is applied at import time: INFO level, with a
# "LEVEL timestamp - message" layout.
logging.basicConfig(
    format="%(levelname)s %(asctime)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)


class SampleSubgraphSR:
    def __init__(
        self, kg_interface: graph.KGInterface, text_encoder: text_encoder.TextEncoder
    ):
        self.kg = kg_interface
        self.text_encoder = text_encoder

    def score_path_list_and_relation_list(
        self,
        question: str,
        path_list: List[List[str]],
        path_score_list: List[float],
        relation_list_list: List[List[str]],
        theta: float = 0.07,
    ) -> List[Tuple[List[str], float]]:

        results = []
        query_lined_list = [
            "#".join([question] + [self.kg.get_relation_label(r) for r in path])
            for path in path_list
        ]
        all_relation_list = list(set(sum(relation_list_list, [])))
        q_emb = self.text_encoder.embed(query_lined_list).unsqueeze(1)  # [B, 1, D]
        target_emb = self.text_encoder.embed(
            [self.kg.get_relation_label(x) for x in all_relation_list]
        ).unsqueeze(
            0
        )  # [1, L, D]
        sim_score = torch.cosine_similarity(q_emb, target_emb, dim=2) / theta  # [B, L]
        for i, (path, path_score, relation_list) in enumerate(
            zip(path_list, path_score_list, relation_list_list)
        ):
            for relation in relation_list:
                j = all_relation_list.index(relation)
                new_path = path + [relation]
                score = float(sim_score[i, j]) + path_score
                results.append((i, new_path, score))
        return results

    def sample_relations(
        self,
        query: str,
        node_id: int,
        candidate_list: Optional[Iterable[int]] = None,
        num_retrieve: int = 3,
    ):
        raise NotImplementedError

    def sample_tail_entities(
        self, query: str, candidate_tail_nodes: Iterable[int], relation_id: int
    ) -> Tuple[Dict[int, float], List, str, str]:
        raise NotImplementedError

    def sample_metapaths(
        self,
        query,
        topic_entity,
        num_neighbours: int = 100,
        max_hop: int = 3,
        num_return_paths: int = 10,
        num_beams: int = 10,
    ):
        candidate_paths = [[[], 0, [topic_entity]]]  # list of: path, score and leaves
        result_paths = []
        n_hop = 0

        while (
            candidate_paths and len(result_paths) < num_return_paths and n_hop < max_hop
        ):
            search_list = []
            # try every possible next_hop
            relation_list_list = []
            path_list = []
            path_score_list = []
            # logger.info(f'candidate_paths: {candidate_paths}')
            for path, path_score, path_leaves in candidate_paths:
                path_list.append(path)
                path_score_list.append(path_score)
                # logger.info(f'path_to_candidate_relations: {topic_entity}, {path}')
                candidate_relations = []
                for leaf in path_leaves:
                    leaf_neigh = self.kg.get_entity_relations(leaf)
                    candidate_relations.extend(list(leaf_neigh["head"].keys()))
                    candidate_relations.extend(list(leaf_neigh["tail"].keys()))
                candidate_relations = list(set(candidate_relations)) + [
                    self.kg.knowledge_graph.num_total_relations - 1
                ]  # append END_REL
                relation_list_list.append(candidate_relations)
            search_list = self.score_path_list_and_relation_list(
                query, path_list, path_score_list, relation_list_list
            )

            # search_list = sorted(search_list, key=lambda x: x[-1], reverse=True)[:num_beams]
            # Update candidate_paths and result_paths
            n_hop = n_hop + 1
            new_candidate_paths = []
            for path_source, path, score in sorted(
                search_list, key=lambda x: x[-1], reverse=True
            )[:num_beams]:
                # if len(new_candidate_paths) >= num_beams:
                #     break
                if path[-1] == self.kg.knowledge_graph.num_total_relations - 1:
                    result_paths.append([path, score])
                else:
                    new_current_leaves = []
                    for leaf in candidate_paths[path_source][-1]:
                        neigh_dict = self.kg.get_entity_relations(leaf)
                        new_current_leaves.extend(
                            neigh_dict["head"].get(path[-1], [])
                            + neigh_dict["tail"].get(path[-1], [])
                        )
                    new_current_leaves = list(set(new_current_leaves))
                    if len(new_current_leaves) > num_neighbours:
                        new_current_leaves = random.sample(
                            new_current_leaves, num_neighbours
                        )
                    new_candidate_paths.append([path, score, new_current_leaves])
            candidate_paths = new_candidate_paths

        # Force early stop
        if n_hop == max_hop and candidate_paths:
            for path, score, _ in candidate_paths:
                path = path + [self.kg.knowledge_graph.num_total_relations - 1]
                result_paths.append([path, score])
        result_paths = sorted(result_paths, key=lambda x: x[1], reverse=True)[
            :num_return_paths
        ]

        return result_paths

    def __call__(
        self,
        query: str,
        seed_node_id: int | str | Iterable[int | str],
        num_neighbours: int = 10,
        num_relations: int = 10,
        max_hop: int = 3,
        num_return_paths: int = 6,
        num_beams: int = 5,
        max_subgraph_size: int = 500,
    ):
        if isinstance(seed_node_id, int) or isinstance(seed_node_id, str):
            seed_node_id = [seed_node_id]

        total_triples = []
        total_paths = []
        for topic_entity in seed_node_id:
            # sample metapaths
            result_paths = self.sample_metapaths(
                query,
                topic_entity,
                10 * num_neighbours,
                max_hop,
                num_return_paths,
                num_beams,
            )
            total_paths.extend([[topic_entity] + p for p in result_paths])

        # recover triples from paths
        for topic_entity, path, _ in sorted(
            total_paths, key=lambda x: x[2] / len(x[1]), reverse=True
        ):
            nodes, triples = set(), set()
            hop_nodes, next_hop_set = set(), set()
            hop_nodes.add(topic_entity)
            nodes.add(topic_entity)
            for relation in path:
                next_hop_set = set()
                if relation == self.kg.knowledge_graph.num_total_relations - 1:
                    continue
                for node in hop_nodes:
                    neigh = self.kg.get_entity_relations(node)
                    tail_set = set(
                        neigh["head"].get(relation, [])
                        + neigh["tail"].get(relation, [])
                    )
                    if len(tail_set) > num_neighbours:
                        tail_set = set(random.sample(list(tail_set), num_neighbours))
                    next_hop_set |= tail_set
                    triples |= {(node, relation, tail) for tail in tail_set}
                hop_nodes = next_hop_set
                nodes = nodes | hop_nodes
                if len(list(nodes)) > 1000:
                    break
            total_triples.extend(list(triples))
            total_triples_arr = np.array(total_triples)
            total_triples = total_triples_arr[
                np.sort(np.unique(total_triples_arr, axis=0, return_index=True)[1])
            ].tolist()

            if len(total_triples) > max_subgraph_size:
                total_triples = total_triples[:max_subgraph_size]
                break

        return total_triples, result_paths
