from typing import Any, Dict

from .cogvideox import COGVIDEOX_T2V_FULL_FINETUNE_CONFIG, COGVIDEOX_T2V_LORA_CONFIG
from .hunyuan_video import (
    HUNYUAN_VIDEO_T2V_FULL_FINETUNE_CONFIG,
    HUNYUAN_VIDEO_T2V_LORA_CONFIG,
)
from .ltx_video import LTX_VIDEO_T2V_FULL_FINETUNE_CONFIG, LTX_VIDEO_T2V_LORA_CONFIG
from .wan import WAN_T2V_FULL_FINETUNE_CONFIG, WAN_T2V_LORA_CONFIG


# Registry of training configs, keyed first by model name and then by
# training type ("lora" or "full-finetune"). The leaf values are the config
# objects imported from the per-model sibling modules; the annotation below
# mirrors the return annotation of get_config_from_model_name.
#
# NOTE: key insertion order is user-visible. get_config_from_model_name lists
# these keys, in this order, in its "not supported" error messages.
SUPPORTED_MODEL_CONFIGS: Dict[str, Dict[str, Dict[str, Any]]] = {
    "hunyuan_video": {
        "lora": HUNYUAN_VIDEO_T2V_LORA_CONFIG,
        "full-finetune": HUNYUAN_VIDEO_T2V_FULL_FINETUNE_CONFIG,
    },
    "ltx_video": {
        "lora": LTX_VIDEO_T2V_LORA_CONFIG,
        "full-finetune": LTX_VIDEO_T2V_FULL_FINETUNE_CONFIG,
    },
    "cogvideox": {
        "lora": COGVIDEOX_T2V_LORA_CONFIG,
        "full-finetune": COGVIDEOX_T2V_FULL_FINETUNE_CONFIG,
    },
    "wan": {
        "lora": WAN_T2V_LORA_CONFIG,
        "full-finetune": WAN_T2V_FULL_FINETUNE_CONFIG,
    },
}


def get_config_from_model_name(model_name: str, training_type: str) -> Dict[str, Any]:
    """Look up the registered config for a model and training type.

    Args:
        model_name: Key into ``SUPPORTED_MODEL_CONFIGS`` selecting the model.
        training_type: Key into that model's entry selecting the training
            type.

    Returns:
        The config object stored under ``model_name`` and ``training_type``.
        The registry's own object is returned, not a copy.

    Raises:
        ValueError: If ``model_name`` is not registered, or if
            ``training_type`` is not registered for that model. The message
            lists the supported choices.
    """
    registry = SUPPORTED_MODEL_CONFIGS

    # The model is validated first so the training-type check can safely
    # index into that model's entry.
    if model_name not in registry:
        known_models = list(registry.keys())
        raise ValueError(
            f"Model {model_name} not supported. Supported models are: {known_models}"
        )

    per_model_configs = registry[model_name]
    if training_type not in per_model_configs:
        known_training_types = list(per_model_configs.keys())
        raise ValueError(
            f"Training type {training_type} not supported for model {model_name}. Supported training types are: {known_training_types}"
        )

    return per_model_configs[training_type]
