import torch as th
import torch.distributed as dist
from . import dist_util


def get_generator(generator, num_samples=0, seed=0):
    """
    Build the random-number helper selected by name.

    :param generator: one of "dummy", "determ" or "determ-indiv".
    :param num_samples: total sample count passed to the deterministic
                        generators; not used for "dummy".
    :param seed: base seed passed to the deterministic generators; not used
                 for "dummy".
    :return: a DummyGenerator, DeterministicGenerator or
             DeterministicIndividualGenerator instance.
    :raises NotImplementedError: if the name is not one of the above.
    """
    if generator == "dummy":
        return DummyGenerator()
    if generator == "determ":
        return DeterministicGenerator(num_samples, seed)
    if generator == "determ-indiv":
        return DeterministicIndividualGenerator(num_samples, seed)
    # Name the offending value so a typo in a config/CLI flag is easy to spot.
    raise NotImplementedError(
        f"Unknown generator {generator!r}; "
        "expected 'dummy', 'determ' or 'determ-indiv'"
    )


class DummyGenerator:
    """
    Pass-through generator: each method forwards its arguments unchanged to
    the torch function of the same name. No generator or seed is injected here.
    """

    def randn(self, *shape_args, **options):
        """Forward to ``th.randn`` and return its result."""
        sample = th.randn(*shape_args, **options)
        return sample

    def randint(self, *bound_args, **options):
        """Forward to ``th.randint`` and return its result."""
        sample = th.randint(*bound_args, **options)
        return sample

    def randn_like(self, *tensor_args, **options):
        """Forward to ``th.randn_like`` and return its result."""
        sample = th.randn_like(*tensor_args, **options)
        return sample


class DeterministicGenerator:
    """
    RNG to deterministically sample num_samples samples that does not depend on batch_size or mpi_machines
    Uses a single rng and samples num_samples sized randomness and subsamples the current indices
    """

    def __init__(self, num_samples, seed=0):
        """
        :param num_samples: size of the leading dimension of every global draw.
        :param seed: seed applied to the CPU generator and, when CUDA is
                     available, to the CUDA generator.
        """
        if dist.is_initialized():
            self.rank = dist.get_rank()
            self.world_size = dist.get_world_size()
        else:
            print("Warning: Distributed not initialised, using single rank")
            self.rank = 0
            self.world_size = 1
        self.num_samples = num_samples
        self.done_samples = 0
        self.seed = seed
        self.rng_cpu = th.Generator()
        # rng_cuda only exists when CUDA is available.
        if th.cuda.is_available():
            self.rng_cuda = th.Generator(dist_util.dev())
        self.set_seed(seed)

    def get_global_size_and_indices(self, size):
        """
        Compute the shape of the global draw and the rows owned by this rank.

        :param size: requested per-rank shape; size[0] is the batch size.
        :return: (global_size, indices) where global_size replaces the leading
                 dimension with num_samples and indices holds size[0] row
                 positions starting at done_samples + rank with stride
                 world_size.
        """
        global_size = (self.num_samples, *size[1:])
        indices = th.arange(
            self.done_samples + self.rank,
            self.done_samples + self.world_size * int(size[0]),
            self.world_size,
        )
        # Positions past the end are clamped, so they repeat the last sample.
        indices = th.clamp(indices, 0, self.num_samples - 1)
        assert (
            len(indices) == size[0]
        ), f"rank={self.rank}, ws={self.world_size}, l={len(indices)}, bs={size[0]}"
        return global_size, indices

    def get_generator(self, device):
        """
        Return the CPU generator for CPU devices, otherwise the CUDA generator.

        Raises AttributeError for a non-CPU device when the CUDA generator was
        not created in __init__.
        """
        return self.rng_cpu if th.device(device).type == "cpu" else self.rng_cuda

    def randn(self, *size, dtype=th.float, device="cpu"):
        """
        Draw a (num_samples, *size[1:]) normal tensor and return this rank's rows.

        The shape may be given as separate ints, ``randn(4, 2)``, or as one
        sequence, ``randn((4, 2))``, matching ``th.randn``.
        """
        if len(size) == 1 and isinstance(size[0], (tuple, list)):
            # th.Size is a tuple subclass, so tensor.size() is handled here too.
            size = tuple(size[0])
        global_size, indices = self.get_global_size_and_indices(size)
        generator = self.get_generator(device)
        return th.randn(*global_size, generator=generator, dtype=dtype, device=device)[indices]

    def randint(self, low, high, size, dtype=th.long, device="cpu"):
        """
        Draw a (num_samples, *size[1:]) integer tensor with the given low/high
        bounds and return this rank's rows.
        """
        global_size, indices = self.get_global_size_and_indices(size)
        generator = self.get_generator(device)
        return th.randint(low, high, generator=generator, size=global_size, dtype=dtype, device=device)[indices]

    def randn_like(self, tensor):
        """Call randn with the size, dtype and device of ``tensor``."""
        size, dtype, device = tensor.size(), tensor.dtype, tensor.device
        return self.randn(*size, dtype=dtype, device=device)

    def set_done_samples(self, done_samples):
        """Move the index window to start at done_samples and reseed the generators."""
        self.done_samples = done_samples
        self.set_seed(self.seed)

    def get_seed(self):
        """Return the seed most recently applied."""
        return self.seed

    def set_seed(self, seed):
        """Seed the generators and remember the seed."""
        # Record the seed so get_seed() and the reseed in set_done_samples()
        # use the value that was actually applied, not the constructor's.
        self.seed = seed
        self.rng_cpu.manual_seed(seed)
        if th.cuda.is_available():
            self.rng_cuda.manual_seed(seed)


