import sys
from functools import partial
from typing import Any, Callable, Optional, Text

import torch
import torch.nn as nn
from torchvision.models._api import WeightsEnum

from pyprojroot import here as project_root

sys.path.insert(0, str(project_root()))

from models.blocks import Encoder, MPNNFeatureExtractor


class MoleculeTransformer(nn.Module):
  """Common base for molecular transformers.

  Holds the shared hyper-parameters, a GNN feature extractor (`gnn_extractor`) and an `Encoder` (`encoder`).
  `forward` is deliberately left unimplemented; concrete models must override it.
  """

  def __init__(
          self,
          atom_dim: int,
          num_layers: int,
          num_heads: int,
          hidden_dim: int,
          mlp_dim: int,
          device: torch.device,
          dropout: float = 0.0,
          attention_dropout: float = 0.0,
          norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
  ):
    """Store the hyper-parameters and build the feature extractor and encoder.

    Args:
      atom_dim: first argument handed to `MPNNFeatureExtractor`.
      num_layers: number of layers handed to `Encoder`.
      num_heads: handed to both `MPNNFeatureExtractor` and `Encoder`.
      hidden_dim: handed to both `MPNNFeatureExtractor` and `Encoder`.
      mlp_dim: handed to `Encoder`.
      device: device the feature extractor is moved to. The encoder is not moved here.
      dropout: handed to `Encoder`.
      attention_dropout: handed to `Encoder`.
      norm_layer: factory for normalisation layers, handed to `Encoder`.
    """
    super().__init__()

    # Plain (non-module) attributes kept for later inspection.
    self.atom_dim, self.hidden_dim, self.mlp_dim = atom_dim, hidden_dim, mlp_dim
    self.attention_dropout, self.dropout = attention_dropout, dropout
    self.norm_layer = norm_layer

    # Sub-modules are created in a fixed order: extractor first, encoder second.
    extractor = MPNNFeatureExtractor(atom_dim, hidden_dim, num_heads)
    self.gnn_extractor = extractor.to(device)

    encoder_args = (num_layers, num_heads, hidden_dim, mlp_dim, dropout, attention_dropout, norm_layer)
    self.encoder = Encoder(*encoder_args)

  def forward(self, x, y, context_length, *args, **kwargs):
    """Always raises NotImplementedError; subclasses provide the real forward pass."""
    raise NotImplementedError("this is an abstract class.")


def replicate_batch(x, y):
  """Give every position of every context its own copy of that context.

  Args:
    x: torch.Tensor of shape [B, N, C].
    y: torch.Tensor whose first dimension is B (e.g. [B, N, D]); it is repeated along dim 0 exactly like x.
  Returns:
    (x, y, gather_idx) where x has shape [B*N, N, C], y has N times its original first dimension, and
    gather_idx has shape [B*N] and cycles through 0..N-1, so the k-th copy of a context is paired with position k.

  """
  num_contexts, context_size, _ = x.shape

  # One position index per copy: 0..N-1 for the first context, then 0..N-1 again for the next, and so on.
  gather_idx = torch.arange(context_size, device=x.device).repeat(num_contexts)

  # Each context appears N consecutive times along the batch axis.
  replicated_x = torch.repeat_interleave(x, context_size, dim=0)
  replicated_y = torch.repeat_interleave(y, context_size, dim=0)

  return replicated_x, replicated_y, gather_idx


