import importlib
import json
from pathlib import Path
import importlib
import json
from pathlib import Path

import json
import importlib
from argparse import Namespace
import inspect
import torch.nn as nn

class ModelFactoryp:
    """Build model instances from a JSON registry, adapting to the constructor style.

    Each registry entry is keyed by model type. ``module`` (resolved relative to
    the ``models`` package) and ``class`` are mandatory. ``default_config`` and
    ``required_params`` are optional.
    """

    def __init__(self, registry_path='configs/model_registry.json'):
        """Load the model registry from the JSON file at *registry_path*."""
        self.registry = self._load_registry(registry_path)

    def _load_registry(self, path):
        """Read and return the model registry stored as JSON at *path*."""
        # Explicit UTF-8 so the registry parses identically regardless of the
        # platform's locale encoding.
        with open(path, 'r', encoding='utf-8') as f:
            return json.load(f)

    def create_model(self, model_type, **kwargs):
        """Instantiate the model registered under *model_type*.

        The registry's ``default_config`` is merged with *kwargs* (kwargs win).
        If the model's ``__init__`` takes exactly one argument besides ``self``,
        and that argument is named ``args``, the merged config is passed as an
        ``argparse.Namespace``. Otherwise it is unpacked as keyword arguments.
        The instance's ``.float()`` result is returned.

        Raises:
            ValueError: *model_type* is not registered, or a parameter listed in
                ``required_params`` is absent from the merged config.
            ImportError: the module cannot be imported or lacks the named class
                (the original error is chained as ``__cause__``).
        """
        import copy  # local import keeps this block self-contained

        if model_type not in self.registry:
            raise ValueError(f"未知模型类型: {model_type}")

        model_info = self.registry[model_type]

        # Dynamically import the model class; chain the underlying error so the
        # root cause remains visible in the traceback.
        try:
            module = importlib.import_module("models." + model_info['module'])
            model_class = getattr(module, model_info['class'])
        except (ImportError, AttributeError) as e:
            raise ImportError(f"无法导入模型 {model_type}: {e}") from e

        # Deep-copy the defaults: a shallow copy would share nested lists/dicts
        # with the registry, so a model mutating them would corrupt later calls.
        config = copy.deepcopy(model_info.get('default_config') or {})
        config.update(kwargs)

        # Validate required parameters; the key itself is optional in the registry.
        for param in model_info.get('required_params') or []:
            if param not in config:
                raise ValueError(f"模型 {model_type} 需要参数: {param}")

        # Report the effective parameters.
        print(f"为模型 {model_type} 传递的有效参数:")
        for key, value in config.items():
            print(f"  - {key}: {value}")
        print("-" * 20)

        # Parameter names include 'self', so (self, args) means the constructor
        # expects a single namespace-style object.
        init_params = list(inspect.signature(model_class.__init__).parameters)

        if len(init_params) == 2 and 'args' in init_params:
            args = Namespace(**config)
            print("模型构造函数接受 'args' 对象，将参数打包。")
            return model_class(args).float()

        print("模型构造函数接受关键字参数，直接解包参数。")
        return model_class(**config).float()

    def get_available_models(self):
        """Return the list of registered model type names."""
        return list(self.registry.keys())

    def get_model_info(self, model_type):
        """Return the registry entry for *model_type*, or ``{}`` if unknown.

        Note: the returned dict is the registry's own object, not a copy.
        """
        return self.registry.get(model_type, {})


class ModelFactory:
    """Build model instances from a JSON registry by unpacking config as kwargs.

    Each registry entry is keyed by model type. ``module`` (resolved relative to
    the ``models`` package) and ``class`` are mandatory. ``default_config`` and
    ``required_params`` are optional.
    """

    def __init__(self, registry_path='configs/model_registry.json'):
        """Load the model registry from the JSON file at *registry_path*."""
        self.registry = self._load_registry(registry_path)

    def _load_registry(self, path):
        """Read and return the model registry stored as JSON at *path*."""
        # Explicit UTF-8 so the registry parses identically regardless of the
        # platform's locale encoding.
        with open(path, 'r', encoding='utf-8') as f:
            return json.load(f)

    def create_model(self, model_type, **kwargs):
        """Instantiate the model registered under *model_type*.

        The registry's ``default_config`` is merged with *kwargs* (kwargs win),
        unpacked into the model constructor, and the instance's ``.float()``
        result is returned.

        Raises:
            ValueError: *model_type* is not registered, or a parameter listed in
                ``required_params`` is absent from the merged config.
            ImportError: the model module cannot be imported (propagated as-is).
            AttributeError: the module has no attribute named by ``class``.
        """
        import copy  # local import keeps this block self-contained

        if model_type not in self.registry:
            raise ValueError(f"未知模型类型: {model_type}")

        model_info = self.registry[model_type]

        module = importlib.import_module("models." + model_info['module'])
        model_class = getattr(module, model_info['class'])

        # Deep-copy the defaults: a shallow copy would share nested lists/dicts
        # with the registry, so a model mutating them would corrupt later calls.
        config = copy.deepcopy(model_info.get('default_config') or {})
        config.update(kwargs)

        # Validate required parameters; the key itself is optional in the registry.
        for param in model_info.get('required_params') or []:
            if param not in config:
                raise ValueError(f"模型 {model_type} 需要参数: {param}")

        # Report the effective parameters.
        print(f"为模型 {model_type} 传递的有效参数:")
        for key, value in config.items():
            print(f"  - {key}: {value}")
        print("-" * 20)

        return model_class(**config).float()

    def get_available_models(self):
        """Return the list of registered model type names."""
        return list(self.registry.keys())

    def get_model_info(self, model_type):
        """Return the registry entry for *model_type*, or ``{}`` if unknown.

        Note: the returned dict is the registry's own object, not a copy.
        """
        return self.registry.get(model_type, {})


