import os
import json
import logging
import dataclasses
from dataclasses import asdict
import torch

from omegaconf import DictConfig, ListConfig, OmegaConf
from .auto_config import AutoConfig
from ..models.model_registry import MODEL_REGISTRY

logger = logging.getLogger(__name__)

class UnknownModelError(ValueError, KeyError):
    """Raised when a config's ``name`` has no entry in ``MODEL_REGISTRY``.

    Subclasses both ``ValueError`` (what ``AutoModel.from_config`` has always
    raised for an unknown name) and ``KeyError`` (what the bare registry lookup
    in ``AutoModel.from_pretrained`` used to raise), so existing ``except``
    clauses written against either type keep working.
    """

    def __str__(self):
        # KeyError.__str__ would wrap the message in quotes; use the plain
        # BaseException formatting so the message reads naturally.
        return BaseException.__str__(self)


class AutoModel:
    """Factory that builds the model class registered under ``config.name``.

    Both entry points read the ``name`` field of a config, look the
    corresponding class up in ``MODEL_REGISTRY`` and delegate construction /
    loading to that class.
    """

    @classmethod
    def _resolve_model_class(cls, config, source=None):
        """Return the ``MODEL_REGISTRY`` entry selected by ``config.name``.

        Args:
            config: any object exposing a ``name`` attribute.
            source: optional description of where ``config`` came from
                (e.g. an experiment directory); appended to error messages
                when given.

        Raises:
            ValueError: if ``config`` has no ``name`` attribute, or it is None.
            UnknownModelError: if ``name`` is not a key of ``MODEL_REGISTRY``.
        """
        model_type = getattr(config, "name", None)
        # Only decorate messages when a source is known, so from_config keeps
        # its original, exact error text.
        suffix = f" (loaded from {source})" if source is not None else ""
        if model_type is None:
            raise ValueError(f"Config object must have 'name' field.{suffix}")
        # Membership test (rather than try/except KeyError) works for any
        # registry type that implements __contains__/__getitem__.
        if model_type not in MODEL_REGISTRY:
            raise UnknownModelError(f"Unknown model name: {model_type}{suffix}")
        return MODEL_REGISTRY[model_type]

    @classmethod
    def from_pretrained(
        cls,
        exp_dir: str,
        **kwargs,  # checkpoint_filename, iter, epoch, avg, strict
    ):
        """Load a model from ``exp_dir``.

        The config is read with ``AutoConfig.from_pretrained(exp_dir)`` only
        to pick the model class. That class's own ``from_pretrained`` then
        receives the same ``exp_dir`` and performs the actual loading.

        Args:
            exp_dir: directory passed to both ``AutoConfig.from_pretrained``
                and the selected model class's ``from_pretrained``.
            **kwargs: forwarded unchanged to ``ModelClass.from_pretrained``.
                Per the original note these select the checkpoint
                (``checkpoint_filename``, ``iter``, ``epoch``, ``avg``,
                ``strict``); their exact semantics are defined by the model
                class.

        Returns:
            Whatever ``ModelClass.from_pretrained`` returns.

        Raises:
            ValueError: if the loaded config has no ``name``.
            UnknownModelError: if ``config.name`` is not registered.
        """
        config = AutoConfig.from_pretrained(exp_dir)
        ModelClass = cls._resolve_model_class(config, source=exp_dir)
        return ModelClass.from_pretrained(exp_dir, **kwargs)

    @classmethod
    def from_config(cls, config, **kwargs):
        """Instantiate the registered model class for ``config``.

        Args:
            config: object whose ``name`` attribute selects the model class.
                It is also passed as the first constructor argument.
            **kwargs: forwarded unchanged to the model class constructor.

        Returns:
            ``ModelClass(config, **kwargs)``.

        Raises:
            ValueError: if ``config`` has no ``name``, or the name is unknown
                (the latter is an ``UnknownModelError``, a ``ValueError``
                subclass).
        """
        ModelClass = cls._resolve_model_class(config)
        return ModelClass(config, **kwargs)
