
import torch
import torch.nn as nn
import math
import json
from torch.utils.checkpoint import checkpoint
import sys
import os
from torch import Tensor
import numpy as np
from typing import Optional, Tuple
import torch.nn.functional as F
from model import Embeddings

# Directory that contains this file, with symlinks resolved by realpath.
curr_path = os.path.dirname(os.path.realpath(__file__))
# Append that directory to the import search path so modules located next to
# this file can be imported by bare name.
# NOTE(review): `from model import Embeddings` above executes before this
# append, so this line cannot be what makes `model` importable here -- confirm
# the ordering is intended.
sys.path.append(curr_path)

class SoftmaxAttention(nn.Module):
    """Scaled dot-product attention with an optional additive key mask.

    Args:
        head_dim: value whose square root divides the raw attention scores.
    """

    def __init__(self, head_dim):
        super().__init__()
        self.head_dim = head_dim

    def forward(self, Q, K, V, mask):
        """Return ``softmax(Q @ K^T / sqrt(head_dim)) @ V``.

        When ``mask`` is not ``None`` it is indexed as ``mask[:, None, :]`` and
        ``1e6 * (1 - mask)`` is subtracted from the scores before the softmax,
        so key positions whose mask value is 0 receive (numerically) no weight.
        """
        scale = math.sqrt(self.head_dim)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / scale

        if mask is not None:
            # Large negative offset instead of -inf keeps fully masked rows finite.
            penalty = 1e6 * (1 - mask[:, None, :])
            scores = scores - penalty

        weights = torch.softmax(scores, dim=-1)
        return torch.matmul(weights, V)

