import torch.nn
from transformer_blocks.transformer_decoders import TransformerDecoder

from ltsgns_mp.util.own_types import ConfigDict


class ContextTransformer(torch.nn.Module):
    """Reduce a sequence of per-graph ``r`` vectors to one vector per batch item.

    The input is linearly embedded into ``config.latent_dimension`` features,
    passed through a ``TransformerDecoder``, and the decoder output at the last
    sequence position is linearly projected back to ``r_dim`` features.
    """

    def __init__(self, config: ConfigDict, r_dim: int, device: str):
        """Build the decoder and the input/output projections, then move them to ``device``.

        Args:
            config: Configuration providing ``latent_dimension`` plus the
                ``TransformerDecoder`` hyper-parameters read below.
            r_dim: Feature size of the incoming and outgoing ``r`` vectors.
            device: Target passed to ``.to(...)`` for every sub-module.
        """
        super().__init__()
        self.config = config

        # Sub-modules are created in the order decoder -> embedding -> output
        # projection; each construction may draw from the global RNG.
        decoder_kwargs = dict(
            embed_dim=config.latent_dimension,
            n_heads=config.n_heads,
            attn_pdrop=config.attn_pdrop,
            resid_pdrop=config.resid_pdrop,
            n_layers=config.n_layers,
            block_size=config.block_size,
            bias=config.bias,
            use_rms_norm=config.use_rms_norm,
            use_rot_embed=config.use_rot_embed,
            use_relative_pos=config.use_relative_pos,
            rotary_xpos=config.rotary_xpos,
            mlp_pdrop=config.mlp_pdrop,
        )
        self._context_transformer = TransformerDecoder(**decoder_kwargs)
        self._embedding = torch.nn.Linear(r_dim, config.latent_dimension)
        self._output_proj = torch.nn.Linear(config.latent_dimension, r_dim)

        # Move each sub-module to the requested device explicitly.
        for submodule in (
            self._context_transformer,
            self._embedding,
            self._output_proj,
        ):
            submodule.to(device)

    def forward(self, per_graph_r: torch.Tensor) -> torch.Tensor:
        """Embed, decode, and project the last sequence position.

        Args:
            per_graph_r: Tensor of shape ``(batch_size, seq_length, r_dim)``.

        Returns:
            The output projection applied to the decoder output at the last
            sequence position (presumably ``(batch_size, r_dim)``, assuming the
            decoder keeps the ``(batch, seq, latent)`` layout -- TODO confirm).
        """
        tokens = self._embedding(per_graph_r)
        decoded = self._context_transformer(tokens)
        # Only the final position along the sequence axis is used as the summary.
        last_token = decoded[:, -1, :]
        return self._output_proj(last_token)



