import math

import torch
import torch.nn.functional as F
from torch import nn


def get_fixed_sin_cos_encodings(d_model, max_len):
    """
    Build the table of fixed sine/cosine positional encodings.

    Row ``pos`` stores ``sin(pos * w_i)`` in the even columns and ``cos(pos * w_i)``
    in the odd columns, with ``w_i = exp(-2i * ln(10000) / d_model)``.

    Args:
        d_model: hidden state dimensionality (asserted to be even)
        max_len: max sequence length, i.e. the number of rows of the table
    Returns: PE tensor of shape (max_len, d_model)
    """
    assert d_model % 2 == 0
    frequencies = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
    # Outer product of positions and frequencies: one row per position.
    angles = torch.arange(max_len).unsqueeze(1) * frequencies
    table = torch.zeros(max_len, d_model)
    table[:, 0::2] = torch.sin(angles)
    table[:, 1::2] = torch.cos(angles)
    return table


class AbsolutePositionalEncoding(nn.Module):
    """
    Absolute positional encoding backed by either a learned or a fixed table.

    Args:
        max_len: number of positions stored in the table
        d_model: hidden state dimensionality
        trainable: if True the table is the weight of an ``nn.Embedding``;
            otherwise it is a fixed sin-cos buffer registered under ``pe``
    """

    def __init__(self, max_len: int, d_model: int, trainable=False):
        super().__init__()
        self.max_len = max_len
        self.trainable = trainable
        if trainable:
            self.pe = nn.Embedding(max_len, d_model)
        else:
            self.register_buffer("pe", get_fixed_sin_cos_encodings(d_model, max_len))

    def _table(self):
        """Return the (max_len, d_model) encoding table regardless of its storage."""
        return self.pe.weight if self.trainable else self.pe

    def forward(self, x):
        """
        Return the encodings for the first ``x.shape[1]`` positions, once per batch item.

        Args:
            x: tensor whose dim 0 is the batch size and dim 1 the sequence length
        Returns: tensor of shape (x.shape[0], x.shape[1], d_model)
        """
        batch_size = x.size(0)
        actual_len = x.shape[1]
        assert actual_len <= self.max_len

        # Slice before repeating so that only the rows actually needed are copied
        # (the result is also contiguous, unlike a slice of the repeated full table).
        return self._table()[:actual_len].unsqueeze(0).repeat(batch_size, 1, 1)

    def get_pe(self, position):
        """Look up the encoding(s) stored at index/indices ``position`` of the table."""
        return self._table()[position]


class RelativePositionalEncoding(nn.Module):
    """
    Relative positional encoding: one embedding per (clipped) query/key offset.

    Args:
        max_relative_position: offsets are clipped to this magnitude
        d_model: dimensionality of each embedding
        trainable: if True the table is the weight of an ``nn.Embedding``;
            otherwise it is a fixed sin-cos buffer
        cross_attn: if True, offsets are the non-negative values
            ``(length_k - 1 - k) + q``; otherwise the signed values ``k - q``
    """

    def __init__(
        self, max_relative_position: int, d_model: int, trainable=False, cross_attn=False
    ):
        super().__init__()
        self.max_relative_position = max_relative_position
        self.trainable = trainable
        self.cross_attn = cross_attn
        self.num_embeddings = (
            (max_relative_position * 2 + 1) if not cross_attn else (max_relative_position + 1)
        )
        if trainable:
            self.embeddings_table = nn.Embedding(self.num_embeddings, d_model)
        else:
            # NOTE(review): the fixed table always has 2 * max_relative_position + 1 rows,
            # even when cross_attn only indexes the first max_relative_position + 1 of
            # them. Kept as-is so the buffer shape of existing checkpoints is unchanged.
            self.register_buffer(
                "embeddings_table",
                get_fixed_sin_cos_encodings(d_model, max_relative_position * 2 + 1),
            )

    def forward(self, length_q, length_k):
        """
        Gather the embedding of every query/key position pair.

        Args:
            length_q: number of query positions
            length_k: number of key positions
        Returns: tensor of shape (length_q, length_k, d_model)
        """
        embeddings_table = (
            self.embeddings_table.weight if self.trainable else self.embeddings_table
        )
        # Build the indices on the table's device so indexing does not go through a
        # CPU index tensor (and the legacy torch.LongTensor constructor) on every call.
        device = embeddings_table.device
        q_positions = torch.arange(length_q, device=device)[:, None]

        if self.cross_attn:
            k_positions = torch.arange(length_k - 1, -1, -1, device=device)[None, :]
            distance_mat = k_positions + q_positions
        else:
            k_positions = torch.arange(length_k, device=device)[None, :]
            distance_mat = k_positions - q_positions

        indices = torch.clamp(
            distance_mat, -self.max_relative_position, self.max_relative_position
        )
        if not self.cross_attn:
            # Shift signed offsets into the valid index range [0, 2 * max_relative_position].
            indices = indices + self.max_relative_position

        return embeddings_table[indices]