class DeterministicIndividualGenerator:
    """
    RNG to deterministically sample num_samples samples that does not depend on batch_size or mpi_machines
    Uses a separate rng for each sample to reduce memory usage
    """

    def __init__(self, num_samples, seed=0):
        """
        :param num_samples: number of per-sample generators to create.
        :param seed: base seed; see set_seed for how per-sample seeds are derived.
        """
        if dist.is_initialized():
            self.rank = dist.get_rank()
            self.world_size = dist.get_world_size()
        else:
            print("Warning: Distributed not initialised, using single rank")
            self.rank = 0
            self.world_size = 1
        self.num_samples = num_samples
        self.done_samples = 0
        self.seed = seed
        # One generator per sample index.
        self.rng_cpu = [th.Generator() for _ in range(num_samples)]
        # rng_cuda only exists when CUDA is available.
        if th.cuda.is_available():
            self.rng_cuda = [th.Generator(dist_util.dev()) for _ in range(num_samples)]
        self.set_seed(seed)

    def get_size_and_indices(self, size):
        """
        Compute the per-sample shape and the sample indices owned by this rank.

        :param size: requested per-rank shape; size[0] is the batch size.
        :return: ((1, *size[1:]), indices) where indices holds size[0]
                 positions starting at done_samples + rank with stride
                 world_size.
        """
        indices = th.arange(
            self.done_samples + self.rank,
            self.done_samples + self.world_size * int(size[0]),
            self.world_size,
        )
        # Positions past the end are clamped, so they reuse the last generator.
        indices = th.clamp(indices, 0, self.num_samples - 1)
        assert (
            len(indices) == size[0]
        ), f"rank={self.rank}, ws={self.world_size}, l={len(indices)}, bs={size[0]}"
        return (1, *size[1:]), indices

    def get_generator(self, device):
        """
        Return the list of CPU generators for CPU devices, otherwise the CUDA list.

        Raises AttributeError for a non-CPU device when the CUDA generators
        were not created in __init__.
        """
        return self.rng_cpu if th.device(device).type == "cpu" else self.rng_cuda

    def randn(self, *size, dtype=th.float, device="cpu"):
        """
        Draw one (1, *size[1:]) normal tensor per owned sample index, each from
        that sample's own generator, and concatenate them along dim 0.

        The shape may be given as separate ints, ``randn(4, 2)``, or as one
        sequence, ``randn((4, 2))``, matching ``th.randn``.
        """
        if len(size) == 1 and isinstance(size[0], (tuple, list)):
            # th.Size is a tuple subclass, so tensor.size() is handled here too.
            size = tuple(size[0])
        size, indices = self.get_size_and_indices(size)
        generator = self.get_generator(device)
        return th.cat(
            [
                th.randn(*size, generator=generator[i], dtype=dtype, device=device)
                for i in indices.tolist()
            ],
            dim=0,
        )

    def randint(self, low, high, size, dtype=th.long, device="cpu"):
        """
        Draw one (1, *size[1:]) integer tensor per owned sample index with the
        given low/high bounds, each from that sample's own generator, and
        concatenate them along dim 0.
        """
        size, indices = self.get_size_and_indices(size)
        generator = self.get_generator(device)
        return th.cat(
            [
                th.randint(
                    low,
                    high,
                    generator=generator[i],
                    size=size,
                    dtype=dtype,
                    device=device,
                )
                for i in indices.tolist()
            ],
            dim=0,
        )

    def randn_like(self, tensor):
        """Call randn with the size, dtype and device of ``tensor``."""
        size, dtype, device = tensor.size(), tensor.dtype, tensor.device
        return self.randn(*size, dtype=dtype, device=device)

    def set_done_samples(self, done_samples):
        """Move the index window to start at done_samples; generators are not reseeded."""
        self.done_samples = done_samples

    def get_seed(self):
        """Return the base seed most recently applied."""
        return self.seed

    def set_seed(self, seed):
        """
        Seed every per-sample generator and remember the base seed.

        Generator i is seeded with i + num_samples * seed.
        """
        # Record the seed so get_seed() reports the value actually applied.
        self.seed = seed
        for i, rng_cpu in enumerate(self.rng_cpu):
            rng_cpu.manual_seed(i + self.num_samples * seed)
        if th.cuda.is_available():
            for i, rng_cuda in enumerate(self.rng_cuda):
                rng_cuda.manual_seed(i + self.num_samples * seed)
