import Layers as L
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


class EncoderLayer(nn.Module):
    """One encoder block: self-attention then position-wise feed-forward.

    Each of the two sub-layers is wrapped in its own ``L.SubLayerConnection``
    (two clones are created), which receives the current activations and a
    callable implementing the sub-layer.

    Args:
        d_model: model width; stored on the instance so ``Encoder`` can size
            its LayerNorm from it.
        self_attn: attention module, called as ``self_attn(q, k, v, mask)``.
        feed_forward: module/callable applied to the activations.
        dropout: dropout rate forwarded to ``L.SubLayerConnection``.
    """

    def __init__(self, d_model, self_attn, feed_forward, dropout=0.1):
        super().__init__()
        # Registration order is kept as-is: it fixes the order of
        # ``parameters()`` (and therefore of the init RNG draws).
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.sublayer = L.clones(L.SubLayerConnection(d_model, dropout), 2)
        self.d_model = d_model

    def forward(self, x, mask=None):
        """Apply the attention sub-layer, then the feed-forward sub-layer."""

        def _attend(h):
            # Self-attention: the same tensor is query, key and value.
            return self.self_attn(h, h, h, mask)

        attended = self.sublayer[0](x, _attend)
        return self.sublayer[1](attended, self.feed_forward)


class Encoder(nn.Module):
    """Stack of cloned encoder layers followed by one final LayerNorm.

    Args:
        EncodeLayer: prototype layer, cloned ``num_of_EncodeLayer`` times via
            ``L.clones``. It must expose ``d_model``, which sizes the norm.
        num_of_EncodeLayer: number of stacked layers.
    """

    def __init__(self, EncodeLayer, num_of_EncodeLayer: int):
        super(Encoder, self).__init__()
        self.EncodeLayers = L.clones(EncodeLayer, num_of_EncodeLayer)
        # Final normalization of the residual stream. The attribute name is
        # unchanged so existing state_dict keys still load.
        self.layerNorm = L.LayerNorm(EncodeLayer.d_model)

    def forward(self, x, mask=None):
        """Pass ``x`` (and ``mask``) through every layer, then normalize once.

        Args:
            x: input activations (output of the source embedding).
            mask: optional attention mask forwarded unchanged to each layer.

        Returns:
            The normalized output of the last layer.
        """
        for encoder in self.EncodeLayers:
            x = encoder(x, mask)
        # The norm is applied once after the whole stack rather than after
        # every layer; reusing one shared LayerNorm inside the loop
        # renormalized the residual stream between layers with tied weights.
        # NOTE(review): this assumes L.SubLayerConnection is pre-norm
        # (x + dropout(sublayer(norm(x)))), so the stream is otherwise never
        # normalized at the output — confirm against Layers.py.
        return self.layerNorm(x)


class DecoderLayer(nn.Module):
    """One decoder block: masked self-attention, source attention, feed-forward.

    Each of the three sub-layers is wrapped in its own
    ``L.SubLayerConnection`` (three clones are created).

    Args:
        d_model: model width; stored so ``Decoder`` can size its LayerNorm.
        self_attn: attention over the target, called as ``(q, k, v, mask)``.
        src_attn: attention from the target onto the encoder memory.
        feed_forward: module/callable applied to the activations.
        dropout: dropout rate forwarded to ``L.SubLayerConnection``.
    """

    def __init__(self, d_model, self_attn, src_attn, feed_forward, dropout):
        super().__init__()
        # Keep this registration order: it determines ``parameters()`` order.
        self.d_model = d_model
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.sublayer = L.clones(L.SubLayerConnection(d_model, dropout), 3)

    def forward(self, x, memory, src_mask, tgt_mask):
        """Run the three sub-layers on ``x`` using encoder output ``memory``."""

        def _self_attend(h):
            # Target attends to itself under ``tgt_mask``.
            return self.self_attn(h, h, h, tgt_mask)

        def _cross_attend(h):
            # Target queries attend to encoder memory (keys and values).
            return self.src_attn(h, memory, memory, src_mask)

        x = self.sublayer[0](x, _self_attend)
        x = self.sublayer[1](x, _cross_attend)
        return self.sublayer[2](x, self.feed_forward)


class Decoder(nn.Module):
    """Stack of ``N`` cloned decoder layers followed by one final LayerNorm.

    Args:
        layer: prototype decoder layer, cloned ``N`` times via ``L.clones``.
            It must expose ``d_model``, which sizes the norm.
        N: number of stacked layers.
    """

    def __init__(self, layer, N):
        super(Decoder, self).__init__()
        self.decodelayers = L.clones(layer, N)
        # Final normalization of the residual stream. The attribute name is
        # unchanged so existing state_dict keys still load.
        self.layerNorm = L.LayerNorm(layer.d_model)

    def forward(self, x, memony, src_mask, tgt_mask):
        """Run ``x`` through every decoder layer, then normalize once.

        Args:
            x: embedded target activations.
            memony: encoder output ("memory"; the parameter name's spelling
                is kept for backward compatibility with keyword callers).
            src_mask: mask forwarded to each layer's source attention.
            tgt_mask: mask forwarded to each layer's self-attention.

        Returns:
            The normalized output of the last layer.
        """
        for layer in self.decodelayers:
            x = layer(x, memony, src_mask, tgt_mask)
        # The norm is applied once after the whole stack rather than after
        # every layer (one shared LayerNorm was previously reapplied inside
        # the loop).
        # NOTE(review): this assumes L.SubLayerConnection is pre-norm — confirm.
        return self.layerNorm(x)


