import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import math
import copy
from torch.autograd import Variable

# Default compute device for this module: the GPU when CUDA is usable, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


class embeddings(nn.Module):
    """Token-id lookup table whose output is scaled by ``sqrt(d_model)``."""

    def __init__(self, d_model: int, vocab: int):
        """Create a ``vocab`` x ``d_model`` embedding table.

        Args:
            d_model: Width of each embedding vector.
            vocab: Number of rows (distinct ids) in the table.
        """
        super().__init__()
        self.d_model = d_model
        self.embedding = nn.Embedding(num_embeddings=vocab, embedding_dim=d_model)

    def forward(self, x):  # x: tensor[..., col]  => tensor[..., col, d_model]
        """Look up the ids in ``x`` and multiply the vectors by ``sqrt(d_model)``."""
        scale = math.sqrt(self.d_model)
        looked_up = self.embedding(x)
        return looked_up * scale


class Positional_Encoding(nn.Module):  # More: Rotary Position Embedding
    """Add fixed sinusoidal position encodings to a batch of embeddings, then apply dropout.

    The table is computed once in ``__init__`` for ``len_position`` positions and
    stored in the buffer ``PE`` with shape ``[1, len_position, d_model]``.  Because
    it is a registered buffer it follows the module through ``.to()`` / ``.cuda()``
    / ``.double()``; it is registered as non-persistent so the keys of
    ``state_dict()`` are the same as when ``PE`` was a plain attribute.
    """

    def __init__(self, d_model: int, dropout: float, len_position=512):
        """Precompute the sinusoidal table.

        Args:
            d_model: Size of the last dimension of the inputs (even or odd).
            dropout: Dropout probability applied to the summed output.
            len_position: Number of positions to precompute.
        """
        super(Positional_Encoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(0., len_position).unsqueeze(1)  # tensor[len_pos,1]
        div_term = torch.exp(torch.arange(0., d_model, 2) * (-(math.log(10000.0) / d_model)))  # tensor[ceil(d_model/2)]
        angles = position * div_term  # tensor[len_pos,ceil(d_model/2)]

        table = torch.zeros(len_position, d_model)  # tensor[len_pos,d_model]
        table[:, 0::2] = torch.sin(angles)
        # For an odd d_model there is one fewer odd column than even columns,
        # so trim the angles; for an even d_model this slice is a no-op.
        table[:, 1::2] = torch.cos(angles[:, :d_model // 2])

        # Leading axis of size 1 broadcasts over the batch: tensor[1,len_pos,d_model].
        self.register_buffer('PE', table.unsqueeze(0), persistent=False)

    def forward(self, x):
        """Return ``dropout(x + PE[:, :x.size(1)])``.

        Args:
            x: Tensor of shape ``[batch, seq_len, d_model]``; ``seq_len`` is
                expected to be at most ``len_position``.
        """
        encoding = self.PE[:, :x.size(1)]
        # Fallback for callers that put inputs on another device without moving
        # this (parameter-free) module; a no-op once the module has been moved.
        if encoding.device != x.device:
            encoding = encoding.to(x.device)
        return self.dropout(x + encoding)


def clones(module, N):
    """Return an ``nn.ModuleList`` holding ``N`` independent deep copies of ``module``.

    Each copy starts with the same parameter values as ``module`` but owns its
    own tensors.
    """
    replicas = nn.ModuleList()
    for _ in range(N):
        replicas.append(copy.deepcopy(module))
    return replicas


def Attention(query, key, value, mask=None, dropout=None):
    """Compute scaled dot-product attention.

    Args:
        query: Tensor whose last two dims are ``[q_len, d_k]``.
        key: Tensor whose last two dims are ``[k_len, d_k]``.
        value: Tensor whose last two dims are ``[k_len, d_v]``.
        mask: Optional tensor broadcastable to the score shape
            ``[..., q_len, k_len]``; positions where ``mask == 0`` are
            excluded from attention.
        dropout: Optional callable (e.g. an ``nn.Dropout``) applied to the
            attention weights.

    Returns:
        Tuple ``(output, p_attn)``: the weighted sum ``p_attn @ value`` and the
        attention weights after softmax (and dropout, when given).
    """
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        # -1e9 is not representable in low-precision dtypes such as float16,
        # so clamp the fill to the dtype's most negative finite value.  For
        # float32/float64 the fill is still exactly -1e9.
        fill_value = max(-1e9, torch.finfo(scores.dtype).min)
        # In-place is fine here: `scores` is a fresh temporary owned by this call.
        scores.masked_fill_(mask == 0, fill_value)
    p_attn = F.softmax(scores, dim=-1)
    if dropout is not None:
        p_attn = dropout(p_attn)

    return torch.matmul(p_attn, value), p_attn


class MultiHeadedAttention(nn.Module):
    """Multi-head attention: project, split into heads, attend, merge, project.

    Attributes set here:
        d_k: Per-head width, ``d_model // head_num``.
        head_num: Number of heads.
        dropout: The dropout probability passed to the constructor.
        Dropout: ``nn.Dropout`` applied to the attention weights.
        liners: Four ``d_model -> d_model`` linear layers, used in order for the
            query, key, value and output projections.
        attn: Attention weights of the most recent ``forward`` call
            (``None`` until ``forward`` has run).
    """

    def __init__(self, head_num: int, d_model, dropout=0.1):
        """Build the projections.

        Raises:
            ValueError: If ``d_model`` is not divisible by ``head_num``; the
                per-head reshape in ``forward`` could never succeed otherwise.
        """
        super(MultiHeadedAttention, self).__init__()
        if d_model % head_num != 0:
            raise ValueError(
                "d_model (%d) must be divisible by head_num (%d)" % (d_model, head_num)
            )
        self.d_k = d_model // head_num
        self.head_num = head_num
        self.dropout = dropout
        self.Dropout = nn.Dropout(p=dropout)
        self.liners = clones(nn.Linear(d_model, d_model), 4)
        self.Attention = None  # never written by forward; kept so existing readers do not break
        self.attn = None  # the attribute forward() actually fills in

    def forward(self, query, key, value, mask=None):
        """Run multi-head attention.

        Args:
            query: Tensor ``[batch, q_len, d_model]``.
            key: Tensor ``[batch, k_len, d_model]``.
            value: Tensor ``[batch, v_len, d_model]``.
            mask: Optional mask; a head axis is inserted at dim 1 so the same
                mask is broadcast over every head.

        Returns:
            Tensor ``[batch, q_len, d_model]``.
        """
        nbatches = query.size(0)
        seq_q_len = query.size(1)
        seq_k_len = key.size(1)
        seq_v_len = value.size(1)
        if mask is not None:
            mask = mask.unsqueeze(1)  # add the head axis

        # Input projections.
        query = self.liners[0](query)
        key = self.liners[1](key)
        value = self.liners[2](value)

        # Split the model dimension into heads: [batch, len, head_num, d_k].
        query = query.view(nbatches, seq_q_len, self.head_num, self.d_k)
        key = key.view(nbatches, seq_k_len, self.head_num, self.d_k)
        value = value.view(nbatches, seq_v_len, self.head_num, self.d_k)

        # Move heads in front of the sequence axis: [batch, head_num, len, d_k].
        query = query.transpose(1, 2)
        key = key.transpose(1, 2)
        value = value.transpose(1, 2)

        x, self.attn = Attention(query, key, value, mask=mask, dropout=self.Dropout)

        # Merge heads back: [batch, q_len, head_num * d_k]; contiguous() is
        # required before view() after the transpose.
        x = x.transpose(1, 2).contiguous()
        x = x.view(nbatches, -1, self.head_num * self.d_k)
        return self.liners[-1](x)


class LayerNorm(nn.Module):
    """Normalise the last dimension, then apply a learned scale and shift.

    Uses ``torch.Tensor.std`` with its default settings and adds ``eps`` to the
    standard deviation itself (not to the variance).
    """

    def __init__(self, features, eps=1e-6):
        """Create the scale ``a_2`` (ones) and shift ``b_2`` (zeros) of size ``features``."""
        super().__init__()
        gain = torch.ones(features)
        bias = torch.zeros(features)
        self.a_2 = nn.Parameter(gain)
        self.b_2 = nn.Parameter(bias)
        self.eps = eps

    def forward(self, x):
        """Return ``a_2 * (x - mean) / (std + eps) + b_2`` over the last dimension."""
        centered = x - x.mean(dim=-1, keepdim=True)
        denom = x.std(dim=-1, keepdim=True) + self.eps
        scaled = self.a_2 * centered
        return scaled / denom + self.b_2


class PositionWiseFeedForward(nn.Module):
    """Two linear layers with ReLU and dropout in between, applied on the last dimension."""

    def __init__(self, d_model, d_Hidden, dropout=0.1):
        """Build ``d_model -> d_Hidden -> d_model`` with dropout probability ``dropout``."""
        super().__init__()
        self.linear_1 = nn.Linear(in_features=d_model, out_features=d_Hidden)
        self.linear_2 = nn.Linear(in_features=d_Hidden, out_features=d_model)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        """Return ``linear_2(dropout(relu(linear_1(x))))``."""
        hidden = F.relu(self.linear_1(x))
        return self.linear_2(self.dropout(hidden))


class SubLayerConnection(nn.Module):
    """Residual connection around a sublayer, followed by layer normalisation (post-norm)."""

    def __init__(self, d_model, dropout=0.1):
        """Create the ``LayerNorm(d_model)`` and the dropout applied to the sublayer output."""
        super().__init__()
        self.LayerNorm = LayerNorm(d_model)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, sublayer):
        """Return ``LayerNorm(x + dropout(sublayer(x)))``.

        ``sublayer`` is any callable taking ``x`` and returning a tensor that can
        be added to ``x``.
        """
        # Post-norm form; the pre-norm alternative would be
        # x + dropout(sublayer(LayerNorm(x))).
        branch = self.dropout(sublayer(x))
        residual = x + branch
        return self.LayerNorm(residual)
