import math
import numpy as np
from typing import Iterator, Optional
import torch
from torch.utils.data.dataloader import _BaseDataLoaderIter
from torch.utils.data import Dataset, _DatasetKind
from torch.utils.data.distributed import DistributedSampler
from operator import itemgetter
import torch.distributed as dist
import warnings

__all__ = ['HistoricSampler']


def info_hack_indices(self):
    """Drop-in ``__next__`` for DataLoader iterators that also exposes sample indices.

    The control flow is the same as the stock iterator step (lazy reset,
    fetch, yielded-count bookkeeping, IterableDataset length warning).  The
    only difference is what happens when the loader's dataset is a
    ``HistoricSampler``: the fetched batch is then unpacked as
    ``(indices, data)``, the indices are handed to
    ``HistoricSampler.set_active_indices`` and the caller receives
    ``(indices, *data)``.  For any other dataset the fetched batch is
    returned untouched.

    NOTE(review): the ``(indices, data)`` unpacking presumably relies on the
    collate function turning the ``(index, sample)`` pairs produced by
    ``HistoricSampler.__getitem__`` into a two-element batch -- confirm when
    using a custom ``collate_fn``.
    """
    with torch.autograd.profiler.record_function(self._profile_name):
        if self._sampler_iter is None:
            # TODO(https://github.com/pytorch/pytorch/issues/76750)
            self._reset()  # type: ignore[call-arg]

        wraps_historic = isinstance(self._dataset, HistoricSampler)
        fetched = self._next_data()
        if wraps_historic:
            indices, data = fetched
        else:
            data = fetched
        self._num_yielded += 1

        # Warn when an IterableDataset yields more batches than len() promised.
        if self._dataset_kind == _DatasetKind.Iterable:
            reported_len = self._IterableDataset_len_called
            if reported_len is not None and self._num_yielded > reported_len:
                warn_msg = ("Length of IterableDataset {} was reported to be {} (when accessing len(dataloader)), but {} "
                            "samples have been fetched. ").format(self._dataset, reported_len, self._num_yielded)
                if self._num_workers > 0:
                    warn_msg += ("For multiprocessing data-loading, this could be caused by not properly configuring the "
                                 "IterableDataset replica at each worker. Please see "
                                 "https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset for examples.")
                warnings.warn(warn_msg)

        if not wraps_historic:
            return data

        # Let the sampler know which dataset rows make up the current batch.
        self._dataset.set_active_indices(indices)
        return (indices, *data)


# Monkey-patch: replace ``__next__`` on the DataLoader iterator base class with
# ``info_hack_indices``.  This runs at import time and affects every DataLoader
# iterator in the process, not only loaders built on a HistoricSampler.
_BaseDataLoaderIter.__next__ = info_hack_indices


@torch.no_grad()
def concat_all_gather(tensor, dim=0):
    """Gather ``tensor`` from every rank and concatenate the pieces along ``dim``.

    *** Warning ***: torch.distributed.all_gather has no gradient, and the
    whole function runs under ``torch.no_grad()``, so the result never
    carries gradient information.

    When no process group is initialised (single-process training) the
    "world" consists of this process only, so the result is simply a
    gradient-free copy of ``tensor`` instead of an error.

    Args:
        tensor: Tensor to gather. ``all_gather`` receives one same-shaped
            buffer per rank, so every rank is expected to pass the same shape.
        dim: Dimension along which the per-rank tensors are concatenated.

    Returns:
        The concatenation of the per-rank tensors, in rank order.
    """
    if dist.is_available() and dist.is_initialized():
        # One receive buffer per rank; all_gather overwrites their contents.
        tensors_gather = [torch.ones_like(tensor)
                          for _ in range(dist.get_world_size())]
        dist.all_gather(tensors_gather, tensor, async_op=False)
    else:
        tensors_gather = [tensor]
    return torch.cat(tensors_gather, dim=dim)