class LayerNorm(nn.Module):
    """
    Normalisation over the last dimension with a learnable scale and shift.

    The centred input is divided by ``std + eps``, where ``std`` is the default
    ``Tensor.std`` of the last dimension.

    Args:
        features: size of the last dimension (length of ``weight`` and ``bias``)
        eps: constant added to the standard deviation before dividing
    """

    def __init__(self, features, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(features))
        self.bias = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        """Normalise ``x`` along its last dimension, then scale and shift it."""
        centered = x - x.mean(-1, keepdim=True)
        spread = x.std(-1, keepdim=True) + self.eps
        return self.weight * centered / spread + self.bias


class Attention(nn.Module):
    """
    Scaled dot-product attention with optional relative positional terms.

    Args:
        positional_encoding_k: optional module called as ``(length_q, length_k)``;
            its output is combined with the queries and added to the scores
        positional_encoding_v: optional module called as ``(length_q, length_v)``;
            its output is combined with the attention weights and added to the output
    """

    def __init__(
        self,
        positional_encoding_k: RelativePositionalEncoding = None,
        positional_encoding_v: RelativePositionalEncoding = None,
    ):
        super().__init__()
        self.positional_encoding_k = positional_encoding_k
        self.positional_encoding_v = positional_encoding_v

    def forward(self, query, key, value, mask=None, dropout=None, one_direction=False):
        """
        Compute the attention output and the attention weights.

        Sequence lengths are read from dim 2, and the einsum patterns treat the
        inputs as (batch, heads, length, dim) tensors.

        Args:
            query, key, value: projected inputs
            mask: optional tensor; scores are set to -1e9 wherever ``mask == 0``
            dropout: optional callable applied to the attention weights
            one_direction: if True, scores above the main diagonal are masked out
        Returns: tuple ``(output, attention weights)``
        """
        n_queries = query.size(2)
        logits = torch.matmul(query, key.transpose(-2, -1))

        if self.positional_encoding_k is not None:
            rel_k = self.positional_encoding_k(n_queries, key.size(2))
            logits = logits + torch.einsum("b h q d, q k d -> b h q k", query, rel_k)

        # Scaling is applied after the positional term has been added.
        logits = logits / math.sqrt(query.size(-1))

        if mask is not None:
            logits = logits.masked_fill(mask == 0, -1e9)

        if one_direction:  # Required for self-attention, but not for cross-attention
            lower_triangle = torch.tril(torch.ones_like(logits))
            logits = logits.masked_fill(lower_triangle == 0, -1e9)

        weights = F.softmax(logits, dim=-1)
        if dropout is not None:
            weights = dropout(weights)

        context = torch.matmul(weights, value)
        if self.positional_encoding_v is not None:
            rel_v = self.positional_encoding_v(n_queries, value.size(2))
            context = context + torch.einsum("b h q v, q v d -> b h q d", weights, rel_v)

        return context, weights


