from collections import deque
from collections.abc import Iterable

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing

from llm_graph_walk import graph, text_encoder


class PEConv(MessagePassing):
    """Parameter-free message-passing layer using mean aggregation.

    The layer has no learnable weights: each message is the unmodified
    neighbour feature ``x_j`` and the messages are combined with the
    ``"mean"`` aggregator configured in the constructor.
    """

    def __init__(self):
        super().__init__(aggr="mean")

    def forward(self, edge_index, x):
        """Run one propagation step of ``x`` along ``edge_index``.

        Args:
            edge_index: Edge index handed to ``MessagePassing.propagate``.
            x: Node features to propagate.

        Returns:
            Whatever ``propagate`` produces for the mean-aggregated messages.
        """
        aggregated = self.propagate(edge_index, x=x)
        return aggregated

    def message(self, x_j):
        """Use the neighbour feature itself as the message (identity)."""
        return x_j


class DDE(nn.Module):
    """Stack of ``PEConv`` layers applied along forward and reverse edges.

    Starting from the same input features, the module runs ``num_rounds``
    propagation steps over the forward edge index and ``num_reverse_rounds``
    steps over the reverse edge index, keeping the output of every step.
    """

    def __init__(self, num_rounds, num_reverse_rounds):
        """Create the two independent layer stacks.

        Args:
            num_rounds: Number of ``PEConv`` layers for the forward direction.
            num_reverse_rounds: Number of ``PEConv`` layers for the reverse
                direction.
        """
        super().__init__()

        self.layers = nn.ModuleList([PEConv() for _ in range(num_rounds)])
        self.reverse_layers = nn.ModuleList(
            [PEConv() for _ in range(num_reverse_rounds)]
        )

    def forward(self, topic_entity_one_hot, edge_index, reverse_edge_index):
        """Propagate the input features and collect every intermediate result.

        Args:
            topic_entity_one_hot: Initial node features; both directions start
                from this same tensor.
            edge_index: Edge index used by the forward layers.
            reverse_edge_index: Edge index used by the reverse layers.

        Returns:
            A list with one entry per layer: all forward-round outputs in
            order, followed by all reverse-round outputs in order.
        """
        collected = []

        directions = (
            (self.layers, edge_index),
            (self.reverse_layers, reverse_edge_index),
        )
        for convs, edges in directions:
            # Each direction restarts from the original input features.
            hidden = topic_entity_one_hot
            for conv in convs:
                hidden = conv(edges, hidden)
                collected.append(hidden)

        return collected


class SampleRelations(torch.nn.Module):
    """Score edges between relations (candidate metapaths) for a question.

    Every relation label is embedded once with a frozen text encoder. At
    call time each relation's feature vector is its embedding, optionally the
    one-hot neighbourhood indicator, and the ``DDE`` propagations of that
    indicator over ``rel_edge_id``. An MLP then scores every column of
    ``rel_edge_id`` from the question embedding, the seed-entity embedding
    and the features of the edge's two endpoint relations.
    """

    def __init__(
        self,
        text_encoder: text_encoder.TextEncoder,
        rel_labels: list[str],
        rel_edge_id: torch.Tensor,
        topic_pe=True,
        dde_num_rounds=2,
        dde_num_reverse_rounds=2,
        n_metap=0,
    ):
        """Build the scorer.

        Args:
            text_encoder: Encoder exposing ``model``, ``embed`` and
                ``emb_size``. Its model parameters are frozen here.
            rel_labels: Relation labels; embedded once and cached in
                ``self.relation_embs``.
            rel_edge_id: Tensor whose two rows hold the endpoint relation
                indices of each candidate edge (one edge per column).
            topic_pe: Whether to append the one-hot neighbourhood indicator
                to each relation's features.
            dde_num_rounds: Number of forward ``DDE`` propagation rounds.
            dde_num_reverse_rounds: Number of reverse ``DDE`` rounds.
            n_metap: When positive, ``forward`` also returns the ``n_metap``
                best-scoring edges.
        """
        super().__init__()

        self.n_metap = n_metap
        self.rel_edges = rel_edge_id
        self.text_encoder = text_encoder
        # The encoder is only used to produce features; no gradients are
        # needed for its weights.
        for param in self.text_encoder.model.parameters():
            param.requires_grad = False
        self.relation_embs = self.text_encoder.embed(rel_labels)
        emb_size = self.text_encoder.emb_size
        self.topic_pe = topic_pe
        self.dde = DDE(dde_num_rounds, dde_num_reverse_rounds)

        # Per-relation feature width: text embedding, optional 2-way one-hot,
        # and one 2-wide DDE output per propagation round.
        rel_feat_size = emb_size + 2 * (dde_num_rounds + dde_num_reverse_rounds)
        if topic_pe:
            rel_feat_size += 2
        # MLP input: question + seed-entity embeddings, then both endpoints.
        pred_in_size = 2 * emb_size + 2 * rel_feat_size

        self.pred = nn.Sequential(
            nn.Linear(pred_in_size, emb_size), nn.ReLU(), nn.Linear(emb_size, 1)
        )

    def forward(self, question, q_id_label, neigh_r_onehot):
        """Score every edge in ``rel_edges`` for the given question and seed.

        Args:
            question: Input passed to ``text_encoder.embed`` for the question.
            q_id_label: Input passed to ``text_encoder.embed`` for the seed
                entity label.
            neigh_r_onehot: Tensor of 0/1 flags with one entry per relation
                (same length as ``rel_labels``).

        Returns:
            ``(pred_triple_logits, top_subgraph_edges)``. The logits have one
            row per column of ``rel_edges``. The second element is a list of
            ``[head_relation, tail_relation]`` pairs for the ``n_metap``
            highest logits, or ``[]`` when ``n_metap`` is not positive.
        """
        q_emb = self.text_encoder.embed(question)
        entity_embs = self.text_encoder.embed(q_id_label)
        device = entity_embs.device

        rel_edges = self.rel_edges.to(device)

        # ``F.one_hot`` yields int64; cast to float so that mean aggregation
        # and the concatenation below operate on floating-point features
        # regardless of how the backend divides integer tensors.
        relation_entity_one_hot = (
            F.one_hot(neigh_r_onehot.long(), num_classes=2).to(device).float()
        )

        h_e_list = [self.relation_embs.to(device)]
        if self.topic_pe:
            h_e_list.append(relation_entity_one_hot)
        h_e_list.extend(
            self.dde(relation_entity_one_hot, rel_edges, rel_edges[[1, 0]])
        )
        h_e = torch.cat(h_e_list, dim=1)

        num_edges = rel_edges.shape[1]
        # NOTE(review): ``expand`` assumes the embeddings broadcast to one row
        # per edge (presumably shape (1, emb_size)) - confirm against
        # ``TextEncoder.embed``.
        h_triple = torch.cat(
            [
                q_emb.expand(num_edges, -1),
                entity_embs.expand(num_edges, -1),
                h_e[rel_edges[0]],
                h_e[rel_edges[1]],
            ],
            dim=1,
        )

        pred_triple_logits = self.pred(h_triple)

        if self.n_metap <= 0:
            return pred_triple_logits, []

        metap_ids = torch.topk(
            pred_triple_logits.flatten(),
            k=min(self.n_metap, pred_triple_logits.shape[0]),
        ).indices
        # The indices live on the logits' device while ``self.rel_edges`` may
        # live elsewhere; align devices before indexing and use ``tolist``,
        # which (unlike ``numpy``) does not require a CPU tensor.
        top_subgraph_edges = self.rel_edges[
            :, metap_ids.to(self.rel_edges.device)
        ].T.tolist()
        return pred_triple_logits, top_subgraph_edges


