import copy
import math

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F


### Model definition


def make_model(d_atom, cutoff, edge_dim, num_radial=16,
               N=2, d_model=128, h=8, dropout=0.1,
               N_dense=2, aggregation_type='grover',
               n_output=1, n_generator_layers=1,
               generator_attn_hidden=128, generator_attn_heads=4, generator_d_linear=32,
               **kwargs):
    """Build a GraphTransformer from hyperparameters and Xavier-initialise it.

    Args:
        d_atom: size of the per-atom input feature vector (input of Embeddings).
        cutoff: distance cutoff passed to BesselBasisLayerEnvelope.
        edge_dim: channel count of the edge tensor seen by the attention edge
            layers. GraphTransformer concatenates the ``num_radial`` RBF
            channels onto the raw edge features before attention, so this must
            cover both.
        num_radial: number of Bessel basis functions per distance.
        N: number of encoder layers.
        d_model: model width.
        h: number of attention heads.
        dropout: dropout probability used throughout.
        N_dense: number of Linear layers in each feed-forward block.
        aggregation_type, n_output, n_generator_layers, generator_attn_hidden,
        generator_attn_heads, generator_d_linear: forwarded to Generator.
        **kwargs: accepted and ignored.

    Returns:
        A GraphTransformer whose parameters with more than one dimension have
        been re-initialised with ``nn.init.xavier_uniform_``.
    """
    # Prototype sublayers; every encoder layer is built from deep copies.
    attn_proto = MultiHeadedAttention(h, d_model, edge_dim, dropout)
    ff_proto = PositionwiseFeedForward(d_model, N_dense, dropout)

    # Construction order matches the historical one so RNG consumption
    # (and therefore seeded initialisation) is unchanged.
    encoder = Encoder(
        EncoderLayer(d_model, copy.deepcopy(attn_proto), copy.deepcopy(ff_proto), dropout),
        N,
    )
    embeddings = Embeddings(d_model, d_atom, dropout)
    head = Generator(d_model, aggregation_type, n_output, n_generator_layers, dropout,
                     generator_attn_hidden, generator_attn_heads, generator_d_linear)
    distance_basis = BesselBasisLayerEnvelope(num_radial=num_radial, cutoff=cutoff)

    model = GraphTransformer(encoder, embeddings, head, distance_basis)

    # Only matrices/higher-rank tensors get Xavier; biases and norms keep defaults.
    for weight in filter(lambda t: t.dim() > 1, model.parameters()):
        nn.init.xavier_uniform_(weight)
    return model


