# coding=utf-8
# Copyright (c) 2021 Ant Group
# Author: Xiang Hu

from functools import partial
import math
from typing import List, Optional
import torch.nn as nn
import torch.nn.functional as F
import torch
from copy import deepcopy
import numpy as np
from .r2d2_common import ROLE_LEFT, ROLE_RIGHT

ACTIVATION_POOL = ['relu', 'gelu']


def _get_activation_fn(activation):
    if activation in ACTIVATION_POOL:
        return getattr(F, activation)

    raise RuntimeError("activation should be relu/gelu, not {}".format(activation))


class TreeEncoderLayer(nn.Module):
    """Post-norm transformer encoder layer used by ``BinaryEncoder``.

    Role/position embeddings are added to the attention queries and keys only;
    the attention values are the un-embedded inputs.
    """

    def __init__(self, d_model, nhead, dim_feedforward, dropout, max_role_count, activation='gelu'):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        # NOTE: on a (seq_len, batch_size, dim) input, InstanceNorm1d reads the tensor as
        # (N, C, L) and normalizes every (position, batch) vector over its last (feature)
        # axis, with no learned affine parameters (affine defaults to False).
        self.norm1 = nn.InstanceNorm1d(d_model)
        self.norm2 = nn.InstanceNorm1d(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.position_embedding = nn.Embedding(max_role_count, d_model)

        self.activation = _get_activation_fn(activation)

    def forward(self, src, src_mask=None, pos_ids=None):
        """
        :param src: concatenation of task embeddings and representation for left and right.
                    src shape: (task_embeddings + left + right, batch_size, dim)
        :param src_mask: optional attention mask of shape (seq_len, seq_len), forwarded
                         to ``nn.MultiheadAttention`` as ``attn_mask``.
        :param pos_ids: role ids, either (seq_len,) shared across the batch or
                        (seq_len, batch_size). If None, no position embedding is added.
        :return: tensor with the same shape as ``src``.
        """
        query_key = src
        if pos_ids is not None:
            if pos_ids.dim() == 1:
                # Share the role ids across the batch: (seq_len,) -> (seq_len, batch_size).
                pos_ids = pos_ids.unsqueeze(1).expand(-1, src.shape[1])
            # Computed once and reused for both queries and keys.
            query_key = src + self.position_embedding(pos_ids)
        src2 = self.self_attn(query_key, query_key, src, attn_mask=src_mask)[0]
        src = self.norm1(src + self.dropout1(src2))
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = self.norm2(src + self.dropout2(src2))
        return src


class BinaryEncoder(nn.Module):
    """Stack of ``TreeEncoderLayer``s applied to task embeddings plus a left/right pair.

    ``forward`` expects ``src`` laid out as ``(batch_size, task_count + 2, dim)``:
    the task embeddings followed by the left and right representations.
    """

    def __init__(self, config):
        super().__init__()
        layer = TreeEncoderLayer(config.hidden_size,
                                 config.num_attention_heads,
                                 config.intermediate_size,
                                 max_role_count=config.max_role_embeddings,
                                 dropout=config.attention_probs_dropout_prob,
                                 activation='gelu')
        self.layers = nn.ModuleList([layer] + [deepcopy(layer) for _ in range(config.encoder_num_hidden_layers - 1)])
        # Default pos_ids / attention masks, keyed by (task_count, device). The device
        # is taken from the input, so cached tensors stay valid after ``.to()``/``.cuda()``
        # and for per-device replicas; a single cached device could not guarantee that.
        self._mask_cache = {}
        self._pos_ids_cache = {}

    @property
    def device(self):
        """Device of the module's parameters (looked up on every access, never cached)."""
        return next(self.parameters()).device

    def _default_pos_ids(self, task_count, device):
        """Return (cached) role ids: 0 for every task, then ROLE_LEFT and ROLE_RIGHT."""
        key = (task_count, device)
        pos_ids = self._pos_ids_cache.get(key)
        if pos_ids is None:
            pos_ids = torch.tensor([0] * task_count + [ROLE_LEFT, ROLE_RIGHT], dtype=torch.long,
                                   device=device)
            self._pos_ids_cache[key] = pos_ids
        return pos_ids

    def _default_mask(self, task_count, device):
        """Return the (cached) additive attention mask of shape (task_count + 2, task_count + 2).

        Each task position attends to itself and to the left/right positions; the left
        and right positions attend only to the left/right positions.
        """
        key = (task_count, device)
        mask = self._mask_cache.get(key)
        if mask is None:
            size = task_count + 2
            mask = torch.full((size, size), float('-inf'), dtype=torch.float, device=device)
            diag = torch.arange(task_count, device=device)
            mask[diag, diag] = 0.0
            mask[:, -2:] = 0.0  # every row may attend to left and right
            self._mask_cache[key] = mask
        return mask

    def forward(self, src, src_mask=None, pos_ids=None):
        """
        :param src: (batch_size, task_count + 2, dim); task embeddings, then left, then right.
        :param src_mask: optional additive attention mask of shape
                         (task_count + 2, task_count + 2); a cached default is used when None.
        :param pos_ids: optional role ids per sequence position; a cached default
                        (0 for tasks, then ROLE_LEFT, ROLE_RIGHT) is used when None.
        :return: (batch_size, task_count + 2, dim)
        :raises ValueError: if a default is needed but ``src`` has fewer than 2 positions.
        """
        task_count = src.shape[1] - 2
        if (pos_ids is None or src_mask is None) and task_count < 0:
            raise ValueError("src must contain at least the left and right positions, "
                             "got sequence length {}".format(src.shape[1]))
        if pos_ids is None:
            pos_ids = self._default_pos_ids(task_count, src.device)
        if src_mask is None:
            src_mask = self._default_mask(task_count, src.device)

        # The layers operate on (seq_len, batch_size, dim).
        output = src.permute(1, 0, 2)
        for mod in self.layers:
            output = mod(output, src_mask, pos_ids)

        return output.permute(1, 0, 2)


# class ContextEncoderLayerLegacy(nn.Module):
#     def __init__(self, d_model, nhead, dim_feedforward, dropout, activation='gelu'):
#         super().__init__()
#         self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
#         self.linear1 = nn.Linear(d_model, dim_feedforward)
#         self.dropout = nn.Dropout(dropout)
#         self.linear2 = nn.Linear(dim_feedforward, d_model)

#         self.norm1 = nn.LayerNorm(d_model)
#         self.norm2 = nn.LayerNorm(d_model)
#         self.dropout1 = nn.Dropout(dropout)
#         self.dropout2 = nn.Dropout(dropout)

#         self.activation = _get_activation_fn(activation)

#     def forward(self, src, sequence_to_batch, batch_to_sequence, attn_mask=None, key_paddig_mask=None):
#         """
#         :param src: concatenation of task embeddings and representation for left and right.
#                     src shape: (task_embeddings + left + right, batch_size, dim)
#         :param src_mask:
#         :param pos_ids:
#         :return:
#         """
#         src2 = self.self_attn(src, src, src, attn_mask=attn_mask, key_padding_mask=key_paddig_mask)[0]
#         src = src + self.dropout1(src2)
#         src = self.norm1(src)
#         temp = batch_to_sequence(src)
#         # save memory        
#         temp = self.linear2(self.dropout(self.activation(self.linear1(temp))))
#         src2 = sequence_to_batch(temp)
#         src = src + self.dropout2(src2)
#         src = self.norm2(src)
#         return src


# class UniLMEncoderLegacy(nn.Module):
#     def __init__(self, config):
#         super().__init__()
#         layer = ContextEncoderLayerLegacy(config.hidden_size,
#                                     config.num_attention_heads,
#                                     config.intermediate_size,
#                                     dropout=config.attention_probs_dropout_prob,
#                                     activation='gelu')
#         self.max_positions = config.max_positions
#         self.position_embeddings = nn.Embedding(config.max_positions * 2 + 2, config.embedding_dim)
#         self.layers = nn.ModuleList([layer] + [deepcopy(layer) for _ in range(config.encoder_num_hidden_layers - 1)])
#         self._device = None
        
#     @property
#     def device(self):
#         if self._device is None:
#             self._device = next(self.parameters()).device
#         return self._device
    
#     def convert_sequence_to_batch(self, embeddings, gather_indices):
#         dim = embeddings.shape[-1]
#         return embeddings[gather_indices.flatten()].view(*gather_indices.shape, dim)

#     def convert_batch_to_sequence(self, batch, indices_mapping):
#         return batch[indices_mapping[:, 0], indices_mapping[:, 1], :]

#     def forward(self, 
#                 input_ids_batch: torch.Tensor = None,
#                 memory: torch.Tensor = None, 
#                 input_ids: List[List[int]] = None,
#                 embeddings: nn.Embedding = None,
#                 bidirectional_pos: bool = True):
#         """_summary_

#         Args:
#             input_ids_batch: (N, len), batchified input ids
#             input_ids (List[List[int]]): unaligned input ids, len(input_ids) == N
#             memory (torch.Tensor): (N, 2, dim)
#             embeddings (nn.Embedding): _description_
#         Return:
#             encoded_memory: torch.Tensor: (N, 2, dim)
#             outputs: (total(input_ids), dim), encoding result of all input_ids
#         """
#         assert input_ids_batch is not None or input_ids is not None
#         memory_len = memory.shape[1]
#         assert memory_len == 2
#         if input_ids_batch is not None:
#             input_embedding = embeddings(input_ids_batch)  # (N, len, dim)
#             if not bidirectional_pos:
#                 position_ids = torch.arange(0, memory_len + input_ids_batch.shape[-1],
#                                             device=self.device)
#                 position_embeddings = self.position_embeddings(position_ids)
#             else:
#                 pos_ids = []
#                 for memory_pos in range(memory_len):
#                     pos_ids.append(memory_pos)
#                 for id_pos in range(input_ids_batch.shape[-1]):
#                     pos_ids.append(id_pos + memory_len)
#                 for id_pos in range(input_ids_batch.shape[-1]):
#                     pos_ids.append(self.max_positions + memory_len + input_ids_batch.shape[-1] - id_pos - 1)
#                 all_pos_embedding = self.position_embeddings(torch.tensor(pos_ids, device=self.device))
#                 split_pos = memory_len + input_ids_batch.shape[-1]
#                 pos_embeddings = all_pos_embedding[memory_len: split_pos] + all_pos_embedding[split_pos:]
#                 position_embeddings = torch.cat([all_pos_embedding[:memory_len], pos_embeddings], dim=0)
#             # (L, dim)
#             hidden_states = torch.cat([memory, input_embedding], dim=1)
#             hidden_states = hidden_states + position_embeddings.unsqueeze(0)
#             convert_b2s = lambda x: x
#             convert_s2b = lambda x: x
#             max_ids_len = input_ids_batch.shape[1] + memory.shape[1]
#             key_padding_masks = None
#         else:
#             input_ids_flatten = []
#             assert len(input_ids) == memory.shape[0]
#             gather_indices = []
#             offset = memory.shape[0] * memory.shape[1]  # N * L
#             max_ids_len = 0
#             reverse_indices = []
#             for batch_i in range(len(input_ids)):
#                 for m_idx in range(memory_len):
#                     reverse_indices.append([batch_i, m_idx])
#             fwd_pos_ids_flatten = []
#             bwd_pos_ids_flatten = []
#             for batch_i, ids in enumerate(input_ids):
#                 indices = [memory_len * batch_i + m_idx for m_idx in range(memory_len) ]
#                 for id_pos, id in enumerate(ids):
#                     input_ids_flatten.append(id)
#                     fwd_pos_ids_flatten.append(id_pos + memory_len)
#                     assert id_pos < self.max_positions
#                     bwd_pos_ids_flatten.append(len(ids) - id_pos - 1 + self.max_positions + memory_len)
#                     indices.append(offset)
#                     offset += 1
#                 for order in range(memory_len, len(ids) + memory_len):
#                     reverse_indices.append([batch_i, order])
#                 max_ids_len = max(max_ids_len, len(indices))
#                 gather_indices.append(indices)
#             #padding gather indices
#             key_padding_masks = []
#             for indices in gather_indices:
#                 key_padding_masks.append([False] * len(indices) + (max_ids_len - len(indices)) * [True])
#                 if len(indices) < max_ids_len:
#                     indices.extend([indices[-1]] * (max_ids_len - len(indices)))
#             input_ids_flatten = torch.tensor(input_ids_flatten, device=self.device)
#             fwd_pos_ids_flatten = torch.tensor(fwd_pos_ids_flatten, device=self.device)
#             bwd_pos_ids_flatten = torch.tensor(bwd_pos_ids_flatten, device=self.device)
#             input_embedding = embeddings(input_ids_flatten)  # (total_len, dim)
#             if not bidirectional_pos:
#                 fwd_pos_embedding = self.position_embeddings(fwd_pos_ids_flatten)
#                 input_embedding = input_embedding + fwd_pos_embedding
#             else:
#                 fwd_pos_embedding = self.position_embeddings(fwd_pos_ids_flatten)
#                 bwd_pos_embedding = self.position_embeddings(bwd_pos_ids_flatten)
#                 input_embedding = input_embedding + fwd_pos_embedding + bwd_pos_embedding
            
#             # memory: (N, 2, dim)
#             memory_pos_ids = torch.arange(0, memory_len, device=self.device)
#             memory_pos_embedding = self.position_embeddings(memory_pos_ids)
#             memory = memory + memory_pos_embedding
#             memory_flatten = memory.view(-1, memory.shape[-1])  # (N * 2, dim)
#             hidden_states = torch.cat([memory_flatten, input_embedding], dim=0)
#             # (total_len + N * 2, dim)
#             gather_indices = torch.tensor(gather_indices, device=self.device)
#             reverse_indices = torch.tensor(reverse_indices, device=self.device)
#             convert_b2s = partial(self.convert_batch_to_sequence, indices_mapping=reverse_indices)
#             convert_s2b = partial(self.convert_sequence_to_batch, gather_indices=gather_indices)
#             hidden_states = convert_s2b(hidden_states)
#             # (N, max_ids_len)
#         attn_mask = None
#         if key_padding_masks is not None:
#             key_padding_masks = torch.tensor(key_padding_masks, dtype=torch.bool, 
#                                                 device=self.device)
        
#         for mod in self.layers:
#             hidden_states = mod(hidden_states, convert_s2b, convert_b2s, 
#                                 attn_mask=attn_mask, key_paddig_mask=key_padding_masks)
            
#         return hidden_states[:, memory_len:, :]


class ContextEncoderLayer(nn.Module):
    """Post-norm transformer layer in which each query row has its own key/value set.

    Row ``i`` of the query input attends only to ``target[i]`` (see
    ``multi_head_attention``), followed by a residual feed-forward block.
    """

    def __init__(self, d_model, nhead, dim_feedforward, dropout, activation='gelu'):
        super().__init__()
        if d_model % nhead != 0:
            # Otherwise all_head_size != d_model and the head split fails at run time.
            raise ValueError("d_model ({}) must be divisible by nhead ({})".format(d_model, nhead))
        self.num_attention_heads = nhead
        self.attention_head_size = d_model // nhead
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.Q = nn.Linear(d_model, d_model)
        self.K = nn.Linear(d_model, d_model)
        self.V = nn.Linear(d_model, d_model)

        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)  # shared by attention probs and the FFN
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)

    def split_heads(self, x):
        """Reshape the last axis (dim) into (nhead, head_size)."""
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        return x.view(*new_x_shape)

    def multi_head_attention(self, hidden_states, target, attn_mask):
        """Attend from every query row to its own row of keys/values.

        :param hidden_states: queries, (L, dim)
        :param target: per-query keys/values, (L, D, dim)
        :param attn_mask: additive mask of shape (L, D), or None; -inf excludes a key.
        :return: (L, dim)
        """
        query_layer = self.split_heads(self.Q(hidden_states)).permute(1, 0, 2)  # (nhead, L, head)
        key_layer = self.split_heads(self.K(target)).permute(2, 0, 1, 3)  # (nhead, L, D, head)
        value_layer = self.split_heads(self.V(target)).permute(2, 0, 1, 3)  # (nhead, L, D, head)

        # Per-row dot product between the single query and its D keys.
        attention_scores = torch.matmul(query_layer.unsqueeze(-2), key_layer.transpose(-1, -2))  # (nhead, L, 1, D)
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attn_mask is not None:
            # (L, D) -> (1, L, 1, D), broadcast over heads.
            attention_scores = attention_scores + attn_mask.unsqueeze(0).unsqueeze(-2)

        attention_probs = F.softmax(attention_scores, dim=-1)  # (nhead, L, 1, D)
        # Drops whole keys, as in the original Transformer / BERT attention.
        attention_probs = self.dropout(attention_probs)

        context_layer = torch.matmul(attention_probs, value_layer).squeeze(-2)  # (nhead, L, head)
        context_layer = context_layer.permute(1, 0, 2).contiguous()  # (L, nhead, head)
        return context_layer.view(*context_layer.size()[:-2], self.all_head_size)

    def forward(self, src, target, attn_mask):
        """
        :param src: (L, dim) query rows.
        :param target: (L, D, dim) keys/values for each query row.
        :param attn_mask: additive (L, D) mask or None.
        :return: (L, dim)
        """
        src2 = self.multi_head_attention(src, target, attn_mask)
        src = self.norm1(src + self.dropout1(src2))
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = self.norm2(src + self.dropout2(src2))
        return src


