import json
import os
import warnings
from typing import Optional, Union

import torch
from transformers import AutoConfig
from transformers import AutoModelForCausalLM as AutoModelForCausalLMBase
from transformers import (
    LlamaConfig,
    PretrainedConfig,
    Qwen3Config,
    modeling_utils,
)

import setup_path
from train.utils import default_torch_dtype
from train.modeling.modeling_qwen3 import Qwen3ForCausalLM
from train.modeling.modeling_llama import LlamaForCausalLM

class AutoDistributedModelForCausalLM(AutoModelForCausalLMBase):
    """Factory that builds a project-local causal LM from a pretrained config.

    The concrete model class is chosen by the exact type of the config
    returned by ``AutoConfig.from_pretrained`` and looked up in
    ``_model_mapping``.
    """

    # Maps a config class to the candidate model classes implementing it;
    # ``from_pretrained`` always uses the first candidate in the list.
    _model_mapping = {
        Qwen3Config: [Qwen3ForCausalLM],
        LlamaConfig: [LlamaForCausalLM],
    }

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Union[str, os.PathLike[str]],
        torch_dtype: Optional[torch.dtype] = None,
        device: Optional[Union[str, torch.device]] = None,
        cache_dir: Optional[str] = None,
        **config_kwargs,
    ):
        """Instantiate the mapped model class and load its checkpoint.

        Args:
            pretrained_model_name_or_path: Model identifier or path, passed
                to ``AutoConfig.from_pretrained`` and to the model's
                ``load_checkpoint``.
            torch_dtype: Dtype used while constructing the model. When
                ``None``, ``torch.get_default_dtype()`` is used.
            device: Device on which the model is constructed. When ``None``,
                the CPU is used.
            cache_dir: Cache directory forwarded to both the config loader
                and ``load_checkpoint``.
            **config_kwargs: Extra keyword arguments forwarded to
                ``AutoConfig.from_pretrained``.

        Returns:
            The constructed model, after ``load_checkpoint`` has been called
            on it.

        Raises:
            ValueError: If the resolved config type has no entry in
                ``_model_mapping``.
        """
        # Forward cache_dir here as well so the config and the checkpoint
        # are resolved through the same cache location.
        config = AutoConfig.from_pretrained(
            pretrained_model_name_or_path,
            cache_dir=cache_dir,
            **config_kwargs,
        )

        # Explicit raise instead of ``assert``: assertions are stripped under
        # ``python -O``, which would turn this into an opaque KeyError.
        candidates = cls._model_mapping.get(type(config))
        if not candidates:
            raise ValueError(f"Unsupported config type: {type(config)}")
        model_cls = candidates[0]

        device = torch.device("cpu" if device is None else device)

        if torch_dtype is None:
            torch_dtype = torch.get_default_dtype()

        # Build the model with the requested dtype/device in effect, then
        # load the weights from the checkpoint.
        with default_torch_dtype(torch_dtype), torch.device(device):
            model = model_cls(config)
        model.load_checkpoint(pretrained_model_name_or_path, cache_dir=cache_dir)

        # NOTE(review): no explicit ``model.to(torch_dtype)`` /
        # ``model.to(device)`` is applied after loading; presumably
        # ``load_checkpoint`` keeps the parameters' dtype and device --
        # confirm before relying on it.
        return model