import math
import copy
import torch

import numpy as np
import torch.nn as nn
import matplotlib.pyplot as plt

from torch.autograd import Variable
from torch.optim.lr_scheduler import LambdaLR
from torch.nn.functional import log_softmax
from pyitcast.transformer_utils import Batch, get_std_opt
from pyitcast.transformer_utils import run_epoch, greedy_decode
from pyitcast.transformer_utils import SimpleLossCompute, LabelSmoothing


#   define the txt embedding layer
class Embeddings(nn.Module):
    """Token-embedding lookup table whose output is scaled by sqrt(d_model)."""

    def __init__(self, d_model, vocab):
        """
        :param d_model: size of each embedding vector
        :param vocab: number of entries in the vocabulary
        """
        super().__init__()
        self.lut = nn.Embedding(vocab, d_model)  # learnable lookup table
        self.d_model = d_model

    def forward(self, x):
        """
        :param x: tensor of token indices
        :return: the embeddings of x multiplied by sqrt(d_model)
        """
        scale = math.sqrt(self.d_model)
        looked_up = self.lut(x)
        return looked_up * scale


class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to the input, then apply dropout."""

    def __init__(self, d_model, dropout, max_len=5000):
        """
        :param d_model: word embedding dimension (odd values are supported)
        :param dropout: dropout probability applied after adding the encoding
        :param max_len: maximum sequence length the precomputed table covers
        """
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Compute the positional encodings once, using log space for the frequencies.
        pe = torch.zeros(max_len, d_model)  # (max_len, d_model) encoding table
        position = torch.arange(0, max_len).unsqueeze(1)  # (max_len, 1) absolute positions

        # One frequency per even column: 10000^(-2i / d_model).
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        # When d_model is odd there is one fewer odd (cosine) column than even
        # (sine) column, so only the first d_model // 2 frequencies are used here.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])

        pe = pe.unsqueeze(0)  # (1, max_len, d_model) so it broadcasts over the batch

        # A buffer moves with the module (device, state_dict) but is not a parameter,
        # so the optimizer never updates it.
        self.register_buffer("pe", pe)

    def forward(self, x):
        """
        :param x: output of the embedding layer; indexed here as (batch, seq_len, d_model)
            with seq_len <= max_len
        :return: x plus the positional encodings of its first seq_len positions, after dropout
        """
        x = x + self.pe[:, : x.size(1)].requires_grad_(False)
        return self.dropout(x)


# d_model = 512
# vocab = 1000
# dropout = 0.1
# max_len = 60
#
# x = Variable(torch.LongTensor([[100, 2, 421, 508], [491, 998, 1, 221]]))
# emb = Embeddings(d_model, vocab)
# embr = emb(x)
# x = embr
# pe = PositionalEncoding(d_model, dropout, max_len)
# pe_result = pe(x)
# print(pe)
# print(pe_result)


def subsequent_mask(size):
    """
    Build a causal (look-ahead) mask.

    :param size: length of the sequence; the mask is size x size
    :return: numpy bool array of shape (1, size, size) that is True on and below the
        diagonal, i.e. position i may attend to positions 0..i
    """
    # Lower triangle including the diagonal marks the allowed positions.
    allowed = np.tril(np.ones((1, size, size)))
    return allowed.astype(bool)


# plt.figure(figsize=(5, 5))
# plt.imshow(subsequent_mask(20)[0])
# plt.show()

def attention(query, key, value, mask=None, dropout=None):
    """
    Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V.

    :param query: tensor whose last dimension is d_k
    :param key: tensor matmul-compatible with query after transposing its last two dims
    :param value: tensor multiplied by the attention weights
    :param mask: optional tensor; positions where mask == 0 get a score of -1e9
    :param dropout: optional callable (e.g. nn.Dropout) applied to the attention weights
    :return: tuple (weighted values, attention weights)
    """
    scale = math.sqrt(query.size(-1))
    # Scaling key before the matmul is mathematically the same as scaling the scores.
    scores = torch.matmul(query, key.transpose(-2, -1) / scale)

    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)

    weights = scores.softmax(dim=-1)
    if dropout is not None:
        weights = dropout(weights)

    return torch.matmul(weights, value), weights


def clones(module, N):
    """
    Make N independent deep copies of a module.

    :param module: the nn.Module to copy
    :param N: number of copies
    :return: nn.ModuleList holding the copies (none of them is `module` itself)
    """
    return nn.ModuleList(copy.deepcopy(module) for _ in range(N))


class MultiHeadedAttention(nn.Module):
    """Multi-head attention: project Q/K/V, attend per head, concatenate, project again."""

    def __init__(self, head, d_model, dropout=0.1):
        """
        :param head: number of attention heads
        :param d_model: model (word embedding) dimension; must be divisible by head
        :param dropout: dropout probability applied to the attention weights
        """
        super().__init__()
        assert d_model % head == 0  # every head gets an equal d_model // head slice
        self.d_k = d_model // head
        self.head = head
        self.d_model = d_model
        # Four projections: query, key, value, and the final output projection.
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.attn = None  # attention weights from the most recent forward call
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, mask=None):
        """
        :param query: tensor viewed as (batch, -1, head, d_k) after projection;
            assumed (batch, seq_len, d_model) — TODO confirm at call sites
        :param key: same layout as query
        :param value: same layout as key
        :param mask: optional mask; a head axis is inserted at dim 1 so it broadcasts
            over all heads
        :return: tensor of shape (batch, query_len, d_model)
        """
        if mask is not None:
            mask = mask.unsqueeze(1)  # share the same mask across every head
        nbatches = query.size(0)

        def to_heads(proj, t):
            # (batch, seq, d_model) -> (batch, head, seq, d_k)
            return proj(t).view(nbatches, -1, self.head, self.d_k).transpose(1, 2)

        q, k, v = [to_heads(proj, t) for proj, t in zip(self.linears, (query, key, value))]

        context, self.attn = attention(q, k, v, mask=mask, dropout=self.dropout)

        # (batch, head, seq, d_k) -> (batch, seq, head * d_k)
        context = context.transpose(1, 2).contiguous().view(nbatches, -1, self.head * self.d_k)
        return self.linears[-1](context)

# d_model = 512
# vocab = 1000
# dropout = 0.1
# max_len = 60
#
# x = Variable(torch.LongTensor([[100, 2, 421, 508], [491, 998, 1, 221]]))
# emb = Embeddings(d_model, vocab)
# embr = emb(x)
# x = embr
# pe = PositionalEncoding(d_model, dropout, max_len)
# pe_result = pe(x)
# print(pe_result)
#
# head = 8
# embedding_dim = 512
# mask = Variable(torch.zeros(8, 4, 4))
# query = key = value = pe_result
# mha = MultiHeadedAttention(head, embedding_dim, dropout)
# mha_result = mha(query, key, value, mask)


class PositionwiseFeedForward(nn.Module):
    """Position-wise feed-forward network: w_2(dropout(relu(w_1(x))))."""

    def __init__(self, d_model, d_ff, dropout=0.1):
        """
        :param d_model: input and output feature size
        :param d_ff: hidden (inner) feature size of the first linear layer
        :param dropout: dropout probability applied to the hidden activations
        """
        super().__init__()
        self.w_1 = nn.Linear(d_model, d_ff)  # expand: d_model -> d_ff
        self.w_2 = nn.Linear(d_ff, d_model)  # project back: d_ff -> d_model
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """
        :param x: output of the previous layer, last dimension d_model
        :return: tensor with the same leading dimensions and last dimension d_model
        """
        hidden = torch.relu(self.w_1(x))
        hidden = self.dropout(hidden)
        return self.w_2(hidden)


class LayerNorm(nn.Module):
    """Layer normalization over the last dimension with a learnable gain (a_2) and bias (b_2)."""

    def __init__(self, features, eps=1e-6):
        """
        :param features: size of the last dimension (word embedding dimension)
        :param eps: added to the standard deviation to avoid division by zero
        """
        super(LayerNorm, self).__init__()

        # Learnable gain initialised to 1 and bias initialised to 0, wrapped in
        # nn.Parameter so the optimizer updates them.
        self.a_2 = nn.Parameter(torch.ones(features))
        self.b_2 = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        """
        :param x: tensor whose last dimension has size `features`
        :return: normalised tensor with the same shape and dtype as x
        """
        # Compute statistics directly on x: this keeps them in the autograd graph and
        # in x's dtype. keepdim=True lets them broadcast over the last dimension for
        # any batch/sequence size.
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return self.a_2 * (x - mean) / (std + self.eps) + self.b_2


class SublayerConnection(nn.Module):
    """
    Residual wrapper around a sublayer, normalising first:
    x + dropout(sublayer(norm(x))).
    """

    def __init__(self, size, dropout):
        """
        :param size: feature size passed to LayerNorm
        :param dropout: dropout probability applied to the sublayer output
        """
        super().__init__()
        self.norm = LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        """
        :param x: output of the previous layer
        :param sublayer: callable applied to the normalised x; its output must match x's shape
        :return: x plus the dropped-out sublayer output
        """
        normed = self.norm(x)
        branch = self.dropout(sublayer(normed))
        return x + branch


class EncoderLayer(nn.Module):
    """One encoder layer: self-attention followed by a feed-forward network, each residual-wrapped."""

    def __init__(self, size, self_attn, feed_forward, dropout):
        """
        :param size: word embedding dimension
        :param self_attn: attention module called as self_attn(q, k, v, mask),
            e.g. MultiHeadedAttention(head, d_model)
        :param feed_forward: feed-forward module applied position-wise
        :param dropout: dropout probability for the residual connections
        """
        super().__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.dropout = dropout
        # Two residual sublayers: [0] wraps self-attention, [1] wraps the feed-forward net.
        self.sublayer = clones(SublayerConnection(size, dropout), 2)
        self.size = size

    def forward(self, x, mask):
        """
        :param x: input to this layer
        :param mask: attention mask forwarded to self_attn
        :return: output of the feed-forward sublayer
        """
        attn_block, ff_block = self.sublayer
        # Self-attention: query, key and value are all the (normalised) input.
        x = attn_block(x, lambda h: self.self_attn(h, h, h, mask))
        return ff_block(x, self.feed_forward)


class Encoder(nn.Module):
    """A stack of N identical encoder layers followed by a final LayerNorm."""

    def __init__(self, layer, N):
        """
        :param layer: an encoder layer module (must expose .size); it is deep-copied N times
        :param N: number of encoder layers in the stack
        """
        super().__init__()
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)

    def forward(self, x, mask):
        """
        :param x: embedded source sequence
        :param mask: source mask passed to every layer
        :return: normalised output of the last layer
        """
        hidden = x
        for encoder_layer in self.layers:
            hidden = encoder_layer(hidden, mask)
        return self.norm(hidden)


# # instantiate above class and their parameter
# size = 512
# head = 8
# d_model = 512
# d_ff = 64
# c = copy.deepcopy
# attn = MultiHeadedAttention(head, d_model)
# pos = PositionwiseFeedForward(d_model, d_ff, dropout=0.2)
# layer = EncoderLayer(size, c(attn), c(pos), dropout=0.2)
# N = 8
# mask = Variable(torch.zeros(8, 4, 4))
#
# # invoke
# en = Encoder(layer, N)
# en_result = en(x, mask)

class DecoderLayer(nn.Module):
    """One decoder layer: masked self-attention, encoder-decoder attention, then feed-forward."""

    def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
        """
        :param size: word embedding dimension and the size of the decoder layer
        :param self_attn: attention module used with Q = K = V = decoder state
        :param src_attn: attention module used with Q = decoder state, K = V = encoder memory
        :param feed_forward: feed-forward module applied position-wise
        :param dropout: dropout probability for the residual connections
        """
        super().__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(size, dropout), 3)

    def forward(self, x, memory, src_mask, tgt_mask):
        """
        :param x: output of the previous decoder layer (or the target embedding)
        :param memory: output of the encoder
        :param src_mask: mask for attending over the encoder memory
        :param tgt_mask: mask for decoder self-attention
        :return: output of the feed-forward sublayer
        """
        self_block, cross_block, ff_block = self.sublayer
        x = self_block(x, lambda h: self.self_attn(h, h, h, tgt_mask))
        # Cross-attention: queries come from the decoder; keys and values from the encoder memory.
        x = cross_block(x, lambda h: self.src_attn(h, memory, memory, src_mask))
        return ff_block(x, self.feed_forward)


class Decoder(nn.Module):
    """A stack of N identical decoder layers followed by a final LayerNorm."""

    def __init__(self, layer, N):
        """
        :param layer: a decoder layer module (must expose .size); it is deep-copied N times
        :param N: number of decoder layers in the stack
        """
        super().__init__()
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)

    def forward(self, x, memory, src_mask, tgt_mask):
        """
        :param x: embedded target sequence
        :param memory: encoder output
        :param src_mask: source mask passed to every layer
        :param tgt_mask: target mask passed to every layer
        :return: normalised output of the last layer
        """
        hidden = x
        for decoder_layer in self.layers:
            hidden = decoder_layer(hidden, memory, src_mask, tgt_mask)
        return self.norm(hidden)


class Generator(nn.Module):
    """Final linear projection to vocabulary size followed by log-softmax."""

    def __init__(self, d_model, vocab):
        """
        :param d_model: dimension of the word embedding
        :param vocab: the size of the vocabulary
        """
        super().__init__()
        self.project = nn.Linear(d_model, vocab)

    def forward(self, x):
        """
        :param x: decoder output with last dimension d_model
        :return: log-probabilities over the vocabulary along the last dimension
        """
        logits = self.project(x)
        return log_softmax(logits, dim=-1)


class EncoderDecoder(nn.Module):
    """A standard encoder-decoder architecture; the base of the model built by make_model."""

    def __init__(self, encoder, decoder, src_embed, tgt_embed, generator):
        """
        :param encoder: encoder module, called as encoder(x, src_mask)
        :param decoder: decoder module, called as decoder(x, memory, src_mask, tgt_mask)
        :param src_embed: source data embedding function
        :param tgt_embed: target data embedding function
        :param generator: output projection / classifier; not applied by forward()
        """
        super(EncoderDecoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.generator = generator

    def forward(self, src, tgt, src_mask, tgt_mask):
        """
        Encode the source, then decode the target against the encoder memory.

        :param src: source token indices
        :param tgt: target token indices
        :param src_mask: source mask
        :param tgt_mask: target mask
        :return: decoder output (apply self.generator to get log-probabilities)
        """
        # Go through encode()/decode() so both sequences are embedded and the decoder
        # receives its arguments in (x, memory, src_mask, tgt_mask) order.
        memory = self.encode(src, src_mask)
        return self.decode(memory, src_mask, tgt, tgt_mask)

    def encode(self, src, src_mask):
        """Embed the source and run it through the encoder; returns the memory."""
        return self.encoder(self.src_embed(src), src_mask)

    def decode(self, memory, src_mask, tgt, tgt_mask):
        """Embed the target and run it through the decoder against `memory`."""
        return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)


def make_model(src_vocab, tgt_vocab, N=6, d_model=80, d_ff=2048, h=8, dropout=0.1):
    """
    Construct an EncoderDecoder Transformer from hyperparameters.

    :param src_vocab: source data feature vocabulary size
    :param tgt_vocab: target data feature vocabulary size
    :param N: number of encoder layers and of decoder layers
    :param d_model: word vector dimension
    :param d_ff: hidden dimension of the position-wise feed-forward network
    :param h: number of attention heads
    :param dropout: dropout probability for feed-forward, residual and positional-encoding layers
    :return: the Xavier-initialised EncoderDecoder model
    """
    c = copy.deepcopy
    # NOTE: attention modules use MultiHeadedAttention's own default dropout,
    # not the `dropout` argument.
    attn = MultiHeadedAttention(h, d_model)

    # Position-wise feed-forward network.
    ff = PositionwiseFeedForward(d_model, d_ff, dropout)
    position = PositionalEncoding(d_model, dropout)

    # Every sub-module gets its own deep copy, so no weights are shared.
    model = EncoderDecoder(
        Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N),
        Decoder(DecoderLayer(d_model, c(attn), c(attn), c(ff), dropout), N),
        nn.Sequential(Embeddings(d_model, src_vocab), c(position)),
        nn.Sequential(Embeddings(d_model, tgt_vocab), c(position)),
        Generator(d_model, tgt_vocab)
    )

    # Initialise every weight matrix (dim > 1) with Glorot / fan_avg. This keeps each
    # layer's output variance roughly equal to its input variance and helps stable
    # training. Biases and 1-D gains keep their defaults. Use the in-place
    # xavier_uniform_; the non-underscore name is deprecated.
    for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)
    print("successfully initialized the model")
    return model


def data_gen(V, batch_size, nbatches, seq_len=80):
    """
    Generate random batches for the synthetic copy task (target is identical to source).

    :param V: vocabulary size; tokens are drawn uniformly from [1, V - 1]
    :param batch_size: number of sequences per batch
    :param nbatches: number of batches to yield
    :param seq_len: length of every generated sequence (defaults to 80, the previously fixed length)
    :return: generator yielding Batch(src, tgt, 0) objects
    """
    for _ in range(nbatches):
        data = torch.randint(1, V, size=(batch_size, seq_len))
        data[:, 0] = 1  # every sequence starts with token 1
        src = data.requires_grad_(False).clone().detach()
        tgt = data.requires_grad_(False).clone().detach()
        # NOTE(review): the third Batch argument (0) is presumably the padding index,
        # which is why token 0 is never generated above — confirm against pyitcast.
        yield Batch(src, tgt, 0)


def rate(step, model_size, factor, warmup):
    r"""
    Noam learning-rate schedule: linear warm-up followed by inverse-square-root decay.

    lrate = factor \cdot d_{model}^{-0.5} \cdot \min(step^{-0.5}, step \cdot warmup^{-1.5})

    ``step`` is treated as 1 when it is 0, because LambdaLR can call this with
    step 0 and 0 ** -0.5 raises ZeroDivisionError.

    :param step: current optimizer step
    :param model_size: model dimension d_model
    :param factor: overall scale applied to the schedule
    :param warmup: number of warm-up steps; the schedule peaks at step == warmup
    :return: learning-rate multiplier for this step
    """
    if step == 0:
        step = 1
    return factor * (model_size ** (-0.5) * min(step ** (-0.5), step * warmup ** (-1.5)))


def example_simple_model(model, loss, epochs=20, vocab_size=None, batch_size=80):
    """
    Train `model` on the synthetic copy task, then print a greedy decode of a sample sequence.

    :param model: EncoderDecoder as built by make_model
    :param loss: loss-compute callable passed to run_epoch (e.g. SimpleLossCompute)
    :param epochs: number of training epochs
    :param vocab_size: vocabulary size for data_gen; when None it is taken from the
        model's source embedding table (model.src_embed[0].lut)
    :param batch_size: number of sequences per generated batch
    """
    # lr_scheduler = LambdaLR(
    #     optimizer=optimizer, lr_lambda=lambda step: rate(
    #         step, model_size=model.src_embed[0].d_model, factor=1.0, warmup=400
    #     ))

    if vocab_size is None:
        # Derive the vocabulary from the model instead of relying on a global `V`,
        # which only exists when this file runs as a script.
        vocab_size = model.src_embed[0].lut.num_embeddings

    for _ in range(epochs):
        model.train()
        run_epoch(data_gen(vocab_size, batch_size, 20), model, loss)
        model.eval()
        # NOTE(review): the same `loss` object is reused for this evaluation pass. If it
        # performs optimizer steps, these eval epochs also update the weights — confirm
        # against pyitcast's SimpleLossCompute.
        run_epoch(data_gen(vocab_size, batch_size, 5), model, loss)
    model.eval()
    src = torch.LongTensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])
    max_len = src.shape[1]
    src_mask = torch.ones(1, 1, max_len)
    print(greedy_decode(model, src, src_mask, max_len=max_len, start_symbol=0))


if __name__ == '__main__':
    # Copy task: one 11-symbol vocabulary for source and target, token 0 used as padding.
    V = 11
    model = make_model(src_vocab=V, tgt_vocab=V, N=2)  # small 2-layer model
    criterion = LabelSmoothing(size=V, padding_idx=0, smoothing=0.0)  # no smoothing
    optimizer = get_std_opt(model)
    # The loss computer projects with the model's generator and steps the optimizer.
    loss = SimpleLossCompute(model.generator, criterion, optimizer)
    example_simple_model(model=model, loss=loss)