class GraphTransformer(nn.Module):
    """Embed atoms, encode them with edge-aware attention, then read out.

    Args:
        encoder: callable ``(embedded, src_mask, edges_att) -> features``.
        src_embed: callable mapping raw atom features to model width.
        generator: callable ``(features, src_mask) -> prediction``.
        dist_rbf: callable expanding the distance matrix into extra edge
            channels; its output is concatenated onto ``edges_att`` along dim 1.
    """
    def __init__(self, encoder, src_embed, generator, dist_rbf):
        super(GraphTransformer, self).__init__()
        self.encoder = encoder
        self.src_embed = src_embed
        self.generator = generator
        self.dist_rbf = dist_rbf

    def _edge_features(self, distances_matrix, edges_att):
        """Append the distance basis expansion to the edge features along dim 1."""
        return torch.cat((edges_att, self.dist_rbf(distances_matrix)), dim=1)

    def forward(self, src, distances_matrix, edges_att):
        """Run the full model on raw atom features.

        Atoms whose feature vector is entirely zero (sum of absolute values
        over the last dim equals 0) are treated as padding and masked out.
        """
        src_mask = torch.sum(torch.abs(src), dim=-1) != 0
        edges_att = self._edge_features(distances_matrix, edges_att)
        return self.predict(self.encode(src, src_mask, edges_att), src_mask)

    def encode(self, src, src_mask, edges_att):
        """Embed ``src`` and pass it through the encoder."""
        return self.encoder(self.src_embed(src), src_mask, edges_att)

    def predict(self, out, out_mask):
        """Apply the generator head to encoder outputs."""
        return self.generator(out, out_mask)

    def embed_tokens(self, src):
        """Return the atom embeddings for ``src``."""
        return self.src_embed(src)

    def forward_embeds(self, embed, src_mask, distances_matrix, edges_att):
        """Like ``forward`` but starting from precomputed embeddings and mask."""
        edges_att = self._edge_features(distances_matrix, edges_att)
        return self.predict(self.encoder(embed, src_mask, edges_att), src_mask)

    def from_pretrained(self, weights_path: str, map_location=None):
        """Load weights from a saved state dict, skipping the generator head.

        Keys starting with ``generator.`` are discarded so a new prediction
        head can be trained. Keys missing from the file keep this model's
        current values; keys this model does not have make
        ``load_state_dict`` (default ``strict=True``) raise.

        Args:
            weights_path: path to a file saved with ``torch.save(state_dict)``.
                ``torch.load`` unpickles the file: only load trusted checkpoints.
            map_location: forwarded to ``torch.load``. Defaults to the device
                of this model's first parameter (CPU if it has none); the old
                hard-coded ``'cuda:0'`` failed on machines without CUDA.
        """
        if map_location is None:
            first_param = next(self.parameters(), None)
            map_location = first_param.device if first_param is not None else 'cpu'

        excluded_prefixes = ("generator.",)
        pretrained_dict = torch.load(weights_path, map_location=map_location)
        pretrained_dict = {
            key: value
            for key, value in pretrained_dict.items()
            if not key.startswith(excluded_prefixes)
        }

        model_dict = self.state_dict()
        model_dict.update(pretrained_dict)
        self.load_state_dict(model_dict)


class Generator(nn.Module):
    """Readout head: pool per-atom features (or keep them per atom) and project.

    ``aggregation_type`` selects what ``forward`` does:
      * ``'mean'``: masked mean over atoms, then ``proj``.
      * ``'grover'``: ``attn_heads`` attention-pooled vectors (softmax over
        valid atoms), concatenated, then ``proj``.
      * ``'contextual'``: ``proj`` applied to every atom, with no pooling.
      * ``'molbert_pretraining'``: returns
        ``(supervised_proj(masked mean), proj(x))``.

    ``proj`` is a single Linear when ``n_layers == 1``. Otherwise it is a stack
    of Linear -> LeakyReLU(0.1) -> LayerNorm -> Dropout blocks of width
    ``d_linear``, followed by a final Linear to ``n_output``.
    """
    def __init__(self, d_model, aggregation_type='grover', n_output=1, n_layers=1, dropout=0.0,
                 attn_hidden=128, attn_heads=4, d_linear=128):
        super(Generator, self).__init__()
        self.aggregation_type = aggregation_type

        if aggregation_type == 'molbert_pretraining':
            # Sequence-level head on the masked mean; the 200 outputs are fixed.
            self.supervised_proj = nn.Linear(d_model, 200)

        if aggregation_type == 'grover':
            # One score per atom per head; pooled vectors from all heads are
            # concatenated, so the projection input grows to d_model * attn_heads.
            self.att_net = nn.Sequential(
                nn.Linear(d_model, attn_hidden, bias=False),
                nn.Tanh(),
                nn.Linear(attn_hidden, attn_heads, bias=False),
            )
            d_model *= attn_heads

        if n_layers == 1:
            self.proj = nn.Linear(d_model, n_output)
        else:
            layers = []
            in_features = d_model
            # n_layers - 1 hidden blocks, and at least one (this matches the
            # historical layout for n_layers < 1 as well).
            for _ in range(max(n_layers - 1, 1)):
                layers += [
                    nn.Linear(in_features, d_linear),
                    nn.LeakyReLU(negative_slope=0.1),
                    LayerNorm(d_linear),
                    nn.Dropout(dropout),
                ]
                in_features = d_linear
            layers.append(nn.Linear(d_linear, n_output))
            self.proj = nn.Sequential(*layers)

    @staticmethod
    def _masked_mean(out_masked, mask):
        """Average already-masked features over dim 1 using the float ``mask``.

        NOTE(review): a row with no valid atoms divides by zero (NaN result).
        """
        return out_masked.sum(dim=1) / mask.sum(dim=1)

    def forward(self, x, mask):
        """Pool and project encoder outputs.

        Args:
            x: encoder output, presumably (batch, n_atoms, d_model).
            mask: per-atom validity mask broadcastable with ``x`` once a
                trailing dim is added; 0/False marks padding.

        Returns:
            ``proj`` output, or a ``(supervised, contextual)`` pair for
            ``'molbert_pretraining'``.

        Raises:
            ValueError: if ``aggregation_type`` is not a supported value.
        """
        mask = mask.unsqueeze(-1).float()
        out_masked = x * mask

        if self.aggregation_type == 'molbert_pretraining':
            pooled = self._masked_mean(out_masked, mask)
            return self.supervised_proj(pooled), self.proj(x)

        if self.aggregation_type == 'mean':
            pooled = self._masked_mean(out_masked, mask)
        elif self.aggregation_type == 'grover':
            scores = self.att_net(out_masked)
            # Padding atoms get a huge negative score, so their softmax weight is ~0.
            scores = scores.masked_fill(mask == 0, -1e9)
            # Normalise over atoms (dim 1), separately for each head.
            weights = F.softmax(scores, dim=1)
            pooled = torch.matmul(weights.transpose(-1, -2), out_masked)
            # Concatenate the per-head pooled vectors.
            pooled = pooled.view(pooled.size(0), -1)
        elif self.aggregation_type == 'contextual':
            pooled = x
        else:
            raise ValueError(
                f"Unknown aggregation_type {self.aggregation_type!r}; expected one of "
                "'mean', 'grover', 'contextual', 'molbert_pretraining'"
            )
        return self.proj(pooled)
    
    
