# main.py
import logging
import os
import hydra
import json
import torch
import deepspeed
from omegaconf import DictConfig, OmegaConf

from auden.models.model_module import build_model
from auden.trainer.asr_ds_trainer import IcefallAsrTrainer as Trainer

def _create_moe_param_groups(model):
    """Wrap all model parameters in one group and split it for MoE training.

    The DeepSpeed helper is imported lazily so that it is only loaded when
    the model actually uses more than one expert.

    Args:
        model: The model whose ``parameters()`` are handed to the optimizer.

    Returns:
        The parameter groups produced by DeepSpeed's
        ``split_params_into_different_moe_groups_for_optimizer``.
    """
    from deepspeed.moe.utils import split_params_into_different_moe_groups_for_optimizer

    param_group = {"params": list(model.parameters()), "name": "parameters"}
    return split_params_into_different_moe_groups_for_optimizer(param_group)


@hydra.main(version_base=None, config_path="configs", config_name="asr_ds")
def main(cfg: DictConfig):
    """Build the model, wrap it in a DeepSpeed engine and run the trainer.

    Reads from ``cfg``:
        exp_dir: Optional experiment directory; created when set.
        model: Model configuration passed to ``build_model``.
        deepspeed_config: Path to the DeepSpeed JSON configuration file.

    Reads from the environment:
        LOCAL_RANK: CUDA device index for this process (defaults to 0).
        WORLD_SIZE: Number of processes, forwarded to the trainer
            (defaults to 1).

    Args:
        cfg: Hydra configuration loaded from ``configs/asr_ds``.
    """
    logging.info("\n" + OmegaConf.to_yaml(cfg))

    # DeepSpeed automatically handles distributed initialization, but the
    # CUDA device for this process is still selected here from LOCAL_RANK.
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    torch.cuda.set_device(local_rank)

    # Optionally create the experiment directory.
    if "exp_dir" in cfg and cfg.exp_dir:
        os.makedirs(cfg.exp_dir, exist_ok=True)

    # 1) Build the model.
    model = build_model(cfg.model)
    num_param = sum(p.numel() for p in model.parameters())
    logging.info(f"Number of model parameters: {num_param}")

    # Models that expose no `config.num_experts` (or set it to None) are
    # treated as dense instead of failing with AttributeError/TypeError.
    model_config = getattr(model, "config", None)
    num_experts = getattr(model_config, "num_experts", None) or 1
    if num_experts > 1:
        parameters = _create_moe_param_groups(model)
    else:
        parameters = model.parameters()

    # 2) Initialize the DeepSpeed engine.
    # The DeepSpeed settings are loaded from the external JSON file whose
    # path is given by cfg.deepspeed_config.
    with open(cfg.deepspeed_config, "r", encoding="utf-8") as f:
        ds_config = json.load(f)

    model_engine, optimizer, _, scheduler = deepspeed.initialize(
        model=model,
        model_parameters=parameters,
        config=ds_config,
    )

    # 3) Create the trainer. It receives the DeepSpeed engine rather than the
    # raw model, together with the optimizer and scheduler that
    # deepspeed.initialize returned.
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    trainer = Trainer(cfg, model_engine, optimizer, scheduler, world_size)
    trainer.run()

    logging.info("Training finished successfully.")

if __name__ == "__main__":
    # Quieten DeepSpeed before training starts. DeepSpeed registers its logger
    # under the name "DeepSpeed" and logger names are case-sensitive, so
    # setting only "deepspeed" would leave its INFO output untouched. The
    # lowercase name is kept as well to cover any loggers created under it.
    for logger_name in ("DeepSpeed", "deepspeed"):
        logging.getLogger(logger_name).setLevel(logging.WARNING)
    main()