def subsequent_mask(size):
    """Return a (1, size, size) boolean mask hiding future positions.

    Entry ``[0, i, j]`` is True when ``j <= i`` (position ``i`` may attend to
    ``j``) and False for ``j > i``.
    """
    shape = (1, size, size)
    # Ones strictly above the main diagonal mark the "future" positions.
    blocked = np.triu(np.ones(shape), k=1).astype('uint8')
    allowed = torch.from_numpy(blocked) == 0
    return allowed


class Generator(nn.Module):
    """Linear projection to vocabulary size followed by log-softmax.

    Args:
        d_model: input feature size.
        vocab: output vocabulary size.
    """

    def __init__(self, d_model, vocab):
        super().__init__()
        # Attribute name ``liner`` is kept: it is part of the state_dict keys.
        self.liner = nn.Linear(d_model, vocab)

    def forward(self, x):
        """Return log-probabilities over the vocabulary along the last dim."""
        logits = self.liner(x)
        return F.log_softmax(logits, dim=-1)


class EncoderDecoder(nn.Module):
    """Full encoder-decoder model: embeddings, encoder, decoder, generator.

    Args:
        encoder: called as ``encoder(embedded_src, src_mask)``.
        decoder: called as ``decoder(embedded_tgt, memory, src_mask, tgt_mask)``.
        src_embed: maps source token ids to embeddings.
        tgt_embed: maps target token ids to embeddings.
        generator: maps decoder output to the final prediction.
    """

    def __init__(self, encoder, decoder, src_embed, tgt_embed, generator):
        super(EncoderDecoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.generator = generator

    def encode(self, src, src_mask):
        """Embed ``src`` and run the encoder; returns the encoder memory."""
        return self.encoder(self.src_embed(src), src_mask)

    def decode(self, memory, src_mask, tgt, tgt_mask):
        """Embed ``tgt`` and run the decoder against ``memory``.

        The generator is not applied, so callers doing step-by-step decoding
        can encode once and call this repeatedly.
        """
        return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)

    def forward(self, src, tgt, src_mask, tgt_mask):
        """Encode ``src``, decode ``tgt`` against it, and apply the generator."""
        # Same call order as before (src embed, encoder, tgt embed, decoder,
        # generator), so dropout RNG consumption is unchanged.
        memory = self.encode(src, src_mask)
        decoded = self.decode(memory, src_mask, tgt, tgt_mask)
        return self.generator(decoded)


def make_model(src_vocab, tgt_vocab, N=2, d_model=512, d_ff=2048, h=8, dropout=0.1):
    """Build an ``EncoderDecoder`` from hyper-parameters.

    Every sub-module gets an independent deep copy of the prototype attention,
    feed-forward and positional-encoding modules. Afterwards every parameter
    with more than one dimension is re-initialized with Xavier-uniform.

    Args:
        src_vocab: source vocabulary size.
        tgt_vocab: target vocabulary size.
        N: number of encoder layers and of decoder layers.
        d_model: model width.
        d_ff: hidden width of the position-wise feed-forward network.
        h: number of attention heads.
        dropout: dropout rate passed to the sub-modules.
    """
    dup = copy.deepcopy
    attention = L.MultiHeadedAttention(h, d_model)
    feed_forward = L.PositionWiseFeedForward(d_model, d_ff, dropout)
    pos_encoding = L.Positional_Encoding(d_model, dropout)

    # Built in the same order as the arguments below so module construction
    # (and any RNG it consumes) happens in an unchanged sequence.
    encoder = Encoder(
        EncoderLayer(d_model, dup(attention), dup(feed_forward), dropout), N
    )
    decoder = Decoder(
        DecoderLayer(d_model, dup(attention), dup(attention), dup(feed_forward), dropout),
        N,
    )
    src_embed = nn.Sequential(L.embeddings(d_model, src_vocab), dup(pos_encoding))
    tgt_embed = nn.Sequential(L.embeddings(d_model, tgt_vocab), dup(pos_encoding))
    generator = Generator(d_model, tgt_vocab)

    model = EncoderDecoder(encoder, decoder, src_embed, tgt_embed, generator)

    # Xavier-init only matrices/higher-rank tensors; 1-D params keep defaults.
    for weight in (p for p in model.parameters() if p.dim() > 1):
        nn.init.xavier_uniform_(weight)
    return model
