# main.py
import logging
import os
import hydra
import torch
import torch.distributed as dist
from omegaconf import DictConfig, OmegaConf

# from auden.auto.auto_model import AutoModel
from auden.models.model_module import build_model
from auden.trainer.whisper_trainer import WhisperAsrTrainer as Trainer
#from utils.utils import fix_random_seed

@hydra.main(version_base=None, config_path="configs", config_name="asr")
def main(cfg: DictConfig) -> None:
    """Build the model from ``cfg`` and run the ASR trainer.

    The distributed topology is read from the ``RANK``, ``LOCAL_RANK`` and
    ``WORLD_SIZE`` environment variables; when they are absent the script
    falls back to a single process (rank 0, world size 1). If
    ``WORLD_SIZE`` is greater than 1, an NCCL process group is initialized
    before training and is always destroyed afterwards, even when model
    construction or training raises.

    Args:
        cfg: Configuration object. ``cfg.model`` is passed to
            ``build_model``; ``cfg.exp_dir`` (optional) is created as a
            directory when present and non-empty. The whole ``cfg`` is
            forwarded to the trainer.
    """
    logging.basicConfig()
    logging.info("\n" + OmegaConf.to_yaml(cfg))

    # # 1) Fix random seed if desired
    # if "seed" in cfg:
    #     fix_random_seed(cfg.seed)

    # 2) Gather torchrun environment variables
    rank = int(os.environ.get("RANK", 0))
    local_rank = int(os.environ.get("LOCAL_RANK", rank))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    if rank > 0:
        # Only rank 0 keeps the existing log level; other ranks are limited
        # to warnings and above (presumably to avoid duplicated output).
        logging.getLogger().setLevel(logging.WARNING)

    # 3) Initialize process group if multi-GPU
    distributed = world_size > 1
    if distributed:
        torch.cuda.set_device(local_rank)
        dist.init_process_group(backend="nccl", init_method="env://")

    # Everything after a successful init runs under try/finally so the
    # process group is released even if setup or training fails.
    try:
        # 4) (Optional) Create experiment directory
        if "exp_dir" in cfg and cfg.exp_dir:
            os.makedirs(cfg.exp_dir, exist_ok=True)

        # 5) Initialize model
        model = build_model(cfg.model)
        num_param = sum(p.numel() for p in model.parameters())
        logging.info("Number of model parameters: %s", num_param)

        # 6) Create the trainer, passing the model
        trainer = Trainer(
            cfg,
            model,
            rank=rank,
            local_rank=local_rank,
            world_size=world_size,
        )
        trainer.run()
    finally:
        # 7) Destroy process group if used
        if distributed:
            dist.destroy_process_group()

    logging.info("Training finished successfully.")

# Script entry point: main() is called without arguments here; the
# @hydra.main decorator on main is expected to supply ``cfg``.
if __name__ == "__main__":
    main()