class HistoricSampler(Dataset):
    """
    Dataset wrapper that importance-samples training examples from a running
    per-sample gradient-norm history.

    Every item is returned as ``(index, sample)`` so that the training loop can
    report per-sample losses back through :meth:`update`. :meth:`prune` turns
    the accumulated history into per-sample keep probabilities, draws the
    subset used for the next epoch and stores inverse-probability weights that
    :meth:`update` multiplies into the losses. The approach follows the idea of
    pruning less informative samples and rescaling the remaining ones, see
    https://arxiv.org/pdf/2303.04947.pdf

    .. note::
        Dataset is assumed to be of constant size.

    Args:
        dataset: Dataset used for training. Must support ``len()`` and integer indexing.
        budget_num_epochs (int): Epoch budget. ``num_epochs`` is derived from it as
            ``budget / (1 - prune_ratio)`` when ``warmup == -1`` and as
            ``budget / (1 + (1 - prune_ratio)) * 2`` otherwise.
        historic_beta (float): Decay of the per-sample exponential moving average
            of the squared approximate gradient norm.
        smooth_factor (float): Interpolation weight (0..1) that pulls every
            per-sample value towards the dataset mean before keep probabilities
            are computed.
        sampling_generator: ``torch.Generator`` passed to ``torch.rand`` when the
            kept subset is drawn. The draw happens on the CPU.
        full_sample_freq (int, optional): Stored and forwarded to the sampler.
            The sampler logic that used it is currently disabled.
        prune_ratio (float, optional): Peak proportion of samples being pruned.
        delta (float, optional): ``stop_prune`` is ``num_epochs * delta``. Defaults to 0.875.
        warmup (float, optional): ``-1`` disables warm-up. Otherwise the keep
            ratio is linearly interpolated from 1 to ``1 - prune_ratio`` over the
            first ``int(num_epochs * warmup)`` calls to :meth:`prune`.
        dynamic_ratio (bool, optional): Derive the keep ratio from the fraction
            of samples whose last prediction was correct.
        heuristic_ratio (bool, optional): Use the root-mean-square value as the
            reference scale instead of searching for it.
        count_discount (float, optional): Base of the factor
            ``count_discount ** (count - min_count)`` applied per sample.
        heuristic_max (bool, optional): Use ``0.9 * max`` as the reference scale
            (only consulted when ``heuristic_ratio`` is false).
        device (optional): Device holding the history buffers. Defaults to
            ``"cuda"`` when CUDA is available and ``"cpu"`` otherwise.
    """

    def __init__(self, dataset: Dataset, budget_num_epochs: int, historic_beta: float, smooth_factor: float, sampling_generator, full_sample_freq: int = -1,
                 prune_ratio: float = 0.5, delta: float = 0.875, warmup=-1, dynamic_ratio=False, heuristic_ratio=False, count_discount=1, heuristic_max=False,
                 device=None):
        self.dataset = dataset
        self.num_samples = len(self.dataset)
        self.peak_prune_ratio = prune_ratio
        self.warmup = warmup
        if warmup == -1:
            self.num_epochs = budget_num_epochs / (1 - prune_ratio)
        else:
            self.num_epochs = budget_num_epochs / (1 + (1 - prune_ratio)) * 2
        self.delta = delta
        self.historic_beta = historic_beta
        self.count_discount = count_discount
        self.smooth_factor = smooth_factor
        self.sampling_generator = sampling_generator
        self.full_sample_freq = full_sample_freq
        self.dynamic_ratio = dynamic_ratio
        self.heuristic_ratio = heuristic_ratio
        self.heuristic_max = heuristic_max

        # Per-sample loss multipliers; replaced by prune() on every call.
        self.weights = torch.ones(len(self.dataset))
        self.num_pruned_samples = 0
        self.cur_batch_index = None

        # Fall back to the CPU so the class stays usable on hosts without CUDA.
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)

        # EMA of the (normalised) squared approximate gradient norm per sample.
        self.history_approx_grad_sq = torch.zeros(self.num_samples, device=self.device)
        # Number of update() calls that touched each sample (for bias correction).
        self.history_count = torch.zeros(self.num_samples, device=self.device, dtype=torch.int)
        # 1.0 where the most recent top-1 prediction matched the label.
        self.correct = torch.zeros(self.num_samples, device=self.device)
        # Running (biased) EMA of the weighted batch-mean squared gradient norm.
        self.grad_sq_ema = 0
        self.grad_sq_ema_count = 0

    def set_active_indices(self, cur_batch_indices: torch.Tensor):
        """Remember the dataset indices of the batch that was just fetched."""
        self.cur_batch_index = cur_batch_indices

    def update(self, indices, output, values, labels=None):
        """Record per-sample statistics for a batch and return its re-weighted mean loss.

        Args:
            indices: Dataset indices of the samples in the batch.
            output: Model output from which ``values`` was computed; it must be
                part of the autograd graph of ``values`` because the gradient of
                ``values.sum()`` with respect to it is taken.
            values: Per-sample losses. Modified in place (multiplied by the
                per-sample weights).
            labels: Optional targets. When given, top-1 correctness of
                ``output`` is stored for the batch.

        Returns:
            The mean of the re-weighted per-sample losses.
        """
        assert isinstance(values, torch.Tensor)
        if labels is not None:
            _, pred = output.topk(k=1, dim=-1)
            self.correct[indices] = pred.squeeze(-1).eq(labels).float()
        device = values.device
        weights = self.weights[indices].to(device)

        # Norm of d(loss)/d(output) per sample, scaled by 1/sqrt(2).
        # NOTE(review): the bound asserted below presumably relies on `output`
        # being the logits of a softmax cross-entropy loss -- confirm before
        # using a different loss.
        output_grad = torch.autograd.grad(values.sum(), output, retain_graph=True)[0]
        sample_approx_grad = torch.linalg.norm(output_grad, dim=-1).detach() / math.sqrt(2.0)
        assert (sample_approx_grad <= 1).all()

        # TODO check
        grad_sq_mean = (sample_approx_grad ** 2 * weights).mean()
        self.grad_sq_ema = 0.99 * self.grad_sq_ema + 0.01 * grad_sq_mean
        self.grad_sq_ema_count += 1

        # Bias-corrected global EMA, used to normalise this batch's values.
        grad_sq_ema = self.grad_sq_ema / (1 - 0.99 ** self.grad_sq_ema_count)
        self.history_approx_grad_sq[indices] = (
            self.historic_beta * self.history_approx_grad_sq[indices]
            + (1 - self.historic_beta) * (sample_approx_grad ** 2 / grad_sq_ema)
        )

        self.history_count[indices] += 1

        values.mul_(weights)
        return values.mean()

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        # The index travels with the sample so losses can be attributed later.
        return index, self.dataset[index]

    def prune(self, iterations):
        """Draw the subset of sample indices to train on next and refresh the weights.

        Args:
            iterations: Number of previous calls (the sampler passes its epoch
                counter); only used for the warm-up schedule.

        Returns:
            A shuffled ``numpy`` array with the indices that were kept.
        """
        target_keep = 1 - self.peak_prune_ratio
        if self.warmup == -1:
            keep_ratio = target_keep
        else:
            peak = int(self.num_epochs * self.warmup)
            # `peak > 0` guards the division: a short warm-up can round down to
            # zero epochs, in which case there is nothing to interpolate.
            if peak > 0 and iterations <= peak:
                keep_ratio = (iterations * target_keep + (peak - iterations) * 1) / peak
            else:
                keep_ratio = target_keep
            keep_ratio = min(max(keep_ratio, 0), 1)
        if self.dynamic_ratio:
            correct_ratio = self.correct.mean()
            # NOTE(review): this multiplies by the *prune* ratio; confirm that
            # `1 - peak_prune_ratio` was not intended (identical for 0.5).
            keep_ratio = (1 - correct_ratio) + correct_ratio * self.peak_prune_ratio

        # Bias-correct the per-sample EMA. Samples never seen yield 0/0 and are
        # set to 1 explicitly.
        sample_approx_grad_sq = self.history_approx_grad_sq / (1 - torch.pow(self.historic_beta, self.history_count))
        sample_approx_grad_sq[self.history_count == 0] = 1

        c = self.history_count - self.history_count.min()
        sample_approx_grad_sq *= torch.pow(self.count_discount, c)

        sample_approx_grad_sq = (1 - self.smooth_factor) * sample_approx_grad_sq + self.smooth_factor * sample_approx_grad_sq.mean()
        sample_approx_grad = torch.sqrt(sample_approx_grad_sq).double()

        if self.heuristic_ratio:
            ref = torch.sqrt(sample_approx_grad_sq.mean())
        elif self.heuristic_max:
            ref = sample_approx_grad.max() * 0.9
        else:
            # Bisection for the scale `ref` at which mean(min(g / ref, 1)) equals
            # keep_ratio. The mean decreases as the scale grows.
            lb = 0
            # TODO check this
            rb = sample_approx_grad.mean() / keep_ratio
            for _ in range(50):
                mid = (lb + rb) / 2
                if (sample_approx_grad / mid).clamp(max=1).mean() < keep_ratio:
                    rb = mid
                else:
                    lb = mid
            ref = mid

        sample_prob = (sample_approx_grad / ref).clamp(max=1)
        keep_ratio = sample_prob.mean()

        # TODO check this
        self.weights = 1 / sample_prob * keep_ratio

        keep_mask = torch.rand(self.num_samples, generator=self.sampling_generator) < sample_prob.cpu()
        selected_indices = np.arange(self.num_samples)[keep_mask.numpy()]
        self.num_pruned_samples += self.num_samples - len(selected_indices)

        # in-place shuffle
        np.random.shuffle(selected_indices)
        return selected_indices

    @property
    def sampler(self):
        """Build a fresh sampler over this dataset (distributed when a process group is initialised)."""
        sampler = IBSampler(self)
        if dist.is_available() and dist.is_initialized():
            sampler = DistributedIBSampler(sampler)
        return sampler

    def no_prune(self):
        """Return all sample indices as a shuffled list."""
        samples_indices = list(range(len(self)))
        np.random.shuffle(samples_indices)
        return samples_indices

    def mean_score(self):
        """Return the mean of the per-sample history buffer.

        This class keeps no separate ``scores`` tensor; the per-sample quantity
        it tracks is ``history_approx_grad_sq`` (the raw, not bias-corrected EMA).
        """
        return self.history_approx_grad_sq.mean()

    def get_weights(self, indexes):
        return self.weights[indexes]

    def get_pruned_count(self):
        """Return the total number of samples left out over all prune() calls."""
        return self.num_pruned_samples

    @property
    def stop_prune(self):
        return self.num_epochs * self.delta

    def reset_weights(self):
        self.weights[:] = 1


