import os
import logging
import contextlib

import torch

from common.conf import BaseEnvVars, BaseDistRunArgs
from training.distributed.env import get_global_rank, get_local_rank, get_world_size, get_master_addr, get_master_port

# Module-wide logger. getLogger() is called without a name, so this is the
# root logger and records use the root logger's handlers and level.
logger = logging.getLogger()


def setup_env_args(args: BaseEnvVars):
    """Export the fields of ``args`` as process environment variables.

    Each key/value pair returned by ``args.model_dump()`` is compared, in its
    string form, with the current value in ``os.environ``. Entries that are
    missing or different are overwritten and a warning is logged; entries
    that already match are left untouched.

    Args:
        args: Object exposing ``model_dump()``, which must return a mapping
            of environment variable names to values.
    """
    for name, value in args.model_dump().items():
        desired = str(value)
        # Nothing to do when the environment already holds this exact string.
        if os.environ.get(name) == desired:
            continue
        os.environ[name] = desired
        logger.warning(f"Setting {name} to {value}")


def setup_torch_distributed(args: BaseDistRunArgs):
    """
    Initialize ``torch.distributed`` with the NCCL backend for this process.

    Intended for single and multi-GPU / multi-node / SLURM jobs. The rank,
    local rank, world size and master address/port are obtained from the
    helpers in ``training.distributed.env`` and exported as the ``RANK``,
    ``WORLD_SIZE``, ``MASTER_ADDR``, ``MASTER_PORT`` and ``LOCAL_RANK``
    environment variables before the process group is created.

    Args:
        args: Run configuration. Only ``master_port`` and ``is_port_random``
            are read; both are forwarded to ``get_master_port``.

    Returns:
        The tuple ``(rank, local_rank, world_size)``.

    Raises:
        RuntimeError: If ``local_rank`` is not a valid index among the CUDA
            devices visible to this process.
    """
    rank = get_global_rank()
    local_rank = get_local_rank()
    world_size = get_world_size()
    master_addr = get_master_addr()
    master_port = get_master_port(
        # -1 is passed when SLURM_JOB_ID is not set in the environment.
        job_id=int(os.environ.get("SLURM_JOB_ID", -1)),
        port=args.master_port,
        is_port_random=args.is_port_random,
    )

    os.environ["RANK"] = str(rank)
    os.environ["WORLD_SIZE"] = str(world_size)
    os.environ["MASTER_ADDR"] = master_addr
    os.environ["MASTER_PORT"] = str(master_port)
    os.environ["LOCAL_RANK"] = str(local_rank)

    # Validate the device index against the devices actually visible instead
    # of a hard-coded node size, and raise explicitly so the check is not
    # stripped when Python runs with -O.
    num_devices = torch.cuda.device_count()
    if not 0 <= local_rank < num_devices:
        raise RuntimeError(
            f"local_rank={local_rank} is not a valid CUDA device index: "
            f"{num_devices} CUDA device(s) visible to this process"
        )
    torch.cuda.set_device(local_rank)

    # No init_method / rank / world_size are passed: the default "env://"
    # rendezvous reads them from the environment variables exported above.
    torch.distributed.init_process_group(
        backend="nccl",
        device_id=torch.device(f"cuda:{local_rank}"),
    )

    return rank, local_rank, world_size

def clean_torch_distributed(local_rank: int):
    """Synchronize ranks once, then tear down the default process group.

    Does nothing when ``torch.distributed`` is unavailable or no process
    group has been initialized. Otherwise a barrier is attempted first; a
    failing barrier is logged as a warning (with traceback) and does not
    prevent teardown. Errors raised while destroying the process group are
    suppressed, so this function is best-effort and does not propagate them.

    NOTE(review): the barrier is called without an explicit timeout, so how
    long it waits on a crashed peer depends on the process group's own
    timeout settings. Nothing DeepSpeed-specific is cleaned up here.

    Args:
        local_rank: CUDA device index passed to the barrier as ``device_ids``
            when CUDA is available.
    """
    dist = torch.distributed
    if not dist.is_available() or not dist.is_initialized():
        return

    try:
        # Use a device-aware barrier when CUDA is present, a plain one otherwise.
        barrier_kwargs = {}
        if torch.cuda.is_available():
            barrier_kwargs["device_ids"] = [local_rank]
        dist.barrier(**barrier_kwargs)
    except Exception as err:
        logger.warning("Distributed barrier during cleanup failed: %s", err, exc_info=True)
    finally:
        # Always attempt teardown, even if the barrier failed.
        with contextlib.suppress(Exception):
            dist.destroy_process_group()


