import logging

import torch
import torch.distributed as dist

from torch.nn.parallel import DistributedDataParallel
from LLMProxy.option import DistArg

# Module-level logger. Nothing in this module sets a level on it directly,
# so its effective level follows the root logger, which distributed_init()
# sets to INFO on the master process and WARNING elsewhere.
# NOTE(review): the name is not `__name__`/dotted, so this logger is not a
# child of the package's logger and package-level logging config will not
# apply to it by name — consider logging.getLogger(__name__).
logger = logging.getLogger("Distributed Training Setting ...")


def is_master(args: DistArg):
    """Report whether this process should act as the master.

    The decision is based solely on ``args.local_rank`` being ``0``.
    NOTE(review): ``local_rank`` is presumably the rank within a node, in
    which case every node's first process counts as "master" on a
    multi-node job — confirm that is the intended meaning for callers.
    """
    local_rank = args.local_rank
    return local_rank == 0


def distributed_init(args: DistArg):
    """Initialize the default process group for this worker process.

    Steps, in order:
      1. Bind the process to CUDA device ``args.local_rank`` (when CUDA is
         available), so collectives do not all default to ``cuda:0``.
      2. Create the process group through DeepSpeed when
         ``args.use_deepspeed`` is set, otherwise through
         ``torch.distributed.init_process_group``.
      3. Wait on a barrier so every rank finishes setup together.
      4. Set the root logger to INFO on the master process (see
         ``is_master``) and WARNING on all others.

    Args:
        args: distributed options. Reads ``use_deepspeed``,
            ``distributed_backend``, ``distributed_rank``,
            ``distributed_world_size`` and ``local_rank``.

    Raises:
        ImportError: ``args.use_deepspeed`` is set but DeepSpeed cannot be
            imported (the original import error is chained as the cause).
    """
    # Bind this process to its GPU *before* creating the process group so
    # NCCL communicators (and the barrier below) use the intended device.
    # local_rank is used as a CUDA device index, matching create_ddp_model.
    if torch.cuda.is_available():
        torch.cuda.set_device(args.local_rank)
    logger.info(
        "setting CUDA device={} on rank {}".format(
            args.local_rank, args.distributed_rank,
        )
    )

    if args.use_deepspeed:
        # Guard only the import: an ImportError raised from inside
        # init_distributed (e.g. a missing optional dependency) must not be
        # misreported as "deepspeed is not installed".
        try:
            import deepspeed
        except ImportError as err:
            raise ImportError("Please install deepspeed") from err
        deepspeed.init_distributed(
            dist_backend=args.distributed_backend,
            rank=args.distributed_rank,
            world_size=args.distributed_world_size,
        )
    else:
        dist.init_process_group(
            backend=args.distributed_backend,
            rank=args.distributed_rank,
            world_size=args.distributed_world_size,
        )

    # Block until every rank has finished initialization.
    dist.barrier()

    # Suppress INFO-level output on non-master processes to avoid
    # duplicated log lines from every worker.
    level = logging.INFO if is_master(args) else logging.WARNING
    logging.getLogger().setLevel(level)


def is_distributed():
    """Return True only when torch.distributed is usable and initialized.

    ``dist.is_initialized`` is consulted only after ``dist.is_available``
    confirms the distributed package was built into this torch install.
    """
    if not dist.is_available():
        return False
    return dist.is_initialized()


def create_ddp_model(
    model: torch.nn.Module,
    args: DistArg,
    *,
    find_unused_parameters: bool = True,
):
    """Move ``model`` to this process's GPU and wrap it in DDP.

    An initialized default process group is required before calling this.

    Args:
        model: module to replicate; moved to device ``args.local_rank``.
        args: distributed options; only ``local_rank`` is read. It is used
            as the device for placement, ``device_ids`` and
            ``output_device``.
        find_unused_parameters: forwarded to ``DistributedDataParallel``.
            Defaults to ``True`` (the previously hard-coded value), which
            tolerates parameters that get no gradient in a step at the cost
            of an extra autograd-graph traversal every iteration. Pass
            ``False`` when all parameters are always used.

    Returns:
        The ``DistributedDataParallel``-wrapped model.
    """
    return DistributedDataParallel(
        model.to(args.local_rank),
        device_ids=[args.local_rank],
        output_device=args.local_rank,
        find_unused_parameters=find_unused_parameters,
    )