class IBSampler(object):
    """Single-process sampler that re-draws the kept subset at the start of every epoch.

    The object is its own iterator: ``iter(sampler)`` calls :meth:`reset`,
    which asks the wrapped dataset for a fresh index selection via
    ``dataset.prune(iterations)``, and ``next(sampler)`` walks through that
    selection.

    Args:
        dataset: The ``HistoricSampler`` whose ``prune`` method supplies the
            indices. Its ``stop_prune`` and ``full_sample_freq`` values are
            copied onto the sampler.
    """

    def __init__(self, dataset: "HistoricSampler"):
        self.dataset = dataset
        self.stop_prune = dataset.stop_prune
        # Number of reset() calls so far; passed to dataset.prune().
        self.iterations = 0
        # Index selection of the current epoch; None until the first reset().
        self.sample_indices = None
        self.iter_obj = None
        self.full_sample_freq = dataset.full_sample_freq

    def __getitem__(self, idx):
        return self.sample_indices[idx]

    def reset(self):
        """Draw a new index selection from the dataset and restart iteration.

        Every call prunes. ``stop_prune`` and ``full_sample_freq`` are stored
        on the sampler but the scheduling logic that consulted them (switching
        to un-pruned epochs) is currently disabled.
        """
        self.sample_indices = self.dataset.prune(self.iterations)
        self.iter_obj = iter(self.sample_indices)
        self.iterations += 1

    def __next__(self):
        return next(self.iter_obj)  # may raise StopIteration

    def __len__(self):
        # Before the first reset() there is no selection yet. Report the full
        # dataset size instead of failing, so that len(DataLoader) and the
        # DistributedSampler constructor can be used before iteration starts.
        if self.sample_indices is None:
            return len(self.dataset)
        return len(self.sample_indices)

    def __iter__(self):
        self.reset()
        return self