class MraDotProductAttention(nn.Module):
    r"""
    Multi-resolution dot-product attention.

    Head ``h`` works at its own resolution: windows of ``group_by_list[h]``
    consecutive positions are pooled on either the query side or the key/value
    side before ordinary softmax attention is applied to that head.

    Args: dim, scale, group_by_list, downsampling_mode
        dim (int): model dimension; ``head_dim`` is ``dim / len(group_by_list)``
        scale (bool): only selects the stored ``sqrt_dim`` attribute
            (``sqrt(dim)`` or ``1``); ``forward`` does not read it, the score
            scaling is done by ``SoftmaxAttention`` with ``head_dim``
        group_by_list (list of int): pooling window per head (default: ``[1, 2]``);
            ``forward`` indexes it with the head index, so it needs one entry per head
        downsampling_mode (str or None):
            ``"q"``  - average queries per window, attend, then repeat each
                       result ``group`` times to restore the query length;
            ``"kv"`` - average keys/values per window (masked positions are
                       excluded from the average) and attend with full queries;
            ``None`` - no pooling, plain per-head attention.

    Inputs: query, key, value, mask
        - **query** (batch, n_head, q_len, d_head)
        - **key** (batch, n_head, k_len, d_head)
        - **value** (batch, n_head, k_len, d_head)
        - **mask** (batch, k_len) or ``None``: multiplied into keys/values and
          used as ``1 - mask`` in the attention, i.e. 1 = keep, 0 = padding

    The pooled sequence length must be divisible by each head's window size,
    otherwise the internal ``reshape`` raises.

    Returns: context, attn
        - **context** (batch, q_len, n_head * d_head): heads concatenated on the last axis.
        - **attn**: always ``None`` (attention weights are not returned).
    """
    _DOWNSAMPLING_MODES = (None, "q", "kv")

    def __init__(self, dim: int, scale: bool = True, group_by_list=None, downsampling_mode=None) -> None:
        super(MraDotProductAttention, self).__init__()
        if downsampling_mode not in self._DOWNSAMPLING_MODES:
            raise ValueError(
                "downsampling_mode must be one of None, 'q' or 'kv', got {!r}".format(downsampling_mode)
            )
        # A None sentinel avoids sharing one mutable default list between instances.
        self.group_by_list = [1, 2] if group_by_list is None else group_by_list
        self.head_dim = int(dim / len(self.group_by_list))
        self.attn = SoftmaxAttention(self.head_dim)
        self.downsampling_mode = downsampling_mode
        # Kept for backward compatibility; not used by forward().
        self.sqrt_dim = np.sqrt(dim) if scale else 1

    @staticmethod
    def _pool(x, group):
        """Sum each window of ``group`` consecutive positions: (b, L, d) -> (b, L // group, d)."""
        return x.reshape(x.shape[0], x.shape[1] // group, group, x.shape[2]).sum(dim=-2)

    def forward(
            self,
            query: torch.FloatTensor,
            key: torch.FloatTensor,
            value: torch.FloatTensor,
            mask: Optional[torch.FloatTensor] = None,
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
        # query: [bsz, n_head, seq_len, hidden_dim]; mask: [bsz, seq_len]
        mode = self.downsampling_mode
        if mask is not None:
            # Zero padded keys/values so they contribute nothing to pooled sums.
            key = key * mask[:, None, :, None]
            value = value * mask[:, None, :, None]

        head_outputs = []
        for h in range(query.shape[1]):
            group = self.group_by_list[h]
            q_h, k_h, v_h = query[:, h], key[:, h], value[:, h]
            # Keys are only pooled in "kv" mode, so every other mode attends
            # with the key mask at its original resolution.
            head_mask = None if mask is None else torch.clip(mask, min=0, max=1)

            if mode == "q":
                q_h = self._pool(q_h, group) / (group + 1e-6)
            elif mode == "kv":
                if mask is None:
                    denom = group + 1e-6
                else:
                    # Number of real (unmasked) tokens in each window.
                    token_count = mask.reshape(key.shape[0], key.shape[2] // group, group).sum(dim=-1)
                    # A pooled slot is valid when its window holds at least one real token.
                    head_mask = torch.clip(token_count, min=0, max=1)
                    denom = token_count[:, :, None] + 1e-6
                k_h = self._pool(k_h, group) / denom
                v_h = self._pool(v_h, group) / denom

            head_out = self.attn(q_h, k_h, v_h, head_mask)
            if mode == "q":
                # Broadcast each pooled query's result back over its window.
                head_out = head_out.repeat_interleave(group, dim=1)
            head_outputs.append(head_out)

        # Stacking (instead of writing into a torch.empty buffer) keeps the
        # dtype of the computation and never squeezes a length-1 sequence axis.
        attn_out = torch.stack(head_outputs, dim=1)
        context = self.combine_heads(attn_out)

        return context, None

    def combine_heads(self, X):
        """Merge (batch, n_head, seq, d_head) into (batch, seq, n_head * head_dim)."""
        X = X.transpose(1, 2)
        X = X.reshape(X.size(0), X.size(1), X.shape[-2] * self.head_dim)
        return X

class MultiHeadAttention(nn.Module):
    r"""
    Multi-head wrapper around ``MraDotProductAttention``.

    Queries, keys and values are each passed through a learned linear
    projection, split into ``num_attention_heads`` heads of size ``d_head``,
    attended per head, and concatenated again on the feature axis.

    MultiHead(Q, K, V) = Concat(head_1, ..., head_h)
        where head_i = Attention(Q · W_q, K · W_k, V · W_v)

    Unlike the formulation in "Attention Is All You Need", this implementation
    applies no output projection W_o after the concatenation.

    Args:
        dim (int): The dimension of model (default: 512)
        num_attention_heads (int): The number of attention heads. (default: 8)
        group_by_list (list of int): forwarded to ``MraDotProductAttention``
            (default: ``[1, 2]``)
        downsampling_mode (str or None): forwarded to ``MraDotProductAttention``

    Inputs: query, key, value, mask
        - **query** (batch, q_len, d_model): tensor containing projection vector for decoders.
        - **key** (batch, k_len, d_model): tensor containing projection vector for encoders.
        - **value** (batch, v_len, d_model): tensor containing features of the encoded input sequence.
        - **mask** (-): passed unchanged to the inner attention, or ``None``

    Returns: output, attn
        - **output** (batch, output_len, num_attention_heads * d_head): the attended output features.
        - **attn**: whatever the inner attention returns as its second value.
    """
    def __init__(self, dim: int = 512, num_attention_heads: int = 8, group_by_list=None, downsampling_mode = None) -> None:
        super(MultiHeadAttention, self).__init__()

        assert dim % num_attention_heads == 0, "hidden_dim % num_attention_heads should be zero."

        # A None sentinel avoids sharing one mutable default list between instances.
        if group_by_list is None:
            group_by_list = [1, 2]

        self.d_head = int(dim / num_attention_heads)
        self.num_attention_heads = num_attention_heads
        self.query_proj = nn.Linear(dim, self.d_head * num_attention_heads)
        self.key_proj = nn.Linear(dim, self.d_head * num_attention_heads)
        self.value_proj = nn.Linear(dim, self.d_head * num_attention_heads)
        self.scaled_dot_attn = MraDotProductAttention(dim, scale=True, group_by_list=group_by_list, downsampling_mode= downsampling_mode)
        self.group_by_list = group_by_list

    def _split_heads(self, x, batch_size):
        """Reshape (batch, seq, n_head * d_head) to (batch, n_head, seq, d_head)."""
        return x.view(batch_size, -1, self.num_attention_heads, self.d_head).transpose(1, 2)

    def forward(
            self,
            query: torch.FloatTensor,
            key: torch.FloatTensor,
            value: torch.FloatTensor,
            mask: Optional[torch.FloatTensor] = None,
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
        batch_size = value.size(0)

        # [bsz, n_heads, seq_len, d_head]
        query = self._split_heads(self.query_proj(query), batch_size)
        key = self._split_heads(self.key_proj(key), batch_size)
        value = self._split_heads(self.value_proj(value), batch_size)

        context, attn = self.scaled_dot_attn(query, key, value, mask)

        # The inner attention already returns heads merged as (bsz, seq, dim).
        # Transposing that 3-D tensor again (as was done before) interleaves
        # sequence and feature entries, so heads are only merged here when a
        # per-head 4-D tensor comes back.
        if context.dim() == 4:
            context = context.transpose(1, 2)
        context = context.reshape(batch_size, -1, self.num_attention_heads * self.d_head)

        return context, attn


class LinearUnifiedNestedAttention(nn.Module):
    """Two-stage (pack, then unpack) attention.

    Stage one lets ``p`` attend over ``key``/``value`` with key/value
    downsampling ("kv"); stage two lets ``query`` attend over the result of
    stage one with query downsampling ("q") and without any mask.
    """

    def __init__(self, dim, group_by_list, num_attention_heads: int = 2) -> None:
        super().__init__()
        self.group_by_list = group_by_list
        shared_args = (dim, num_attention_heads, self.group_by_list)
        self.pack_attention = MultiHeadAttention(*shared_args, downsampling_mode="kv")
        self.unpack_attention = MultiHeadAttention(*shared_args, downsampling_mode="q")

    def forward(
            self,
            query: torch.FloatTensor,
            key: torch.FloatTensor,
            value: torch.FloatTensor,
            p: torch.FloatTensor,
            attention_padding_mask: torch.BoolTensor = None,
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
        """Return ``(unpacked_context, packed_context)``.

        ``attention_padding_mask`` is only handed to the pack stage; the
        unpack stage attends over the packed context with no mask.
        """
        # Pack: summarise the input sequence into the slots given by `p`.
        packed_context = self.pack_attention(p, key, value, attention_padding_mask)[0]
        # Unpack: the packed summary serves as both keys and values.
        unpacked_context = self.unpack_attention(query, packed_context, packed_context)[0]
        return unpacked_context, packed_context

class LunaTransformerEncoderLayer(nn.Module):
    """One encoder layer: pack/unpack attention followed by a feed-forward block.

    Each sub-block output is added to its input and layer-normalised
    (post-norm residual connections).

    Config keys read here: ``dim``, ``group_by_list``, ``num_head``,
    ``hidden_dim``, ``dropout_prob``.
    """

    def __init__(self, config):
        super().__init__()
        self.luna_attention = LinearUnifiedNestedAttention(config["dim"], config["group_by_list"], config["num_head"])
        self.feed_forward = nn.Sequential(
            nn.Linear(config["dim"], config["hidden_dim"]),
            nn.GELU(),
            torch.nn.Dropout(p = config["dropout_prob"]),
            nn.Linear(config["hidden_dim"], config["dim"]),
            torch.nn.Dropout(p = config["dropout_prob"])
        )
        self.packed_context_layer_norm = nn.LayerNorm(config["dim"])
        # This norm used to be assigned twice in a row; the first instance was
        # immediately replaced, so a single assignment is equivalent.
        self.unpacked_context_layer_norm = nn.LayerNorm(config["dim"])
        self.feed_forward_layer_norm = nn.LayerNorm(config["dim"])

    def forward(self, inputs, p, mask):
        """Run the layer.

        Args:
            inputs: the sequence representation; used as query, key and value.
            p: the packed (projected) representation; used as the pack-stage query.
            mask: padding mask forwarded to the attention, or ``None``.

        Returns:
            ``(outputs, packed_context)``: the updated sequence representation
            and the updated packed representation.
        """
        unpacked_context, packed_context = self.luna_attention(
            query=inputs,
            key=inputs,
            value=inputs,
            p=p,
            attention_padding_mask=mask,
        )
        # Residual + norm for both attention outputs.
        packed_context = self.packed_context_layer_norm(packed_context + p)
        unpacked_context = self.unpacked_context_layer_norm(unpacked_context + inputs)

        # Position-wise feed-forward with its own residual + norm.
        outputs = self.feed_forward(unpacked_context)
        outputs = self.feed_forward_layer_norm(outputs + unpacked_context)
        return outputs, packed_context

class Backbone(nn.Module):
    """Stack of ``LunaTransformerEncoderLayer`` modules over pre-embedded inputs.

    The packed sequence ``p`` fed to the first layer is produced by an
    ``Embeddings`` module looked up with the ids
    ``0 .. project_embedding_length - 1``.

    Config keys read here: ``num_layers``, ``shared_weight``, ``dim``,
    ``project_embedding_length``, ``embedding_dim``, ``dropout_prob`` (the
    encoder layers read further keys from the same config).
    """

    def __init__(self, config):
        super().__init__()

        self.num_layers = config["num_layers"]
        # Stored but not consulted anywhere in this class.
        self.shared_weight = config["shared_weight"]

        self.d_model = config["dim"]
        self.projected_embedding_length = config["project_embedding_length"]

        # The projected sequence gets its own embedding table: one "token" per
        # slot, so vocabulary size and max length both equal the slot count.
        config_projection = {
            "embedding_dim": config["embedding_dim"],
            "dim": config["dim"],
            "max_seq_len": self.projected_embedding_length,
            "vocab_size": self.projected_embedding_length,
            "model_type": "luna_transformer",
        }
        self.project_embedding = Embeddings(config_projection)

        self.dropout = nn.Dropout(p=config["dropout_prob"])

        self.encoders = nn.ModuleList([LunaTransformerEncoderLayer(config) for _ in range(self.num_layers)])

        self.norm = nn.LayerNorm(config["dim"])

    def forward(self, X, mask):
        """Encode ``X``.

        Args:
            X: 3-D tensor of already-embedded inputs (batch, seq_len, dim).
            mask: (batch, seq_len) padding mask, or ``None``. It is forwarded
                to every layer and, when given, multiplied into the final
                output so masked positions come out as zeros.

        Returns:
            Tensor with the same leading shape as ``X``.
        """
        batch_size, _, _ = X.size()

        position_ids = torch.arange(self.projected_embedding_length, dtype = torch.long, device = X.device)[None, :]
        # assumes Embeddings maps ids of shape (1, L) to (1, L, dim) -- TODO confirm.
        # Only the leading batch axis is squeezed: a bare squeeze() would also
        # drop the length axis when project_embedding_length == 1.
        projected_embedded = self.project_embedding(position_ids).squeeze(0)

        proj_length, proj_dim = projected_embedded.size()
        # Share the same projected slots across the batch (expand makes a view).
        projected_embedded = projected_embedded.unsqueeze(0).expand(batch_size, proj_length, proj_dim)

        outputs = self.dropout(X)
        p = self.dropout(projected_embedded)

        for encoder in self.encoders:
            outputs, p = encoder(outputs, p, mask)

        outputs = self.norm(outputs)
        if mask is not None:
            # Zero the padded positions; the layers accept mask=None, so the
            # final masking must tolerate it as well.
            outputs = outputs * mask[:, :, None]

        return outputs

class PositionalEncoding(nn.Module):
    """
    Positional Encoding proposed in "Attention Is All You Need".
    Since transformer contains no recurrence and no convolution, in order for the model to make
    use of the order of the sequence, we must add some positional information.

    "Attention Is All You Need" use sine and cosine functions of different frequencies:
        PE_(pos, 2i)    =  sin(pos / power(10000, 2i / d_model))
        PE_(pos, 2i+1)  =  cos(pos / power(10000, 2i / d_model))

    The table is precomputed once for ``max_length`` positions and stored as
    the non-trainable buffer ``pe`` of shape (1, max_length, d_model).

    Args:
        d_model (int): encoding width; odd values are supported (the last
            column then holds a sine term without a matching cosine).
        max_length (int): number of positions to precompute.
    """
    def __init__(self, d_model: int = 80, max_length: int = 5000) -> None:
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_length, d_model, requires_grad=False)
        position = torch.arange(0, max_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
        angles = position * div_term
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even columns, so
        # the cosine block is trimmed to fit instead of raising a shape error.
        pe[:, 1::2] = torch.cos(angles)[:, : d_model // 2]
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, length: int) -> Tensor:
        """Return the first ``length`` encodings, shape (1, min(length, max_length), d_model)."""
        return self.pe[:, :length]