
import math

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init


class SelfAttention(nn.Module):
    """Self-attention built from learned Q/K/V matrices plus ``nn.MultiheadAttention``.

    ``x`` is multiplied by three ``(input_dim, output_dim)`` matrices to form
    queries, keys and values. These are fed to ``nn.MultiheadAttention``, which
    applies its own input/output projections on top.

    Args:
        input_dim: size of the last dimension of ``x``.
        output_dim: attention embedding size. ``nn.MultiheadAttention`` requires
            it to be divisible by ``n_head``.
        n_head: number of attention heads.
        dropout: dropout probability passed to ``nn.MultiheadAttention``.
    """

    def __init__(self, input_dim, output_dim, n_head, dropout):
        super(SelfAttention, self).__init__()
        # torch.empty replaces the legacy torch.Tensor(...) constructor; the
        # values are overwritten by init_parameters() below in either case.
        self.wq = nn.Parameter(torch.empty(input_dim, output_dim))
        self.wk = nn.Parameter(torch.empty(input_dim, output_dim))
        self.wv = nn.Parameter(torch.empty(input_dim, output_dim))

        # NOTE(review): batch_first is left at its default (False). A 3-D input
        # is therefore read as (seq, batch, dim) and a 2-D input as one
        # unbatched (seq, dim) sequence — confirm this matches the callers.
        self.mha = nn.MultiheadAttention(embed_dim=output_dim, num_heads=n_head, dropout=dropout)

        self.init_parameters()

    def init_parameters(self):
        """Re-draw every parameter from U(-1/sqrt(d), 1/sqrt(d)).

        Here ``d`` is the parameter's last dimension. ``self.parameters()``
        also yields the parameters of ``self.mha``, so this overrides
        ``nn.MultiheadAttention``'s own default initialisation, including its
        zero biases.
        """
        # In-place writes under no_grad instead of mutating `.data`, which
        # silently bypasses autograd's version tracking.
        with torch.no_grad():
            for param in self.parameters():
                bound = 1.0 / math.sqrt(param.size(-1))
                param.uniform_(-bound, bound)

    def forward(self, x):
        """Return the attention output for ``x``; attention weights are discarded.

        ``x``'s last dimension must equal ``input_dim`` (it is matmul'd with
        ``(input_dim, output_dim)`` weights). The output's last dimension is
        ``output_dim``.
        """
        q = torch.matmul(x, self.wq)
        k = torch.matmul(x, self.wk)
        v = torch.matmul(x, self.wv)
        output, _ = self.mha(q, k, v)
        return output


class AttentionBlock(nn.Module):
    """Post-norm residual block: self-attention followed by a feed-forward net.

    Each sub-layer's output is added to its input, and the sum is passed
    through ``self.norm``.

    Args:
        dim: feature size of the input and of the output.
        n_head: number of attention heads. It is passed on to
            ``nn.MultiheadAttention``, which requires it to divide ``dim``.
        dropout: dropout probability used in the attention and in the FFN.
    """

    def __init__(self, dim, n_head, dropout=0.):
        super().__init__()
        # Sub-modules are created in exactly this order so that parameter
        # initialisation consumes the RNG identically and state_dict keys are
        # unchanged.
        self.self_attention = SelfAttention(dim, dim, n_head, dropout)
        # NOTE(review): `fc` is never used in forward(), so its parameters get
        # no gradient (DistributedDataParallel errors out unless
        # find_unused_parameters=True). Kept so existing checkpoints still load.
        self.fc = nn.Linear(dim, dim)
        ffn_layers = [
            nn.Linear(dim, dim),
            nn.LayerNorm([dim]),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim, dim),
        ]
        self.ffn = nn.Sequential(*ffn_layers)
        # NOTE(review): this one LayerNorm (shared affine parameters) follows
        # both the attention residual and the FFN residual. A standard post-LN
        # block uses a separate norm per sub-layer — confirm the sharing is
        # intended.
        self.norm = nn.LayerNorm([dim])

    def _residual(self, x, sublayer):
        """Return ``self.norm(x + sublayer(x))``."""
        return self.norm(x + sublayer(x))

    def forward(self, x):
        for sublayer in (self.self_attention, self.ffn):
            x = self._residual(x, sublayer)
        return x


class AttentionMLP(nn.Module):
    """MLP encoder, a stack of AttentionBlocks, and a linear decoder.

    The input's last dimension (``input_dim``) is lifted to ``hidden_dim`` by
    two Linear -> LayerNorm -> GELU stages. The result passes through
    ``n_layers`` ``AttentionBlock``s and is projected back to ``input_dim``.

    Args:
        input_dim: size of the input's (and the output's) last dimension.
        hidden_dim: width of the hidden representation. It must be divisible
            by ``n_head`` (an ``nn.MultiheadAttention`` requirement).
        dropout: dropout probability forwarded to every AttentionBlock.
        n_head: attention heads per block (default 8).
        n_layers: number of stacked AttentionBlocks (default 4).

    NOTE(review): a 2-D ``(batch, input_dim)`` input reaches
    ``nn.MultiheadAttention`` as a single unbatched sequence, so the rows of
    the batch attend to one another — confirm that is intended.
    """

    def __init__(self, input_dim=38, hidden_dim=64, dropout=0.2, n_head=8, n_layers=4):
        super(AttentionMLP, self).__init__()

        self.hidden_dim = hidden_dim
        self.input_dim = input_dim
        self.n_head = n_head
        self.n_layers = n_layers

        self.fc_head = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LayerNorm([hidden_dim]),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm([hidden_dim]),
            nn.GELU(),
        )
        # Blocks are constructed in order, so with the defaults the
        # state_dict keys (attention.0 ... attention.3) and the init RNG
        # sequence match the previous hand-written four-block stack.
        self.attention = nn.Sequential(
            *(AttentionBlock(dim=hidden_dim, n_head=n_head, dropout=dropout)
              for _ in range(n_layers))
        )

        self.output_fc = nn.Linear(hidden_dim, input_dim)

    def forward(self, x):
        """Encode ``x``, run the attention stack, and decode back to ``input_dim`` features."""
        x = self.fc_head(x)
        x = self.attention(x)
        x = self.output_fc(x)
        return x


if __name__ == "__main__":
    # Smoke test: push a random batch through a default-sized model and print
    # the output shape. The feature size is read from the model rather than
    # repeating the literal, so the demo tracks AttentionMLP's default.
    net = AttentionMLP()
    net.eval()  # dropout off: this is an inference-only demo
    x = torch.randn(256, net.input_dim)
    with torch.no_grad():  # no autograd graph needed just to check shapes
        y = net(x)
    print(y.shape)