class SampleSubgraphFromMP:
    """Retrieve knowledge-graph edges by following top-scoring metapaths.

    For each seed node the metapath sampler scores relation pairs; the best
    pairs are then instantiated on the graph by a breadth-first walk from the
    seed, and the edges of every completed path are collected.
    """

    def __init__(
        self,
        kg_interface: graph.KGInterface,
        metap_sampler: SampleRelations,
    ):
        """Store the graph interface and the metapath scorer."""
        self.kg = kg_interface
        self.metap_sampler = metap_sampler

    def __call__(
        self,
        query: str,
        seed_node_id: int | str | Iterable[int | str],
        num_return_paths: int = 10,
        max_subgraph_size: int = 500,
    ):
        """Collect edges reachable from the seed(s) along the best metapaths.

        Args:
            query: Question text forwarded to the metapath sampler.
            seed_node_id: A single seed node id, or an iterable of ids.
            num_return_paths: Number of top-scoring metapaths to follow per
                seed.
            max_subgraph_size: Expansion stops once at least this many edges
                have been collected. The budget is shared across all seeds
                and metapaths, and is checked before each path is added, so
                the final count can exceed it by less than one path length.

        Returns:
            ``(retrieved_edges, top_mps)``. ``retrieved_edges`` is a list of
            ``[head, relation, tail]`` triples, appended path by path (an
            edge shared by several paths appears once per path). ``top_mps``
            is the list of ``[relation, relation]`` pairs selected for the
            last seed processed, or ``[]`` when no seed was given.
        """
        if isinstance(seed_node_id, (int, str)):
            seed_node_id = [seed_node_id]

        retrieved_edges = []
        # Defined up front so an empty seed collection returns cleanly.
        top_mps = []
        rel_edges = self.metap_sampler.rel_edges

        for seed in seed_node_id:
            # Flag every relation incident to the seed, in either direction.
            neigh_dict = self.kg.knowledge_graph.neighbourhood_dict[seed]
            all_rels = torch.tensor(
                [*neigh_dict["head"], *neigh_dict["tail"]], dtype=torch.long
            )
            init_scores = torch.zeros(self.kg.knowledge_graph.num_total_relations)
            init_scores[all_rels] = 1.0

            # Only the ranking of the logits is used, so skip autograd.
            with torch.no_grad():
                mp_logits, _ = self.metap_sampler(
                    query, self.kg.get_node_label(seed), init_scores
                )
            metap_ids = torch.topk(
                mp_logits.flatten(),
                k=min(num_return_paths, mp_logits.shape[0]),
            ).indices
            # Align devices before indexing; ``tolist`` works on any device.
            top_mps = rel_edges[:, metap_ids.to(rel_edges.device)].T.tolist()

            for mp in top_mps:
                # A pair with identical endpoints is walked as a single hop.
                relations = [mp[0]] if mp[0] == mp[1] else mp

                queue = deque([(seed, [])])
                while queue and len(retrieved_edges) < max_subgraph_size:
                    curr, path = queue.popleft()
                    if len(path) == len(relations):
                        retrieved_edges.extend(path)
                        continue

                    r = relations[len(path)]
                    entity_relations = self.kg.get_entity_relations(curr)
                    # ``curr`` as head of ``r``: step to each tail.
                    for t in entity_relations["head"].get(r, []):
                        queue.append((t, path + [[curr, r, t]]))
                    # ``curr`` as tail of ``r``: step to each head.
                    for h in entity_relations["tail"].get(r, []):
                        queue.append((h, path + [[h, r, curr]]))

        return retrieved_edges, top_mps
