import json
import logging
import os
import hydra
import torch
import torch.distributed as dist
from omegaconf import DictConfig, OmegaConf
from pathlib import Path

from auden.trainer.audio_tag import AudioTagTrainer as Trainer
from auden.models.audio_tag.utils import load_id2label, save_id2label
from auden.auto.auto_model import AutoModel
from auden.auto.auto_config import AutoConfig
#from utils.utils import fix_random_seed

# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)

@hydra.main(version_base=None, config_path="configs/", config_name="train")
def main(cfg: DictConfig):
    """Train an audio-tagging model from a Hydra config.

    Reads ``RANK``, ``LOCAL_RANK`` and ``WORLD_SIZE`` from the environment
    (defaulting to a single process), initializes an NCCL process group when
    ``WORLD_SIZE > 1``, builds the model from ``cfg.model`` and runs the
    trainer. The process group is torn down even if setup or training raises.

    Args:
        cfg: Composed Hydra config. This function reads ``cfg.model``,
            ``cfg.model.id2label_json`` and the optional ``cfg.exp_dir``;
            the full config is also forwarded to the trainer.
    """
    logger.info("\n%s", OmegaConf.to_yaml(cfg))

    # NOTE: seeding is currently disabled (the helper import at the top of
    # the file is commented out as well).
    # if "seed" in cfg:
    #     fix_random_seed(cfg.seed)

    # Environment variables set by the launcher (e.g. torchrun); the defaults
    # describe a plain single-process run.
    rank = int(os.environ.get("RANK", 0))
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))

    # Initialize the process group only for multi-process runs. The device is
    # selected first so NCCL binds this process to its own GPU.
    distributed = world_size > 1
    if distributed:
        torch.cuda.set_device(local_rank)
        dist.init_process_group(backend="nccl", init_method="env://")

    try:
        # The experiment directory is optional: resolve it once so directory
        # creation and artifact saving agree on whether it is available.
        exp_dir = cfg.exp_dir if ("exp_dir" in cfg and cfg.exp_dir) else None
        if exp_dir:
            os.makedirs(exp_dir, exist_ok=True)

        # Build the model from its config and the label mapping.
        config = AutoConfig.from_hydra_cfg(cfg.model)
        id2label = load_id2label(cfg.model.id2label_json)
        model = AutoModel.from_config(config, id2label=id2label)

        # Only rank 0 writes the config/label artifacts, so concurrent ranks
        # do not write the same files.
        if rank == 0:
            if exp_dir:
                config.save_config(exp_dir)
                save_id2label(id2label, exp_dir)
            else:
                logger.warning(
                    "exp_dir is not set; skipping save of model config and id2label."
                )

        # Create the trainer, passing the model, and run training.
        trainer = Trainer(
            cfg,
            model,
            rank=rank,
            local_rank=local_rank,
            world_size=world_size,
        )
        trainer.run()
    finally:
        # Always release the process group, including on failure, so the
        # process does not exit with communicators still open.
        if distributed and dist.is_initialized():
            dist.destroy_process_group()

    logger.info("Training finished successfully.")

# Script entry point. main() is called without arguments here; its
# @hydra.main decorator is what supplies the cfg parameter.
if __name__ == "__main__":
    main()
