import os
import importlib

from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.strategies import DDPStrategy, DeepSpeedStrategy
from config.parse_yaml_args import load_config
# import deepspeed
# deepspeed.ops.op_builder.CPUAdamBuilder().load()


def get_module(cfg, return_class=False):
    """Look up the model class named in ``cfg["module_cfg"]`` and optionally build it.

    The entry ``cfg["module_cfg"]["name"]`` (snake_case, e.g. ``"foo_bar"``)
    selects the Python module ``models.foo_bar`` and, inside it, the class
    ``FooBarModule``.

    Side effect: ``"name"`` is *removed* from ``cfg["module_cfg"]`` in place,
    so the dict left behind holds only the constructor keyword arguments.
    Calling this twice on the same ``cfg`` therefore raises ``KeyError``.

    Args:
        cfg: Full experiment config; must contain a ``"module_cfg"`` dict
            with a ``"name"`` key.
        return_class: If true, return the class instead of an instance.

    Returns:
        The class itself, or ``cls(cfg, **remaining_module_cfg)``.

    Raises:
        KeyError: ``"module_cfg"`` or its ``"name"`` key is missing.
        ImportError: ``models.<name>`` cannot be imported (typically
            ``ModuleNotFoundError``).
        AttributeError: the imported module has no ``<CamelName>Module``.
    """
    def convert_name(snake_name: str):
        # snake_case -> CamelCase + "Module". str.capitalize() also lower-cases
        # the tail of each word, so "my_GAN" maps to "MyGanModule".
        return "".join(map(str.capitalize, snake_name.split("_"))) + "Module"

    module_cfg = cfg["module_cfg"]
    # pop() without a default raises the same KeyError as indexing would.
    name = module_cfg.pop("name")
    imported = importlib.import_module("models." + name)
    module_cls = getattr(imported, convert_name(name))
    return module_cls if return_class else module_cls(cfg, **module_cfg)


def _build_module(config):
    """Create the LightningModule described by ``config['module_cfg']``.

    If ``config`` contains ``'restart_from_path'``, the module class is resolved
    with ``get_module(..., return_class=True)`` and instantiated through
    ``load_from_checkpoint`` on that path (``map_location='cpu'``,
    ``strict=False``). Otherwise a freshly constructed module is returned.
    """
    if 'restart_from_path' not in config:
        return get_module(config)

    restart_path = config['restart_from_path']
    module_class = get_module(config, return_class=True)
    # get_module() pops "name" from config['module_cfg'] as a side effect; filter
    # it out explicitly as well so these kwargs never depend on that mutation.
    module_kwargs = {
        key: value for key, value in config['module_cfg'].items() if key != 'name'
    }
    module = module_class.load_from_checkpoint(
        restart_path,
        cfg=config,
        **module_kwargs,
        map_location='cpu',
        strict=False,
    )
    print('restart_from_path:', restart_path)
    return module


def _build_strategy(config):
    """Select the distributed training strategy from ``config``.

    Priority: DeepSpeed when ``use_deepspeed`` is truthy (kwargs from
    ``deepspeed_cfg``); then DDP with kwargs from a ``DDPStrategy`` entry;
    otherwise DDP with ``find_unused_parameters=True``.
    """
    # .get(): a missing flag means "off", consistent with the other optional
    # keys ('DDPStrategy', 'restart_from_path') instead of raising KeyError.
    if config.get('use_deepspeed', False):
        return DeepSpeedStrategy(**config['deepspeed_cfg'])
    if 'DDPStrategy' in config:
        return DDPStrategy(**config['DDPStrategy'])
    return DDPStrategy(find_unused_parameters=True)


def main():
    """Train or validate a model as directed by the config from ``load_config()``.

    ``config['mode']`` selects the action:

    * ``"test_last"`` -- ``trainer.validate`` with ``ckpt_path="last"``;
    * ``"test_path"`` -- ``trainer.validate`` with ``config['test_path']``;
    * ``"eval"``      -- ``trainer.validate`` on the module as built;
    * ``"resume"``    -- ``trainer.fit`` with ``config['resume_from_path']``;
    * anything else   -- ``trainer.fit`` with ``ckpt_path="last"``.
    """
    # Run from this script's directory (presumably so relative paths in the
    # config resolve the same way regardless of the caller's cwd).
    os.chdir(os.path.dirname(os.path.abspath(__file__)))
    config = load_config()

    ckpt_cfg = config['ckpt_cfg']
    # Seed the RNGs before the module is built so weight init is reproducible.
    seed_everything(config['seed'])

    # Checkpoint callback. NOTE: ckpt_cfg must not repeat the keywords passed
    # explicitly below (duplicate keywords raise TypeError), and Lightning
    # requires ckpt_cfg to set `monitor` when save_top_k > 1.
    checkpoint_callback = ModelCheckpoint(
        **ckpt_cfg,
        auto_insert_metric_name=False,
        save_top_k=30,  # keep the 30 best checkpoints
        save_last=True,
    )

    # WandB alternative, kept for reference. It is stale: it refers to `cfg`,
    # `osp` and `WandbLogger`, none of which are defined/imported here.
    # try:
    #     logger = WandbLogger(
    #         project=cfg['train_name'],
    #         name=f"{cfg['train_name']}-{cfg['train_id']}",
    #         version=cfg['train_id'],
    #         save_dir=osp.join(cfg['log_dir'], cfg['train_name'], cfg['train_id']),
    #         log_model=False
    #     )
    # except:
    #     print('down to tensorboard')

    # TensorBoard logger: logs under log_dir/train_name, versioned by train_id.
    logger = TensorBoardLogger(
        save_dir=config['log_dir'],
        name=config['train_name'],
        version=config['train_id'],
    )
    # Callbacks: checkpointing plus per-step learning-rate logging.
    callback_list = [
        checkpoint_callback,
        LearningRateMonitor(logging_interval="step"),
    ]

    module = _build_module(config)
    strategy = _build_strategy(config)

    # NOTE: trainer_cfg must not set accelerator/logger/callbacks/strategy;
    # those are passed explicitly and a duplicate keyword raises TypeError.
    trainer = Trainer(
        **config['trainer_cfg'],
        accelerator="gpu",
        logger=logger,
        callbacks=callback_list,
        strategy=strategy,
    )

    # if config.get('val_first', False):
    #     trainer.validate(module)

    mode = config['mode']
    if mode == "test_last":
        trainer.validate(module, ckpt_path="last")
    elif mode == "test_path":
        trainer.validate(module, ckpt_path=config['test_path'])
    elif mode == "eval":
        trainer.validate(module)
    elif mode == "resume":
        trainer.fit(module, ckpt_path=config['resume_from_path'])
    else:
        trainer.fit(module, ckpt_path='last')


if __name__ == "__main__":
    main()
