import torch
import torch.distributed as dist
from . import dist_util


def get_generator(generator, num_samples=0, seed=0):
    """
    Build the random generator selected by name.

    :param generator: one of "dummy", "determ" or "determ-indiv".
    :param num_samples: total number of samples, forwarded to the deterministic
                        generators (ignored for "dummy").
    :param seed: base seed, forwarded to the deterministic generators
                 (ignored for "dummy").
    :return: a DummyGenerator, DeterministicGenerator or
             DeterministicIndividualGenerator instance.
    :raises NotImplementedError: if ``generator`` is not a known name.
    """
    if generator == "dummy":
        return DummyGenerator()
    if generator == "determ":
        return DeterministicGenerator(num_samples, seed)
    if generator == "determ-indiv":
        return DeterministicIndividualGenerator(num_samples, seed)
    # Name the offending value so a typo in a config/CLI flag is easy to spot.
    raise NotImplementedError(
        f"unknown generator {generator!r}; expected 'dummy', 'determ' or 'determ-indiv'"
    )


class DummyGenerator:
    """
    Pass-through generator: every method forwards its arguments unchanged to
    the torch function of the same name, so the global torch RNG is used.
    """

    def randn(self, *args, **kwargs):
        """Forward to ``torch.randn``."""
        sample = torch.randn(*args, **kwargs)
        return sample

    def randint(self, *args, **kwargs):
        """Forward to ``torch.randint``."""
        sample = torch.randint(*args, **kwargs)
        return sample

    def randn_like(self, *args, **kwargs):
        """Forward to ``torch.randn_like``."""
        sample = torch.randn_like(*args, **kwargs)
        return sample


class DeterministicGenerator:
    """
    RNG to deterministically sample num_samples samples that does not depend on batch_size or mpi_machines
    Uses a single rng and samples num_samples sized randomness and subsamples the current indices

    Every draw materialises a tensor whose leading dimension is ``num_samples``
    and returns only the rows assigned to this rank for the current batch.
    """

    def __init__(self, num_samples, seed=0):
        """
        :param num_samples: total number of samples; used as the leading
                            dimension of every full draw.
        :param seed: seed applied to the CPU (and, if present, CUDA) generator.
        """
        if dist.is_initialized():
            self.rank = dist.get_rank()
            self.world_size = dist.get_world_size()
        else:
            print("Warning: Distributed not initialised, using single rank")
            self.rank = 0
            self.world_size = 1
        self.num_samples = num_samples
        self.done_samples = 0
        self.seed = seed
        self.rng_cpu = torch.Generator()
        # Always define the attribute so a missing CUDA generator can be
        # reported clearly instead of surfacing as an AttributeError.
        self.rng_cuda = None
        if torch.cuda.is_available():
            self.rng_cuda = torch.Generator(dist_util.dev())
        self.set_seed(seed)

    @staticmethod
    def _as_shape(size):
        """Accept both ``randn(4, 2)`` and ``randn((4, 2))`` like ``torch.randn``."""
        if len(size) == 1 and isinstance(size[0], (tuple, list)):
            return tuple(size[0])
        return size

    def get_global_size_and_indices(self, size):
        """
        Return the shape of the full draw and the rows belonging to this batch.

        :param size: requested shape; ``size[0]`` is the local batch size.
        :return: ``(global_size, indices)`` where ``global_size`` is ``size``
                 with its first entry replaced by ``num_samples`` and
                 ``indices`` holds ``size[0]`` row indices, starting at
                 ``done_samples + rank`` and strided by ``world_size``.
        """
        global_size = (self.num_samples, *size[1:])
        indices = torch.arange(
            self.done_samples + self.rank,
            self.done_samples + self.world_size * int(size[0]),
            self.world_size,
        )
        # Indices running past the end are pinned to the last sample.
        indices = torch.clamp(indices, 0, self.num_samples - 1)
        assert len(indices) == size[0], f"rank={self.rank}, ws={self.world_size}, l={len(indices)}, bs={size[0]}"
        return global_size, indices

    def get_generator(self, device):
        """
        Return the generator to use for ``device``.

        :raises RuntimeError: if a non-CPU device is requested but no CUDA
                              generator was created at construction time.
        """
        if torch.device(device).type == "cpu":
            return self.rng_cpu
        if self.rng_cuda is None:
            raise RuntimeError(
                f"no CUDA generator available for device {device!r}: "
                "CUDA was not available when this generator was created"
            )
        return self.rng_cuda

    def randn(self, *size, dtype=torch.float, device="cpu"):
        """
        Draw this batch's rows of a ``num_samples``-sized normal sample.

        :param size: requested shape, as separate ints or one tuple/list.
        :return: tensor of shape ``size``.
        """
        size = self._as_shape(size)
        global_size, indices = self.get_global_size_and_indices(size)
        generator = self.get_generator(device)
        return torch.randn(*global_size, generator=generator, dtype=dtype, device=device)[indices]

    def randint(self, low, high, size, dtype=torch.long, device="cpu"):
        """Draw this batch's rows of a ``num_samples``-sized integer sample from ``torch.randint(low, high)``."""
        global_size, indices = self.get_global_size_and_indices(size)
        generator = self.get_generator(device)
        return torch.randint(low, high, generator=generator, size=global_size, dtype=dtype, device=device)[indices]

    def randn_like(self, tensor):
        """Same as :meth:`randn`, taking shape, dtype and device from ``tensor``."""
        size, dtype, device = tensor.size(), tensor.dtype, tensor.device
        return self.randn(*size, dtype=dtype, device=device)

    def set_done_samples(self, done_samples):
        """Record how many samples are finished and reseed with the current seed."""
        self.done_samples = done_samples
        # Reseeding makes the next full draw identical to the previous one,
        # so only the selected rows change between batches.
        self.set_seed(self.seed)

    def get_seed(self):
        """Return the seed most recently applied."""
        return self.seed

    def set_seed(self, seed):
        """Reseed every generator and remember ``seed`` for later reseeding."""
        # Store the seed so get_seed() and set_done_samples() stay in sync
        # with the generators' actual state.
        self.seed = seed
        self.rng_cpu.manual_seed(seed)
        if self.rng_cuda is not None:
            self.rng_cuda.manual_seed(seed)


