from copy import deepcopy
import torch.nn as nn
import torch
import math

from model.tree_encoder import _get_activation_fn


class RevDecoderLayer(nn.Module):
    """Attention + feed-forward layer where each query attends over its own targets.

    Every query position ``l`` of ``src`` (shape ``(N, L, d_model)``) has a
    private group of ``D`` key/value vectors, ``target[:, l]`` (``target`` has
    shape ``(N, L, D, d_model)``); attention is computed only within that
    group. The query receives the learned position embedding at index 0 and the
    ``D`` keys receive indices ``1..D`` (values get no position embedding).
    Both sub-layers use a post-norm residual connection.

    Args:
        d_model: Hidden width; must be divisible by ``nhead``.
        nhead: Number of attention heads.
        dim_feedforward: Inner width of the feed-forward block.
        max_positions: Rows in the position-embedding table. Row 0 is used by
            the query, so at most ``max_positions - 1`` targets per query fit.
        dropout: Dropout probability for the attention weights, the
            feed-forward inner activation and both residual branches.
        activation: Activation name, resolved via ``_get_activation_fn``.

    Raises:
        ValueError: If ``nhead`` is not positive or does not divide ``d_model``.
    """

    def __init__(self, d_model, nhead, dim_feedforward, max_positions=200, dropout=0.1, activation='gelu'):
        super().__init__()
        # Fail fast: otherwise a non-divisible width only surfaces later as an
        # opaque ``view`` shape error inside ``split_heads``.
        if nhead <= 0 or d_model % nhead != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by a positive nhead ({nhead})")
        self.num_attention_heads = nhead
        self.attention_head_size = d_model // nhead
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        # Attention projections.
        self.Q = nn.Linear(d_model, d_model)
        self.K = nn.Linear(d_model, d_model)
        self.V = nn.Linear(d_model, d_model)

        # Feed-forward block; ``self.dropout`` is also applied to attention weights.
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.position_embedding = nn.Embedding(max_positions, d_model)

        self.activation = _get_activation_fn(activation)

    def split_heads(self, x):
        """Reshape the trailing ``d_model`` axis into ``(nhead, head_size)``."""
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        return x.view(*new_x_shape)

    def multi_head_attention(self, hidden_states, target, attn_mask):
        """Scaled dot-product attention from each query over its own target group.

        Args:
            hidden_states: Queries, shape ``(N, L, d_model)``.
            target: Keys/values for each query, shape ``(N, L, D, d_model)``.
            attn_mask: ``None`` or an additive ``(N, L, D)`` mask that is added
                to the scaled logits before the softmax (broadcast over heads).

        Returns:
            Attention output of shape ``(N, L, d_model)``.

        Raises:
            ValueError: If ``D + 1`` exceeds the position-embedding table size.
        """
        num_targets = target.shape[2]
        # Keys use positions 1..D (0 belongs to the query), so D + 1 rows are
        # needed; checking here gives a clear error instead of an out-of-range
        # embedding lookup (a device-side assert on CUDA).
        if num_targets + 1 > self.position_embedding.num_embeddings:
            raise ValueError(
                f"got {num_targets} targets per query, but the position table only "
                f"supports {self.position_embedding.num_embeddings - 1}")

        device = hidden_states.device
        q_pos_ids = torch.zeros((hidden_states.shape[0], 1), dtype=torch.long, device=device)  # (N, 1)
        q_pos_embedding = self.position_embedding(q_pos_ids)  # (N, 1, d_model), broadcast over L
        kv_pos_ids = torch.arange(1, num_targets + 1, device=device)  # (D,)
        kv_pos_embeddings = self.position_embedding(kv_pos_ids).unsqueeze(0).unsqueeze(1)  # (1, 1, D, d_model)

        q = self.Q(hidden_states) + q_pos_embedding  # (N, L, d_model)
        k = self.K(target) + kv_pos_embeddings  # (N, L, D, d_model)
        v = self.V(target)  # (N, L, D, d_model)

        query_layer = self.split_heads(q).permute(0, 2, 1, 3)  # (N, nhead, L, head_size)
        key_layer = self.split_heads(k).permute(0, 3, 1, 2, 4)  # (N, nhead, L, D, head_size)
        value_layer = self.split_heads(v).permute(0, 3, 1, 2, 4)  # (N, nhead, L, D, head_size)

        # Treat each query as a length-1 sequence so it is batched against only
        # its own D keys: (N, nhead, L, 1, hs) @ (N, nhead, L, hs, D) -> (N, nhead, L, 1, D).
        attention_scores = torch.matmul(query_layer.unsqueeze(-2), key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attn_mask is not None:
            # (N, L, D) -> (N, 1, L, 1, D) so the mask broadcasts over heads.
            attention_scores = attention_scores + attn_mask.unsqueeze(1).unsqueeze(-2)

        attention_probs = torch.softmax(attention_scores, dim=-1)  # (N, nhead, L, 1, D)
        # This drops whole target tokens from the attention distribution; the
        # practice is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        context_layer = torch.matmul(attention_probs, value_layer).squeeze(-2)  # (N, nhead, L, head_size)
        # Merge heads back: (N, L, nhead, head_size) -> (N, L, d_model).
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        return context_layer.view(*new_context_layer_shape)

    def forward(self, src, target, attn_mask):
        """Apply attention then feed-forward, each as a post-norm residual sub-layer.

        Args:
            src: Query states, shape ``(N, L, d_model)``.
            target: Per-query targets, shape ``(N, L, D, d_model)``.
            attn_mask: ``None`` or an additive ``(N, L, D)`` mask.

        Returns:
            Updated states of shape ``(N, L, d_model)``.
        """
        src2 = self.multi_head_attention(src, target, attn_mask)
        src = self.norm1(src + self.dropout1(src2))
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = self.norm2(src + self.dropout2(src2))
        return src


class RevDecoder(nn.Module):
    """Stack of ``RevDecoderLayer`` modules whose targets are recomputed per layer.

    Args:
        config: Object exposing ``hidden_size``, ``num_attention_heads``,
            ``intermediate_size``, ``attention_probs_dropout_prob`` and
            ``encoder_num_hidden_layers``.
    """

    def __init__(self, config):
        super().__init__()
        layer = RevDecoderLayer(config.hidden_size,
                                config.num_attention_heads,
                                config.intermediate_size,
                                dropout=config.attention_probs_dropout_prob,
                                activation='gelu')
        # NOTE(review): the depth is read from ``encoder_num_hidden_layers``;
        # confirm this is intended rather than a decoder-specific setting.
        # A value below 1 still yields one layer (the first one is always kept).
        self.layers = nn.ModuleList([layer] + [deepcopy(layer) for _ in range(config.encoder_num_hidden_layers - 1)])

    @property
    def device(self):
        """Device of this module's parameters.

        Looked up on every access rather than cached: a cached value goes stale
        once ``.to()`` / ``.cuda()`` moves the module to another device.
        """
        return next(self.parameters()).device

    def forward(self, src, get_target, attn_mask):
        """Run ``src`` through every layer, rebuilding the targets before each one.

        Args:
            src: Initial states fed to the first layer.
            get_target: Callable mapping the current states to the ``target``
                argument of the next layer; called once per layer.
            attn_mask: Passed unchanged to every layer.

        Returns:
            The output of the last layer.
        """
        outputs = src
        for layer in self.layers:
            outputs = layer(outputs, get_target(outputs), attn_mask)
        return outputs