import torch as t
from torch import Tensor
import torch.nn as nn
from torch.nn import MultiheadAttention, Linear, Dropout, LayerNorm
import torch.nn.functional as F
import itertools as it




# Default device string for this module: "cuda" when PyTorch reports an
# available CUDA device, otherwise "cpu".
if t.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"


import math
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding followed by dropout.

    Adapted from
    https://pytorch.org/tutorials/beginner/transformer_tutorial.html

    A table ``pe`` of shape ``(max_len, d_model)`` is built once and registered
    as a buffer (not trained). Even columns hold sines and odd columns hold
    cosines of ``position * div_term``, each multiplied by ``scale``. The
    frequency base used here is 25000.

    Args:
        d_model: width of each encoding vector (odd values are supported).
        scale: multiplier applied to every entry of the table.
        dropout: dropout probability applied by ``forward``, ``embed`` and
            ``embed2d``.
        max_len: number of positions stored in the table.
    """

    def __init__(self, d_model, scale=1.0, dropout=0.1, max_len=16384):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = t.zeros(max_len, d_model)
        position = t.arange(0, max_len, dtype=t.float).unsqueeze(1)
        div_term = t.exp(t.arange(0, d_model, 2).float() * (-math.log(25000.0) / d_model))
        angles = position * div_term
        pe[:, 0::2] = t.sin(angles) * scale
        # For odd d_model there is one fewer cosine column than sine column,
        # so drop the surplus frequency instead of failing on a shape mismatch.
        pe[:, 1::2] = t.cos(angles[:, : d_model // 2]) * scale
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add the first ``x.size(1)`` table rows to ``x`` and apply dropout."""
        x = x + self.pe[:x.size(1), :]
        return self.dropout(x)

    def embed(self, x):
        """Look up table rows by index ``x`` and apply dropout."""
        return self.dropout(self.pe[x])

    def unembed(self, x):
        """Subtract the first ``x.size(1)`` table rows from ``x`` (no dropout)."""
        return x - self.pe[:x.size(1), :]

    def embed2d(self, y, x):
        """Look up rows for ``y`` and ``x``, concatenate them on the last axis
        (``y`` part first) and apply dropout."""
        ey = self.pe[y]
        ex = self.pe[x]
        e = t.cat([ey, ex], dim=-1)
        return self.dropout(e)


