from torch import nn

from models.modules import CausalSelfAttention, Mlp
from models.ops import process_rollout_attention

class GPT2Block(nn.Module):
    """A single pre-LayerNorm GPT-2 style transformer block.

    Computes ``x = x + attn(ln1(x))`` followed by ``x = x + mlp(ln2(x))``.

    Args:
        dim: Feature width; used for both LayerNorms, the attention module
            and the MLP input.
        num_heads: Number of attention heads, forwarded to
            ``CausalSelfAttention``.
        attn_drop: Forwarded to ``CausalSelfAttention`` as ``attn_drop``.
        resid_pdrop: Forwarded to ``CausalSelfAttention`` as ``resid_pdrop``
            and used as the second entry of the MLP's ``drop`` list.
        mlp_ratio: MLP hidden width as a multiple of ``dim``; the resulting
            ``int(mlp_ratio * dim)`` must be strictly greater than ``dim``.

    Raises:
        AssertionError: if the computed MLP hidden width does not exceed
            ``dim``.
    """

    def __init__(self, dim, num_heads, attn_drop=0.1, resid_pdrop=0.1, mlp_ratio=4.0):
        super().__init__()
        mlp_hidden_dim = int(mlp_ratio * dim)
        # The MLP must expand the feature width; report the offending values
        # instead of the previous placeholder message.
        assert mlp_hidden_dim > dim, (
            f'MLP hidden dim ({mlp_hidden_dim}) must be greater than dim '
            f'({dim}); got mlp_ratio={mlp_ratio}'
        )
        self.ln1 = nn.LayerNorm(dim)
        self.attn = CausalSelfAttention(dim, num_heads, attn_drop=attn_drop, resid_pdrop=resid_pdrop)
        self.ln2 = nn.LayerNorm(dim)
        # NOTE(review): `drop` is presumably per-layer dropout probabilities
        # (no dropout after the first layer, resid_pdrop after the second);
        # confirm against models.modules.Mlp.
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,
                       act_layer=nn.GELU, drop=[0.0, resid_pdrop])

    def forward(self, x, attn_span, return_attention=False):
        """Run the block on ``x``.

        Args:
            x: Input features; last dimension is presumably ``dim`` (required
                by the LayerNorms) -- confirm the expected full shape with
                ``CausalSelfAttention``.
            attn_span: Forwarded unchanged to ``CausalSelfAttention``.
            return_attention: Forwarded to the attention module as
                ``return_weights``.

        Returns:
            dict with ``'outputs'`` (the block output) and
            ``'attentions_weights'`` (the ``'weights'`` entry of the attention
            module's output, or ``None`` if it provides none).
        """
        attn_outputs = self.attn(self.ln1(x), attn_span=attn_span,
                                 return_weights=return_attention)
        x = x + attn_outputs.get('outputs')
        x = x + self.mlp(self.ln2(x))
        return {'outputs': x, 'attentions_weights': attn_outputs.get('weights')}

class GPT2(nn.Module):
    """A stack of ``num_layers`` ``GPT2Block`` modules applied in sequence.

    Args:
        num_layers: Number of blocks in the stack.
        dim: Feature width, forwarded to every block.
        num_heads: Number of attention heads, forwarded to every block.
        **kwargs: Extra keyword arguments forwarded to each ``GPT2Block``
            (``attn_drop``, ``resid_pdrop``, ``mlp_ratio``).
    """

    def __init__(self, num_layers, dim, num_heads, **kwargs):
        super().__init__()
        self.blocks = nn.ModuleList(
            [GPT2Block(dim, num_heads, **kwargs) for _ in range(num_layers)]
        )

    def forward(self, x, attn_span=1, rollout_attention=False):
        """Apply every block in order, optionally accumulating attention rollout.

        Args:
            x: Input features, passed to the first block.
            attn_span: Forwarded unchanged to every block.
            rollout_attention: When True, each block is asked for its attention
                weights. Those weights are post-processed with
                ``process_rollout_attention`` and multiplied together across
                layers. When False, no rollout is computed.

        Returns:
            dict with ``'outputs'`` (output of the last block, or ``x`` if the
            stack is empty) and ``'attentions_rollout'`` (the accumulated
            product, or ``None`` when ``rollout_attention`` is False or there
            are no blocks).
        """
        y = x
        attentions_rollout = None
        for block in self.blocks:
            block_outputs = block(y, attn_span=attn_span,
                                  return_attention=rollout_attention)
            y = block_outputs.get('outputs')
            # Only accumulate when rollout was requested. Otherwise any weights
            # the attention module still returns would be multiplied needlessly
            # and leak out as a spurious rollout.
            if not rollout_attention:
                continue
            attn_weights = process_rollout_attention(
                block_outputs.get('attentions_weights'))
            if attentions_rollout is None:
                attentions_rollout = attn_weights
            else:
                # Left-multiply so the product reads layer_L @ ... @ layer_1.
                # NOTE(review): assumes batched (B, T, T) matrices as required
                # by Tensor.bmm -- confirm against process_rollout_attention.
                attentions_rollout = attn_weights.bmm(attentions_rollout)
        return {'outputs': y, 'attentions_rollout': attentions_rollout}

