import pdb
import math
import os
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch_geometric.utils import degree
from torch_sparse import SparseTensor, matmul


class GraphConvLayer(nn.Module):
    """
    Single graph-convolution layer with degree-normalised propagation.

    Parameters:
        in_channels (int): Size of each input node feature vector.
        out_channels (int): Size of each output node feature vector.
        use_weight (bool, optional): Apply the linear map ``W`` after propagation. Default is True.
        use_init (bool, optional): Concatenate the initial features ``x0`` to the propagated
            features before applying ``W`` (``W`` then takes ``2 * in_channels`` inputs).
            Default is False.
    """

    def __init__(self, in_channels, out_channels, use_weight=True, use_init=False):
        super(GraphConvLayer, self).__init__()

        self.use_init = use_init
        self.use_weight = use_weight
        # With use_init the layer sees [propagated, x0] side by side, hence twice the width.
        linear_in = 2 * in_channels if self.use_init else in_channels
        # W is always created, even when neither flag ends up using it in forward().
        self.W = nn.Linear(linear_in, out_channels)

    def reset_parameters(self):
        """Re-initialise the linear map."""
        self.W.reset_parameters()

    def forward(self, x, edge_index, x0):
        """
        Propagate ``x`` over the graph and optionally apply the linear map.

        Parameters:
            x (Tensor): Node features; the first dimension is the number of nodes.
            edge_index (Tensor or pair): Unpacked into two index vectors (rows, cols).
            x0 (Tensor): Initial node features, only used when ``use_init`` is set.

        Returns:
            Tensor: The propagated (and possibly transformed) node features.
        """
        num_nodes = x.shape[0]
        rows, cols = edge_index
        rows, cols = rows.long(), cols.long()

        # Degrees are counted over the second index vector only.
        deg = degree(cols, num_nodes).float()
        inv_sqrt_in = (1. / deg[cols]).sqrt()
        inv_sqrt_out = (1. / deg[rows]).sqrt()
        weights = torch.ones_like(rows) * inv_sqrt_in * inv_sqrt_out
        # Zero-degree endpoints give inf/nan weights; those edges contribute nothing.
        weights = torch.nan_to_num(weights, nan=0.0, posinf=0.0, neginf=0.0)

        adj = SparseTensor(row=cols, col=rows, value=weights, sparse_sizes=(num_nodes, num_nodes))
        out = matmul(adj, x)  # [N, D]

        if self.use_init:
            return self.W(torch.cat([out, x0], 1))
        if self.use_weight:
            return self.W(out)
        return out