class MultiHeadedAttention(nn.Module):
    """
    Multi-head attention followed by a residual connection and layer normalisation.

    Args:
        num_heads: number of attention heads (``d_model`` must be divisible by it)
        d_model: dimensionality of the inputs
        head_size: size of each head; defaults to ``d_model // num_heads``
        dropout: dropout probability applied to the attention weights
        positional_encoding_k: optional positional encoding passed to ``Attention``
        positional_encoding_v: optional positional encoding passed to ``Attention``
        final_layer: if True, add a linear layer mapping the concatenated heads
            back to ``d_model``
    """

    def __init__(
        self,
        num_heads,
        d_model,
        head_size=None,
        dropout=0.0,
        positional_encoding_k=None,
        positional_encoding_v=None,
        final_layer=False,
    ):
        super().__init__()
        assert d_model % num_heads == 0

        self.num_heads = num_heads
        self.head_size = d_model // num_heads if head_size is None else head_size
        inner_dim = self.num_heads * self.head_size

        # One projection each for query, key and value (in that order).
        projections = []
        for _ in range(3):
            projections.append(nn.Linear(d_model, inner_dim))
        self.linear_layers = nn.ModuleList(projections)
        self.attention = Attention(positional_encoding_k, positional_encoding_v)
        self.dropout = nn.Dropout(p=dropout)
        if final_layer:
            self.final_layer = nn.Linear(inner_dim, d_model)
        self.layer_norm = LayerNorm(d_model)

    def forward(self, query, key, value, mask=None, one_direction=True):
        """
        Attend from ``query`` to ``key``/``value`` and return the normalised residual.

        Args:
            query, key, value: inputs whose dim 0 is the batch size
            mask: optional mask forwarded to the attention module
            one_direction: forwarded to the attention module
        Returns: ``layer_norm(attention output + query)``
        """
        batch_size = query.size(0)

        def split_heads(projection, tensor):
            # d_model => (num_heads, head_size), with the heads moved to dim 1.
            projected = projection(tensor)
            return projected.view(batch_size, -1, self.num_heads, self.head_size).transpose(1, 2)

        query_ = split_heads(self.linear_layers[0], query)
        key_ = split_heads(self.linear_layers[1], key)
        value_ = split_heads(self.linear_layers[2], value)

        attended, _ = self.attention(
            query_, key_, value_, mask=mask, dropout=self.dropout, one_direction=one_direction
        )

        # "Concat" the heads using a view, then apply the optional final linear layer.
        merged = attended.transpose(1, 2).contiguous()
        merged = merged.view(batch_size, -1, self.num_heads * self.head_size)
        if hasattr(self, "final_layer"):
            merged = self.final_layer(merged)

        return self.layer_norm(merged + query)


class PositionwiseFeedForward(nn.Module):
    """
    Position-wise feed-forward sub-layer built from two kernel-size-1 convolutions,
    with dropout, a residual connection and layer normalisation.

    Args:
        d_model: dimensionality of the input and output
        d_ff: number of channels of the hidden layer
        dropout: dropout probability (used after the activation and after the
            second convolution)
    """

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.conv1 = nn.Conv1d(d_model, d_ff, kernel_size=1)
        self.activation = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
        self.conv2 = nn.Conv1d(d_ff, d_model, kernel_size=1)
        self.layer_norm = LayerNorm(d_model)

    def forward(self, x):
        """Apply the feed-forward transform to ``x`` and return ``layer_norm(result + x)``."""
        # Conv1d convolves over the last dim, so dims 1 and 2 are swapped around it.
        channels_first = x.permute(0, 2, 1)
        hidden = self.activation(self.conv1(channels_first))
        hidden = self.dropout(hidden)
        projected = self.dropout(self.conv2(hidden)).permute(0, 2, 1)
        return self.layer_norm(projected + x)


