import logging
from typing import Iterator

import torch
from datasets import Dataset
from torch.utils.data import WeightedRandomSampler


# Module-scoped logger, named after this module.
logger: logging.Logger = logging.getLogger(name=__name__)


class CustomWeightedRandomSampler(WeightedRandomSampler):
    """``WeightedRandomSampler`` whose weights can be replaced while an epoch is running.

    Every call to ``__iter__`` starts with a draw made *without* replacement, so until
    the weights are changed the yielded indices are distinct. (As with
    ``torch.multinomial`` without replacement, ``num_samples`` must therefore not exceed
    the number of non-zero weights.) After :meth:`set_weights` is called, the next index
    requested triggers a fresh draw from the new weights using ``self.replacement``.
    Iteration then continues from the current position in that new draw.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Set by ``set_weights``; consumed by ``__iter__`` to trigger a redraw.
        self.reload_rand_tensor = False

    def set_weights(self, weights):
        """Replace the sampling weights and request a redraw before the next yielded index.

        Args:
            weights: A 1-d tensor or sequence of non-negative weights. Floating-point
                tensors are stored as-is (no copy). Anything else is converted to a
                ``torch.double`` tensor, because ``torch.multinomial`` only accepts
                floating-point input.

        Raises:
            ValueError: If ``weights`` is not 1-dimensional.
        """
        if isinstance(weights, torch.Tensor):
            # Integer/bool tensors would only fail later, inside ``torch.multinomial``.
            weights_tensor = weights if torch.is_floating_point(weights) else weights.to(torch.double)
        else:
            weights_tensor = torch.as_tensor(weights, dtype=torch.double)
        # NOTE: no normalization needed; torch.multinomial accepts weights that do not sum to 1.
        if weights_tensor.dim() != 1:
            raise ValueError(
                f"weights should be a 1d sequence but given weights have shape {tuple(weights_tensor.shape)}"
            )
        self.weights = weights_tensor
        self.reload_rand_tensor = True

    def __iter__(self) -> Iterator[int]:
        """Yield ``self.num_samples`` indices drawn from ``self.weights``.

        The initial sequence is drawn without replacement. Before each index is yielded,
        the method checks whether :meth:`set_weights` has been called and not yet handled.
        If so, it draws a new full-length sequence from the current weights (honouring
        ``self.replacement``) and reads the remaining indices from that sequence at the
        same position. The epoch length therefore stays ``num_samples``.
        """
        # No replacement, so every index is seen at most once before the first update.
        rand_tensor = torch.multinomial(
            self.weights,
            self.num_samples,
            False,
            generator=self.generator,
        ).tolist()

        for idx in range(self.num_samples):
            if self.reload_rand_tensor:
                rand_tensor = torch.multinomial(
                    self.weights,
                    self.num_samples,
                    self.replacement,
                    generator=self.generator,
                ).tolist()
                # Lazy %-args, and slicing *before* ``tolist``, keep these lines from
                # converting the whole weight tensor on every reload.
                logger.debug("reload_rand_tensor: %s", self.reload_rand_tensor)
                logger.debug("rand_tensor: %s", rand_tensor[:5])
                logger.debug("weights: %s", self.weights[:5].tolist())
                logger.info("Weights updated")
                self.reload_rand_tensor = False

            yield rand_tensor[idx]


def create_weighted_sampler(*, dataset: Dataset, **kwargs) -> WeightedRandomSampler:
    """Build a ``CustomWeightedRandomSampler`` with uniform weights over ``dataset``.

    Args:
        dataset: The training dataset. Only ``len(dataset)`` is used.
        **kwargs: Ignored. Presumably accepted so a generic factory can pass extra
            options; confirm with callers.

    Returns:
        A ``CustomWeightedRandomSampler`` that draws ``len(dataset)`` indices per epoch,
        with every index given weight ``1 / len(dataset)``.

    Raises:
        ValueError: If ``dataset`` is ``None`` or empty.
    """
    # Explicit raises rather than ``assert``: assertions are stripped under ``python -O``.
    if dataset is None:
        raise ValueError("Train dataset is None")
    num_items = len(dataset)
    if num_items == 0:
        raise ValueError("Train dataset is empty")

    # Uniform weights. torch.multinomial does not require normalized weights, but
    # normalizing keeps the values readable as probabilities.
    weights = torch.ones(num_items, dtype=torch.double) / num_items

    # Internal sanity check on the computation above (not input validation).
    assert torch.isclose(torch.sum(weights), torch.tensor(1.0, dtype=torch.double)), "Weights do not sum to 1"
    logger.info("Weighted sampler created with uniform weights")

    return CustomWeightedRandomSampler(weights, num_items)