### Encoder

def clones(module, N):
    """Return an ``nn.ModuleList`` of ``N`` independent deep copies of ``module``.

    The copies share no parameters with each other or with ``module``.
    """
    return nn.ModuleList(map(copy.deepcopy, [module] * N))


class Encoder(nn.Module):
    """Stack of ``N`` deep-copied encoder layers followed by a final LayerNorm.

    ``layer`` must expose a ``size`` attribute; it sets the width of the final norm.
    """
    def __init__(self, layer, N):
        super(Encoder, self).__init__()
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)

    def forward(self, x, mask, edges_att):
        """Feed ``x`` through each layer in order (all share ``mask`` and ``edges_att``), then normalise."""
        hidden = x
        for encoder_layer in self.layers:
            hidden = encoder_layer(hidden, mask, edges_att)
        return self.norm(hidden)

    
class LayerNorm(nn.Module):
    """Layer normalisation over the last dimension with learnable gain and bias.

    Unlike ``torch.nn.LayerNorm``, this divides by the *unbiased* standard
    deviation (the ``Tensor.std`` default) plus ``eps``, rather than by
    ``sqrt(var + eps)``. Trained weights depend on that exact formula.
    """
    def __init__(self, features, eps=1e-6):
        super(LayerNorm, self).__init__()
        self.a_2 = nn.Parameter(torch.ones(features))   # gain
        self.b_2 = nn.Parameter(torch.zeros(features))  # bias
        self.eps = eps

    def forward(self, x):
        centered = x - x.mean(-1, keepdim=True)
        denom = x.std(-1, keepdim=True) + self.eps
        return self.a_2 * centered / denom + self.b_2

    
class SublayerConnection(nn.Module):
    """Pre-norm residual wrapper computing ``x + dropout(sublayer(norm(x)))``.

    The LayerNorm is applied to the sublayer's input rather than after the
    residual sum, so ``sublayer`` must return a tensor that can be added to ``x``.
    """
    def __init__(self, size, dropout):
        super(SublayerConnection, self).__init__()
        self.norm = LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        residual = x
        transformed = sublayer(self.norm(x))
        return residual + self.dropout(transformed)

    
