"""
Utilities for distributed training.

This module provides helper functions and classes for distributed data parallel training.
"""

import os
import torch
from operator import itemgetter


class DatasetFromSampler(torch.utils.data.Dataset):
    """Expose the indices yielded by a sampler as an indexable dataset.

    The sampler is materialised into a list lazily, on the first item
    access, and that same list is reused for every later lookup.

    Args:
        sampler: PyTorch sampler (any object supporting ``len()`` and
            iteration)
    """

    def __init__(self, sampler):
        self.sampler = sampler
        # Populated on the first __getitem__ call, then reused.
        self.sampler_list = None

    def __len__(self):
        """Return the number of indices the sampler reports."""
        return len(self.sampler)

    def __getitem__(self, index):
        """Return the ``index``-th element produced by the sampler."""
        cached = self.sampler_list
        if cached is None:
            cached = self.sampler_list = list(self.sampler)
        return cached[index]


class DistributedSamplerWrapper(torch.utils.data.distributed.DistributedSampler):
    """Wrapper over Sampler for distributed training

    Allows using any sampler in distributed mode. Useful with
    torch.nn.parallel.DistributedDataParallel. The wrapped sampler's output
    is treated as a sequence, and ``DistributedSampler`` picks which
    positions of that sequence belong to this replica.

    Args:
        sampler: Sampler used for subsampling
        num_replicas: Number of processes in distributed training
        rank: Rank of current process
        shuffle: Whether to shuffle indices
    """

    def __init__(
            self,
            sampler,
            num_replicas=None,
            rank=None,
            shuffle=True,
    ):
        super().__init__(
            DatasetFromSampler(sampler),
            num_replicas=num_replicas,
            rank=rank,
            shuffle=shuffle,
        )
        self.sampler = sampler

    def __iter__(self):
        """Iterate over this replica's share of the wrapped sampler's indices."""
        # Rebuild the view so the wrapped sampler is drawn afresh on every
        # iteration (its cached list is discarded with the old instance).
        # NOTE(review): each process draws the wrapped sampler independently;
        # if it is random and not seeded identically on every rank, the
        # per-rank shares may overlap — confirm for your sampler.
        self.dataset = DatasetFromSampler(self.sampler)
        positions = super().__iter__()
        # Index explicitly: itemgetter(*positions) returns a bare value (not a
        # tuple) for a single position and raises TypeError for none.
        return iter([self.dataset[pos] for pos in positions])


def setup_ddp(rank, world_size, port=12357, backend="nccl", master_addr="localhost"):
    """Setup distributed data parallel process group

    Sets the rendezvous environment variables, initialises the default
    process group and waits on a barrier so every process has joined before
    returning.

    Args:
        rank: Rank of current process
        world_size: Number of processes
        port: Port for communication (written to ``MASTER_PORT``)
        backend: ``torch.distributed`` backend name (default ``"nccl"``)
        master_addr: Rendezvous address (written to ``MASTER_ADDR``)
    """
    os.environ["MASTER_ADDR"] = master_addr
    os.environ["MASTER_PORT"] = str(port)

    # Initialize the process group
    torch.distributed.init_process_group(
        backend,
        rank=rank,
        world_size=world_size,
    )
    if str(backend).lower() == "nccl":
        # Bind this process to its GPU. Assumes rank == local CUDA device
        # index, i.e. a single-node run — TODO confirm for multi-node use.
        torch.cuda.set_device(rank)
        # ``device_ids`` is only supported by the NCCL backend.
        torch.distributed.barrier(device_ids=[rank])
    else:
        torch.distributed.barrier()


def cleanup_ddp():
    """Clean up distributed data parallel process group

    Destroys the default process group if one is active. When no group has
    been initialised (e.g. called from a ``finally`` block after setup_ddp
    failed), this is a no-op instead of raising and masking the original
    error.
    """
    dist = torch.distributed
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()


def is_main_process():
    """Check if current process is the main process

    Returns:
        True if this process has rank 0 in the default process group, or if
        no process group is initialised (a plain single-process run);
        False otherwise.
    """
    dist = torch.distributed
    if not (dist.is_available() and dist.is_initialized()):
        # Without a process group there is only one process, so it is main.
        return True
    return dist.get_rank() == 0


def distribute_loader(loader, num_replicas=None, rank=None):
    """Convert a data loader to a distributed data loader

    Builds a new DataLoader over ``loader.dataset`` whose sampler yields only
    this process's share of the indices. The per-process batch size is the
    original batch size divided by the number of replicas (floored, but never
    below 1). The original loader's collate function, ``drop_last``,
    ``timeout`` and ``worker_init_fn`` are carried over.

    Args:
        loader: Original data loader
        num_replicas: Number of processes; defaults to the world size of the
            default process group
        rank: Rank of the current process; defaults to its rank in the
            default process group

    Returns:
        Distributed data loader

    Raises:
        ValueError: if ``loader.batch_size`` is None (automatic batching
            disabled), since there is no batch to split across replicas.
    """
    if num_replicas is None:
        num_replicas = torch.distributed.get_world_size()
    if rank is None:
        rank = torch.distributed.get_rank()

    if loader.batch_size is None:
        raise ValueError(
            "distribute_loader requires a loader with automatic batching "
            "(got batch_size=None)"
        )

    if isinstance(loader.sampler, torch.utils.data.SubsetRandomSampler):
        # Keep the subset restriction by sharding the subset's own indices.
        sampler = DistributedSamplerWrapper(
            loader.sampler,
            num_replicas=num_replicas,
            rank=rank,
        )
    else:
        # An unshuffled (sequential) loader stays unshuffled; any other
        # sampler keeps the previous shuffle=True behaviour.
        # NOTE(review): custom samplers other than SubsetRandomSampler
        # (e.g. weighted ones) are replaced here, not wrapped.
        sampler = torch.utils.data.distributed.DistributedSampler(
            loader.dataset,
            num_replicas=num_replicas,
            rank=rank,
            shuffle=not isinstance(
                loader.sampler, torch.utils.data.SequentialSampler
            ),
        )

    return torch.utils.data.DataLoader(
        loader.dataset,
        # Clamp to 1: DataLoader rejects batch_size=0, which the plain floor
        # division produces when batch_size < num_replicas.
        batch_size=max(1, loader.batch_size // num_replicas),
        sampler=sampler,
        num_workers=loader.num_workers,
        collate_fn=loader.collate_fn,
        pin_memory=getattr(loader, "pin_memory", True),
        drop_last=loader.drop_last,
        timeout=loader.timeout,
        worker_init_fn=loader.worker_init_fn,
    )