class DistributedIBSampler(DistributedSampler):
    """
    Wrapper over an `IBSampler` for distributed training.

    On every ``iter()`` the wrapped sampler is reset (which re-draws its index
    selection), the positions of that selection are split across the replicas
    and this rank's share is yielded.

    .. note::
        Sampler can change size during training.

    NOTE(review): every rank resets its own wrapped sampler, so the ranks only
    receive disjoint shares of the *same* selection if the selection drawn on
    each rank is identical -- confirm that the random state and the statistics
    driving the selection are synchronised across ranks.

    Args:
        dataset: The `IBSampler` to distribute.
        num_replicas, rank, shuffle, seed: Forwarded to ``DistributedSampler``.
        drop_last: Forwarded to ``DistributedSampler``; note the default here is ``True``.
    """
    class DatasetFromSampler(Dataset):
        """Adapter exposing a sampler through the ``Dataset`` interface."""

        def __init__(self, sampler: "IBSampler"):
            self.dataset = sampler
            self.indices = None

        def reset(self, ):
            self.indices = None
            self.dataset.reset()

        def __len__(self):
            # The base-class constructor asks for the length before the sampler
            # has drawn its first selection; report the full dataset size then.
            if self.dataset.sample_indices is None:
                return len(self.dataset.dataset)
            return len(self.dataset)

        def __getitem__(self, index: int):
            """Gets element of the dataset.
            Args:
                index: index of the element in the dataset
            Returns:
                Single element by index
            """
            return self.dataset[index]

    def __init__(self, dataset: "IBSampler", num_replicas: Optional[int] = None,
                 rank: Optional[int] = None, shuffle: bool = True,
                 seed: int = 0, drop_last: bool = True) -> None:
        sampler = self.DatasetFromSampler(dataset)
        super(DistributedIBSampler, self).__init__(
            sampler, num_replicas, rank, shuffle, seed, drop_last)
        self.sampler = sampler
        self.dataset = sampler.dataset.dataset  # the real dataset.
        self.iter_obj = None

    def __iter__(self) -> Iterator[int]:
        """Reset the wrapped sampler and iterate over this rank's share of its selection.

        ``self.sampler`` is the ``DatasetFromSampler`` adapter around the
        wrapped sampler; ``self.dataset`` is the dataset underneath it.
        """
        self.sampler.reset()
        num_selected = len(self.sampler)
        if self.drop_last and num_selected % self.num_replicas != 0:  # type: ignore[arg-type]
            # Split to nearest available length that is evenly divisible.
            # This is to ensure each rank receives the same amount of data when
            # using this Sampler.
            self.num_samples = math.ceil(
                (num_selected - self.num_replicas) / self.num_replicas  # type: ignore[arg-type]
            )
        else:
            self.num_samples = math.ceil(num_selected / self.num_replicas)  # type: ignore[arg-type]
        self.total_size = self.num_samples * self.num_replicas

        if self.shuffle:
            # deterministically shuffle based on epoch and seed
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            indices = torch.randperm(num_selected, generator=g).tolist()  # type: ignore[arg-type]
        else:
            indices = list(range(num_selected))  # type: ignore[arg-type]

        if not self.drop_last:
            # add extra samples to make it evenly divisible
            padding_size = self.total_size - len(indices)
            if padding_size <= len(indices):
                indices += indices[:padding_size]
            else:
                indices += (indices * math.ceil(padding_size /
                            len(indices)))[:padding_size]
        else:
            # remove tail of data to make it evenly divisible.
            indices = indices[:self.total_size]
        assert len(indices) == self.total_size
        indices = indices[self.rank:self.total_size:self.num_replicas]
        # A plain list handles zero or one position per rank as well;
        # operator.itemgetter rejects an empty argument list and returns a bare
        # item (not a tuple) for a single one.
        self.iter_obj = iter([self.sampler[i] for i in indices])
        return self.iter_obj
   
