import torch
import torch.nn as nn
import math
from eos_line_search.models import transformer as tr


class FullTransformer(tr.Transformer):
    """Encoder-decoder Transformer producing output-vocabulary logits.

    Token id ``0`` is treated as padding when building attention masks (see
    ``generate_mask``). The encoder and decoder embeddings are passed through
    the same positional-encoding module.

    Args:
        in_vocab_size: Number of tokens in the source vocabulary.
        out_vocab_size: Number of tokens in the target vocabulary; also the
            output size of the final linear projection.
        embed_size: Model (embedding) dimension.
        num_heads: Number of attention heads in every attention layer.
        num_encode_blocks: Number of stacked ``EncoderBlock`` layers.
        num_decode_blocks: Number of stacked ``DecoderBlock`` layers.
        dim_ff: Hidden width of each position-wise feed-forward network.
        max_seq_length: Forwarded to ``tr.OriginalPositionalEncoding``;
            presumably the longest sequence it supports (confirm in ``tr``).
    """

    def __init__(
        self,
        in_vocab_size,
        out_vocab_size,
        embed_size,
        num_heads,
        num_encode_blocks,
        num_decode_blocks,
        dim_ff,
        max_seq_length,
    ):
        super().__init__()
        self.encode_embed = nn.Embedding(in_vocab_size, embed_size)
        self.decode_embed = nn.Embedding(out_vocab_size, embed_size)
        self.pos_encode = tr.OriginalPositionalEncoding(embed_size, max_seq_length)

        self.encoder_blocks = nn.ModuleList(
            [
                EncoderBlock(embed_size, num_heads, dim_ff)
                for _ in range(num_encode_blocks)
            ]
        )
        self.decoder_blocks = nn.ModuleList(
            [
                DecoderBlock(embed_size, num_heads, dim_ff)
                for _ in range(num_decode_blocks)
            ]
        )

        # Final projection from model dimension to target-vocabulary logits.
        self.linear = nn.Linear(embed_size, out_vocab_size)

    def generate_mask(self, inp, out):
        """Build boolean attention masks for the encoder and decoder inputs.

        ``True`` means "may attend"; token id ``0`` is treated as padding.

        Args:
            inp: Source token ids, assumed 2-D ``(batch, src_len)``.
            out: Target token ids, assumed 2-D ``(batch, tgt_len)``.

        Returns:
            ``(in_mask, out_mask)`` where ``in_mask`` has shape
            ``(batch, 1, 1, src_len)`` (padding on the last axis, presumably
            the attention key axis) and ``out_mask`` has shape
            ``(batch, 1, tgt_len, tgt_len)``: the target padding mask (on the
            query/row axis) combined with a causal no-peek mask.
        """
        in_mask = (inp != 0).unsqueeze(1).unsqueeze(2)
        out_mask = (out != 0).unsqueeze(1).unsqueeze(3)
        seq_length = out.size(1)
        # Lower-triangular causal mask: position i may attend to positions <= i.
        # Created on ``out``'s device so the ``&`` below also works on GPU
        # (a CPU tensor here would cause a device-mismatch error).
        nopeak_mask = torch.tril(
            torch.ones(seq_length, seq_length, dtype=torch.bool, device=out.device)
        ).unsqueeze(0)
        out_mask = out_mask & nopeak_mask
        return in_mask, out_mask

    def forward(self, inp, out):
        """Run the encoder over ``inp`` and the decoder over ``out``.

        Args:
            inp: Source token ids (see ``generate_mask`` for assumed shape).
            out: Target token ids fed to the decoder (teacher forcing).

        Returns:
            Logits from ``self.linear`` with last dimension ``out_vocab_size``.
        """
        in_mask, out_mask = self.generate_mask(inp, out)
        in_embedded = self.pos_encode(self.encode_embed(inp))
        out_embedded = self.pos_encode(self.decode_embed(out))

        enc_out = in_embedded
        for enc_layer in self.encoder_blocks:
            enc_out = enc_layer(enc_out, in_mask)

        dec_out = out_embedded
        for dec_layer in self.decoder_blocks:
            dec_out = dec_layer(dec_out, enc_out, in_mask, out_mask)

        logits = self.linear(dec_out)
        return logits