class GraphConv(nn.Module):
    """
    Stack of ``GraphConvLayer`` blocks preceded by a linear input projection.

    Parameters:
        in_channels (int): Size of each input node feature vector.
        hidden_channels (int): Width used by the projection and every graph layer.
        gnn_dropout (float): Dropout probability applied after the projection and each layer.
        gnn_num_layers (int): Number of ``GraphConvLayer`` blocks.
        gnn_use_bn (bool): Apply batch normalisation after the projection and each layer.
        gnn_use_residual (bool): Add the projected input features to each layer's output.
        gnn_use_weight (bool): Forwarded to ``GraphConvLayer`` as ``use_weight``.
        gnn_use_init (bool): Forwarded to ``GraphConvLayer`` as ``use_init``.
        gnn_use_act (bool): Apply the activation after each graph layer.
    """

    def __init__(self, in_channels, hidden_channels, gnn_dropout, gnn_num_layers, gnn_use_bn, gnn_use_residual,
                    gnn_use_weight, gnn_use_init, gnn_use_act):
        super(GraphConv, self).__init__()
        self.dropout = gnn_dropout
        self.num_layers = gnn_num_layers
        self.use_bn = gnn_use_bn
        self.use_residual = gnn_use_residual
        self.use_weight = gnn_use_weight
        self.use_init = gnn_use_init
        self.use_act = gnn_use_act
        self.activation = F.relu

        # Registration order (convs, fcs, bns) and creation order (projection first,
        # then the graph layers) are kept so that seeded initialisation is unchanged.
        self.convs = nn.ModuleList()
        self.fcs = nn.ModuleList([nn.Linear(in_channels, hidden_channels)])
        self.bns = nn.ModuleList([nn.BatchNorm1d(hidden_channels)])
        for _ in range(self.num_layers):
            layer = GraphConvLayer(hidden_channels, hidden_channels, self.use_weight, self.use_init)
            self.convs.append(layer)
            self.bns.append(nn.BatchNorm1d(hidden_channels))

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise every sub-module: graph layers, then norms, then the projection."""
        for group in (self.convs, self.bns, self.fcs):
            for module in group:
                module.reset_parameters()

    def forward(self, x, edge_index):
        """
        Project the input features and run them through the graph layers.

        Parameters:
            x (Tensor): Node features with ``in_channels`` columns.
            edge_index: Edge indices handed unchanged to every ``GraphConvLayer``.

        Returns:
            Tensor: Node features after the last graph layer.
        """
        x = self.fcs[0](x)
        if self.use_bn:
            x = self.bns[0](x)
        x = self.activation(x)
        x = F.dropout(x, p=self.dropout, training=self.training)

        # The projected input serves both as the layers' x0 and as the residual term.
        x0 = x

        for idx, conv in enumerate(self.convs, start=1):
            x = conv(x, edge_index, x0)
            if self.use_bn:
                x = self.bns[idx](x)
            if self.use_act:
                x = self.activation(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
            if self.use_residual:
                x = x + x0
        return x

class HypLinear(nn.Module):
    """
    Hyperbolic Linear Layer

    A Euclidean linear map produces the space-like coordinates; the time-like
    coordinate is then recomputed from them and prepended at index 0.

    Parameters:
        manifold (Manifold): Only used for ``expmap0`` when the input is not hyperbolic.
        in_features (int): Number of space-like input features (the layer adds 1 for time).
        out_features (int): Number of space-like output features (the output has one more column).
        c_in (float or Tensor): Curvature constant of the input, used for the time coordinate.
        c_out (float or Tensor): Curvature constant of the output; the result is rescaled
            when it differs from ``c_in``.
        bias (bool, optional): If set to False, the layer will not learn an additive bias. Default is True.
        dropout (float, optional): Stored as ``dropout_rate``; not applied by this layer. Default is 0.0.
    """

    def __init__(self, manifold, in_features, out_features, c_in, c_out, bias=True, dropout=0.0):
        super().__init__()
        self.in_features = in_features + 1  # +1 for time dimension
        self.out_features = out_features
        self.bias = bias
        self.manifold = manifold
        self.c_in = c_in
        self.c_out = c_out

        self.linear = nn.Linear(self.in_features, self.out_features, bias=bias)
        self.dropout_rate = dropout
        self.reset_parameters()

    def reset_parameters(self):
        """Reset layer parameters."""
        init.xavier_uniform_(self.linear.weight, gain=math.sqrt(2))
        if self.bias:
            init.constant_(self.linear.bias, 0)

    def forward(self, x, x_manifold='hyp'):
        """
        Forward pass for hyperbolic linear layer.

        Parameters:
            x (Tensor): Input whose last dimension is ``in_features + 1`` when
                ``x_manifold == 'hyp'`` and ``in_features`` otherwise.
            x_manifold (str, optional): Any value other than ``'hyp'`` makes the layer
                prepend a column of ones and call ``manifold.expmap0`` first.

        Returns:
            Tensor: ``[time, space]`` with ``out_features + 1`` columns.
        """
        if x_manifold != 'hyp':
            x = torch.cat([torch.ones_like(x)[..., 0:1], x], dim=-1)
            x = self.manifold.expmap0(x, self.c_in)
        x_space = self.linear(x)

        x_time = ((x_space ** 2).sum(dim=-1, keepdim=True) + self.c_in).sqrt()
        x = torch.cat([x_time, x_space], dim=-1)
        if self.c_in != self.c_out:
            # ``** 0.5`` works for plain Python numbers as well as tensors, whereas
            # ``.sqrt()`` only exists on tensors.
            # NOTE(review): x satisfies -t^2 + |s|^2 = -c_in here, so after this scaling
            # the invariant is -c_in**2 / c_out rather than -c_out; confirm the intended
            # curvature convention.
            x = x * (self.c_in / self.c_out) ** 0.5
        return x

class HypLayerNorm(nn.Module):
    """
    Hyperbolic Layer Normalization Layer

    Applies ``nn.LayerNorm`` to the space-like coordinates (everything after index 0)
    and recomputes the time-like coordinate from the normalised result.

    Parameters:
        manifold (Manifold): Stored on the module; not used by ``forward``.
        in_features (int): Number of space-like features that are normalised.
        c_in (float or Tensor): Curvature constant used for the time coordinate.
        c_out (float or Tensor): Curvature constant of the output; the result is rescaled
            when it differs from ``c_in``.
    """

    def __init__(self, manifold, in_features, c_in, c_out):
        super(HypLayerNorm, self).__init__()
        self.in_features = in_features
        self.manifold = manifold
        self.c_in = c_in
        self.c_out = c_out
        self.layer = nn.LayerNorm(self.in_features)
        self.reset_parameters()

    def reset_parameters(self):
        """Reset layer parameters."""
        self.layer.reset_parameters()

    def forward(self, x):
        """
        Forward pass for hyperbolic layer normalization.

        Parameters:
            x (Tensor): ``[time, space]`` input; column 0 is discarded and recomputed.

        Returns:
            Tensor: ``[time, space]`` with the same last-dimension size as ``x``.
        """
        x_space = self.layer(x[..., 1:])
        x_time = ((x_space ** 2).sum(dim=-1, keepdim=True) + self.c_in).sqrt()
        x = torch.cat([x_time, x_space], dim=-1)

        if self.c_in != self.c_out:
            # ``** 0.5`` works for plain Python numbers as well as tensors, whereas
            # ``.sqrt()`` only exists on tensors.
            # NOTE(review): x satisfies -t^2 + |s|^2 = -c_in here, so after this scaling
            # the invariant is -c_in**2 / c_out rather than -c_out; confirm the intended
            # curvature convention.
            x = x * (self.c_in / self.c_out) ** 0.5
        return x

class HypDropout(nn.Module):
    """
    Hyperbolic Dropout Layer

    Applies dropout to the space-like coordinates (everything after index 0) and
    recomputes the time-like coordinate from the result.

    Parameters:
        manifold (Manifold): Stored on the module; not used by ``forward``.
        dropout (float): The dropout probability.
        c_in (float or Tensor): Curvature constant used for the time coordinate.
        c_out (float or Tensor): Curvature constant of the output; the result is rescaled
            when it differs from ``c_in``.
    """

    def __init__(self, manifold, dropout, c_in, c_out):
        super(HypDropout, self).__init__()
        self.manifold = manifold
        self.c_in = c_in
        self.c_out = c_out
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, training=False):
        """
        Forward pass for hyperbolic dropout.

        Parameters:
            x (Tensor): ``[time, space]`` input.
            training (bool, optional): When False (the default) ``x`` is returned
                untouched, including any curvature rescaling being skipped.

        Returns:
            Tensor: ``x`` itself when ``training`` is False, otherwise a new
            ``[time, space]`` tensor of the same shape.
        """
        if not training:
            return x

        x_space = self.dropout(x[..., 1:])
        x_time = ((x_space ** 2).sum(dim=-1, keepdim=True) + self.c_in).sqrt()
        x = torch.cat([x_time, x_space], dim=-1)
        if self.c_in != self.c_out:
            # ``** 0.5`` works for plain Python numbers as well as tensors, whereas
            # ``.sqrt()`` only exists on tensors.
            # NOTE(review): x satisfies -t^2 + |s|^2 = -c_in here, so after this scaling
            # the invariant is -c_in**2 / c_out rather than -c_out; confirm the intended
            # curvature convention.
            x = x * (self.c_in / self.c_out) ** 0.5
        return x

class HypActivation(nn.Module):
    """
    Hyperbolic Activation Layer

    Applies the activation to the space-like coordinates (everything after index 0)
    and recomputes the time-like coordinate from the result.

    Parameters:
        manifold (Manifold): Stored on the module; not used by ``forward``.
        activation (function): The activation function applied to the space coordinates.
        c_in (float or Tensor): Curvature constant used for the time coordinate.
        c_out (float or Tensor): Curvature constant of the output; the result is rescaled
            when it differs from ``c_in``.
    """

    def __init__(self, manifold, activation, c_in, c_out):
        super(HypActivation, self).__init__()
        self.manifold = manifold
        self.c_in = c_in
        self.c_out = c_out
        self.activation = activation

    def forward(self, x):
        """
        Forward pass for hyperbolic activation.

        Parameters:
            x (Tensor): ``[time, space]`` input; column 0 is discarded and recomputed.

        Returns:
            Tensor: ``[time, space]`` built from the activated space coordinates.
        """
        x_space = self.activation(x[..., 1:])
        x_time = ((x_space ** 2).sum(dim=-1, keepdim=True) + self.c_in).sqrt()
        x = torch.cat([x_time, x_space], dim=-1)
        if self.c_in != self.c_out:
            # ``** 0.5`` works for plain Python numbers as well as tensors, whereas
            # ``.sqrt()`` only exists on tensors.
            # NOTE(review): x satisfies -t^2 + |s|^2 = -c_in here, so after this scaling
            # the invariant is -c_in**2 / c_out rather than -c_out; confirm the intended
            # curvature convention.
            x = x * (self.c_in / self.c_out) ** 0.5
        return x


class TransConvLayer(nn.Module):
    """
    Multi-head attention layer operating on ``[time, space]`` hyperbolic features.

    Parameters:
        manifold (Manifold): Provides ``cinner`` and ``mid_point`` for the 'full' attention path.
        in_channels (int): Number of space-like input features.
        out_channels (int): Number of space-like output features per head.
        num_heads (int): Number of attention heads.
        c_in: Curvature constant handed to the per-head ``HypLinear`` projections.
        c_out: Curvature constant used when rebuilding the output time coordinate.
        use_weight (bool, optional): Project the values with per-head ``HypLinear`` layers;
            otherwise the source input is used as the value for every head. Default is True.
        args: Configuration object; ``attention_type``, ``power_k`` and
            ``trans_heads_concat`` are read from it, so it must be provided despite
            the ``None`` default.
    """

    def __init__(self, manifold, in_channels, out_channels, num_heads, c_in, c_out, use_weight=True, args=None):
        super().__init__()
        self.manifold = manifold
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.c_in = c_in
        self.c_out = c_out
        self.num_heads = num_heads
        self.use_weight = use_weight
        self.attention_type = args.attention_type

        def new_projection():
            return HypLinear(self.manifold, self.in_channels, self.out_channels, self.c_in, self.c_out)

        # Key and query projections are created head by head (key first) so the order
        # in which random initial weights are drawn stays the same.
        self.Wk = nn.ModuleList()
        self.Wq = nn.ModuleList()
        for _ in range(self.num_heads):
            self.Wk.append(new_projection())
            self.Wq.append(new_projection())

        if use_weight:
            self.Wv = nn.ModuleList()
            for _ in range(self.num_heads):
                self.Wv.append(new_projection())

        self.scale = nn.Parameter(torch.tensor([math.sqrt(out_channels)]))
        self.bias = nn.Parameter(torch.zeros(()))
        self.norm_scale = nn.Parameter(torch.ones(()))
        self.v_map_mlp = nn.Linear(in_channels, out_channels, bias=True)
        self.final_layer = nn.Linear(self.num_heads * self.out_channels, self.out_channels, bias=True)
        self.power_k = args.power_k
        self.trans_heads_concat = args.trans_heads_concat

    @staticmethod
    def fp(x, p=2):
        """Raise ``x`` element-wise to the power ``p`` and rescale it to the L2 norm of ``x``."""
        powered = x ** p
        original_norm = torch.norm(x, p=2, dim=-1, keepdim=True)
        powered_norm = torch.norm(powered, p=2, dim=-1, keepdim=True)
        return (original_norm / powered_norm) * powered

    def full_attention(self, qs, ks, vs, output_attn=False):
        """
        Softmax attention with scores derived from ``manifold.cinner``.

        Parameters:
            qs, ks, vs (Tensor): Per-head queries, keys and values; the first two
                dimensions are swapped before they are handed to the manifold.
            output_attn (bool, optional): Also return the softmax weights.

        Returns:
            Tensor or (Tensor, Tensor): The aggregated output, plus the attention
            weights when ``output_attn`` is set.
        """
        # negative squared distance (less than 0)
        scores = 2 + 2 * self.manifold.cinner(qs.transpose(0, 1), ks.transpose(0, 1))  # [H, N, N]
        scores = scores / self.scale + self.bias  # [H, N, N]
        weights = F.softmax(scores, dim=-1)  # [H, N, N]

        per_head = self.manifold.mid_point(vs.transpose(0, 1), weights)
        per_head = per_head.transpose(0, 1)  # [N, H, D]

        # Second mid_point call merges the heads — presumably over dim 1; confirm
        # against the manifold implementation.
        merged = self.manifold.mid_point(per_head, self.c_out)
        if output_attn:
            return merged, weights
        return merged

    def linear_focus_attention(self, hyp_qs, hyp_ks, hyp_vs, output_attn=False):
        """
        Linear-complexity attention on the space-like coordinates.

        Parameters:
            hyp_qs, hyp_ks, hyp_vs (Tensor): ``[N, H, 1 + D]`` tensors; column 0 (time)
                is dropped before attention is computed.
            output_attn (bool, optional): When set, the output is returned twice
                (no attention matrix is materialised on this path).

        Returns:
            Tensor or (Tensor, Tensor): ``[N, 1 + out_channels]`` output with a
            recomputed time coordinate in column 0.
        """
        q_space = hyp_qs[..., 1:]
        k_space = hyp_ks[..., 1:]
        v_space = hyp_vs[..., 1:]

        # Non-negative feature maps, divided by the learned scale and sharpened by fp.
        temperature = self.norm_scale.abs() + 1e-6
        q_feat = self.fp((F.relu(q_space) + 1e-6) / temperature, p=self.power_k)  # [N, H, D]
        k_feat = self.fp((F.relu(k_space) + 1e-6) / temperature, p=self.power_k)  # [N, H, D]

        # Aggregate K^T V over all N once per head, then contract with Q.
        kv = torch.einsum('nhm,nhd->hmd', k_feat, v_space)  # [H, D, D]
        numer = torch.einsum('nhm,hmd->nhd', q_feat, kv)  # [N, H, D]

        # Normaliser: Q contracted with the per-head sum of K.
        k_sum = torch.einsum('nhd->hd', k_feat)
        denom = torch.einsum('nhd,hd->nh', q_feat, k_sum).unsqueeze(-1)  # [N, H, 1]

        out = numer / (denom + 1e-6)  # [N, H, D]

        # Add a linear map of the values on top of the attention result.
        out = out + self.v_map_mlp(v_space)  # [N, H, D]

        if self.trans_heads_concat:
            flat = out.reshape(-1, self.num_heads * self.out_channels)
            out = self.final_layer(flat)
        else:
            out = out.mean(dim=1)

        out_time = ((out ** 2).sum(dim=-1, keepdims=True) + self.c_out) ** 0.5
        out = torch.cat([out_time, out], dim=-1)

        if output_attn:
            return out, out
        return out

    def forward(self, query_input, source_input, edge_index=None, edge_weight=None, output_attn=False):
        """
        Project the inputs per head and run the configured attention.

        Parameters:
            query_input (Tensor): Input for the query projections.
            source_input (Tensor): Input for the key (and value) projections.
            edge_index, edge_weight: Accepted but not used.
            output_attn (bool, optional): Also return what the attention routine
                reports as attention.

        Returns:
            Tensor or (Tensor, Tensor): Attention output, plus the attention term
            when ``output_attn`` is set.

        Raises:
            NotImplementedError: If ``attention_type`` is neither 'linear_focused' nor 'full'.
        """
        # feature transformation
        queries = [self.Wq[h](query_input) for h in range(self.num_heads)]
        keys = [self.Wk[h](source_input) for h in range(self.num_heads)]
        if self.use_weight:
            values = [self.Wv[h](source_input) for h in range(self.num_heads)]
        else:
            values = [source_input for _ in range(self.num_heads)]

        query = torch.stack(queries, dim=1)  # [N, H, D]
        key = torch.stack(keys, dim=1)  # [N, H, D]
        value = torch.stack(values, dim=1)  # [N, H, D]

        if self.attention_type == 'linear_focused':
            attend = self.linear_focus_attention
        elif self.attention_type == 'full':
            attend = self.full_attention
        else:
            raise NotImplementedError

        if output_attn:
            final_output, attn = attend(query, key, value, output_attn)
            return final_output, attn
        return attend(query, key, value)


class TransConv(nn.Module):
    """
    Hyperbolic transformer encoder: input projection, attention stack, output projection.

    Parameters:
        manifold (Manifold): Passed to every sub-layer and used for ``mid_point``.
        c_in: Curvature constant of the input projection.
        c_hidden: Curvature constant used by all hidden layers.
        c_out: Curvature constant of the output projection.
        in_channels (int): Number of (Euclidean) input features.
        hidden_channels (int): Number of space-like hidden features.
        act (function): Activation wrapped by ``HypActivation``.
        args: Configuration object; ``trans_num_layers``, ``trans_num_heads``, ``dropout``,
            ``trans_use_bn``, ``trans_use_residual``, ``trans_use_act``,
            ``add_positional_encoding`` and ``device`` are read here, so it must be
            provided despite the ``None`` default.
    """

    def __init__(self, manifold, c_in, c_hidden, c_out, in_channels, hidden_channels, act, args=None):
        super().__init__()
        self.manifold = manifold
        self.c_in = c_in
        self.c_hidden = c_hidden
        self.c_out = c_out

        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.num_layers = args.trans_num_layers
        self.num_heads = args.trans_num_heads
        self.dropout_rate = args.dropout
        self.use_bn = args.trans_use_bn
        self.residual = args.trans_use_residual
        self.use_act = args.trans_use_act
        self.use_weight = True  # Use matrix V

        self.convs = nn.ModuleList()
        self.fcs = nn.ModuleList()
        self.bns = nn.ModuleList()

        self.fcs.append(HypLinear(self.manifold, self.in_channels, self.hidden_channels, c_in, c_hidden))
        self.bns.append(HypLayerNorm(self.manifold, self.hidden_channels, c_hidden, c_hidden))

        self.add_pos_enc = args.add_positional_encoding
        self.positional_encoding = HypLinear(self.manifold, self.in_channels, self.hidden_channels, c_in, c_hidden)
        # Registered as a buffer so that .to()/.cuda() move it together with the module
        # (a plain tensor attribute would stay on args.device). persistent=False keeps
        # it out of state_dict, so existing checkpoints still load.
        self.register_buffer('epsilon', torch.tensor([1.0], device=args.device), persistent=False)

        for _ in range(self.num_layers):
            self.convs.append(
                TransConvLayer(self.manifold, self.hidden_channels, self.hidden_channels, self.num_heads, c_hidden, c_hidden,
                              use_weight=self.use_weight, args=args))
            self.bns.append(HypLayerNorm(self.manifold, self.hidden_channels, c_hidden, c_hidden))

        self.dropout = HypDropout(self.manifold, self.dropout_rate, c_hidden, c_hidden)
        self.activation = HypActivation(self.manifold, act, c_hidden, c_hidden)

        self.fcs.append(HypLinear(self.manifold, self.hidden_channels, self.hidden_channels, c_hidden, c_out))

    def forward(self, x_input):
        """
        Encode Euclidean input features.

        Parameters:
            x_input (Tensor): Euclidean features with ``in_channels`` columns.

        Returns:
            Tensor: Output of the final ``HypLinear`` projection.
        """
        # the original inputs are in Euclidean
        x = self.fcs[0](x_input, x_manifold='euc')
        # add positional embedding
        if self.add_pos_enc:
            x_pos = self.positional_encoding(x_input, x_manifold='euc')
            x = self.manifold.mid_point(torch.stack((x, self.epsilon * x_pos), dim=1), self.c_hidden)

        if self.use_bn:
            x = self.bns[0](x)
        if self.use_act:
            x = self.activation(x)
        x = self.dropout(x, training=self.training)

        # Activation and dropout are applied to the input projection only; inside the
        # attention stack each layer is followed by the residual merge and the norm.
        for i, conv in enumerate(self.convs):
            layer_input = x
            x = conv(x, x)
            if self.residual:
                # Residual connection: combine the layer output with its own input.
                x = self.manifold.mid_point(torch.stack((x, layer_input), dim=1), self.c_hidden)
            if self.use_bn:
                x = self.bns[i + 1](x)

        x = self.fcs[-1](x)
        return x

