# mmseg/utils/log_trainable_params_hook.py
from mmengine.hooks import Hook
from mmengine.registry import HOOKS
import os

@HOOKS.register_module()
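# NOTE: the @HOOKS.register_module() decorator above applies to the active
# LogTrainableParamsHook class further below; the commented-out block here is
# an earlier version kept only for reference.
# class LogTrainableParamsHook(Hook):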
#     def before_train(self, runner):
#         model = runner.model
#         total, trainable = 0, 0
#         for name, param in model.named_parameters():
#             num = param.numel()
#             total += num
#             if param.requires_grad:
#                 trainable += num
#         log_str = f'[LogTrainableParamsHook] Total params: {total:,} | Trainable: {trainable:,}'
#         runner.logger.info(log_str)

#         # Save the summary to work_dir/log_trainable_params.txt
#         save_path = os.path.join(runner.work_dir, 'log_trainable_params.txt')
#         with open(save_path, 'w') as f:
#             f.write(log_str + '\n')

class LogTrainableParamsHook(Hook):
    """Log which model parameters are trainable right before training starts.

    For every parameter with ``requires_grad=True`` the hook logs its name
    and shape. A summary line then gives the trainable and total parameter
    counts. Output goes through ``runner.logger`` instead of ``print``, so it
    is recorded alongside the rest of the run's log. In distributed runs it
    follows whatever rank filtering the logger applies.
    """

    def before_train(self, runner) -> None:
        """Log trainable parameter names/shapes and a count summary.

        Args:
            runner: The runner driving training. Only ``runner.model``
                (iterated via ``named_parameters()``) and ``runner.logger``
                are used.
        """
        model = runner.model
        # NOTE(review): if the model is wrapped (e.g. DDP), parameter names
        # will carry the wrapper's prefix (such as "module.") — confirm.
        lines = ['[LogTrainableParamsHook] Trainable parameters:']
        total, trainable = 0, 0
        for name, param in model.named_parameters():
            num = param.numel()
            total += num
            if param.requires_grad:
                trainable += num
                lines.append(f'  {name} - {tuple(param.shape)}')
        # Guard against a model with no parameters at all.
        ratio = trainable / total if total else 0.0
        lines.append(
            f'[LogTrainableParamsHook] Trainable params: {trainable:,} / '
            f'Total params: {total:,} ({ratio:.2%})')
        # One multi-line message keeps the report contiguous in the log.
        runner.logger.info('\n'.join(lines))
