from abc import ABC
from dataclasses import dataclass
from typing import Dict, List, Literal, Optional, Type, Union
from torch import nn
from transformers.models.auto.auto_factory import _BaseAutoModelClass
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, AutoModelForMaskedLM

from utils.layer_utils import LayerSchema, LayerSpec, LayerType, TransformerLayerSchema


@dataclass
class PrunedTransformerConfig:
    """Size hyperparameters describing a pruned transformer.

    Plain data holder: values are stored exactly as passed in; this class
    performs no validation and no cross-field consistency checks.
    """

    # NOTE(review): the three list-valued fields are presumably indexed per
    # layer (i.e. len == num_layers) -- confirm against the code that builds
    # this config.
    num_heads: List[int]  # head counts (list-valued)
    hidden_size: List[int]  # hidden widths (list-valued)
    intermediate_dimension: List[int]  # presumably MLP inner widths -- TODO confirm
    head_size: int  # a single scalar, unlike the list-valued fields above
    num_layers: int  # layer count


@dataclass
class ModelMetadata:
    """Bundle of a teacher/student model pair with its tokenizer, config and layer schemas.

    Plain data holder; nothing is loaded or validated here.
    """

    # NOTE(review): "teacher"/"student" suggests a distillation or pruning
    # setup -- confirm with the code that constructs this object.
    teacher_model: nn.Module
    student_model: nn.Module
    # NOTE(review): AutoTokenizer / AutoConfig are the transformers factory
    # classes; the objects stored here are presumably the concrete tokenizer /
    # config instances returned by their `from_pretrained` -- confirm, and
    # consider annotating with the concrete base types.
    tokenizer: AutoTokenizer
    config: AutoConfig
    schema: List[LayerSchema]  # layer schemas associated with the model

# === 2. Example schemas ===

# Schema instance for the "LlamaDecoderLayer" layer class.
# Each entry of `layers` maps a short key to a LayerSpec; its three positional
# arguments appear to be (parent submodule, projection name, "row"/"column"
# tag) -- NOTE(review): confirm the meaning of the tag in utils.layer_utils.
llama_schema = TransformerLayerSchema(
    layer_type=LayerType.transformer,
    layer_name="LlamaDecoderLayer",
    norm_type="pre_norm",
    layers={
        # Projections under `self_attn`.
        "q": LayerSpec("self_attn", "q_proj", "row"),
        "k": LayerSpec("self_attn", "k_proj", "row"),
        "v": LayerSpec("self_attn", "v_proj", "row"),
        "o": LayerSpec("self_attn", "o_proj", "column"),
        # Projections under `mlp`.
        "fc1": LayerSpec("mlp", "gate_proj", "row"),
        "fc2": LayerSpec("mlp", "up_proj", "row"),
        "fc3": LayerSpec("mlp", "down_proj", "column"),
    },
)

# === 3. Schema registries ===

# Lookup from a model class name (string) to the list of layer schemas
# registered for that model.
MODEL_SCHEMA_REGISTRY: Dict[str, List[LayerSchema]] = {
    "LlamaForCausalLM": [llama_schema],
    # Register further model class names here, each with its list of schemas.
}

# Lookup from a layer class name (string) to its single schema. The key used
# below is the same string passed as `layer_name` when `llama_schema` is built.
LAYER_CLASS_REGISTRY: Dict[str, LayerSchema] = {
    "LlamaDecoderLayer": llama_schema,
    # Register further layer class names here.
}

# === 4. Model type mapping (only for loading) ===
# Lookup from a model-type string to the transformers Auto* class registered
# for it; per the section header this is intended for model loading only.
MODEL_CLASS_MAP: Dict[str, Type[_BaseAutoModelClass]] = {
    "causal_lm": AutoModelForCausalLM,
    "masked_lm": AutoModelForMaskedLM,
}