class TransformerEncoderBlock(nn.Module):
    """
    Encoder block: one-directional self-attention followed by a feed-forward sub-layer.

    Args:
        hidden: model dimensionality
        attn_heads: number of attention heads
        head_size: size of each attention head
        feed_forward_hidden: hidden size of the feed-forward sub-layer
        dropout: dropout probability of the feed-forward sub-layer
        attn_dropout: dropout probability of the attention weights
        self_positional_encoding_k: optional positional encoding for the keys
        self_positional_encoding_v: optional positional encoding for the values
        final_layer: whether the attention module gets a final linear layer
        **kwargs: accepted and ignored
    """

    def __init__(
        self,
        hidden,
        attn_heads,
        head_size,
        feed_forward_hidden,
        dropout,
        attn_dropout=0.1,
        self_positional_encoding_k=None,
        self_positional_encoding_v=None,
        final_layer=True,
        **kwargs
    ):
        super().__init__()
        attention_config = dict(
            num_heads=attn_heads,
            d_model=hidden,
            head_size=head_size,
            dropout=attn_dropout,
            positional_encoding_k=self_positional_encoding_k,
            positional_encoding_v=self_positional_encoding_v,
            final_layer=final_layer,
        )
        self.self_attention = MultiHeadedAttention(**attention_config)
        self.feed_forward = PositionwiseFeedForward(
            d_model=hidden, d_ff=feed_forward_hidden, dropout=dropout
        )

    def forward(self, x, active_entries):
        """
        Run the block on ``x``.

        Args:
            x: input sequence
            active_entries: mask source; tiled along its last dim by its dim-1 size
                (assumes shape (batch, seq_len, 1) — TODO confirm against callers)
        Returns: output of the feed-forward sub-layer
        """
        seq_len = active_entries.size(1)
        self_att_mask = active_entries.repeat(1, 1, seq_len).unsqueeze(1)
        attended = self.self_attention(x, x, x, self_att_mask, True)
        return self.feed_forward(attended)


class TransformerDecoderBlock(nn.Module):
    """
    Decoder block: one-directional self-attention, cross-attention over the encoder
    output, then a feed-forward sub-layer.

    Args:
        hidden: model dimensionality
        attn_heads: number of attention heads
        head_size: size of each attention head
        feed_forward_hidden: hidden size of the feed-forward sub-layer
        dropout: dropout probability of the feed-forward sub-layer
        attn_dropout: dropout probability of the attention weights
        self_positional_encoding_k: optional key positional encoding (self-attention)
        self_positional_encoding_v: optional value positional encoding (self-attention)
        cross_positional_encoding_k: optional key positional encoding (cross-attention)
        cross_positional_encoding_v: optional value positional encoding (cross-attention)
        final_layer: whether the attention modules get a final linear layer
        **kwargs: accepted and ignored
    """

    def __init__(
        self,
        hidden,
        attn_heads,
        head_size,
        feed_forward_hidden,
        dropout,
        attn_dropout,
        self_positional_encoding_k=None,
        self_positional_encoding_v=None,
        cross_positional_encoding_k=None,
        cross_positional_encoding_v=None,
        final_layer=False,
        **kwargs
    ):
        super().__init__()
        # Settings shared by the self- and cross-attention modules.
        shared_config = dict(
            num_heads=attn_heads,
            d_model=hidden,
            head_size=head_size,
            dropout=attn_dropout,
            final_layer=final_layer,
        )
        # Registered on the module but not used by forward().
        self.layer_norm = LayerNorm(hidden)
        self.self_attention = MultiHeadedAttention(
            positional_encoding_k=self_positional_encoding_k,
            positional_encoding_v=self_positional_encoding_v,
            **shared_config,
        )
        self.cross_attention = MultiHeadedAttention(
            positional_encoding_k=cross_positional_encoding_k,
            positional_encoding_v=cross_positional_encoding_v,
            **shared_config,
        )
        self.feed_forward = PositionwiseFeedForward(
            d_model=hidden, d_ff=feed_forward_hidden, dropout=dropout
        )

    def forward(self, x, encoder_x, active_entries, active_encoder_br):
        """
        Run the block on ``x`` while attending to ``encoder_x``.

        Args:
            x: decoder input sequence
            encoder_x: encoder output used as keys and values of the cross-attention
            active_entries: mask source for the decoder positions
                (assumes shape (batch, seq_len, 1) — TODO confirm against callers)
            active_encoder_br: mask source for the encoder positions
                (presumably (batch, encoder_len) — TODO confirm against callers)
        Returns: output of the feed-forward sub-layer
        """
        seq_len = active_entries.size(1)
        self_att_mask = active_entries.repeat(1, 1, seq_len).unsqueeze(1)
        pairwise_active = active_encoder_br.unsqueeze(1) * active_entries
        cross_att_mask = pairwise_active.unsqueeze(1)

        hidden_state = self.self_attention(x, x, x, self_att_mask, True)
        hidden_state = self.cross_attention(
            hidden_state, encoder_x, encoder_x, cross_att_mask, False
        )
        return self.feed_forward(hidden_state)