class PositionWiseFeedForward(nn.Module):
    """Two-layer feed-forward network applied independently at each position.

    Computes ``layer2(act(layer1(x)))``: a projection from ``embed_size`` up to
    ``dim_ff``, a non-linearity, and a projection back to ``embed_size``.

    Args:
        embed_size: Size of the last dimension of the input and output.
        dim_ff: Hidden width of the intermediate layer.
        activation: Optional ``nn.Module`` applied between the two linear
            layers. Defaults to ``nn.Sigmoid()`` to preserve existing
            behaviour; pass e.g. ``nn.ReLU()`` for the classic Transformer FFN.
    """

    def __init__(self, embed_size, dim_ff, activation=None):
        super().__init__()
        self.layer1 = nn.Linear(embed_size, dim_ff)
        self.layer2 = nn.Linear(dim_ff, embed_size)
        # Configurable activation; Sigmoid kept as the default for compatibility.
        self.act = nn.Sigmoid() if activation is None else activation

    def forward(self, x):
        """Apply the feed-forward network over the last dimension of ``x``."""
        return self.layer2(self.act(self.layer1(x)))


class EncoderBlock(nn.Module):
    """One post-norm Transformer encoder layer.

    Applies self-attention and then a position-wise feed-forward network,
    each wrapped as ``norm(x + sublayer(x))``.

    Args:
        embed_size: Model dimension of the input and output.
        num_heads: Number of heads in the self-attention layer.
        dim_ff: Hidden width of the feed-forward network.
    """

    def __init__(self, embed_size, num_heads, dim_ff):
        super().__init__()
        # Creation order is kept fixed so state_dict layout and init RNG use
        # stay the same.
        self.self_atten = tr.MultiHeadAttention(embed_size, num_heads)
        self.ff = PositionWiseFeedForward(embed_size, dim_ff)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)

    def forward(self, x, mask):
        """Encode ``x``; ``mask`` is passed straight to the self-attention layer."""
        # Self-attention sub-layer with residual + LayerNorm.
        hidden = self.norm1(x + self.self_atten(x, x, x, mask))
        # Feed-forward sub-layer with residual + LayerNorm.
        return self.norm2(hidden + self.ff(hidden))


class DecoderBlock(nn.Module):
    """One post-norm Transformer decoder layer.

    Runs three sub-layers in order, each wrapped as ``norm(x + sublayer(x))``:
    self-attention masked by ``out_mask``, attention over the encoder output
    masked by ``in_mask``, and a position-wise feed-forward network.

    Args:
        embed_size: Model dimension of the decoder state and encoder output.
        num_heads: Number of heads in each attention layer.
        dim_ff: Hidden width of the feed-forward network.
    """

    def __init__(self, embed_size, num_heads, dim_ff):
        super().__init__()
        # Creation order is kept fixed so state_dict layout and init RNG use
        # stay the same.
        self.self_atten = tr.MultiHeadAttention(embed_size, num_heads)
        self.cross_atten = tr.MultiHeadAttention(embed_size, num_heads)
        self.ff = PositionWiseFeedForward(embed_size, dim_ff)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)
        self.norm3 = nn.LayerNorm(embed_size)

    def forward(self, x, encode_out, in_mask, out_mask):
        """Decode ``x`` while attending to ``encode_out``.

        Args:
            x: Decoder hidden states.
            encode_out: Output of the encoder stack.
            in_mask: Mask passed to the cross-attention layer.
            out_mask: Mask passed to the self-attention layer.
        """
        # Masked self-attention over the decoder's own sequence.
        hidden = self.norm1(x + self.self_atten(x, x, x, out_mask))
        # Cross-attention: decoder state first, encoder output for the other two
        # arguments (presumably query, key, value — confirm in tr).
        crossed = self.cross_atten(hidden, encode_out, encode_out, in_mask)
        hidden = self.norm2(hidden + crossed)
        # Position-wise feed-forward sub-layer.
        return self.norm3(hidden + self.ff(hidden))