class DeterministicIndividualGenerator:
    """
    RNG to deterministically sample num_samples samples that does not depend on batch_size or mpi_machines
    Uses a separate rng for each sample to reduce memory usage

    Sample ``i`` owns a generator seeded with ``i + num_samples * seed``; each
    draw produces one row per selected sample and concatenates the rows.
    """

    def __init__(self, num_samples, seed=0):
        """
        :param num_samples: total number of samples; one generator is created
                            per sample (per device type).
        :param seed: base seed from which the per-sample seeds are derived.
        """
        if dist.is_initialized():
            self.rank = dist.get_rank()
            self.world_size = dist.get_world_size()
        else:
            print("Warning: Distributed not initialised, using single rank")
            self.rank = 0
            self.world_size = 1
        self.num_samples = num_samples
        self.done_samples = 0
        self.seed = seed
        self.rng_cpu = [torch.Generator() for _ in range(num_samples)]
        # Always define the attribute so a missing CUDA generator list can be
        # reported clearly instead of surfacing as an AttributeError.
        self.rng_cuda = None
        if torch.cuda.is_available():
            self.rng_cuda = [torch.Generator(dist_util.dev()) for _ in range(num_samples)]
        self.set_seed(seed)

    @staticmethod
    def _as_shape(size):
        """Accept both ``randn(4, 2)`` and ``randn((4, 2))`` like ``torch.randn``."""
        if len(size) == 1 and isinstance(size[0], (tuple, list)):
            return tuple(size[0])
        return size

    def get_size_and_indices(self, size):
        """
        Return the per-sample shape and the sample indices of this batch.

        :param size: requested shape; ``size[0]`` is the local batch size.
        :return: ``(row_size, indices)`` where ``row_size`` is ``size`` with
                 its first entry replaced by 1 and ``indices`` holds
                 ``size[0]`` sample indices, starting at
                 ``done_samples + rank`` and strided by ``world_size``.
        """
        indices = torch.arange(
            self.done_samples + self.rank,
            self.done_samples + self.world_size * int(size[0]),
            self.world_size,
        )
        # Indices running past the end are pinned to the last sample.
        indices = torch.clamp(indices, 0, self.num_samples - 1)
        assert len(indices) == size[0], f"rank={self.rank}, ws={self.world_size}, l={len(indices)}, bs={size[0]}"
        return (1, *size[1:]), indices

    def get_generator(self, device):
        """
        Return the list of per-sample generators to use for ``device``.

        :raises RuntimeError: if a non-CPU device is requested but no CUDA
                              generators were created at construction time.
        """
        if torch.device(device).type == "cpu":
            return self.rng_cpu
        if self.rng_cuda is None:
            raise RuntimeError(
                f"no CUDA generators available for device {device!r}: "
                "CUDA was not available when this generator was created"
            )
        return self.rng_cuda

    def randn(self, *size, dtype=torch.float, device="cpu"):
        """
        Draw one normal row per sample of this batch, each from its own generator.

        :param size: requested shape, as separate ints or one tuple/list.
        :return: tensor of shape ``size``.
        """
        size = self._as_shape(size)
        size, indices = self.get_size_and_indices(size)
        generator = self.get_generator(device)
        # tolist() yields plain ints for indexing the Python list of generators.
        return torch.cat(
            [torch.randn(*size, generator=generator[i], dtype=dtype, device=device) for i in indices.tolist()],
            dim=0,
        )

    def randint(self, low, high, size, dtype=torch.long, device="cpu"):
        """Draw one integer row per sample of this batch from ``torch.randint(low, high)``."""
        size, indices = self.get_size_and_indices(size)
        generator = self.get_generator(device)
        return torch.cat(
            [
                torch.randint(
                    low,
                    high,
                    generator=generator[i],
                    size=size,
                    dtype=dtype,
                    device=device,
                )
                for i in indices.tolist()
            ],
            dim=0,
        )

    def randn_like(self, tensor):
        """Same as :meth:`randn`, taking shape, dtype and device from ``tensor``."""
        size, dtype, device = tensor.size(), tensor.dtype, tensor.device
        return self.randn(*size, dtype=dtype, device=device)

    def set_done_samples(self, done_samples):
        """Record how many samples are finished; generators are left untouched."""
        self.done_samples = done_samples

    def get_seed(self):
        """Return the base seed most recently applied."""
        return self.seed

    def set_seed(self, seed):
        """Reseed every per-sample generator from ``seed`` and remember it."""
        # Store the seed so get_seed() reflects the generators' actual state.
        self.seed = seed
        for i, rng_cpu in enumerate(self.rng_cpu):
            rng_cpu.manual_seed(i + self.num_samples * seed)
        if self.rng_cuda is not None:
            for i, rng_cuda in enumerate(self.rng_cuda):
                rng_cuda.manual_seed(i + self.num_samples * seed)