class LearnablePositionalEncoding(nn.Module):
    """Positional encoding table that can optionally be trained.

    The table ``pe`` has shape ``(max_len, d_model)``. It is stored as an
    ``nn.Parameter`` when ``learnable`` is true and as a buffer otherwise.

    Args:
        d_model: width of each encoding vector (odd values are supported).
        period: frequency base of the sinusoidal initialisation.
        scale: multiplier applied to the initial table values.
        dropout: dropout probability applied by ``forward``, ``embed`` and
            ``embed2d``.
        max_len: number of positions stored in the table.
        init: ``"sine"`` for a sinusoidal initialisation; any other value
            draws the table from a standard normal distribution.
        embed_style: stored on the instance as ``self.embed_style``; it is not
            read by any method of this class.
        learnable: whether ``pe`` is a trainable parameter.
    """

    def __init__(self, d_model, period=10000.0, scale=1.0, dropout=0.1, max_len=16384, init="sine", embed_style="1d", learnable=True):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        if init == "sine":
            pe = t.zeros(max_len, d_model)
            position = t.arange(0, max_len, dtype=t.float).unsqueeze(1)
            div_term = t.exp(t.arange(0, d_model, 2).float() * (-math.log(period) / d_model))
            angles = position * div_term
            pe[:, 0::2] = t.sin(angles) * scale
            # For odd d_model there is one fewer cosine column than sine
            # column, so drop the surplus frequency instead of failing.
            pe[:, 1::2] = t.cos(angles[:, : d_model // 2]) * scale
        else:
            pe = t.randn(max_len, d_model) * scale

        if learnable:
            self.pe = nn.Parameter(pe)
        else:
            self.register_buffer("pe", pe)
        self.embed_style = embed_style

    def forward(self, x, negative=False):
        """Add (or, with ``negative=True``, subtract) the first ``x.size(1)``
        table rows to/from ``x``, then apply dropout."""
        if negative:
            x = x - self.pe[:x.size(1), :]
        else:
            x = x + self.pe[:x.size(1), :]
        return self.dropout(x)

    def embed(self, x):
        """Look up table rows by index ``x`` and apply dropout."""
        return self.dropout(self.pe[x])

    def embed2d(self, y, x):
        """Look up rows for ``y`` and ``x``, concatenate them on the last axis
        (``y`` part first) and apply dropout."""
        ey = self.pe[y]
        ex = self.pe[x]
        e = t.cat([ey, ex], dim=-1)
        return self.dropout(e)




# All of the following code is based on the standard PyTorch implementation.

class IndependentTransformerDecoderLayer(nn.Module):
    """Decoder layer with cross-attention and a feed-forward block only.

    Unlike ``TransformerDecoderLayer`` there is no self-attention sub-layer:
    each target position attends to ``memory`` and then goes through the
    feed-forward network, with a residual connection and layer norm after
    each step.

    Args:
        d_model: feature width of the target and memory tensors.
        nhead: number of attention heads.
        dim_feedforward: hidden width of the feed-forward network.
        dropout: dropout probability used throughout the layer.
        activation: name of the feed-forward activation, resolved by
            ``_get_activation_fn``.
        special_mode: optional dict of switches. If it contains a truthy
            ``"fixed_key"`` entry, that entry must be a
            ``(max_len, scale, mode)`` triple; a learnable table ``offset`` of
            shape ``(max_len, d_model)`` is then used as the attention key in
            place of ``memory``. ``mode == "sine"`` initialises the table
            sinusoidally, anything else draws it from a normal distribution.
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu", special_mode=None):
        super().__init__()
        # Build a fresh dict per instance: a ``{}`` default in the signature
        # would be one shared object aliased by every layer.
        self.special_mode = {} if special_mode is None else special_mode

        self.multihead_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
        self.linear1 = Linear(d_model, dim_feedforward)
        self.dropout = Dropout(dropout)
        self.linear2 = Linear(dim_feedforward, d_model)
        self.norm2 = LayerNorm(d_model)
        self.norm3 = LayerNorm(d_model)
        self.dropout2 = Dropout(dropout)
        self.dropout3 = Dropout(dropout)

        fixed_key = self.special_mode.get("fixed_key")
        if fixed_key:
            max_len, scale, mode = fixed_key
            if mode == "sine":
                pe = t.zeros(max_len, d_model)
                position = t.arange(0, max_len, dtype=t.float).unsqueeze(1)
                div_term = t.exp(t.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
                angles = position * div_term
                pe[:, 0::2] = t.sin(angles) * scale
                # Odd d_model has one fewer cosine column than sine column.
                pe[:, 1::2] = t.cos(angles[:, : d_model // 2]) * scale
            else:
                pe = t.randn(max_len, d_model) * scale

            self.offset = nn.Parameter(pe)

        self.activation = _get_activation_fn(activation)

    def __setstate__(self, state):
        # Older pickles may lack ``activation``; fall back to ReLU.
        if 'activation' not in state:
            state['activation'] = F.relu
        super().__setstate__(state)

    def forward(self, tgt: Tensor, memory: Tensor, **kwargs) -> Tensor:
        """Run cross-attention and the feed-forward block on ``tgt``.

        ``memory.shape[0]`` is treated as the source length and
        ``memory.shape[1]`` as the batch size (the sequence-first layout that
        ``MultiheadAttention`` uses by default). Extra keyword arguments are
        accepted and ignored.
        """
        if self.special_mode.get("fixed_key"):
            src_len, batch_size = memory.shape[0], memory.shape[1]
            # Keys and values must have the same length, so use only the first
            # ``src_len`` rows of the key table. A memory longer than the
            # table still fails inside the attention call.
            keys = self.offset[:src_len].unsqueeze(1).expand(-1, batch_size, -1)
            tgt2 = self.multihead_attn(tgt, keys, memory)[0]
        else:
            tgt2 = self.multihead_attn(tgt, memory, memory)[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)
        return tgt


def _get_activation_fn(activation):
    if activation == "relu":
        return F.relu
    elif activation == "gelu":
        return F.gelu


class TransformerDecoderLayer(nn.Module):
    """Decoder layer: self-attention, cross-attention, then feed-forward.

    Each of the three sub-layers is followed by dropout, a residual
    connection and layer normalisation. No attention masks are applied.

    Args:
        d_model: feature width of the target and memory tensors.
        nhead: number of attention heads.
        dim_feedforward: hidden width of the feed-forward network.
        dropout: dropout probability used throughout the layer.
        activation: name of the feed-forward activation, resolved by
            ``_get_activation_fn``.
        special_mode: optional dict stored as ``self.special_mode``; it is not
            read by any method of this class.
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu", special_mode=None):
        super().__init__()
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = Linear(d_model, dim_feedforward)
        self.dropout = Dropout(dropout)
        self.linear2 = Linear(dim_feedforward, d_model)
        # Build a fresh dict per instance: a ``{}`` default in the signature
        # would be one shared object aliased by every layer.
        self.special_mode = {} if special_mode is None else special_mode
        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.norm3 = LayerNorm(d_model)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)
        self.dropout3 = Dropout(dropout)

        self.activation = _get_activation_fn(activation)

    def __setstate__(self, state):
        # Older pickles may lack ``activation``; fall back to ReLU.
        if 'activation' not in state:
            state['activation'] = F.relu
        super().__setstate__(state)

    def forward(self, tgt: Tensor, memory: Tensor, **kwargs) -> Tensor:
        """Apply the three sub-layers to ``tgt`` and return the result.

        Extra keyword arguments are accepted and ignored.
        """
        # Self-attention over the target sequence.
        tgt2 = self.self_attn(tgt, tgt, tgt)[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)
        # Cross-attention: target queries against memory keys/values.
        tgt2 = self.multihead_attn(tgt, memory, memory)[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        # Position-wise feed-forward network.
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)
        return tgt


class FixedKeyMultiheadAttention(nn.Module):
    """Multi-head attention whose key projection is a frozen identity matrix.

    Queries and values go through learnable projections; keys are multiplied
    by a non-trainable identity matrix held in the ``k_proj_weight`` buffer.
    The attention itself is delegated to ``F.multi_head_attention_forward``
    with ``use_separate_proj_weight=True``.

    Args:
        embed_dim: model width; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        dropout: dropout probability on the attention weights.
        bias: if true, create a learnable ``in_proj_bias`` of size
            ``3 * embed_dim``.
        add_bias_kv: if true, append one extra key/value position. The value
            part (``bias_v``) is learnable; the key part (``bias_k``) is a
            zero buffer.
        add_zero_attn: forwarded to ``F.multi_head_attention_forward``.
        kdim: feature width of ``key`` (defaults to ``embed_dim``).
        vdim: feature width of ``value`` (defaults to ``embed_dim``).
    """

    def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None):
        # Must name this class (or use the zero-arg form): ``self`` is not a
        # ``MultiheadAttention`` instance, so super(MultiheadAttention, self)
        # raises TypeError.
        super().__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim

        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"

        self.q_proj_weight = nn.Parameter(t.empty(embed_dim, embed_dim))
        # Register the identity directly as a buffer; assigning it as a plain
        # attribute first makes register_buffer fail with "already exists".
        # A rectangular identity keeps the shape valid when kdim != embed_dim.
        self.register_buffer('k_proj_weight', t.eye(embed_dim, self.kdim))
        self.v_proj_weight = nn.Parameter(t.empty(embed_dim, self.vdim))
        self.register_parameter('in_proj_weight', None)
        if bias:
            self.in_proj_bias = nn.Parameter(t.empty(3 * embed_dim))
        else:
            self.register_parameter('in_proj_bias', None)
        self.out_proj = Linear(embed_dim, embed_dim)

        if add_bias_kv:
            # multi_head_attention_forward needs bias_k and bias_v together.
            # NOTE(review): bias_k is kept as a frozen zero vector so the key
            # side stays parameter-free - confirm this is the intended design.
            self.register_buffer('bias_k', t.zeros(1, 1, embed_dim))
            self.bias_v = nn.Parameter(t.empty(1, 1, embed_dim))
        else:
            self.bias_k = None
            self.bias_v = None

        self.add_zero_attn = add_zero_attn

        self._reset_parameters()

    def _reset_parameters(self):
        """Initialise the learnable projections, biases and ``bias_v``."""
        nn.init.xavier_uniform_(self.q_proj_weight)
        nn.init.xavier_uniform_(self.v_proj_weight)

        if self.in_proj_bias is not None:
            nn.init.constant_(self.in_proj_bias, 0.)
            nn.init.constant_(self.out_proj.bias, 0.)
        if self.bias_v is not None:
            nn.init.xavier_normal_(self.bias_v)

    def __setstate__(self, state):
        # Delegate to nn.Module so unpickling and copy.deepcopy restore the
        # module's attributes; returning early would leave the copy empty.
        super().__setstate__(state)

    def forward(self, query, key, value, key_padding_mask = None,
                need_weights = True, attn_mask = None):
        """Compute attention and return what
        ``F.multi_head_attention_forward`` returns: the attention output and
        the attention weights (``None`` when ``need_weights`` is false)."""
        return F.multi_head_attention_forward(
            query, key, value, self.embed_dim, self.num_heads,
            self.in_proj_weight, self.in_proj_bias,
            self.bias_k, self.bias_v, self.add_zero_attn,
            self.dropout, self.out_proj.weight, self.out_proj.bias,
            training=self.training,
            key_padding_mask=key_padding_mask, need_weights=need_weights,
            attn_mask=attn_mask, use_separate_proj_weight=True,
            q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
            v_proj_weight=self.v_proj_weight)