import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import ipdb as pdb
import numpy as np
import copy


def clones(module, N):
    """Return an ``nn.ModuleList`` holding ``N`` independent deep copies of ``module``.

    Each copy has its own parameters, so the copies do not share weights with
    each other or with ``module``. ``N == 0`` yields an empty list.
    """
    stack = nn.ModuleList()
    for _ in range(N):
        stack.append(copy.deepcopy(module))
    return stack


def head_split(x, head_dim: int):
    """Regroup everything after the first two dimensions into heads of size ``head_dim``.

    The result has shape ``[x.shape[0], x.shape[1], -1, head_dim]``. For an
    input of shape ``[batch, seq, emb_dim]`` this gives
    ``[batch, seq, emb_dim // head_dim, head_dim]``. The trailing size must be
    divisible by ``head_dim`` or ``reshape`` raises.
    """
    target_shape = tuple(x.shape[:2]) + (-1, head_dim)
    return x.reshape(target_shape)


def self_attention(query, key, value, mask):
    '''Scaled dot-product attention over inputs that are already split into heads.

    Args:
        query: Tensor of shape [batch, seq, num_heads, emb_dim / num_heads].
        key: Tensor of shape [batch, seq, num_heads, emb_dim / num_heads].
        value: Tensor of shape [batch, seq, num_heads, emb_dim / num_heads].
        mask: Tensor of shape [1, 1, seq, seq] (anything that broadcasts
            against [batch, num_heads, seq, seq] works), or ``None`` to
            disable masking. Positions where ``mask == 0`` get the minimum
            finite logit of the dtype, so they receive ~zero weight.

    Returns:
        Tensor of shape [batch, seq, num_heads, emb_dim / num_heads].
    '''
    scale = math.sqrt(query.size(-1))

    # Per-head similarity between every query step t and key step T: [B, H, t, T].
    scores = torch.einsum("bthd,bThd->bhtT", query, key)
    scores.div_(scale)

    if mask is not None:
        # The finite minimum (not -inf) keeps fully-masked rows from becoming NaN.
        floor = torch.finfo(scores.dtype).min
        scores = scores.masked_fill(mask == 0, floor)

    weights = F.softmax(scores, -1)  # normalise over the key axis

    # Weighted sum of values, laid back out as [B, t, H, d].
    return torch.einsum("bhtT,bThd->bthd", weights, value)
    


class LearnablePositionalEncoding(nn.Module):
    """Add a trainable positional-embedding table to a sequence, then apply dropout.

    The table has shape ``[1, max_len, d_model]`` and is initialised from a
    standard normal distribution. It is registered as a trainable
    ``nn.Parameter`` under ``self.pe``.

    Args:
        d_model: embedding dimension.
        dropout: dropout probability applied after the addition (default 0.0).
        max_len: number of positions in the table (default 200).
        init_range: currently unused; accepted for interface compatibility.
    """

    def __init__(self, d_model, dropout=0.0, max_len=200, init_range = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        table = torch.empty(1, max_len, d_model, dtype=torch.float32, device="cpu")
        table.normal_(0, 1)
        self.pe = nn.Parameter(table, requires_grad=True)

    def forward(self, x):
        """Add the first ``x.size(1)`` rows of the table to ``x`` and apply dropout.

        Presumably ``x`` is ``[batch, seq, d_model]`` (the table is sliced on
        dim 1). ``seq`` must not exceed ``max_len``, or the broadcast add fails.
        """
        seq_len = x.size(1)
        shifted = x + self.pe[:, :seq_len, :]
        return self.dropout(shifted)


class AbsolutePositionalEncoding(nn.Module):
    r"""Add fixed sinusoidal position encodings to a sequence, then apply dropout.

    The encodings have the same dimension as the embeddings, so the two can be
    summed. Even channels use sine and odd channels use cosine, at
    geometrically spaced frequencies:

    .. math::
        \text{PE}(pos, 2i) = \sin\left(pos / P^{2i / d_{model}}\right)

        \text{PE}(pos, 2i+1) = \cos\left(pos / P^{2i / d_{model}}\right)

    where :math:`pos` is the position, :math:`i` is the channel-pair index and
    :math:`P` is ``max_period``. An odd ``d_model`` is supported; its last
    channel is then a sine channel.

    Args:
        d_model: the embedding dimension (required).
        dropout: the dropout probability applied after the addition (default=0.1).
        max_period: the base :math:`P` of the frequency progression (default=10000.0).
        max_len: the maximum supported sequence length (default=500).

    Examples:
        >>> pos_encoder = AbsolutePositionalEncoding(d_model=512)
    """

    def __init__(self, d_model, dropout=0.1, max_period = 10000.0, max_len=500):
        super(AbsolutePositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # [max_len, 1]
        # One frequency per sine channel, i.e. ceil(d_model / 2) entries.
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(max_period) / d_model)
        )

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        # There are only floor(d_model / 2) cosine channels. For odd d_model
        # this drops the last frequency; for even d_model it is the whole tensor.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])

        # A buffer, not a Parameter: it is not trained, but it follows .to()/.cuda()
        # and is saved in state_dict.
        self.register_buffer('pe', pe.unsqueeze(0))  # [1, max_len, d_model]

    def forward(self, x):
        r"""Add the positional encodings to ``x`` and apply dropout.

        Args:
            x: the sequence fed to the positional encoder model (required).

        Shape:
            x: [batch size, seq, embed dim], with seq <= max_len. Longer
                sequences fail in the addition, because the buffer only has
                max_len rows.
            output: [batch size, seq, embed dim]

        Examples:
            >>> output = pos_encoder(x)
        """
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)