class BatchedSeedGenerator:
    """
    RNG that gives every sample of a batch its own explicitly seeded generator.

    Row ``i`` of every tensor produced is drawn from a generator seeded with
    ``seeds[i]``, so the batch size must equal the number of seeds.
    """

    def __init__(self, seeds=None):
        """
        :param seeds: indexable sequence of per-sample seeds; each entry is
                      converted with ``int()``.
        :raises TypeError: if ``seeds`` is None.
        """
        if seeds is None:
            raise TypeError("seeds must be a sequence of integer seeds, got None")
        self.num_samples = len(seeds)
        self._seeds = [int(seeds[i]) for i in range(self.num_samples)]
        if torch.cuda.is_available():
            self.rng = [torch.Generator(dist_util.dev()) for _ in range(self.num_samples)]
        else:
            self.rng = [torch.Generator() for _ in range(self.num_samples)]
        for rng, seed in zip(self.rng, self._seeds):
            rng.manual_seed(seed)
        # Generators for device types other than the one ``self.rng`` lives
        # on, keyed by device type and created on first use.
        self._extra_rng = {}

    def _generators_for(self, device):
        """Return the per-sample generators whose device type matches ``device``."""
        device = torch.device(device)
        if not self.rng or self.rng[0].device.type == device.type:
            return self.rng
        # A generator must live on the same device type as the tensor being
        # sampled (e.g. the default device="cpu" on a CUDA host), so build a
        # matching set seeded with the same seeds.
        if device.type not in self._extra_rng:
            self._extra_rng[device.type] = [torch.Generator(device).manual_seed(seed) for seed in self._seeds]
        return self._extra_rng[device.type]

    def randn(self, size, dtype=torch.float, device="cpu"):
        """
        Draw one normal row per seed and stack the rows.

        :param size: requested shape; ``size[0]`` must equal the number of seeds.
        :return: tensor of shape ``size``.
        """
        assert size[0] == self.num_samples, f"batch size {size[0]} != number of seeds {self.num_samples}"
        if self.num_samples == 0:
            # torch.cat rejects an empty list; return the empty batch directly.
            return torch.empty((0, *size[1:]), dtype=dtype, device=device)
        return torch.cat(
            [
                torch.randn(1, *size[1:], generator=rng, dtype=dtype, device=device)
                for rng in self._generators_for(device)
            ],
            dim=0,
        )

    def randint(self, low, high, size, dtype=torch.long, device="cpu"):
        """
        Draw one integer row per seed from ``torch.randint(low, high)`` and stack the rows.

        :param size: requested shape; ``size[0]`` must equal the number of seeds.
        :return: tensor of shape ``size``.
        """
        assert size[0] == self.num_samples, f"batch size {size[0]} != number of seeds {self.num_samples}"
        if self.num_samples == 0:
            # torch.cat rejects an empty list; return the empty batch directly.
            return torch.empty((0, *size[1:]), dtype=dtype, device=device)
        return torch.cat(
            [
                torch.randint(
                    low,
                    high,
                    generator=rng,
                    size=(1, *size[1:]),
                    dtype=dtype,
                    device=device,
                )
                for rng in self._generators_for(device)
            ],
            dim=0,
        )

    def randn_like(self, tensor):
        """Same as :meth:`randn`, taking shape, dtype and device from ``tensor``."""
        size, dtype, device = tensor.size(), tensor.dtype, tensor.device
        return self.randn(size, dtype=dtype, device=device)