class UniLMEncoder(nn.Module):
    """Encodes sentences jointly with per-sentence memory slots via ``ContextEncoderLayer``s.

    All memory slots and tokens of the batch are packed into one flat sequence of rows.
    Every row belonging to sentence ``i`` (its memory slots and its tokens) attends to
    the same key set: sentence ``i``'s memory slots followed by its tokens.
    """

    def __init__(self, config):
        super().__init__()
        layer = ContextEncoderLayer(config.hidden_size,
                                    config.num_attention_heads,
                                    config.intermediate_size,
                                    dropout=config.attention_probs_dropout_prob,
                                    activation='gelu')
        self.max_positions = config.max_positions
        # Shared position table: memory slots use ids [0, mem_len), forward token positions
        # start at mem_len, backward (distance-from-end) positions are offset by max_positions.
        self.position_embeddings = nn.Embedding(config.max_positions * 2 + 2, config.embedding_dim)
        self.layers = nn.ModuleList([layer] + [deepcopy(layer) for _ in range(config.encoder_num_hidden_layers - 1)])

    @property
    def device(self):
        """Device of the module's parameters (looked up on every access, never cached)."""
        return next(self.parameters()).device

    def forward(self, input_ids: List[List[int]] = None,
                memory: torch.Tensor = None,
                embeddings: nn.Embedding = None,
                bidirectional_pos: bool = True):
        """
        :param input_ids: token ids per sentence, ``len(input_ids) == N``; sentences may
                          have different lengths, including zero.
        :param memory: (N, mem_len, dim) memory slots, one group per sentence.
        :param embeddings: token embedding module applied to ``input_ids``.
        :param bidirectional_pos: if True, tokens get a forward and a backward position
                                  embedding; otherwise only the forward one.
        :return: (N, max_len, dim) hidden states of each sentence's tokens, where max_len is
                 the longest sentence length. Shorter sentences are padded by repeating the
                 last row of their key set (last token, or last memory slot if empty).
        :raises ValueError: if ``len(input_ids) != memory.shape[0]``.
        """
        if len(input_ids) != memory.shape[0]:
            raise ValueError("len(input_ids) ({}) must equal memory.shape[0] ({})".format(
                len(input_ids), memory.shape[0]))
        device = memory.device
        N, mem_len, dim = memory.shape
        memory = memory.reshape(N * mem_len, dim)  # (N * mem_len, dim)

        # Rows [0, N * mem_len) hold the memory slots; each sentence's tokens follow as one
        # contiguous run of rows, in sentence order.
        token_starts = []
        flatten_ids = []
        fwd_pos_ids = []
        bwd_pos_ids = []
        offset = N * mem_len
        max_key_len = mem_len
        for ids in input_ids:
            n_tokens = len(ids)
            token_starts.append(offset)
            offset += n_tokens
            flatten_ids.extend(ids)
            max_key_len = max(max_key_len, n_tokens + mem_len)
            for id_pos in range(n_tokens):
                fwd_pos_ids.append(id_pos + mem_len)
                bwd_pos_ids.append(n_tokens - 1 - id_pos + mem_len + self.max_positions)
        total_len = offset

        # Index/mask tables are filled on the host and copied to the device once, instead
        # of issuing several small device writes per sentence.
        gather_indices = np.zeros((total_len, max_key_len), dtype=np.int64)
        mask = np.full((total_len, max_key_len), -np.inf, dtype=np.float32)
        return_indices = np.zeros((N, max_key_len - mem_len), dtype=np.int64)
        for sent_i, tok_start in enumerate(token_starts):
            mem_start = sent_i * mem_len
            tok_end = tok_start + len(input_ids[sent_i])
            key_pos = list(range(mem_start, mem_start + mem_len)) + list(range(tok_start, tok_end))
            key_len = len(key_pos)
            if key_len < max_key_len:
                # Pad with a duplicate of the last key; the padded columns stay -inf below,
                # for memory rows as well as token rows.
                key_pos.extend([key_pos[-1]] * (max_key_len - key_len))
            return_indices[sent_i] = key_pos[mem_len:]
            # The sentence's memory rows and token rows share one key set and one mask row.
            for row_start, row_end in ((mem_start, mem_start + mem_len), (tok_start, tok_end)):
                gather_indices[row_start:row_end] = key_pos
                mask[row_start:row_end, :key_len] = 0.0

        token_ids = torch.tensor(flatten_ids, dtype=torch.long, device=device)
        fwd_pos_ids = torch.tensor(fwd_pos_ids, dtype=torch.long, device=device)
        input_embedding = embeddings(token_ids) + self.position_embeddings(fwd_pos_ids)
        if bidirectional_pos:
            bwd_pos_ids = torch.tensor(bwd_pos_ids, dtype=torch.long, device=device)
            input_embedding = input_embedding + self.position_embeddings(bwd_pos_ids)
        mem_pos_ids = torch.arange(mem_len, device=device).repeat(N)  # 0..mem_len-1 per sentence
        memory = memory + self.position_embeddings(mem_pos_ids)
        hidden = torch.cat([memory, input_embedding], dim=0)  # (total_len, dim)

        gather_indices = torch.from_numpy(gather_indices).to(device)
        attn_mask = torch.from_numpy(mask).to(device)
        for layer in self.layers:
            target = hidden[gather_indices]  # (total_len, max_key_len, dim)
            hidden = layer(hidden, target, attn_mask)

        return hidden[torch.from_numpy(return_indices).to(device)]