from typing import Iterable

import torch
import torch.nn as nn

from transformers.models.gpt_neox.modeling_gpt_neox import (
    GPTNeoXForCausalLM,
    GPTNeoXModel,
)

from .embeddings import PartlyFrozenEmbeddings
from .base import PreTrainedModelForAIM


class GPTNeoXForAIM(PreTrainedModelForAIM):
    """
    GPT-NeoX for Attention Influence Modeling (AIM)

    Supplies the GPT-NeoX-specific attribute locations (backbone, decoder
    layer stack, input embedding table) to ``PreTrainedModelForAIM``.

    NOTE(review): ``self.model`` is not assigned in this class; it is assumed
    to be set by the base class to the backbone returned by
    ``get_base_model`` — confirm against ``PreTrainedModelForAIM``.
    """

    def get_base_model(self, model: GPTNeoXForCausalLM) -> GPTNeoXModel:
        """Return the ``gpt_neox`` backbone held by a GPT-NeoX causal-LM wrapper."""
        backbone = model.gpt_neox
        return backbone

    def get_layers(self) -> Iterable[nn.Module]:
        """Return the backbone's decoder layer container (``self.model.layers``)."""
        backbone = self.model
        return backbone.layers

    def set_layers(self, layers: Iterable[nn.Module]):
        """
        Install ``layers`` as the backbone's decoder layer stack.

        The modules are wrapped in an ``nn.ModuleList`` so that PyTorch
        registers them as submodules of the backbone.

        NOTE(review): nothing else on the backbone (e.g. its config's layer
        count) is updated here — confirm callers don't rely on that.
        """
        registered = nn.ModuleList(layers)
        self.model.layers = registered

    def create_partly_frozen_embeddings(self, frozen_mask: torch.Tensor) -> nn.Embedding:
        """
        Wrap the backbone's input embeddings (``embed_in``) in ``PartlyFrozenEmbeddings``.

        Args:
            frozen_mask: passed through unchanged as ``frozen_mask``; presumably
                selects which embedding rows stay frozen — see
                ``PartlyFrozenEmbeddings`` for the exact contract.

        NOTE(review): annotated as ``nn.Embedding``; confirm that
        ``PartlyFrozenEmbeddings`` is (or behaves as) an ``nn.Embedding``.
        """
        source_table = self.model.embed_in
        return PartlyFrozenEmbeddings(
            embeddings=source_table,
            frozen_mask=frozen_mask,
        )