class TransformerMultiInputBlock(nn.Module):
    """
    Transformer block over two or three parallel streams ("t", "o" and optionally
    "v"): each stream gets self-attention, cross-attention to the other streams and
    its own feed-forward sub-layer.

    Args:
        hidden: model dimensionality
        attn_heads: number of attention heads
        head_size: size of each attention head
        feed_forward_hidden: hidden size of the feed-forward sub-layers
        dropout: dropout probability of the feed-forward sub-layers
        attn_dropout: dropout probability of the attention weights
        self_positional_encoding_k: optional key positional encoding, shared by all
            attention modules of the block
        self_positional_encoding_v: optional value positional encoding, shared by all
            attention modules of the block
        n_inputs: number of streams (2 or 3)
        final_layer: whether the attention modules get a final linear layer
        disable_cross_attention: if True, no cross-attention modules are created or used
        isolate_subnetwork: "t", "o" or "v" to skip every cross-attention involving
            that stream; any other value isolates nothing
        **kwargs: accepted and ignored
    """

    def __init__(
        self,
        hidden,
        attn_heads,
        head_size,
        feed_forward_hidden,
        dropout,
        attn_dropout,
        self_positional_encoding_k=None,
        self_positional_encoding_v=None,
        n_inputs=2,
        final_layer=False,
        disable_cross_attention=False,
        isolate_subnetwork="",
        **kwargs
    ):
        super().__init__()
        self.n_inputs = n_inputs
        self.disable_cross_attention = disable_cross_attention
        self.isolate_subnetwork = isolate_subnetwork

        def make_attention():
            # Every attention module of the block shares the same configuration.
            return MultiHeadedAttention(
                num_heads=attn_heads,
                d_model=hidden,
                head_size=head_size,
                dropout=attn_dropout,
                positional_encoding_k=self_positional_encoding_k,
                positional_encoding_v=self_positional_encoding_v,
                final_layer=final_layer,
            )

        # Creation order is kept stable: it fixes both the registration order of the
        # submodules and the order in which random initialisation is consumed.
        self.self_attention_o = make_attention()
        self.self_attention_t = make_attention()
        if not disable_cross_attention:
            self.cross_attention_ot = make_attention()
            self.cross_attention_to = make_attention()

        if n_inputs == 3:
            self.self_attention_v = make_attention()
            if not disable_cross_attention:
                self.cross_attention_tv = make_attention()
                self.cross_attention_vt = make_attention()
                self.cross_attention_ov = make_attention()
                self.cross_attention_vo = make_attention()

        self.feed_forwards = nn.ModuleList(
            [
                PositionwiseFeedForward(
                    d_model=hidden, d_ff=feed_forward_hidden, dropout=dropout
                )
                for _ in range(n_inputs)
            ]
        )

    def _cross_enabled(self, first, second):
        """Return True if cross-attention between streams ``first`` and ``second`` is active."""
        return not self.disable_cross_attention and self.isolate_subnetwork not in (
            first,
            second,
        )

    def forward(self, x_tov, x_s, active_entries_treat_outcomes, active_entries_vitals=None):
        """
        Run the block on all streams.

        Args:
            x_tov: sequence of ``n_inputs`` tensors, ordered (x_t, x_o[, x_v])
            x_s: optional tensor added to the "o" stream before its feed-forward
                sub-layer; skipped when None
            active_entries_treat_outcomes: mask source for the "t" and "o" streams
                (assumes shape (batch, seq_len, 1) — TODO confirm against callers)
            active_entries_vitals: mask source for the "v" stream; required when
                ``n_inputs == 3`` (assumed to have the same layout — TODO confirm)
        Returns: ``(out_t, out_o)`` for two inputs, ``(out_t, out_o, out_v)`` for three
        Raises:
            ValueError: if ``n_inputs == 3`` and ``active_entries_vitals`` is None
        """
        assert len(x_tov) == self.n_inputs
        if self.n_inputs == 2:
            x_t, x_o = x_tov
        else:
            x_t, x_o, x_v = x_tov
            if active_entries_vitals is None:
                raise ValueError("active_entries_vitals is required when n_inputs == 3")

        self_att_mask_ot = active_entries_treat_outcomes.repeat(1, 1, x_t.size(1)).unsqueeze(1)
        cross_att_mask_ot = cross_att_mask_to = self_att_mask_ot

        x_t_ = self.self_attention_t(x_t, x_t, x_t, self_att_mask_ot, True)
        x_to_ = (
            self.cross_attention_to(x_t_, x_o, x_o, cross_att_mask_ot, True)
            if self._cross_enabled("t", "o")
            else x_t_
        )

        x_o_ = self.self_attention_o(x_o, x_o, x_o, self_att_mask_ot, True)
        x_ot_ = (
            self.cross_attention_ot(x_o_, x_t, x_t, cross_att_mask_to, True)
            if self._cross_enabled("o", "t")
            else x_o_
        )

        if self.n_inputs == 2:
            out_t = self.feed_forwards[0](x_to_)
            out_o = self.feed_forwards[1](x_ot_ if x_s is None else x_ot_ + x_s)
            return out_t, out_o

        self_att_mask_v = active_entries_vitals.repeat(1, 1, x_v.size(1)).unsqueeze(1)
        cross_att_mask_ot_v = (
            active_entries_vitals.squeeze(-1).unsqueeze(1) * active_entries_treat_outcomes
        ).unsqueeze(1)
        cross_att_mask_v_ot = (
            active_entries_treat_outcomes.squeeze(-1).unsqueeze(1) * active_entries_vitals
        ).unsqueeze(1)

        # Each stream pair goes through its own module. The "v"-related calls used
        # to reuse self_attention_o / cross_attention_to / cross_attention_ot, which
        # left the dedicated modules created in __init__ unused.
        # NOTE(review): checkpoints trained with the previous routing will produce
        # different outputs in the three-input configuration.
        x_tv_ = (
            self.cross_attention_tv(x_t_, x_v, x_v, cross_att_mask_ot_v, True)
            if self._cross_enabled("t", "v")
            else 0.0
        )
        x_ov_ = (
            self.cross_attention_ov(x_o_, x_v, x_v, cross_att_mask_ot_v, True)
            if self._cross_enabled("o", "v")
            else 0.0
        )

        x_v_ = self.self_attention_v(x_v, x_v, x_v, self_att_mask_v, True)
        x_vt_ = (
            self.cross_attention_vt(x_v_, x_t, x_t, cross_att_mask_v_ot, True)
            if self._cross_enabled("v", "t")
            else x_v_
        )
        x_vo_ = (
            self.cross_attention_vo(x_v_, x_o, x_o, cross_att_mask_v_ot, True)
            if self._cross_enabled("v", "o")
            else 0.0
        )

        out_t = self.feed_forwards[0](x_to_ + x_tv_)
        out_o = self.feed_forwards[1](
            x_ot_ + x_ov_ if x_s is None else x_ot_ + x_ov_ + x_s
        )
        out_v = self.feed_forwards[2](x_vt_ + x_vo_)

        return out_t, out_o, out_v