class ContextTransformer_v2(MoleculeTransformer):
  """Context Modeling Molecular Transformer v2.

  This version incorporates the label information into the molecular embedding itself: every molecule is represented
  by the concatenation of its GNN embedding and an embedding of its label, so that attention weights can consider
  both. The label half of the molecule being predicted is replaced by a learned mask token (`label_emb`).
  """

  def __init__(
          self,
          atom_dim: int,
          num_layers: int,
          num_heads: int,
          hidden_dim: int,
          mlp_dim: int,
          device: torch.device,
          dropout: float = 0.0,
          attention_dropout: float = 0.0,
          norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
  ):
    """Build the extractor, encoder, label embedding, mask token and output head.

    Args:
      atom_dim: first argument handed to `MPNNFeatureExtractor`.
      num_layers: number of layers handed to `Encoder`.
      num_heads: handed to both `MPNNFeatureExtractor` and `Encoder`.
      hidden_dim: width handed to `Encoder`; molecule and label embeddings each use `hidden_dim // 2`.
      mlp_dim: handed to `Encoder`.
      device: device the feature extractor is moved to.
      dropout: handed to `Encoder`.
      attention_dropout: handed to `Encoder`.
      norm_layer: factory for normalisation layers, handed to `Encoder`.
    """
    super().__init__(atom_dim, num_layers, num_heads, hidden_dim, mlp_dim, device, dropout, attention_dropout,
                     norm_layer)

    # Replace the parent's extractor with one sized hidden_dim // 2: the molecule representation is later
    # concatenated with a label embedding of the same width.
    # NOTE(review): this presumably requires an even hidden_dim so the concatenation matches the encoder width.
    self.gnn_extractor = MPNNFeatureExtractor(atom_dim, hidden_dim // 2, num_heads).to(device)
    self.encoder = Encoder(
      num_layers,
      num_heads,
      hidden_dim,
      mlp_dim,
      dropout,
      attention_dropout,
      norm_layer,
    )
    # Embeds a scalar label into hidden_dim // 2 channels.
    self.class_emb = torch.nn.Linear(in_features=1, out_features=hidden_dim // 2, bias=False)
    # Learned token that stands in for the label of the molecule being predicted.
    self.label_emb = torch.nn.Parameter(torch.zeros(1, 1, hidden_dim // 2))
    self.output_proj = torch.nn.Linear(in_features=hidden_dim, out_features=1, bias=False)

  def forward(self, x, y, context_length, *args, **kwargs):
    """Predict the label of every molecule from the other molecules in its context.

    Args:
      x: molecular graph input consumed by `gnn_extractor`, which must return a 2-D [B, hidden_dim // 2] tensor.
      y: labels, one per molecule; assumed to be a float tensor of shape [B] (TODO confirm against callers).
      context_length: number of molecules per context; B must be divisible by it.
      *args: ignored.
      **kwargs: forwarded to the encoder.

    Returns:
      One prediction per input molecule, with a trailing dimension of size 1.
    """
    # Step 1: Extract molecular graph features with a GNN and embed the labels.
    x = self.gnn_extractor(x)
    y = self.class_emb(torch.unsqueeze(y, 1))

    # Step 2: Group the flat batch into contexts of `context_length` molecules.
    B, C = x.shape
    x = x.reshape(B // context_length, context_length, C)
    y = y.reshape(B // context_length, context_length, C)

    # Step 3: Repeat each context |context| times so that we predict on each molecule in the context.
    x, y, gather_idx = replicate_batch(x, y)

    # Step 4: Hide the label of the query molecule. Row k of each group of `context_length` copies has
    # position k as its query, hence the tiled identity matrix.
    query_mask = torch.eye(context_length, device=x.device, dtype=torch.bool).repeat(B // context_length, 1).unsqueeze(
      -1)
    y += self.label_emb * query_mask - y * query_mask  # Replace the true label w/ the masked token.

    # Step 5: Concatenate molecule embeddings and label embeddings along last axis.
    x = torch.cat((x, y), dim=-1)

    # Step 6: Pass inputs and labels through the context model.
    x = self.encoder(x, None, **kwargs)

    # Step 7: Pick out the encoder output at each row's query position. Only the gathered axis is squeezed:
    # a bare squeeze() would also drop the batch axis when the batch holds a single row.
    y = torch.take_along_dim(x, gather_idx.reshape(-1, 1, 1), 1).squeeze(1)

    # Step 8: Linear projection to determine the label.
    return self.output_proj(y)

  def forward_test(self, train_examples, test_examples, train_labels, test_labels, context_length):
    """Predict labels for test molecules using a single labelled context.

    Every test molecule gets its own copy of the context. Position 0 of that copy is overwritten with the test
    molecule and the mask token, so the first training example and its label are not seen by the encoder.

    Args:
      train_examples: graph input for the context; the extractor must return exactly `context_length` rows.
      test_examples: graph input for the molecules to predict; the extractor must return a 2-D tensor.
      train_labels: labels of the context molecules, one scalar each.
      test_labels: unused.
      context_length: number of rows in the context.

    Returns:
      One prediction per test molecule, with a trailing dimension of size 1.
    """
    # Step 1: Extract molecular graph features with a GNN.
    train_examples = self.gnn_extractor(train_examples)
    test_examples = self.gnn_extractor(test_examples)

    # Step 2: Shape the context as a single sequence and embed its labels.
    _, C = train_examples.shape
    train_examples = train_examples.reshape(1, context_length, C)
    train_labels = self.class_emb(torch.unsqueeze(train_labels, -1)).reshape(1, context_length, -1)

    # Step 3: One copy of the context per test molecule; slot 0 becomes the query with a masked label.
    num_test, _ = test_examples.shape
    x = train_examples.repeat((num_test, 1, 1))
    y = train_labels.repeat((num_test, 1, 1))
    x[:, 0, :] = test_examples
    y[:, 0, :] = self.label_emb
    x = torch.cat((x, y), dim=-1)

    # Step 4: Pass inputs and labels through the context model.
    x = self.encoder(x, None)
    y = x[:, 0, :]

    # Step 5: Linear projection to determine the label.
    return self.output_proj(y)


def _molecule_transformer(
        atom_dim: int,
        num_layers: int,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        weights: Optional[WeightsEnum],
        progress: bool,
        device: torch.device,
        model_type: Text = 'MoleculeTransformer',
        **kwargs: Any,
) -> MoleculeTransformer:
  """Build the model named by `model_type` and move it to `device`.

  Only 'ContextTransformer_v2' can currently be built. Note that the default `model_type` is not a buildable
  type, so callers have to pass one explicitly.

  Args:
    atom_dim: forwarded to the model constructor.
    num_layers: forwarded to the model constructor.
    num_heads: forwarded to the model constructor.
    hidden_dim: forwarded to the model constructor.
    mlp_dim: forwarded to the model constructor.
    weights: accepted but not used by this function.
    progress: accepted but not used by this function.
    device: forwarded to the model constructor; the finished model is also moved to it.
    model_type: name of the model class to build.
    **kwargs: extra keyword arguments forwarded to the model constructor.

  Returns:
    The constructed model, on `device`.

  Raises:
    NotImplementedError: if `model_type` names a variant that has not been released.
    ValueError: if `model_type` is not a known model name.
  """
  # Specific exception types instead of a bare Exception; both are still caught by `except Exception`.
  if model_type in ('ContextTransformer_v1', 'ContextTransformer_v3'):
    raise NotImplementedError("Not released for Neurips.")
  if model_type != 'ContextTransformer_v2':
    raise ValueError(f'model type: {model_type} is not recognized.')

  model = ContextTransformer_v2(
    atom_dim=atom_dim,
    num_layers=num_layers,
    num_heads=num_heads,
    hidden_dim=hidden_dim,
    mlp_dim=mlp_dim,
    device=device,
    **kwargs,
  )
  model.to(device)
  return model


def mt_small_32(*, device=torch.device('cuda:0'), weights=None,
                progress: bool = True, model_type: Text = 'MoleculeTransformer',
                **kwargs: Any) -> MoleculeTransformer:
  """Build the "small" configuration: atom_dim 32, 1 layer, 1 head, hidden_dim 128, mlp_dim 128.

  All keyword arguments, including any extras in `kwargs`, are handed to `_molecule_transformer` unchanged.
  NOTE: `model_type` must name a type that `_molecule_transformer` can build; the default here is not one of them.
  """
  architecture = dict(atom_dim=32, num_layers=1, num_heads=1, hidden_dim=128, mlp_dim=128)
  return _molecule_transformer(
    **architecture,
    weights=weights,
    progress=progress,
    device=device,
    model_type=model_type,
    **kwargs,
  )


def mt_medium_32(*, device=torch.device('cuda:0'), weights=None,
                 progress: bool = True, model_type: Text = 'MoleculeTransformer',
                 **kwargs: Any) -> MoleculeTransformer:
  """Build the "medium" configuration: atom_dim 32, 8 layers, 8 heads, hidden_dim 512, mlp_dim 2048.

  All keyword arguments, including any extras in `kwargs`, are handed to `_molecule_transformer` unchanged.
  NOTE: `model_type` must name a type that `_molecule_transformer` can build; the default here is not one of them.
  """
  architecture = dict(atom_dim=32, num_layers=8, num_heads=8, hidden_dim=512, mlp_dim=2048)
  return _molecule_transformer(
    **architecture,
    weights=weights,
    progress=progress,
    device=device,
    model_type=model_type,
    **kwargs,
  )


def mt_base_32(*, device=torch.device('cuda:0'), weights=None,
               progress: bool = True, model_type: Text = 'MoleculeTransformer',
               **kwargs: Any) -> MoleculeTransformer:
  """Build the "base" configuration: atom_dim 32, 12 layers, 12 heads, hidden_dim 768, mlp_dim 3072.

  All keyword arguments, including any extras in `kwargs`, are handed to `_molecule_transformer` unchanged.
  NOTE: `model_type` must name a type that `_molecule_transformer` can build; the default here is not one of them.
  """
  architecture = dict(atom_dim=32, num_layers=12, num_heads=12, hidden_dim=768, mlp_dim=3072)
  return _molecule_transformer(
    **architecture,
    weights=weights,
    progress=progress,
    device=device,
    model_type=model_type,
    **kwargs,
  )


def mt_large_32(*, device=torch.device('cuda:0'), weights=None,
                progress: bool = True, model_type: Text = 'MoleculeTransformer',
                **kwargs: Any) -> MoleculeTransformer:
  """Build the "large" configuration: atom_dim 32, 24 layers, 16 heads, hidden_dim 1024, mlp_dim 4096.

  All keyword arguments, including any extras in `kwargs`, are handed to `_molecule_transformer` unchanged.
  NOTE: `model_type` must name a type that `_molecule_transformer` can build; the default here is not one of them.
  """
  architecture = dict(atom_dim=32, num_layers=24, num_heads=16, hidden_dim=1024, mlp_dim=4096)
  return _molecule_transformer(
    **architecture,
    weights=weights,
    progress=progress,
    device=device,
    model_type=model_type,
    **kwargs,
  )


def mt_huge_32(*, device=torch.device('cuda:0'), weights=None, progress: bool = True,
               model_type: Text = 'MoleculeTransformer',
               **kwargs: Any) -> MoleculeTransformer:
  """Build the "huge" configuration: atom_dim 32, 32 layers, 16 heads, hidden_dim 1280, mlp_dim 5120.

  All keyword arguments, including any extras in `kwargs`, are handed to `_molecule_transformer` unchanged.
  NOTE: `model_type` must name a type that `_molecule_transformer` can build; the default here is not one of them.
  """
  architecture = dict(atom_dim=32, num_layers=32, num_heads=16, hidden_dim=1280, mlp_dim=5120)
  return _molecule_transformer(
    **architecture,
    weights=weights,
    progress=progress,
    device=device,
    model_type=model_type,
    **kwargs,
  )
