import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import distributions as pyd

from src.modules.nn import MLP, CrossAttention, SelfAttention, LayerNorm



class SelfAttentionBlock(nn.Module):
    """Residual block: self-attention followed by an MLP.

    Each sub-layer receives a normalised copy of its input and its output
    is added back onto the un-normalised input (residual connection).

    Args:
        config: Configuration object. This class reads ``config.n_embd`` and
            ``config.bias`` directly and passes the whole object on to
            ``SelfAttention`` and ``MLP``.
    """

    def __init__(self, config):
        super().__init__()
        # NOTE: attribute names and creation order are kept stable so that
        # state_dict keys and seed-dependent initialisation do not change.
        self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
        self.attn_1 = SelfAttention(config)
        self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
        self.mlp = MLP(config)

    def forward(self, x):
        """Run both residual sub-layers on ``x`` and return the result.

        Assumes the sub-layers preserve the shape of their input so the
        residual additions are valid -- TODO confirm against
        ``SelfAttention`` / ``MLP``.
        """
        # Attention sub-layer: normalise, attend, add back onto the input.
        attended = self.attn_1(self.ln_1(x))
        hidden = x + attended

        # Feed-forward sub-layer, same normalise -> transform -> add pattern.
        transformed = self.mlp(self.ln_2(hidden))
        return hidden + transformed
    
    
class AttentionTerminal(nn.Module):
    """Stack of self-attention blocks with a single-logit output head.

    The name suggests the logit is an episode-termination score; that is
    not established by this class -- verify against the caller.

    Args:
        config: Configuration object. This class reads ``config.n_layer``,
            ``config.n_embd`` and ``config.bias``; the whole object is also
            forwarded to every ``SelfAttentionBlock``.
    """

    def __init__(self, config):
        super().__init__()
        self.transformer = nn.ModuleDict(dict(
            h = nn.ModuleList([SelfAttentionBlock(config) for _ in range(config.n_layer)]),
            ln_f = LayerNorm(config.n_embd, bias=config.bias),
        ))
        self.logits = nn.Linear(config.n_embd, 1)

    def forward(self, x):
        """Return one logit per batch element, read from the last position.

        Args:
            x: Input features; assumed to be ``(batch, seq, n_embd)`` --
                TODO confirm. The code indexes ``x[:, -1]`` and feeds the
                result to a ``Linear(n_embd, 1)``.

        Returns:
            Tensor with a trailing dimension of size 1.
        """
        for block in self.transformer.h:
            x = block(x)
        # Apply the final norm before the head. It was previously constructed
        # but never called, which left the head reading the un-normalised
        # residual stream and left ln_f's parameters without gradients.
        x = self.transformer.ln_f(x)
        # Only the final position is used to produce the logit.
        return self.logits(x[:, -1])
    
    
class AttentionReward(nn.Module):
    """Stack of self-attention blocks with a single-scalar output head.

    The name suggests the scalar is a reward prediction; that is not
    established by this class -- verify against the caller.

    Args:
        config: Configuration object. This class reads ``config.n_layer``,
            ``config.n_embd`` and ``config.bias``; the whole object is also
            forwarded to every ``SelfAttentionBlock``.
    """

    def __init__(self, config):
        super().__init__()
        self.transformer = nn.ModuleDict(dict(
            h = nn.ModuleList([SelfAttentionBlock(config) for _ in range(config.n_layer)]),
            ln_f = LayerNorm(config.n_embd, bias=config.bias),
        ))
        self.logits = nn.Linear(config.n_embd, 1)

    def forward(self, x):
        """Return one scalar per batch element, read from the last position.

        Args:
            x: Input features; assumed to be ``(batch, seq, n_embd)`` --
                TODO confirm. The code indexes ``x[:, -1]`` and feeds the
                result to a ``Linear(n_embd, 1)``.

        Returns:
            Tensor with a trailing dimension of size 1.
        """
        for block in self.transformer.h:
            x = block(x)
        # Apply the final norm before the head. It was previously constructed
        # but never called, which left the head reading the un-normalised
        # residual stream and left ln_f's parameters without gradients.
        x = self.transformer.ln_f(x)
        # Only the final position is used to produce the output.
        return self.logits(x[:, -1])