import torch
import torch.nn as nn

from typing import Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Parameter

from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.nn.inits import glorot, zeros
from torch_geometric.typing import (
    Adj,
    OptTensor,
    PairTensor,
    SparseTensor,
    torch_sparse,
)
from torch_geometric.utils import (
    add_self_loops,
    is_torch_sparse_tensor,
    remove_self_loops,
    softmax,
)
from torch_geometric.utils.sparse import set_sparse_value

class QRGAT(MessagePassing):
    r"""Query- and relation-aware graph attention layer.

    Before any projection, every node feature row is concatenated with the
    query embedding ``q_emb`` and with ``topic_entity_one_hot`` and passed
    through an input dropout.  Attention logits are then computed per edge
    and per head as

    ``att . leaky_relu(x_i + x_j + lin_edge(edge_attr))``

    and, when edge attributes are present, a query/relation interaction term
    ``query_edge_proj(q_emb) . lin_edge(edge_attr)`` is added.  The logits
    are softmax-normalised over the edges grouped by ``index``, the messages
    ``x_j * alpha`` are aggregated, and the result (plus optional residual
    and bias) is layer-normalised.

    Args:
        in_channels: Width of the input *after* the node features have been
            concatenated with ``q_emb`` and ``topic_entity_one_hot`` (the
            linear layers are applied to that concatenation).  A tuple gives
            separate widths for source and target nodes.
        out_channels: Feature width per attention head.
        heads: Number of attention heads.
        concat: If ``True`` the heads are concatenated
            (``heads * out_channels`` outputs), otherwise they are averaged
            (``out_channels`` outputs).
        negative_slope: Slope of the leaky ReLU inside the attention logit.
        dropout: Dropout probability applied to the normalised attention
            coefficients.
        add_self_loops: Whether existing self loops are removed and fresh
            ones are added before attention is computed.
        edge_dim: Width of ``edge_attr``; ``None`` disables the edge
            projection and the query/relation interaction term.
        fill_value: Passed to ``add_self_loops`` to create edge attributes
            for the inserted self loops.
        bias: Whether the linear layers use a bias and whether an additive
            output bias is learned.
        share_weights: If ``True`` source and target nodes share one linear
            projection.
        residual: If ``True`` a bias-free linear projection of the
            (augmented) input is added to the output.
        my_gat_dropout: Dropout probability applied to the augmented node
            inputs.
        emb_size_query: Width of ``q_emb`` as seen by the query projection
            used in the interaction term.
        **kwargs: Forwarded to ``MessagePassing``.
    """

    def __init__(
        self,
        in_channels: Union[int, Tuple[int, int]],
        out_channels: int,
        heads: int = 1,
        concat: bool = True,
        negative_slope: float = 0.2,
        dropout: float = 0.0,
        add_self_loops: bool = True,
        edge_dim: Optional[int] = None,
        fill_value: Union[float, Tensor, str] = 'mean',
        bias: bool = True,
        share_weights: bool = False,
        residual: bool = False,
        my_gat_dropout: float = 0.1,
        emb_size_query: int = 1024,
        **kwargs,
    ):
        super().__init__(node_dim=0, **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads
        self.concat = concat
        self.negative_slope = negative_slope
        self.dropout = dropout
        self.add_self_loops = add_self_loops
        self.edge_dim = edge_dim
        self.fill_value = fill_value
        self.residual = residual
        self.share_weights = share_weights

        # Width of the tensor returned by ``forward``: all heads side by side
        # when concatenating, a single head's width when averaging.
        total_out_channels = out_channels * (heads if concat else 1)
        proj_channels = heads * out_channels

        self.my_gat_dropout = nn.Dropout(my_gat_dropout)
        # The norm is applied to the final output, so it must use the output
        # width; sizing it with ``heads * out_channels`` unconditionally made
        # ``concat=False`` fail in ``forward``.
        self.norm = nn.LayerNorm(total_out_channels)

        if isinstance(in_channels, int):
            src_channels, dst_channels = in_channels, in_channels
        else:
            src_channels, dst_channels = in_channels[0], in_channels[1]

        self.lin_l = Linear(src_channels, proj_channels, bias=bias,
                            weight_initializer='glorot')
        if share_weights:
            self.lin_r = self.lin_l
        else:
            self.lin_r = Linear(dst_channels, proj_channels, bias=bias,
                                weight_initializer='glorot')

        self.att = Parameter(torch.empty(1, heads, out_channels))

        if edge_dim is not None:
            self.lin_edge = Linear(edge_dim, proj_channels, bias=False,
                                   weight_initializer='glorot')
            self.query_edge_proj = Linear(emb_size_query, proj_channels,
                                          bias=False)
        else:
            self.lin_edge = None
            self.query_edge_proj = None

        if residual:
            self.res = Linear(
                dst_channels,
                total_out_channels,
                bias=False,
                weight_initializer='glorot',
            )
        else:
            self.register_parameter('res', None)

        if bias:
            self.bias = Parameter(torch.empty(total_out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise every learnable parameter of the layer."""
        super().reset_parameters()
        self.lin_l.reset_parameters()
        self.lin_r.reset_parameters()
        if self.lin_edge is not None:
            self.lin_edge.reset_parameters()
        if self.query_edge_proj is not None:
            self.query_edge_proj.reset_parameters()
        if self.res is not None:
            self.res.reset_parameters()
        self.norm.reset_parameters()
        glorot(self.att)
        zeros(self.bias)

    def _add_query_context(self, feats: Tensor, q_emb: Tensor,
                           topic_entity_one_hot: Tensor) -> Tensor:
        """Append query and topic features to every row, then apply dropout.

        ``q_emb`` and ``topic_entity_one_hot`` are expanded along the first
        dimension to the number of rows in ``feats`` before concatenation.
        """
        num_rows = feats.size(0)
        augmented = torch.cat(
            [
                feats,
                q_emb.expand(num_rows, -1),
                topic_entity_one_hot.expand(num_rows, -1),
            ],
            dim=1,
        )
        return self.my_gat_dropout(augmented)

    def forward(
        self,
        x: Union[Tensor, PairTensor],
        edge_index: Adj,
        q_emb: Tensor,
        topic_entity_one_hot: Tensor,
        edge_attr: OptTensor = None,
        return_attention_weights: Optional[bool] = None,
    ) -> Union[
            Tensor,
            Tuple[Tensor, Tuple[Tensor, Tensor]],
            Tuple[Tensor, SparseTensor],
        ]:
        """Run one round of query-aware attention message passing.

        Args:
            x: 2-D node features, or a ``(source, target)`` pair of them.
            edge_index: Graph connectivity (tensor or ``SparseTensor``).
            q_emb: Query embedding, expandable to one row per node.  The
                interaction term reshapes its projection to
                ``(1, heads, out_channels)``, so it has to describe a single
                query whenever ``edge_attr`` is used.
            topic_entity_one_hot: Per-node topic indicator features,
                expandable to one row per node.
            edge_attr: Optional edge features of width ``edge_dim``.
            return_attention_weights: If this is a ``bool`` (the check is on
                the type, not the value) the attention coefficients are
                returned alongside the output.

        Returns:
            The layer-normalised node features, optionally together with the
            edge index and the attention coefficients.

        Raises:
            NotImplementedError: If self loops are requested for a
                ``SparseTensor`` ``edge_index`` while ``edge_dim`` is set.
        """
        H, C = self.heads, self.out_channels

        res: Optional[Tensor] = None

        x_l: OptTensor = None
        x_r: OptTensor = None
        if isinstance(x, Tensor):
            assert x.dim() == 2
            x = self._add_query_context(x, q_emb, topic_entity_one_hot)

            if self.res is not None:
                res = self.res(x)

            x_l = self.lin_l(x).view(-1, H, C)
            if self.share_weights:
                x_r = x_l
            else:
                x_r = self.lin_r(x).view(-1, H, C)
        else:
            x_l, x_r = x[0], x[1]
            assert x_l.dim() == 2
            x_l = self._add_query_context(x_l, q_emb, topic_entity_one_hot)

            if x_r is not None:
                x_r = self._add_query_context(x_r, q_emb,
                                              topic_entity_one_hot)
                # The residual follows the target nodes, which are the ones
                # receiving the aggregated messages.
                if self.res is not None:
                    res = self.res(x_r)

            x_l = self.lin_l(x_l).view(-1, H, C)
            if x_r is not None:
                x_r = self.lin_r(x_r).view(-1, H, C)

        assert x_l is not None
        assert x_r is not None

        if self.add_self_loops:
            if isinstance(edge_index, Tensor):
                num_nodes = x_l.size(0)
                if x_r is not None:
                    num_nodes = min(num_nodes, x_r.size(0))
                # Drop existing self loops first so each node ends up with
                # exactly one, whose attribute comes from ``fill_value``.
                edge_index, edge_attr = remove_self_loops(
                    edge_index, edge_attr)
                edge_index, edge_attr = add_self_loops(
                    edge_index, edge_attr, fill_value=self.fill_value,
                    num_nodes=num_nodes)
            elif isinstance(edge_index, SparseTensor):
                if self.edge_dim is None:
                    edge_index = torch_sparse.set_diag(edge_index)
                else:
                    raise NotImplementedError(
                        "The usage of 'edge_attr' and 'add_self_loops' "
                        "simultaneously is currently not yet supported for "
                        "'edge_index' in a 'SparseTensor' form")

        # ``q_emb`` is handed through to ``edge_update`` by name.
        alpha = self.edge_updater(edge_index, x=(x_l, x_r),
                                  edge_attr=edge_attr, q_emb=q_emb)

        out = self.propagate(edge_index, x=(x_l, x_r), alpha=alpha)

        if self.concat:
            out = out.view(-1, self.heads * self.out_channels)
        else:
            out = out.mean(dim=1)

        if res is not None:
            out = out + res

        if self.bias is not None:
            out = out + self.bias

        out = self.norm(out)

        if isinstance(return_attention_weights, bool):
            if isinstance(edge_index, Tensor):
                if is_torch_sparse_tensor(edge_index):
                    adj = set_sparse_value(edge_index, alpha)
                    return out, (adj, alpha)
                else:
                    return out, (edge_index, alpha)
            elif isinstance(edge_index, SparseTensor):
                return out, edge_index.set_value(alpha, layout='coo')
        else:
            return out

    def edge_update(self, x_j: Tensor, x_i: Tensor, edge_attr: OptTensor,
                    index: Tensor, ptr: OptTensor,
                    dim_size: Optional[int], q_emb: Tensor) -> Tensor:
        """Compute normalised attention coefficients for every edge.

        Args:
            x_j: Projected source-node features per edge, ``(E, H, C)``.
            x_i: Projected target-node features per edge, ``(E, H, C)``.
            edge_attr: Optional raw edge features; 1-D input is treated as a
                single feature column.
            index: Grouping index forwarded to ``softmax``.
            ptr: Optional grouping pointer forwarded to ``softmax``.
            dim_size: Optional number of groups forwarded to ``softmax``.
            q_emb: Query embedding used for the query/relation interaction.

        Returns:
            Attention coefficients of shape ``(E, H)`` after softmax and
            attention dropout.
        """
        combined = x_i + x_j

        edge_proj = None
        if edge_attr is not None:
            if edge_attr.dim() == 1:
                edge_attr = edge_attr.view(-1, 1)
            assert self.lin_edge is not None
            edge_proj = self.lin_edge(edge_attr).view(
                -1, self.heads, self.out_channels)
            combined = combined + edge_proj

        activated = F.leaky_relu(combined, self.negative_slope)
        alpha = (activated * self.att).sum(dim=-1)

        if edge_proj is not None and self.query_edge_proj is not None:
            # Per-head dot product between the projected query and the
            # projected edge feature; the leading dimension of size 1
            # broadcasts over all edges.
            query_proj = self.query_edge_proj(q_emb).view(
                1, self.heads, self.out_channels)
            alpha = alpha + (query_proj * edge_proj).sum(dim=-1)

        alpha = softmax(alpha, index, ptr, dim_size)
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)
        return alpha

    def message(self, x_j: Tensor, alpha: Tensor) -> Tensor:
        """Weight each source-node feature by its attention coefficient."""
        return x_j * alpha.unsqueeze(-1)

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.out_channels}, heads={self.heads})')
        
class TripleRetriever(nn.Module):
    """Score (head, relation, tail) triples against a query embedding.

    Entity embeddings are refined by two independent stacks of ``QRGAT``
    layers, one over the edges ``head -> tail`` and one over the reversed
    edges.  The two results are concatenated per node, and each triple is
    scored by an MLP over ``[query, head, relation, tail]``.

    Args:
        emb_size: Width of the entity, relation and query embeddings.
        topic_pe: Stored on the module as ``self.topic_pe``; it is not read
            anywhere in this class.
        num_gat_layers: Number of ``QRGAT`` layers per direction.
        num_heads: Attention heads per layer; each head gets
            ``emb_size // num_heads`` channels.
        dropout: Dropout probability used inside the GAT layers and in the
            prediction MLP.
    """

    def __init__(
        self,
        emb_size,
        topic_pe,
        num_gat_layers=4,
        num_heads=4,
        dropout=0.1
    ):
        super().__init__()
        self.topic_pe = topic_pe
        self.emb_size = emb_size

        # One shared learned vector for all entities without a text embedding.
        self.non_text_entity_emb = nn.Embedding(1, emb_size)

        head_channels = emb_size // num_heads
        hidden_size = head_channels * num_heads
        # Each layer concatenates the query embedding and the topic-entity
        # indicator onto the node features.  The original code hard-codes two
        # extra channels for the indicator (presumably the width of
        # ``topic_entity_one_hot`` -- confirm against the caller).
        context_size = emb_size + 2

        self.forward_gats = nn.ModuleList()
        self.reverse_gats = nn.ModuleList()
        node_size = emb_size
        for _ in range(num_gat_layers):
            for stack in (self.forward_gats, self.reverse_gats):
                stack.append(
                    QRGAT(
                        in_channels=node_size + context_size,
                        out_channels=head_channels,
                        heads=num_heads,
                        dropout=dropout,
                        edge_dim=emb_size,
                        my_gat_dropout=dropout,
                        # The query projection must match the actual query
                        # width instead of QRGAT's default of 1024.
                        emb_size_query=emb_size,
                    )
                )
            # Later layers consume the previous layer's output, which is
            # narrower than emb_size when emb_size % num_heads != 0.
            node_size = hidden_size

        gat_out_dim = hidden_size * 2
        pred_in_size = gat_out_dim * 2 + emb_size * 2

        self.pred = nn.Sequential(
            nn.Linear(pred_in_size, emb_size),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(emb_size, 1)
        )

    def forward(
        self,
        h_id_tensor,
        r_id_tensor,
        t_id_tensor,
        q_emb,
        entity_embs_dict,
        num_non_text_entities,
        relation_embs_dict,
        topic_entity_one_hot
    ):
        """Return one score per triple.

        Args:
            h_id_tensor: Head node index of every triple.
            r_id_tensor: Relation index of every triple.
            t_id_tensor: Tail node index of every triple.
            q_emb: Query embedding, expandable to one row per triple.
            entity_embs_dict: Mapping whose values are the text-entity
                embeddings; they become the first node rows in iteration
                order.  May be empty.
            num_non_text_entities: Number of additional nodes that all share
                the learned non-text embedding; they follow the text
                entities in the node table.
            relation_embs_dict: Mapping whose values are the relation
                embeddings, indexed by ``r_id_tensor`` in iteration order.
            topic_entity_one_hot: Per-node topic indicator features passed
                to every GAT layer.

        Returns:
            Tensor with one row per triple and a single column: the output
            of the prediction MLP.
        """
        device = next(self.parameters()).device
        if entity_embs_dict:
            entity_embs = torch.stack(list(entity_embs_dict.values()))
        else:
            entity_embs = torch.empty(0, self.emb_size, device=device)
        relation_embs = torch.stack(list(relation_embs_dict.values()))
        non_text_emb = self.non_text_entity_emb(
            torch.zeros(1, dtype=torch.long, device=device)
        ).expand(num_non_text_entities, -1)
        node_embs = torch.cat([entity_embs, non_text_emb], dim=0)

        # Relation embedding per edge; identical for every layer and for both
        # directions, so gather it once.
        triple_rel_embs = relation_embs[r_id_tensor]

        edge_index = torch.stack([h_id_tensor, t_id_tensor], dim=0)
        reverse_edge_index = torch.stack([t_id_tensor, h_id_tensor], dim=0)

        forward_x = node_embs
        for gat in self.forward_gats:
            forward_x = gat(forward_x, edge_index, q_emb,
                            topic_entity_one_hot, triple_rel_embs)

        reverse_x = node_embs
        for gat in self.reverse_gats:
            reverse_x = gat(reverse_x, reverse_edge_index, q_emb,
                            topic_entity_one_hot, triple_rel_embs)

        node_embs = torch.cat([forward_x, reverse_x], dim=1)

        h_triple = torch.cat([
            q_emb.expand(len(r_id_tensor), -1),
            node_embs[h_id_tensor],
            triple_rel_embs,
            node_embs[t_id_tensor],
        ], dim=1)
        return self.pred(h_triple)