#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from typing import Optional, Tuple
import torch
import torch.nn as nn
from .mha_utils import MultiheadAttentionStable
import copy
import math


# Transformer encoder layer with only self-attention
class SelfAttentionEncoderLayer(nn.Module):
    """Post-norm Transformer encoder layer: self-attention, then a feed-forward network.

    Each sub-block is wrapped as ``LayerNorm(x + Dropout(sublayer(x)))``.

    Args:
        d_model: Feature size of the input and output tokens.
        nhead: Number of attention heads.
        dim_feedforward: Hidden size of the feed-forward network.
        dropout: Dropout probability used inside the attention module, between
            the two linear layers, and on both residual branches.
        activation: Non-linearity applied after the first linear layer. The
            default instance is created once at definition time and shared by
            every layer that relies on it; ``nn.ReLU`` holds no state, so the
            sharing is harmless.
    """

    def __init__(
        self,
        d_model=1736,
        nhead=8,
        dim_feedforward=2048,
        dropout=0.1,
        activation=nn.ReLU(inplace=True),
    ):
        super().__init__()
        # batch_first is left at its default (False): sequence-first inputs.
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

        # position-wise feed-forward network
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        # one LayerNorm / Dropout pair per residual branch
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = activation

    def forward(
        self,
        src: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        src_key_padding_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the layer on ``src`` and return a tensor of the same shape.

        Args:
            src: Input tokens, sequence-first, typically ``(L, N, d_model)``.
            attn_mask: Optional attention mask forwarded to the attention module.
            src_key_padding_mask: Optional key padding mask forwarded to the
                attention module.

        Returns:
            The encoded tokens (a single tensor, not a tuple).
        """
        attended = self.norm1(src + self._sa_block(src, attn_mask, src_key_padding_mask))
        return self.norm2(attended + self._ff_block(attended))

    # self-attention block
    def _sa_block(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor],
        key_padding_mask: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Return the dropout-regularised self-attention output for ``x``."""
        # need_weights=False: the attention map is not used by this layer.
        attn_out, _ = self.self_attn(
            x,
            x,
            x,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )
        return self.dropout1(attn_out)

    # feed forward block
    def _ff_block(self, x: torch.Tensor) -> torch.Tensor:
        """Return the dropout-regularised feed-forward output for ``x``."""
        hidden = self.dropout(self.activation(self.linear1(x)))
        return self.dropout2(self.linear2(hidden))


# Transformer encoder layer with cross-attention and optional self-attention
class CrossAttentionEncoderLayer(nn.Module):
    """Post-norm encoder layer where queries cross-attend to a key/value sequence.

    Pipeline of one forward pass:

    1. optional self-attention (+ optional FFN) on the key/value sequence,
    2. optional self-attention (+ optional FFN) on the query sequence,
    3. cross-attention from the queries to the key/value sequence,
    4. a feed-forward network on the queries.

    Every sub-block is wrapped as ``LayerNorm(x + Dropout(sublayer(x)))``.

    Args:
        d_model: Feature size of the query tokens.
        d_kv: Feature size of the key/value tokens.
        nhead: Number of attention heads (used by every attention module).
        dim_feedforward: Hidden size of every feed-forward network.
        dropout: Dropout probability used throughout the layer.
        activation: Non-linearity used by every feed-forward network.
        q_sa: Enable self-attention on the queries.
        q_ffn: Enable the query FFN; only takes effect when ``q_sa`` is True.
        kv_sa: Enable self-attention on the keys/values.
        kv_ffn: Enable the key/value FFN; only takes effect when ``kv_sa`` is True.
    """

    def __init__(
        self,
        d_model=1736,
        d_kv=256,
        nhead=8,
        dim_feedforward=2048,
        dropout=0.1,
        activation=nn.ReLU(inplace=True),
        q_sa=False,
        q_ffn=False,
        kv_sa=True,
        kv_ffn=False,
    ):
        super().__init__()
        self.nhead = nhead
        self.q_sa = q_sa
        self.q_ffn = q_ffn
        self.kv_sa = kv_sa
        self.kv_ffn = kv_ffn

        # self attention key & value
        if kv_sa:
            self.self_attn_kv = nn.MultiheadAttention(d_kv, nhead, dropout=dropout)
            self.dropout_kv = nn.Dropout(dropout)
            self.norm_kv = nn.LayerNorm(d_kv)
            # key & value FFN if required
            if self.kv_ffn:
                self.linear1_kv = nn.Linear(d_kv, dim_feedforward)
                self.dropout_kv_l1 = nn.Dropout(dropout)
                self.linear2_kv = nn.Linear(dim_feedforward, d_kv)
                self.dropout_kv_l2 = nn.Dropout(dropout)
                self.norm_kv_ffn = nn.LayerNorm(d_kv)

        # query may come from a SA layer, so no SA needed here
        if q_sa:
            # self attention query
            self.self_attn_q = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
            self.dropout_q = nn.Dropout(dropout)
            self.norm_q = nn.LayerNorm(d_model)
            # query FFN if required
            if self.q_ffn:
                self.linear1_q = nn.Linear(d_model, dim_feedforward)
                self.dropout_q_l1 = nn.Dropout(dropout)
                self.linear2_q = nn.Linear(dim_feedforward, d_model)
                self.dropout_q_l2 = nn.Dropout(dropout)
                self.norm_q_ffn = nn.LayerNorm(d_model)

        # cross attention
        self.cross_attn = MultiheadAttentionStable(d_model, nhead, dropout=dropout, kdim=d_kv, vdim=d_kv)
        self.dropout_ca = nn.Dropout(dropout)
        self.norm_ca = nn.LayerNorm(d_model)
        # cross attention FFN
        self.linear1_ca = nn.Linear(d_model, dim_feedforward)
        self.dropout_ca_l1 = nn.Dropout(dropout)
        self.linear2_ca = nn.Linear(dim_feedforward, d_model)
        self.dropout_ca_l2 = nn.Dropout(dropout)
        self.norm_ca_ffn = nn.LayerNorm(d_model)

        self.activation = activation

    def _residual_ffn(self, x, linear1, dropout1, linear2, dropout2, norm):
        """Apply ``norm(x + dropout2(linear2(dropout1(activation(linear1(x))))))``."""
        hidden = dropout1(self.activation(linear1(x)))
        return norm(x + dropout2(linear2(hidden)))

    def forward(self, q, kv, attn_mask=None, q_padding_mask=None, kv_padding_mask=None):
        """Update the queries by attending to the key/value sequence.

        Args:
            q (torch.Tensor): queries, shape (L, N, qdim)
            kv (torch.Tensor): keys/values, shape (S, N, kvdim)
            attn_mask (torch.Tensor, optional): mask passed to the optional
                self-attention blocks only. NOTE(review): the same mask is
                given to both the key/value and the query self-attention, so
                enabling both with L != S needs a mask valid for both — confirm
                against callers.
            q_padding_mask (torch.Tensor, optional): shape (N, L), True marks a
                padded query. ``None`` means no query is padded.
            kv_padding_mask (torch.Tensor, optional): shape (N, S), True marks a
                padded key. ``None`` means no key is padded.

        Returns:
            tuple: ``(q, kv, cross_attn_weights)`` — the updated queries, the
            (possibly self-attended) keys/values, and the attention weights
            returned by the cross-attention module (presumably (N, L, S),
            averaged over heads as in ``nn.MultiheadAttention`` — confirm
            against ``MultiheadAttentionStable``).
        """
        if self.kv_sa:
            # key & value self-attention
            kv_attn, _ = self.self_attn_kv(
                kv, kv, kv, attn_mask=attn_mask, key_padding_mask=kv_padding_mask, need_weights=False
            )
            kv = self.norm_kv(kv + self.dropout_kv(kv_attn))
            if self.kv_ffn:
                kv = self._residual_ffn(
                    kv,
                    self.linear1_kv,
                    self.dropout_kv_l1,
                    self.linear2_kv,
                    self.dropout_kv_l2,
                    self.norm_kv_ffn,
                )

        if self.q_sa:
            # query self-attention
            q_attn, _ = self.self_attn_q(
                q, q, q, attn_mask=attn_mask, key_padding_mask=q_padding_mask, need_weights=False
            )
            q = self.norm_q(q + self.dropout_q(q_attn))
            if self.q_ffn:
                q = self._residual_ffn(
                    q,
                    self.linear1_q,
                    self.dropout_q_l1,
                    self.linear2_q,
                    self.dropout_q_l2,
                    self.norm_q_ffn,
                )

        # cross-attention
        if q_padding_mask is None:
            # Nothing to mask on the query side (previously this path crashed).
            cross_attn_mask = None
        else:
            # expand q_padding_mask (N, L) to an attention mask (N*nhead, L, S):
            # every key is masked for a padded query. Such rows are fully
            # masked, which is presumably why the "stable" attention variant
            # is used here — confirm in mha_utils.
            src_len = kv.shape[0] if kv_padding_mask is None else kv_padding_mask.shape[1]
            cross_attn_mask = (
                q_padding_mask.unsqueeze(2).expand(-1, -1, src_len).repeat_interleave(self.nhead, dim=0)
            )
        # key_padding_mask (N, S) same as function input
        q_ca, cross_attn_weights = self.cross_attn(
            q, kv, kv, attn_mask=cross_attn_mask, key_padding_mask=kv_padding_mask, need_weights=True
        )
        q = self.norm_ca(q + self.dropout_ca(q_ca))

        # cross-attention FFN
        q = self._residual_ffn(
            q,
            self.linear1_ca,
            self.dropout_ca_l1,
            self.linear2_ca,
            self.dropout_ca_l2,
            self.norm_ca_ffn,
        )

        return q, kv, cross_attn_weights

# Spatial encoder, for all pairs in one frame
class SATranformerEncoder(nn.Module):
    """Stack of independent copies of a self-attention encoder layer.

    Args:
        encoder_layer: Layer prototype; it is cloned ``num_layers`` times.
        num_layers: Number of clones in the stack.
    """

    def __init__(self, encoder_layer, num_layers):
        super().__init__()
        self.num_layers = num_layers
        self.layers = _get_clones(encoder_layer, num_layers)

    # spatial encoder no positional encoding
    def forward(self, src, attn_mask=None, src_key_padding_mask=None):
        """Feed ``src`` through every layer in order, passing the same masks to each."""
        mask_kwargs = {
            "attn_mask": attn_mask,
            "src_key_padding_mask": src_key_padding_mask,
        }
        hidden = src
        for encoder_layer in self.layers:
            hidden = encoder_layer(hidden, **mask_kwargs)
        return hidden

# Cross-attention encoder
class CrossEncoder(nn.Module):
    """Stack of independent copies of a cross-attention encoder layer.

    Args:
        encoder_layer: Layer prototype; it is cloned ``num_layers`` times. Each
            clone is called as ``layer(q, kv, attn_mask=..., q_padding_mask=...,
            kv_padding_mask=...)`` and must return ``(q, kv, weights)``.
        num_layers: Number of clones in the stack.
    """

    def __init__(self, encoder_layer, num_layers):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers

    def forward(self, q, kv, attn_mask=None, q_padding_mask=None, kv_padding_mask=None):
        """Run the queries and keys/values through every layer in order.

        Args:
            q: Queries; indexed here as (L, N, ...).
            kv: Keys/values; indexed here as (S, N, ...).
            attn_mask: Forwarded unchanged to every layer.
            q_padding_mask: Forwarded unchanged to every layer.
            kv_padding_mask: Forwarded unchanged to every layer.

        Returns:
            tuple: ``(q, kv, cross_attn_weights)``. The weights are a float
            tensor of shape (num_layers, N, L, S) holding each layer's
            cross-attention weights, or ``None`` when the stack is empty.
        """
        if self.num_layers <= 0:
            # Empty stack: nothing to run and no weights to report.
            return q, kv, None

        # Allocate the buffer directly on the target device instead of
        # creating it on the CPU and copying it over on every call.
        cross_attn_weights = torch.zeros(
            (self.num_layers, q.shape[1], q.shape[0], kv.shape[0]),
            device=q.device,
        )

        for i, layer in enumerate(self.layers):
            # kv self-attention every layer
            q, kv, layer_weights = layer(
                q,
                kv,
                attn_mask=attn_mask,
                q_padding_mask=q_padding_mask,
                kv_padding_mask=kv_padding_mask,
            )
            cross_attn_weights[i] = layer_weights

        return q, kv, cross_attn_weights



def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
