import logging

import torch.distributed as dist
from accelerate import Accelerator, DeepSpeedPlugin
from accelerate.state import AcceleratorState
from CoLM.option import TrainArg

# Module-level logger; it is registered under a descriptive label rather than
# the module's import path.
logger = logging.getLogger("Distributed Training Setting ...")


def build_distributed_configuration(args: TrainArg):
    """Create the distributed configuration selected by ``args.distributed_type``.

    Args:
        args: Training options. ``args.distributed_type`` selects the
            implementation: ``"Pytorch"`` or ``"Accelerator"``.

    Returns:
        A ``DistributedConfiguration`` for ``"Pytorch"`` or an
        ``HFDistributedConfiguration`` for ``"Accelerator"``.

    Raises:
        ValueError: If ``args.distributed_type`` is neither of the two
            supported values.
    """
    if args.distributed_type == "Pytorch":
        return DistributedConfiguration(args=args)
    if args.distributed_type == "Accelerator":
        return HFDistributedConfiguration(args=args)
    # ValueError is a subclass of Exception, so callers that caught the
    # previous bare Exception keep working.
    raise ValueError(
        "Please specify your distributed configuration: expected "
        f"'Pytorch' or 'Accelerator', got {args.distributed_type!r}."
    )


class DistributedConfiguration:
    """Distributed setup that talks to ``torch.distributed`` (or DeepSpeed) directly.

    Constructing an instance initialises the process group and then sets the
    root logger level according to the rank.

    Attributes:
        args: Training options. This class reads ``use_deepspeed``,
            ``distributed_backend``, ``distributed_rank``,
            ``distributed_world_size`` and ``local_rank`` from it.
        accelerator: Left as ``None`` by this class; subclasses may assign one.
    """
    args: TrainArg
    accelerator: Accelerator = None

    def __init__(self, args: TrainArg):
        """Store ``args``, initialise the process group, then set log levels."""
        self.args = args

        self.distributed_init()
        self.set_logging_level()

    def distributed_init(self):
        """Initialise the process group from the backend/rank/world size in ``args``.

        Uses ``deepspeed.init_distributed`` when ``args.use_deepspeed`` is
        truthy, otherwise ``torch.distributed.init_process_group``.

        Raises:
            ImportError: If DeepSpeed is requested but cannot be imported.
        """
        if self.args.use_deepspeed:
            # Guard only the import: an ImportError raised while DeepSpeed is
            # initialising must not be reported as "deepspeed is missing".
            try:
                import deepspeed
            except ImportError as err:
                raise ImportError("Please install deepspeed") from err
            deepspeed.init_distributed(
                dist_backend=self.args.distributed_backend,
                rank=self.args.distributed_rank,
                world_size=self.args.distributed_world_size,
            )
        else:
            dist.init_process_group(
                backend=self.args.distributed_backend,
                rank=self.args.distributed_rank,
                world_size=self.args.distributed_world_size,
            )

    def set_logging_level(self):
        """Set the root logger to INFO on the master rank and WARNING elsewhere.

        When a process group is initialised, all ranks synchronise on a
        barrier before the level is changed.
        """
        # NOTE(review): this only logs the device index; no device is selected
        # in this method -- confirm the CUDA device is set elsewhere.
        # The record is emitted before the root level is adjusted below, so
        # whether it is shown depends on logging configured beforehand.
        logger.info(
            "setting CUDA device={} on rank {}".format(
                self.args.local_rank, self.args.distributed_rank,
            )
        )
        if self.is_distributed():
            dist.barrier()

        if self.is_master:
            logging.getLogger().setLevel(logging.INFO)
        else:
            logging.getLogger().setLevel(logging.WARNING)

    @property
    def is_master(self):
        """``True`` when ``args.distributed_rank`` is 0."""
        return self.args.distributed_rank == 0

    @property
    def use_deepspeed(self):
        """The ``args.use_deepspeed`` flag."""
        return self.args.use_deepspeed

    @property
    def use_deepseed(self):
        """Misspelled alias of ``use_deepspeed``, kept for existing callers."""
        return self.use_deepspeed

    def is_distributed(self):
        """Return whether ``torch.distributed`` is available and initialised."""
        return dist.is_available() and dist.is_initialized()

    def prepare(self, *args):
        """Return the given objects unchanged, as a tuple.

        NOTE(review): a single argument comes back as a 1-tuple, whereas
        ``Accelerator.prepare`` appears to return the bare object in that
        case -- confirm callers handle both shapes.
        """
        return args


class HFDistributedConfiguration(DistributedConfiguration):
    """Distributed configuration that delegates to a Hugging Face ``Accelerator``."""

    def __init__(self, args: TrainArg):
        """Hand ``args`` to the parent initialiser, which drives the setup."""
        super().__init__(args)

    def distributed_init(self):
        """Build the ``Accelerator`` and keep it on ``self.accelerator``.

        Mixed precision and gradient accumulation steps come from ``args``.
        A ``DeepSpeedPlugin`` built from ``args.deepspeed_config`` is passed
        only when DeepSpeed is enabled. The master rank then prints the
        accelerator state.
        """
        # Only supply the plugin keyword when DeepSpeed is requested, so the
        # non-DeepSpeed call is made with exactly two keyword arguments.
        optional_kwargs = {}
        if self.use_deepseed:
            optional_kwargs["deepspeed_plugin"] = DeepSpeedPlugin(
                hf_ds_config=self.args.deepspeed_config
            )

        self.accelerator = Accelerator(
            mixed_precision=self.args.precision,
            gradient_accumulation_steps=self.args.gradient_accumulation_step,
            **optional_kwargs,
        )

        if not self.is_master:
            return
        self.accelerator.print(f"{AcceleratorState()}")

    def prepare(self, *args):
        """Pass the given objects through ``Accelerator.prepare`` and return its result."""
        wrapped = self.accelerator.prepare(*args)
        return wrapped