import torch
from packaging import version
from torch import distributed as dist
from typing import List, Tuple


def check_env():
    if version.parse(torch.version.cuda) < version.parse("11.3"):
        # https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/cudagraph.html
        raise RuntimeError("NCCL CUDA Graph support requires CUDA 11.3 or above")
    if version.parse(version.parse(torch.__version__).base_version) < version.parse("2.2.0"):
        # https://pytorch.org/blog/accelerating-pytorch-with-cuda-graphs/
        raise RuntimeError(
            "CUDAGraph with NCCL support requires PyTorch 2.2.0 or above. "
            "If it is not released yet, please install nightly built PyTorch with "
            "`pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121`"
        )


def is_power_of_2(n: int) -> bool:
    return (n & (n - 1) == 0) and n != 0


class DistriConfig:
    def __init__(
        self,
        height: int = 1024,
        width: int = 1024,
        do_classifier_free_guidance: bool = True,
        split_batch: bool = True,
        warmup_steps: int = 4,
        comm_checkpoint: int = 60,
        mode: str = "corrected_async_gn",
        use_cuda_graph: bool = True,
        parallelism: str = "patch",
        split_scheme: str = "row",
        verbose: bool = False,
    ):
        try:
            # Initialize the process group
            dist.init_process_group("nccl")
            # Get the rank and world_size
            rank = dist.get_rank()
            world_size = dist.get_world_size()
        except Exception as e:
            rank = 0
            world_size = 1
            print(f"Failed to initialize process group: {e}, falling back to single GPU")

        assert is_power_of_2(world_size)
        check_env()

        self.world_size = world_size
        self.rank = rank
        self.height = height
        self.width = width
        self.do_classifier_free_guidance = do_classifier_free_guidance
        self.split_batch = split_batch
        self.warmup_steps = warmup_steps
        self.comm_checkpoint = comm_checkpoint
        self.mode = mode
        self.use_cuda_graph = use_cuda_graph

        self.parallelism = parallelism
        self.split_scheme = split_scheme

        self.verbose = verbose

        if do_classifier_free_guidance and split_batch:
            n_device_per_batch = world_size // 2
            if n_device_per_batch == 0:
                n_device_per_batch = 1
        else:
            n_device_per_batch = world_size

        self.n_device_per_batch = n_device_per_batch

        self.height = height
        self.width = width

        device = torch.device(f"cuda:{rank}")
        torch.cuda.set_device(device)
        self.device = device

        batch_group = None
        split_group = None
        if do_classifier_free_guidance and split_batch and world_size >= 2:
            batch_groups = []
            for i in range(2):
                batch_groups.append(dist.new_group(list(range(i * (world_size // 2), (i + 1) * (world_size // 2)))))
            batch_group = batch_groups[self.batch_idx()]
            split_groups = []
            for i in range(world_size // 2):
                split_groups.append(dist.new_group([i, i + world_size // 2]))
            split_group = split_groups[self.split_idx()]
        self.batch_group = batch_group
        self.split_group = split_group

    def batch_idx(self, rank: int or None = None) -> int:
        if rank is None:
            rank = self.rank
        if self.do_classifier_free_guidance and self.split_batch:
            return 1 - int(rank < (self.world_size // 2))
        else:
            return 0  # raise NotImplementedError

    def split_idx(self, rank: int or None = None) -> int:
        if rank is None:
            rank = self.rank
        return rank % self.n_device_per_batch


class PatchParallelismCommManager:
    def __init__(self, distri_config: DistriConfig):
        self.distri_config = distri_config

        self.torch_dtype = None
        self.numel = 0
        self.numel_dict = {}

        self.buffer_list = None

        self.starts = []
        self.ends = []
        self.shapes = []

        self.idx_queue = []

        self.handles = None

    def register_tensor(
        self, shape: Tuple[int, ...] or List[int], torch_dtype: torch.dtype, layer_type: str = None
    ) -> int:
        if self.torch_dtype is None:
            self.torch_dtype = torch_dtype
        else:
            assert self.torch_dtype == torch_dtype
        self.starts.append(self.numel)
        numel = 1
        for dim in shape:
            numel *= dim
        self.numel += numel
        if layer_type is not None:
            if layer_type not in self.numel_dict:
                self.numel_dict[layer_type] = 0
            self.numel_dict[layer_type] += numel

        self.ends.append(self.numel)
        self.shapes.append(shape)
        return len(self.starts) - 1

    def create_buffer(self):
        distri_config = self.distri_config
        if distri_config.rank == 0 and distri_config.verbose:
            print(
                f"Create buffer with {self.numel / 1e6:.3f}M parameters for {len(self.starts)} tensors on each device."
            )
            for layer_type, numel in self.numel_dict.items():
                print(f"  {layer_type}: {numel / 1e6:.3f}M parameters")

        self.buffer_list = [
            torch.empty(self.numel, dtype=self.torch_dtype, device=self.distri_config.device)
            for _ in range(self.distri_config.n_device_per_batch)
        ]
        self.handles = [None for _ in range(len(self.starts))]

    def get_buffer_list(self, idx: int) -> List[torch.Tensor]:
        buffer_list = [t[self.starts[idx] : self.ends[idx]].view(self.shapes[idx]) for t in self.buffer_list]
        return buffer_list

    def communicate(self):
        distri_config = self.distri_config
        start = self.starts[self.idx_queue[0]]
        end = self.ends[self.idx_queue[-1]]
        tensor = self.buffer_list[distri_config.split_idx()][start:end]
        buffer_list = [t[start:end] for t in self.buffer_list]
        handle = dist.all_gather(buffer_list, tensor, group=self.distri_config.batch_group, async_op=True)
        for i in self.idx_queue:
            self.handles[i] = handle
        self.idx_queue = []

    def enqueue(self, idx: int, tensor: torch.Tensor):
        distri_config = self.distri_config
        if idx == 0 and len(self.idx_queue) > 0:
            self.communicate()
        assert len(self.idx_queue) == 0 or self.idx_queue[-1] == idx - 1
        self.idx_queue.append(idx)
        self.buffer_list[distri_config.split_idx()][self.starts[idx] : self.ends[idx]].copy_(tensor.flatten())

        if len(self.idx_queue) == distri_config.comm_checkpoint:
            self.communicate()

    def clear(self):
        if len(self.idx_queue) > 0:
            self.communicate()
        if self.handles is not None:
            for i in range(len(self.handles)):
                if self.handles[i] is not None:
                    self.handles[i].wait()
                    self.handles[i] = None
