import logging
import os
import hydra
import torch
import torch.distributed as dist
from omegaconf import DictConfig, OmegaConf

from auden.trainer.tta import TtaTrainer as Trainer
from auden.auto.auto_model import AutoModel
from auden.auto.auto_config import AutoConfig
from auden.auto.auto_tokenizer import AutoTokenizer, save_tokenizer
# from utils.utils import fix_random_seed

logger = logging.getLogger(__name__)

@hydra.main(version_base=None, config_path="configs", config_name="train")
def main(cfg: DictConfig) -> None:
    """Build the tokenizer and model from the Hydra config and run the TtaTrainer.

    Reads the torchrun environment variables (``RANK``, ``LOCAL_RANK``,
    ``WORLD_SIZE``), initialises an NCCL process group when more than one
    process is launched, builds the tokenizer and model, saves their
    configuration from rank 0, and hands everything to the trainer.

    Args:
        cfg: Hydra configuration. ``cfg.tokenizer`` is unpacked as keyword
            arguments to ``AutoTokenizer.from_pretrained`` and its ``source``
            field is passed to ``save_tokenizer``; ``cfg.model`` is converted
            via ``AutoConfig.from_hydra_cfg``; ``cfg.exp_dir`` is an optional
            output directory for the saved tokenizer/config.
    """
    logger.info("\n%s", OmegaConf.to_yaml(cfg))

    # # 1) Fix random seed if desired (disabled until fix_random_seed is importable)
    # if "seed" in cfg:
    #     fix_random_seed(cfg.seed)

    # 2) Gather torchrun environment variables; defaults describe a single process.
    rank = int(os.environ.get("RANK", 0))
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))

    # 3) Initialize process group if multi-GPU
    distributed = world_size > 1
    if distributed:
        torch.cuda.set_device(local_rank)
        dist.init_process_group(backend="nccl", init_method="env://")

    try:
        # 4) (Optional) Create experiment directory. `in` is used instead of
        #    attribute access so a missing key does not raise in struct mode.
        exp_dir = cfg.exp_dir if "exp_dir" in cfg else None
        if exp_dir:
            os.makedirs(exp_dir, exist_ok=True)

        # 5) initialize model; only rank 0 writes artifacts so that ranks do
        #    not race on the same files, and only when an exp_dir is set.
        tokenizer = AutoTokenizer.from_pretrained(**cfg.tokenizer)
        if rank == 0 and exp_dir:
            save_tokenizer(cfg.tokenizer.source, exp_dir)
        config = AutoConfig.from_hydra_cfg(cfg.model)
        model = AutoModel.from_config(config, tokenizer=tokenizer)
        if rank == 0 and exp_dir:
            config.save_config(exp_dir)

        # 6) Create the trainer, passing the model
        trainer = Trainer(cfg, model, rank=rank, local_rank=local_rank, world_size=world_size)
        trainer.run()
    finally:
        # 7) Destroy the process group even if setup or training raised, so
        #    NCCL resources are released; skip it if it was already torn down.
        if distributed and dist.is_initialized():
            dist.destroy_process_group()

    logger.info("Training finished successfully.")

if __name__ == "__main__":
    # Called without arguments: the @hydra.main decorator composes ``cfg``
    # (including any command-line overrides) and passes it to main().
    main()
