import math
import sys
from collections import OrderedDict
from functools import partial
from typing import Any, Callable, Dict, List, NamedTuple, Optional

import torch
import torch.nn as nn
from torch import Tensor
from torchvision.models.vision_transformer import MLPBlock

from pyprojroot import here as project_root

# Prepend the path returned by pyprojroot's `here()` to the module search path.
# Presumably this lets the `modules.*` imports below resolve regardless of the
# current working directory -- NOTE(review): confirm against the project layout.
sys.path.insert(0, str(project_root()))

from modules.gnn import GNNConfig
from modules.graph_readout import GraphReadoutConfig
from modules.graph_feature_extractor import GraphFeatureExtractorConfig, GraphFeatureExtractor

class PositionalEncoding(nn.Module):
  """Standard Transformer positional encoding (sin + cos).

  A fixed table ``pe`` of shape ``[max_len, 1, d_model]`` is precomputed and
  registered as a buffer. Even channels hold ``sin(pos * freq)`` and odd
  channels hold ``cos(pos * freq)``. Both even and odd ``d_model`` are
  supported.

  Args:
    d_model: Size of the last (feature) dimension of the inputs.
    dropout: Dropout probability applied after adding the encoding.
    max_len: Number of positions stored in the precomputed table.
  """

  def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 500):
    super().__init__()
    self.dropout = nn.Dropout(p=dropout)

    position = torch.arange(max_len).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
    # One column of angles per even channel: [max_len, ceil(d_model / 2)].
    angles = position * div_term
    pe = torch.zeros(max_len, 1, d_model)
    pe[:, 0, 0::2] = torch.sin(angles)
    # There are only d_model // 2 odd channels; for odd d_model that is one
    # fewer than the number of angle columns, so trim before assigning.
    pe[:, 0, 1::2] = torch.cos(angles[:, : d_model // 2])
    self.register_buffer('pe', pe)

  def forward(self, x: Tensor, *args, **kwargs) -> Tensor:
    """Add the positional encoding to ``x`` and apply dropout.

    Input: [batch_size, seq_length, hidden_dim]

    Pe: [max_len, 1, hidden_dim]; its first ``seq_length`` rows are broadcast
    over the batch dimension, so ``seq_length`` is expected not to exceed
    ``max_len``.

    Output: [batch_size, seq_length, hidden_dim]

    Extra positional and keyword arguments are accepted and ignored.
    """
    # Work sequence-first so the [seq, 1, dim] slice of the table broadcasts
    # over the batch axis, then restore the batch-first layout.
    x = torch.transpose(x, 0, 1)
    x = x + self.pe[:x.size(0)]
    return torch.transpose(self.dropout(x), 0, 1)


class MPNNFeatureExtractor(nn.Module):
  """Encode the local topology around an atom with an MPNN.

  Thin wrapper that builds a ``GraphFeatureExtractor`` from fixed GNN and
  readout settings. ``atom_dim`` becomes the initial node feature dimension
  and ``hidden_dim`` becomes the readout output dimension. ``num_heads`` is
  accepted but not used by this constructor; the head counts passed to the
  configs are hard-coded.
  """

  def __init__(self, atom_dim, hidden_dim, num_heads):
    super().__init__()
    # Message-passing settings are fixed and independent of the constructor arguments.
    message_passing_cfg = GNNConfig(
      type='PNA',
      num_edge_types=3,
      hidden_dim=128,
      num_heads=4,
      per_head_dim=64,
      intermediate_dim=1024,
      message_function_depth=1,
      num_layers=10,
    )
    # Only the readout's output size follows the caller-supplied hidden_dim.
    readout_cfg = GraphReadoutConfig(
      readout_type='combined',
      use_all_states=True,
      num_heads=12,
      head_dim=64,
      output_dim=hidden_dim,
    )
    extractor_cfg = GraphFeatureExtractorConfig(
      initial_node_feature_dim=atom_dim,
      gnn_config=message_passing_cfg,
      readout_config=readout_cfg,
      output_norm='off',
    )
    self.gfe = GraphFeatureExtractor(extractor_cfg)

  def forward(self, x, *args, **kwargs) -> torch.Tensor:
    """Run ``x`` through the wrapped feature extractor; extra arguments are ignored."""
    features = self.gfe(x)
    return features

class Encoder(nn.Module):
  """Transformer Model Encoder for sequence to sequence translation.

  Applies dropout to the input, runs it through ``num_layers`` stacked
  ``EncoderBlock`` modules (registered as ``encoder_layer_<i>``), and
  normalizes the result with ``norm_layer(hidden_dim)``.
  """

  def __init__(
          self,
          num_layers: int,
          num_heads: int,
          hidden_dim: int,
          mlp_dim: int,
          dropout: float,
          attention_dropout: float,
          norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
  ):
    super().__init__()
    self.dropout = nn.Dropout(dropout)
    # Blocks are created in index order so that parameter registration (and
    # therefore state-dict key order) follows the layer index.
    self.layers = nn.Sequential(
      OrderedDict(
        (
          f"encoder_layer_{idx}",
          EncoderBlock(num_heads, hidden_dim, mlp_dim, dropout, attention_dropout, norm_layer),
        )
        for idx in range(num_layers)
      )
    )
    self.ln = norm_layer(hidden_dim)

  def forward(self, x: torch.Tensor, y: torch.Tensor = None, *args, **kwargs):
    """Encode ``x`` of shape (batch_size, seq_length, hidden_dim).

    ``y`` and any additional arguments are accepted but ignored. A
    ``torch._assert`` rejects inputs that are not 3-dimensional.
    """
    torch._assert(x.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {x.shape}")
    dropped = self.dropout(x)
    encoded = self.layers(dropped)
    return self.ln(encoded)


class EncoderBlock(nn.Module):
  """Transformer encoder block.

  Pre-norm layout: normalization is applied before self-attention and before
  the MLP, and each sub-layer's output is added back to its input.
  """

  def __init__(
          self,
          num_heads: int,
          hidden_dim: int,
          mlp_dim: int,
          dropout: float,
          attention_dropout: float,
          norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
  ):
    super().__init__()
    self.num_heads = num_heads

    # Self-attention sub-layer: norm -> multi-head attention -> dropout.
    self.ln_1 = norm_layer(hidden_dim)
    self.self_attention = nn.MultiheadAttention(
      hidden_dim,
      num_heads,
      dropout=attention_dropout,
      batch_first=True,
    )
    self.dropout = nn.Dropout(dropout)

    # Feed-forward sub-layer: norm -> MLP.
    self.ln_2 = norm_layer(hidden_dim)
    self.mlp = MLPBlock(hidden_dim, mlp_dim, dropout)

  def forward(self, input: torch.Tensor):
    """Apply attention and MLP sub-layers with residual connections.

    ``input`` must be 3-dimensional, (batch_size, seq_length, hidden_dim);
    the result is the sum of the attention residual and the MLP output.
    """
    torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
    normed = self.ln_1(input)
    # Attention weights are not requested, so only the first tuple element is used.
    attended = self.self_attention(query=normed, key=normed, value=normed, need_weights=False)[0]
    residual = self.dropout(attended) + input

    return residual + self.mlp(self.ln_2(residual))
