import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric
from torch.nn import Linear, LayerNorm
from .egnn_pyg import EGNN_Sparse
from .utils import nodeEncoder, edgeEncoder

import math


class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a sequence, then apply dropout.

    Even feature columns receive ``sin(pos * freq)`` and odd columns
    ``cos(pos * freq)``; works for both even and odd ``d_model``.

    Args:
        d_model: Feature (embedding) size of the inputs.
        dropout: Dropout probability applied after adding the encoding.
        max_len: Longest sequence length the precomputed table covers.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)  # [max_len, 1]
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        # With odd d_model there is one fewer odd column than even column, so
        # trim the frequency vector; for even d_model this slice is a no-op.
        pe[:, 0, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        Args:
            x: Tensor, shape [seq_len, batch_size, embedding_dim]
               OR shape [seq_len, embedding_dim]

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Remember whether we added a batch axis so only that axis is removed
        # again; a real batch axis of size 1 must be preserved.
        added_batch_dim = x.dim() == 2
        if added_batch_dim:
            x = x.unsqueeze(1)

        x = self.dropout(x + self.pe[:x.size(0)])

        return x.squeeze(1) if added_batch_dim else x


class CrossAttention(nn.Module):
    """Single-head scaled dot-product cross-attention.

    Queries attend over a separate key/value sequence; the attended values are
    projected back to the query feature size.

    Args:
        d_query: Feature size of the query input.
        d_kv: Feature size of the key and value inputs.
        d_hidden: Size of the shared attention space.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, d_query, d_kv, d_hidden, dropout=0.1):
        super().__init__()
        self.query_proj = nn.Linear(d_query, d_hidden)
        self.key_proj = nn.Linear(d_kv, d_hidden)
        self.value_proj = nn.Linear(d_kv, d_hidden)
        self.out_proj = nn.Linear(d_hidden, d_query)
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value):
        """
        Args:
            query: [..., L, d_query]
            key, value: [..., L1, d_kv]; any leading dims must broadcast with
                those of ``query``.

        Returns:
            Tensor of shape [..., L, d_query].
        """
        q = self.query_proj(query)  # [..., L, d_hidden]
        k = self.key_proj(key)  # [..., L1, d_hidden]
        v = self.value_proj(value)  # [..., L1, d_hidden]

        # Swap the last two axes (not axes 0/1) so batched inputs work too;
        # for 2-D inputs this is exactly transpose(0, 1).
        scores = q @ k.transpose(-2, -1) / (k.shape[-1] ** 0.5)  # [..., L, L1]
        weights = self.dropout(F.softmax(scores, dim=-1))

        return self.out_proj(weights @ v)  # [..., L, d_query]


class NormalizedResidualBlock(nn.Module):
    """Pre-LayerNorm wrapper around ``layer`` with an optional gated residual.

    ``forward`` computes ``update = dropout(layer(layer_norm(x), *args, **kwargs))``.
    By default no residual connection is applied and ``update`` is returned
    as-is. With ``use_residual=True`` the block returns ``x + gate * update``,
    where ``gate`` is a learnable scalar initialised to zero (so the block
    starts out as the identity in that mode).

    Args:
        layer: Sub-module applied to the normalized input.
        embedding_dim: Feature size normalized by the LayerNorm.
        dropout: Dropout probability applied to the layer output.
        use_residual: If True, add the gated update to the input.
    """

    def __init__(
            self,
            layer: nn.Module,
            embedding_dim: int,
            dropout: float = 0.1,
            use_residual: bool = False,
    ):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.use_residual = use_residual

        self.layer = layer
        self.dropout_module = nn.Dropout(dropout)
        self.layer_norm = LayerNorm(self.embedding_dim)
        # Registered regardless of ``use_residual`` so state_dict keys do not
        # depend on the flag.
        self.gate = nn.Parameter(torch.zeros(1))

    def forward(self, x, *args, **kwargs):
        update = self.dropout_module(self.layer(self.layer_norm(x), *args, **kwargs))
        if self.use_residual:
            return x + self.gate * update
        return update


class FeedForwardNetwork(nn.Module):
    """Two-layer MLP applied over the last axis: Linear -> GELU -> Dropout -> Linear.

    Args:
        embedding_dim: Input and output feature size.
        ffn_embedding_dim: Hidden feature size.
        activation_dropout: Dropout probability applied after the activation.
        max_tokens_per_msa: Stored as an attribute; not read by ``forward``.
    """

    def __init__(
            self,
            embedding_dim: int,
            ffn_embedding_dim: int,
            activation_dropout: float = 0.1,
            max_tokens_per_msa: int = 2 ** 14,
    ):
        super().__init__()
        self.embedding_dim, self.ffn_embedding_dim = embedding_dim, ffn_embedding_dim
        self.max_tokens_per_msa = max_tokens_per_msa
        # Sub-modules are registered in a fixed order: activation, dropout,
        # then the two projections.
        self.activation_fn = nn.GELU()
        self.activation_dropout_module = nn.Dropout(activation_dropout)
        self.fc1 = nn.Linear(embedding_dim, ffn_embedding_dim)
        self.fc2 = nn.Linear(ffn_embedding_dim, embedding_dim)

    def forward(self, x):
        hidden = self.activation_dropout_module(self.activation_fn(self.fc1(x)))
        return self.fc2(hidden)


class EGNN_NET(torch.nn.Module):
    """Stack of ``EGNN_Sparse`` layers with time, secondary-structure and RAG conditioning.

    Node inputs are ``cat([data.x, data.extra_x])``, optionally embedded with
    ``nodeEncoder``; ``data.pos`` is prepended so the message-passing layers
    see ``[coords | feats]``. After each layer the features are modulated by a
    per-graph time embedding and passed through a small feed-forward stack.
    The final features are combined with a projection of ``data.rag_embed``
    and mapped to ``output_dim`` logits per node.

    ``embed_ss`` selects where the secondary-structure embedding is added:
    ``-3`` before the first layer, ``-2`` before the last layer, ``-1`` just
    before the output head; other values leave it unused.

    Args:
        input_feat_dim: Node feature width when ``embedding`` is False.
        hidden_channels: Hidden width of the conditioning MLPs and ``m_dim``
            of each ``EGNN_Sparse`` layer.
        edge_attr_dim: Edge feature width when ``embedding`` is False.
        dropout: Dropout probability (layers, FF stacks and output head).
        n_layers: Number of message-passing layers.
        output_dim: Size of the per-node output.
        embedding: If True, encode nodes/edges to ``embedding_dim`` first.
        embedding_dim: Feature width used when ``embedding`` is True.
        mlp_num, update_edge, update_coors, norm_coors, update_global,
        norm_feat: Forwarded to ``EGNN_Sparse``.
        embed_ss: Secondary-structure injection point (see above).
        ss_dim: Last-dim size of ``data.ss`` (default 8).
        rag_dim: Last-dim size of ``data.rag_embed`` (default 20).
    """

    def __init__(self, input_feat_dim, hidden_channels, edge_attr_dim, dropout=0.0, n_layers=1, output_dim=20,
                 embedding=False, embedding_dim=64, mlp_num=2, update_edge=True, update_coors=True, norm_coors=True,
                 update_global=True, embed_ss=-1, norm_feat=False, ss_dim=8, rag_dim=20):
        super(EGNN_NET, self).__init__()
        self.dropout = dropout

        self.update_edge = update_edge
        self.norm_coors = norm_coors
        self.update_coors = update_coors
        self.update_global = update_global
        # Attribute name (including the "layes" spelling) is part of the
        # state_dict keys, so it is kept as-is.
        self.mpnn_layes = nn.ModuleList()
        self.time_mlp_list = nn.ModuleList()
        self.ff_list = nn.ModuleList()

        self.embedding = embedding
        self.embed_ss = embed_ss
        self.n_layers = n_layers

        # Width of node features / edge features flowing through the layers.
        feat_dim = embedding_dim if embedding else input_feat_dim
        layer_edge_dim = embedding_dim if embedding else edge_attr_dim

        self.time_mlp = nn.Sequential(nn.Linear(1, hidden_channels), nn.SiLU(),
                                      nn.Linear(hidden_channels, feat_dim))
        self.ss_mlp = nn.Sequential(nn.Linear(ss_dim, hidden_channels), nn.SiLU(),
                                    nn.Linear(hidden_channels, feat_dim))

        for _ in range(n_layers):
            self.mpnn_layes.append(
                EGNN_Sparse(feat_dim, m_dim=hidden_channels, edge_attr_dim=layer_edge_dim, dropout=dropout,
                            mlp_num=mlp_num, update_edge=self.update_edge, norm_coors=self.norm_coors,
                            update_coors=self.update_coors, update_global=self.update_global,
                            norm_feats=norm_feat))
            # Produces (scale, shift) for the per-graph affine modulation.
            self.time_mlp_list.append(nn.Sequential(nn.SiLU(), nn.Linear(feat_dim, feat_dim * 2)))
            # NOTE(review): PyG's LayerNorm defaults to mode="graph"; inside
            # nn.Sequential it is called without ``batch``, so statistics are
            # taken over all nodes of the mini-batch together — confirm intended.
            self.ff_list.append(nn.Sequential(nn.Linear(feat_dim, feat_dim), nn.Dropout(p=dropout), nn.SiLU(),
                                              torch_geometric.nn.norm.LayerNorm(feat_dim),
                                              nn.Linear(feat_dim, feat_dim)))

        if embedding:
            self.node_embedding = nodeEncoder(embedding_dim)
            self.edge_embedding = edgeEncoder(embedding_dim)
        self.lin = Linear(feat_dim, output_dim)

        # Sized with ``feat_dim`` so the non-embedding path is shape-consistent
        # with ``self.lin``; for ``embedding=True`` this equals ``embedding_dim``.
        self.rag_mlp = nn.Sequential(nn.Linear(rag_dim, hidden_channels), nn.ReLU(),
                                     nn.Linear(hidden_channels, feat_dim))
        self.rag_mlp2 = nn.Sequential(nn.Linear(feat_dim, hidden_channels), nn.ReLU(),
                                      nn.Linear(hidden_channels, feat_dim))

    def forward(self, data, time):
        """Return per-node outputs of shape [N, output_dim].

        Args:
            data: Graph batch providing ``x``, ``pos``, ``extra_x``,
                ``edge_index``, ``edge_attr``, ``ss``, ``batch`` and
                ``rag_embed``. ``pos`` is assumed to have 3 columns (the code
                slices columns 0:3 as coordinates) — TODO confirm.
            time: Input to ``time_mlp`` (last dim 1); presumably one row per
                graph, since rows are gathered with ``data.batch`` — verify.
        """
        x, pos, extra_x = data.x, data.pos, data.extra_x
        edge_index, edge_attr = data.edge_index, data.edge_attr
        ss, batch = data.ss, data.batch

        t = self.time_mlp(time)
        ss_embed = self.ss_mlp(ss)

        x = torch.cat([x, extra_x], dim=1)
        if self.embedding:
            x = self.node_embedding(x)
            edge_attr = self.edge_embedding(edge_attr)

        if self.embed_ss == -3:
            x = x + ss_embed

        # Layers operate on [coords | feats].
        x = torch.cat([pos, x], dim=1)

        # Graph index used to broadcast per-graph time terms to nodes; a
        # single unbatched graph has ``batch is None``.
        graph_index = batch if batch is not None else torch.zeros(x.size(0), dtype=torch.long, device=x.device)

        for i, layer in enumerate(self.mpnn_layes):
            if self.embed_ss == -2 and i == self.n_layers - 1:
                coords, feats = x[:, 0:3], x[:, 3:]
                x = torch.cat([coords, feats + ss_embed], dim=-1)

            if self.update_edge:
                h, edge_attr = layer(x, edge_index, edge_attr, batch)
            else:
                h = layer(x, edge_index, edge_attr, batch)

            coords, feats = h[:, 0:3], h[:, 3:]

            # Affine modulation by the time embedding: feats * (1 + scale) + shift.
            scale, shift = self.time_mlp_list[i](t).chunk(2, dim=1)
            feats = feats * (scale[graph_index] + 1) + shift[graph_index]
            feats = self.ff_list[i](feats)

            x = torch.cat([coords, feats], dim=-1)

        x = x[:, 3:]

        # RAG conditioning. Out-of-place add: ``x`` is a view into the
        # concatenated tensor, so ``+=`` would mutate that tensor in place.
        x = x + self.rag_mlp(data.rag_embed)
        x = self.rag_mlp2(x)

        if self.embed_ss == -1:
            x = x + ss_embed

        x = F.dropout(x, p=self.dropout, training=self.training)
        return self.lin(x)