class EncoderLayer(nn.Module):
    """One encoder block: a self-attention sublayer, then a feed-forward sublayer.

    Each sublayer is wrapped in its own pre-norm residual connection.
    ``size`` is stored so that Encoder can size its final norm.
    """
    def __init__(self, size, self_attn, feed_forward, dropout):
        super(EncoderLayer, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(size, dropout), 2)
        self.size = size

    def forward(self, x, mask, edges_att):
        attn_block, ff_block = self.sublayer
        # Self-attention: the normalised input is used as query, key and value.
        x = attn_block(x, lambda normed: self.self_attn(normed, normed, normed, edges_att, mask))
        return ff_block(x, self.feed_forward)

    
### Attention           

class EdgeFeaturesLayer(nn.Module):
    """Per-edge MLP that maps edge channels to per-head relative features.

    The input is channel-first, ``(batch, d_edge, n, n)``. The MLP is applied
    to each (i, j) pair, and the ``d_out`` outputs are split into ``h`` heads
    of ``d_out // h`` each. The result has shape ``(batch, h, d_out // h, n, n)``.
    """
    def __init__(self, d_edge, d_out, d_hidden, h, dropout):
        super(EdgeFeaturesLayer, self).__init__()
        self.d_k = d_out // h
        self.h = h

        self.nn = nn.Sequential(
            nn.Linear(d_edge, d_hidden),
            nn.LeakyReLU(negative_slope=0.1),
            nn.Dropout(dropout),
            nn.Linear(d_hidden, d_out),
        )

    def forward(self, p_edge):
        batch, _, n_rows, n_cols = p_edge.shape
        # Move channels last so the Linear layers act on each edge's feature vector.
        channels_last = p_edge.permute(0, 2, 3, 1)
        projected = self.nn(channels_last).permute(0, 3, 1, 2)
        return projected.view(batch, self.h, self.d_k, n_rows, n_cols)


def attention(query, key, value,
              relative_K, relative_V,
              relative_u, relative_v,
              mask=None, dropout=None, eps=1e-6, inf=1e12):
    """Scaled dot-product attention with edge-conditioned relative terms.

    Shapes, as produced by MultiHeadedAttention / EdgeFeaturesLayer:
        query, key, value:       (b, h, n, d_k)
        relative_K, relative_V:  (b, h, d_k, n, n)
        relative_u:              (1, h, 1, d_k)
        relative_v:              (1, h, 1, d_k, 1)
        mask:                    (b, 1, n); entries equal to 0 mark keys to ignore.

    Args:
        dropout: optional module applied to the attention probabilities.
        eps: unused; kept for interface compatibility.
        inf: magnitude of the fill value for masked scores.

    Returns:
        ``(atoms_features, p_attn)`` with shapes (b, h, n, d_k) and (b, h, n, n).
    """
    b, h, n, d_k = query.size(0), query.size(1), query.size(2), query.size(-1)
    scale = math.sqrt(d_k)

    # Content term: q_i . k_j
    scores1 = torch.matmul(query, key.transpose(-2, -1)) / scale
    # Edge term: (q_i + k_i) . K_ij
    scores2 = torch.matmul((query + key).view(b, h, n, 1, d_k),
                           relative_K.permute(0, 1, 3, 2, 4)).view(b, h, n, n) / scale
    # NOTE(review): key @ u^T has shape (b, h, n, 1). It broadcasts along the
    # last (key) axis, so query row i gets the same value k_i . u added to
    # every column j. Softmax over dim -1 is invariant to a per-row constant,
    # so this term does not change p_attn (up to rounding), and relative_u
    # receives no gradient. A key-dependent bias u . k_j would be
    # torch.matmul(relative_u, key.transpose(-2, -1)) of shape (b, h, 1, n).
    # It is left unchanged because switching would alter the outputs of
    # existing checkpoints.
    scores3 = torch.matmul(key, relative_u.transpose(-2, -1))
    # Global edge bias: v . K_ij
    scores4 = torch.matmul(relative_K.permute(0, 1, 3, 4, 2), relative_v).squeeze(-1)
    scores = scores1 + scores2 + scores3 + scores4

    if mask is not None:
        # A (b, 1, 1, n) mask broadcasts over heads and query rows inside
        # masked_fill, so there is no need to materialise a (b, h, n, n) copy.
        scores = scores.masked_fill(mask.unsqueeze(1) == 0, -inf)
    p_attn = F.softmax(scores, dim=-1)

    if dropout is not None:
        p_attn = dropout(p_attn)

    # sum_j p_ij * v_j
    atoms_features1 = torch.matmul(p_attn, value)
    # sum_j p_ij * V_ij -> (b, h, d_k, n), then permuted to (b, h, n, d_k)
    atoms_features2 = (p_attn.unsqueeze(2) * relative_V).sum(-1).permute(0, 1, 3, 2)

    atoms_features = atoms_features1 + atoms_features2

    return atoms_features, p_attn


