from dataclasses import dataclass, field

from torch import nn
from transformers import AutoTokenizer
from transformers import PreTrainedTokenizerBase

@dataclass
class ModelConfig:
    """Hyperparameters, plus the tokenizer, passed to ``Transformer(config)``.

    Every field except ``tokenizer`` has a default. Field order is part of the
    positional ``__init__`` signature generated by ``@dataclass``, so do not
    reorder fields.
    """

    # Tokenizer instance. ``AutoTokenizer`` is only a factory and cannot be
    # instantiated, so it is not a valid type for an instance. The objects
    # returned by ``AutoTokenizer.from_pretrained`` derive from
    # ``PreTrainedTokenizerBase``, which is why that class is named here.
    tokenizer: PreTrainedTokenizerBase

    ctx_len: int = 1024  # presumably the maximum sequence length — TODO confirm
    # 50304 is a multiple of 64. Presumably this is the GPT-2 vocabulary (50257)
    # padded up for efficiency — TODO confirm.
    vocab_size: int = 50304
    num_layers: int = 8
    num_heads: int = 8
    num_key_value: int = 2  # presumably the number of key/value heads (GQA) — TODO confirm
    attn_bias: bool = True  # presumably toggles bias terms in attention — TODO confirm
    eos_token_id: int = -1  # -1 looks like a "not set" sentinel — TODO confirm with callers
    embed_dim: int = 512
    # 4x the *default* embed_dim. This is a fixed default: it is NOT recomputed
    # when embed_dim is overridden, so pass mlp_dim explicitly in that case.
    mlp_dim: int = 512 * 4

    # Routed-model options. Each uses default_factory so that every instance
    # gets its own list instead of sharing one mutable default.
    target_layers: list[int] = field(default_factory=list)
    aux_labels: list[str] = field(default_factory=list)

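# Abstract base class (ABC) for transformer models, intended to be subclassed.
# It is not declared with abc.ABC, so nothing prevents direct instantiation.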
class Transformer(nn.Module):
    """Base ``nn.Module`` for transformer models built from a ``ModelConfig``.

    Appears intended as an abstract base to be subclassed. It does not use
    ``abc.ABC``/``ABCMeta``, so instantiating it directly is not blocked at
    runtime.
    """
    def __init__(self, config: ModelConfig) -> None:
        """Store the model configuration.

        Args:
            config: Hyperparameters and tokenizer for this model.
        """
        super().__init__()
        # Keep a reference to the config on the module for later use.
        self.config = config
        # Left as None here. Presumably a subclass assigns a concrete value
        # identifying the model variant — TODO confirm.
        self.model_type = None