import torch.nn as nn

from .transformer import TransformerBlock
from .embedding import TBTEmbedding
import torch

class GPT(nn.Module):
    """Stack of Transformer blocks applied to a ``TBTEmbedding`` of the input.

    ``forward`` builds two attention-side tensors from the batch (a boolean
    padding mask and a per-row decay mask) and passes them, together with the
    embedded input, through every block in ``transformer_blocks``.
    """

    # Rows whose ``ds_to_sub_end`` value is >= this threshold use the linear
    # decay mask; all other rows use the dense one.
    # NOTE(review): the meaning of 500 is not visible in this file - confirm.
    _TRIGGER_THRESHOLD = 500
    # Multiplier applied to the clamped position for the linear decay mask.
    _LINEAR_DECAY_SCALE = 10
    # Numerator of the per-step exponent (divided by the number of attention
    # heads) used by the dense decay mask.
    _DENSE_DECAY_RATE = 8

    def __init__(self, args, hidden=768, n_layers=12, attn_heads=12, dropout=0.1):
        """
        :param args: mapping of configuration values. The keys ``'hidden'``,
            ``'max_len'``, ``'attn_heads'`` and ``'device'`` are read here and
            the whole mapping is forwarded to ``TBTEmbedding``.
        :param hidden: GPT model hidden size handed to every Transformer block
        :param n_layers: numbers of Transformer blocks(layers)
        :param attn_heads: number of attention heads handed to every block
        :param dropout: dropout rate handed to every block
        """
        super().__init__()
        self.hidden = hidden
        self.n_layers = n_layers
        self.attn_heads = attn_heads

        # NOTE(review): this value is stored but never passed to the blocks;
        # they are built with ``hidden * 1`` below. Kept as-is so existing
        # checkpoints keep loading - confirm which width is intended.
        self.feed_forward_hidden = hidden * 4

        self.tbt_embedding = TBTEmbedding(args)

        block_inner_size = hidden * 1
        self.transformer_blocks = nn.ModuleList(
            [TransformerBlock(hidden, attn_heads, block_inner_size, dropout)
             for _ in range(n_layers)])

        # NOTE(review): the projection is sized from args['hidden'], not from
        # the ``hidden`` parameter; the two presumably agree - confirm.
        self.project = nn.Sequential(
            nn.Linear(args['hidden'], args['hidden']),
            torch.nn.LeakyReLU())

        self.max_length = args['max_len']
        # NOTE(review): read from ``args`` while ``self.attn_heads`` comes from
        # the keyword parameter; only this one is used by the decay mask.
        self.attention_heads = args['attn_heads']
        # Retained for callers that read it; ``forward`` allocates on the
        # device of its inputs instead of relying on this stored value.
        self.device = args["device"]

    @staticmethod
    def _tile_over_rows(per_row, n_rows):
        """Copy a 2-D ``(batch, k)`` tensor into shape ``(batch, 1, n_rows, k)``."""
        return per_row.unsqueeze(1).repeat(1, n_rows, 1).unsqueeze(1)

    def _build_decay_mask(self, padding_mask, ds_to_sub_end, batch_size):
        """Build the decay mask, choosing linear or dense decay per row."""
        n_rows = padding_mask.size(1)

        # Countdown max_length-1 .. 0, created directly on the input's device
        # so no host-to-device copy (or stale stored device) is involved.
        position = torch.arange(self.max_length - 1, -1, step=-1,
                                device=padding_mask.device)

        # Shift every sample's countdown by (max_length - sum of its padding
        # mask) and clamp it to [0, max_length].
        max_gap = (self.max_length - torch.sum(padding_mask, -1)).unsqueeze(-1)
        position = position.unsqueeze(0).repeat(batch_size, 1) - max_gap
        position = torch.clamp(position, 0, self.max_length)

        linear_decay = self._tile_over_rows(
            position * self._LINEAR_DECAY_SCALE, n_rows)

        # position / (2 ** (rate * (max_length - position)) + eps); the epsilon
        # keeps the denominator away from zero.
        rate = self._DENSE_DECAY_RATE / self.attention_heads
        dense_decay = 1.0 / (torch.pow(2, rate * (self.max_length - position)) + 1e-7) * position
        dense_decay = self._tile_over_rows(dense_decay, n_rows)

        # One trigger value per row, repeated along the last axis so it lines
        # up element-wise with the two candidate masks.
        trigger = ds_to_sub_end.unsqueeze(-1).repeat(1, 1, self.max_length).unsqueeze(1)
        return torch.where(trigger >= self._TRIGGER_THRESHOLD, linear_decay, dense_decay)

    def forward(self, data):
        """Embed the batch and run it through all Transformer blocks.

        :param data: mapping with tensors under ``'padding_mask'``,
            ``'gpt_input'`` and ``'ds_to_sub_end'``; the whole mapping is also
            handed to the embedding layer. ``padding_mask`` and
            ``ds_to_sub_end`` are assumed to be ``(batch, seq)`` with
            ``seq == max_len`` - TODO confirm against the data loader.
        :return: tuple ``(x, x_ori)`` where ``x`` is the output of the last
            Transformer block and ``x_ori`` is the projected embedding taken
            before any block is applied.
        """
        padding_mask = data['padding_mask']
        batch_size = data['gpt_input'].size(0)

        # [1] padding mask: True wherever padding_mask > 0, one copy per row.
        mask = self._tile_over_rows(padding_mask > 0, padding_mask.size(1))

        # [2] decay mask
        decay_mask = self._build_decay_mask(
            padding_mask, data['ds_to_sub_end'], batch_size)

        # [3] embedding, projection of the raw embedding, then the block stack
        x = self.tbt_embedding(data)
        x_ori = self.project(x)

        for transformer in self.transformer_blocks:
            # Call the module itself (not .forward) so registered hooks run.
            x = transformer(x, mask, decay_mask)

        return x, x_ori