class MultiHeadedAttention(nn.Module):
    """Multi-head self-attention with edge-conditioned relative key/value terms.

    Edge features are mapped to per-head relative tensors by two
    EdgeFeaturesLayer instances (``relative_K`` / ``relative_V``) and combined
    with the projected queries, keys and values in ``attention``.
    """
    def __init__(self, h, d_model, edge_dim, dropout=0.1):
        "Take in model size and number of heads."
        super(MultiHeadedAttention, self).__init__()
        assert d_model % h == 0
        # We assume d_v always equals d_k
        self.d_k = d_model // h
        self.h = h

        # linears[0..2] project query/key/value; linears[3] is the output projection.
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.attn = None  # attention probabilities from the most recent forward
        self.dropout = nn.Dropout(p=dropout)

        self.relative_K = EdgeFeaturesLayer(edge_dim, d_model, self.d_k, h, dropout)
        self.relative_V = EdgeFeaturesLayer(edge_dim, d_model, self.d_k, h, dropout)

        # torch.empty would leave these uninitialised (arbitrary memory, possibly
        # NaN/inf) unless a caller re-initialises them afterwards. Zero-init gives
        # a defined value and consumes no RNG, so seeded construction followed by
        # an explicit re-init (e.g. Xavier) yields the same weights as before.
        self.relative_u = nn.Parameter(torch.zeros(1, self.h, 1, self.d_k))
        self.relative_v = nn.Parameter(torch.zeros(1, self.h, 1, self.d_k, 1))

    def forward(self, query, key, value, edges_att, mask=None):
        """Attend over atoms using node features plus edge features ``edges_att``.

        Args:
            query, key, value: presumably (batch, n_atoms, d_model).
            edges_att: channel-first edge tensor consumed by EdgeFeaturesLayer.
            mask: presumably (batch, n_atoms) with 0/False marking padding
                (TODO confirm); expanded below so it applies to all heads.

        Returns:
            Tensor of shape (batch, n_atoms, d_model). The attention
            probabilities are also stored in ``self.attn``.
        """
        if mask is not None:
            # Same mask applied to all h heads.
            mask = mask.unsqueeze(1)
        nbatches = query.size(0)

        # 1) Do all the linear projections in batch from d_model => h x d_k
        query, key, value = \
            [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
             for l, x in zip(self.linears, (query, key, value))]

        relative_K = self.relative_K(edges_att)
        relative_V = self.relative_V(edges_att)

        # 2) Apply attention on all the projected vectors in batch.
        x, self.attn = attention(query, key, value,
                                 relative_K, relative_V,
                                 self.relative_u, self.relative_v,
                                 mask=mask, dropout=self.dropout)

        # 3) "Concat" the heads using a view and apply the final linear.
        x = x.transpose(1, 2).contiguous() \
             .view(nbatches, -1, self.h * self.d_k)
        return self.linears[-1](x)


### Conv 1x1 aka Positionwise feed forward

