import torch
from torch import nn

from .peak_embedding import FeedForwardBlock

class SelfAttentionLayer(nn.Module):
    """Post-norm encoder layer: multi-head self-attention, then a feed-forward block.

    The input is expected batch-first, ``(batch, seq, d_model)``. The attention
    module keeps PyTorch's default sequence-first layout, so ``_sa_block`` swaps
    the first two axes around the attention call.

    When ``collapse_seq`` is true, only the first sequence position is used as
    the attention query and as the attention residual. The sequence axis is
    therefore reduced away (``(batch, d_model)`` after the first norm).
    Otherwise every position attends to every other and the sequence axis is
    kept.

    Args:
        d_model: embedding width of the input and output.
        nhead: number of attention heads.
        ff_dim: forwarded to ``FeedForwardBlock`` (presumably its hidden width).
        dropout: used for attention-weight dropout and for dropout on the
            attention output, and also forwarded to ``FeedForwardBlock``.
        collapse_seq: pool the sequence onto its first position (see above).
    """

    def __init__(self, d_model, nhead, ff_dim, dropout=0.1, collapse_seq=False):
        super().__init__()

        self._collapse_seq = collapse_seq

        # Registration order is kept as-is: it fixes the parameter-init RNG
        # order and the state_dict layout.
        self._self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self._dropout = nn.Dropout(dropout)
        self._ff_block = FeedForwardBlock(d_model, d_model, ff_dim, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        """Apply attention and feed-forward sub-layers, each with residual + LayerNorm."""
        # When collapsing, only the query position (index 0) carries the residual.
        residual = x[:, 0, :] if self._collapse_seq else x
        hidden = self.norm1(residual + self._sa_block(x))
        return self.norm2(hidden + self._ff_block(hidden))

    def _sa_block(self, x):
        """Run multi-head attention on batch-first ``x`` and apply output dropout."""
        seq_first = x.transpose(0, 1)

        if not self._collapse_seq:
            attended, _ = self._self_attn(seq_first, seq_first, seq_first, need_weights=False)
            # Return to batch-first layout.
            return self._dropout(attended.transpose(0, 1))

        # Single query (the first position) attending over the full sequence;
        # the length-1 query axis is then dropped.
        query = seq_first[:1]
        attended, _ = self._self_attn(query, seq_first, seq_first, need_weights=False)
        return self._dropout(attended[0])

class AttentionHead(nn.Module):
    """Pooling head: the first sequence position attends over the whole sequence.

    The attended vector for that position goes through dropout and LayerNorm,
    then through a feed-forward block. There is no residual connection around
    either step, and no normalisation after the feed-forward block.

    Input is expected batch-first, ``(batch, seq, embd_dim)``; the sequence axis
    is reduced away.

    Args:
        embd_dim: embedding width.
        num_heads: number of attention heads.
        ff_dim: forwarded to ``FeedForwardBlock`` (presumably its hidden width).
        dropout: attention-weight dropout, output dropout, and forwarded to
            ``FeedForwardBlock``.
    """

    def __init__(self, embd_dim, num_heads, ff_dim, dropout=0.1):
        super().__init__()

        # Same registration order as before (init RNG order / state_dict keys).
        self._self_attn = nn.MultiheadAttention(embd_dim, num_heads, dropout=dropout)
        self._dropout = nn.Dropout(dropout)
        self._ff_block = FeedForwardBlock(embd_dim, embd_dim, ff_dim, dropout)
        self._norm1 = nn.LayerNorm(embd_dim)

    def forward(self, x):
        """Pool ``x`` onto its first position, then apply the feed-forward block."""
        return self._ff_block(self._sa_block(x))

    def _sa_block(self, x):
        """Attend from position 0 over all positions; return ``norm(dropout(out))``."""
        # nn.MultiheadAttention here uses the default (seq, batch, dim) layout.
        tokens = x.transpose(0, 1)
        pooled, _ = self._self_attn(tokens[:1], tokens, tokens, need_weights=False)
        # pooled is (1, batch, dim); drop the single query position.
        return self._norm1(self._dropout(pooled[0]))

class MeanPool(nn.Module):
    """Pooling head: one ``SelfAttentionLayer`` followed by a mean over axis 1.

    Takes the same constructor arguments as ``AttentionHead``. The wrapped
    layer is built with ``collapse_seq`` left at its default (``False``), so
    the sequence axis survives until the mean reduces it.
    """

    def __init__(self, embd_dim, num_heads, ff_dim, dropout=0.1):
        super().__init__()

        self._sa_layer = SelfAttentionLayer(embd_dim, num_heads, ff_dim, dropout)

    def forward(self, x):
        """Encode ``x`` with the attention layer, then average over dimension 1."""
        # TODO: entries that represent missing values (presumably padding) are
        # still included in this average; they should be masked out.
        encoded = self._sa_layer(x)
        return encoded.mean(dim=1)

class SANet(nn.Module):
    """Full network: embedding, a stack of ``SelfAttentionLayer``s, then a head.

    The modules run in this order inside a single ``nn.Sequential`` stored at
    ``self._layers``, so state_dict keys are ``_layers.<index>.*``:

    1. ``peak_embedding``
    2. ``num_sa_layers - 1`` ``SelfAttentionLayer`` instances
    3. ``attention_head``

    ``num_sa_layers`` presumably counts the head as the final attention layer,
    hence the ``- 1``. Values of 1 or less add no intermediate layers.

    Args:
        peak_embedding: module applied first to the input.
        attention_head: module applied last (e.g. an ``AttentionHead`` or
            ``MeanPool`` instance).
        num_sa_layers: see above.
        embd_dim, num_heads, ff_dim, dropout: forwarded to every intermediate
            ``SelfAttentionLayer``.
    """

    def __init__(
            self, peak_embedding, attention_head, num_sa_layers, embd_dim, num_heads, ff_dim,
            dropout):
        super().__init__()

        # Build the stack in a local list and register it once. The original
        # code stored a plain list on the module attribute first and then
        # rebound that same attribute to an nn.Sequential.
        layers = [peak_embedding]
        layers.extend(
            SelfAttentionLayer(embd_dim, num_heads, ff_dim, dropout)
            for _ in range(num_sa_layers - 1)
        )
        layers.append(attention_head)

        self._layers = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through embedding, attention stack and head in order."""
        return self._layers(x)