class PositionwiseFeedForward(nn.Module):
    """Per-atom MLP with ``N_dense`` Linear layers and LeakyReLU(0.1) between them.

    With ``N_dense == 1`` it is a single ``d_model -> d_model`` Linear.
    Otherwise the hidden width is ``2 * d_model``: ``d_model -> 2d``, then
    ``N_dense - 2`` layers ``2d -> 2d``, then ``2d -> d_model``. Dropout follows
    each hidden activation. ``N_dense == 0`` makes ``forward`` the identity.
    """
    def __init__(self, d_model, N_dense, dropout=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.N_dense = N_dense
        lin_factor = 2
        wide = d_model * lin_factor
        if N_dense == 1:
            layers = [nn.Linear(d_model, d_model)]
        else:
            layers = [nn.Linear(d_model, wide)]
            layers.extend(nn.Linear(wide, wide) for _ in range(N_dense - 2))
            layers.append(nn.Linear(wide, d_model))

        self.linears = nn.ModuleList(layers)
        self.dropout = clones(nn.Dropout(dropout), N_dense)
        self.nonlinearity = nn.LeakyReLU(negative_slope=0.1)

    def forward(self, x):
        if self.N_dense == 0:
            return x

        *hidden_layers, output_layer = self.linears
        for linear, drop in zip(hidden_layers, self.dropout):
            x = drop(self.nonlinearity(linear(x)))
        return output_layer(x)

    
## Embeddings

class Embeddings(nn.Module):
    """Project per-atom input features (``d_atom``) to ``d_model``, then apply dropout.

    The projection lives in ``lut`` (a Linear layer, despite the name).
    """
    def __init__(self, d_model, d_atom, dropout):
        super(Embeddings, self).__init__()
        self.d_model = d_model
        self.lut = nn.Linear(d_atom, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        embedded = self.lut(x)
        return self.dropout(embedded)


## Distance Layers

class Envelope(nn.Module):
    """Polynomial envelope (divided by its input) that goes to zero smoothly at 1.

    With ``p = exponent + 1`` it computes
    ``1/x + a*x**(p-1) + b*x**p + c*x**(p+1)`` for ``x < 1`` and ``0`` otherwise.
    The coefficients ``a``, ``b`` and ``c`` are fixed by ``p`` below.
    """
    def __init__(self, exponent, **kwargs):
        super().__init__(**kwargs)
        self.exponent = exponent

        p = exponent + 1
        self.p = p
        self.a = -(p + 1) * (p + 2) / 2
        self.b = p * (p + 2)
        self.c = -p * (p + 1) / 2

    def forward(self, inputs):
        p = self.p
        # Same left-to-right summation order as the closed-form expression.
        env_val = (
            1 / inputs
            + self.a * inputs ** (p - 1)
            + self.b * inputs ** p
            + self.c * inputs ** (p + 1)
        )
        inside_cutoff = inputs < 1
        return torch.where(inside_cutoff, env_val, torch.zeros_like(inputs))
    
    
class BesselBasisLayerEnvelope(nn.Module):
    """Expand pairwise distances into ``num_radial`` enveloped sine features.

    For scaled distance ``d = r / cutoff``, feature k (k = 1..num_radial) is
    ``Envelope(d) * sin(k * pi * d)``. The output is channel-first:
    ``(batch, num_radial, n, n)`` for an input of shape ``(batch, n, n)``.
    """
    def __init__(self, num_radial, cutoff, envelope_exponent=5, **kwargs):
        super().__init__(**kwargs)
        self.num_radial = num_radial
        self.cutoff = cutoff
        # NOTE(review): not used in forward; presumably intended as the
        # sqrt(2 / cutoff) basis normalisation. Confirm before applying it,
        # since doing so would change outputs of trained models.
        self.sqrt_cutoff = np.sqrt(2. / cutoff)
        self.inv_cutoff = 1. / cutoff
        self.envelope = Envelope(envelope_exponent)

        # Registered as a non-persistent buffer (not a hard-coded .cuda() tensor),
        # so it follows module.to(device)/.cuda() and works on CPU-only machines,
        # while staying out of state_dict so existing checkpoints still load.
        self.register_buffer(
            'frequencies',
            np.pi * torch.arange(1, num_radial + 1).float(),
            persistent=False,
        )

    def forward(self, inputs):
        # Tiny offset keeps the envelope's 1/d term finite for zero distances.
        inputs = inputs.unsqueeze(-1) + 1e-6
        d_scaled = inputs * self.inv_cutoff
        d_cutoff = self.envelope(d_scaled)
        # (batch, n, n, num_radial) -> (batch, num_radial, n, n)
        return (d_cutoff * torch.sin(self.frequencies * d_scaled)).permute(0, 3, 1, 2)
    
    
