import math
import numpy as np
from typing import Iterator, Optional
import torch
from torch.utils.data.dataloader import _BaseDataLoaderIter
from torch.utils.data import Dataset, _DatasetKind
from torch.utils.data.distributed import DistributedSampler
from operator import itemgetter
import torch.distributed as dist
import warnings
from collections import defaultdict

# Public API for `from <module> import *`: only the dataset wrapper. The samplers and
# helpers below are still importable by explicit name.
__all__ = ['OrderDP']


def info_hack_indices(self):
    """Drop-in replacement for ``_BaseDataLoaderIter.__next__`` that understands OrderDP.

    When the loader's dataset is an :class:`OrderDP`, each item is an
    ``(index, sample)`` pair (see ``OrderDP.__getitem__``). The fetched batch is
    therefore unpacked into ``(indices, data)``. This assumes a collate_fn that
    keeps that two-element structure, e.g. the default one. The indices are
    handed to ``OrderDP.set_active_indices`` so that a later ``OrderDP.update``
    call can attribute per-sample losses. Only ``data`` is returned to the
    caller.

    For any other dataset the batch is returned as fetched. In both cases the
    usual "IterableDataset length exceeded" warning is emitted when applicable.
    """
    with torch.autograd.profiler.record_function(self._profile_name):
        if self._sampler_iter is None:
            self._reset()  # type: ignore[call-arg]
        carries_indices = isinstance(self._dataset, OrderDP)
        fetched = self._next_data()
        if carries_indices:
            # OrderDP batches look like [indices, data]; split them apart.
            batch_indices, fetched = fetched
        self._num_yielded += 1
        exceeded_reported_len = (
            self._dataset_kind == _DatasetKind.Iterable
            and self._IterableDataset_len_called is not None
            and self._num_yielded > self._IterableDataset_len_called
        )
        if exceeded_reported_len:
            message = ("Length of IterableDataset {} was reported to be {} (when accessing len(dataloader)), but {} "
                       "samples have been fetched. ").format(self._dataset, self._IterableDataset_len_called,
                                                             self._num_yielded)
            if self._num_workers > 0:
                message += ("For multiprocessing data-loading, this could be caused by not properly configuring the "
                            "IterableDataset replica at each worker.")
            warnings.warn(message)
        if carries_indices:
            self._dataset.set_active_indices(batch_indices)
        return fetched


# Import-time monkey-patch. Every DataLoader iterator that inherits __next__ from
# _BaseDataLoaderIter will now go through info_hack_indices in this process.
_BaseDataLoaderIter.__next__ = info_hack_indices


@torch.no_grad()
def concat_all_gather(tensor, dim=0):
    """
    Gather ``tensor`` from every process and concatenate the copies along ``dim``.

    Every rank must pass a tensor of the same shape and dtype, as required by
    ``torch.distributed.all_gather``. Slot ``i`` of the gathered list holds the
    tensor from rank ``i``, so the result is ordered by rank.

    If torch.distributed is unavailable or no process group is initialised,
    there is only the local tensor. A copy of it is returned (``torch.cat`` of a
    one-element list), so the helper also works in single-process runs instead
    of raising from ``dist.get_world_size()``.

    *** Warning ***: torch.distributed.all_gather has no gradient.

    Args:
        tensor (torch.Tensor): local tensor to share.
        dim (int): dimension along which the per-rank tensors are concatenated.

    Returns:
        torch.Tensor: the concatenation of all ranks' tensors along ``dim``.
    """
    if not (dist.is_available() and dist.is_initialized()):
        return torch.cat([tensor], dim=dim)
    # all_gather overwrites every buffer, so no initialisation is needed.
    tensors_gather = [torch.empty_like(tensor)
                      for _ in range(dist.get_world_size())]
    dist.all_gather(tensors_gather, tensor, async_op=False)
    return torch.cat(tensors_gather, dim=dim)


class OrderDP(Dataset):
    """
    Dataset wrapper that trains each epoch on a loss-ordered random subset of the data.

    Each epoch, :meth:`prune` draws ``random_len`` distinct indices uniformly at
    random. It keeps the ``top_q`` of them with the largest recorded score, i.e.
    the latest per-sample loss stored by :meth:`update`. These are the samples
    the network has learned least well so far. All other samples are skipped for
    that epoch. No gradient rescaling is applied to the kept samples.

    Usage protocol, as wired up in this module:

    * ``__getitem__`` returns ``(index, sample)`` so each batch carries its indices.
    * The patched DataLoader iterator strips the indices and passes them to
      :meth:`set_active_indices`.
    * The training loop calls :meth:`update` with that batch's per-sample losses.
      It stores them and returns their mean.

    .. note::
        Dataset is assumed to be of constant size.

    Args:
        dataset: Dataset used for training.
        num_epochs (int): Total number of training epochs. ``stop_prune`` is
            ``num_epochs * delta``.
        random_len_ratio (float, optional): Fraction of the dataset drawn uniformly
            at random each epoch. Defaults to 0.75.
        top_q_ratio (float, optional): Fraction of that random draw that is kept,
            choosing the highest-scoring samples. Values >= 1 keep the whole draw.
            Defaults to 0.85.
        prune_ratio (float, optional): Stored as
            ``keep_ratio = clip(1 - prune_ratio, 0.1, 1.0)``. The current
            :meth:`prune` does not use it. Defaults to 0.5.
        delta (float, optional): Fraction of ``num_epochs`` after which pruning is
            meant to stop (exposed as :attr:`stop_prune`). It should be close to 1.
            NOTE(review): ``IBSampler`` currently calls :meth:`prune` in every
            epoch regardless of ``stop_prune``. Defaults to 0.875.
    """

    def __init__(self, dataset: Dataset, num_epochs: int, random_len_ratio: float = 0.75, top_q_ratio: float = 0.85,
                 prune_ratio: float = 0.5, delta: float = 0.875):
        self.dataset = dataset
        self.total = len(self.dataset)
        # Number of indices drawn uniformly at random per epoch ...
        self.random_len = int(random_len_ratio * self.total)
        # ... and how many of those (highest score first) are kept.
        self.top_q = int(top_q_ratio * self.random_len)
        # Not read by prune(); kept as a public attribute.
        self.keep_ratio = min(1.0, max(1e-1, 1.0 - prune_ratio))
        self.num_epochs = num_epochs
        self.delta = delta
        # Latest loss recorded for each sample; a smaller value means the sample is better learned.
        # Unseen samples start at 3.0, so prune() ranks them above any sample whose recorded loss is < 3.
        self.scores = torch.ones(self.total) * 3
        # Created so get_weights()/reset_weights() work. All ones; update() does not apply them.
        self.weights = torch.ones(self.total)
        self.num_pruned_samples = 0
        # sample index (int) -> number of prune() calls that selected it.
        self.sample_usage = defaultdict(int)
        # Indices of the batch currently in flight, set by set_active_indices().
        self.cur_batch_index = None

    def set_active_indices(self, cur_batch_indices: torch.Tensor):
        """Remember the dataset indices of the batch that update() will receive next."""
        self.cur_batch_index = cur_batch_indices

    def update(self, values):
        """Record per-sample losses for the active batch and return their mean.

        Args:
            values (torch.Tensor): one loss per sample of the batch most recently
                passed to :meth:`set_active_indices` (shape ``[B]`` or ``[B, 1]``).

        Returns:
            torch.Tensor: ``values.mean()``, still attached to the autograd graph.

        Raises:
            AssertionError: if ``values`` is not a tensor, or if no matching batch
                of indices is active. The active indices are consumed by each
                call, so a second update() before the next batch fails.
        """
        assert isinstance(values, torch.Tensor)
        batch_size = values.shape[0]
        assert self.cur_batch_index is not None and len(self.cur_batch_index) == batch_size, 'not enough index'
        device = values.device

        indices = self.cur_batch_index.to(device)
        # Flatten so a [B, 1] loss lines up with the [B] index vector.
        loss_val = values.detach().clone().reshape(-1)
        self.cur_batch_index = []

        if dist.is_available() and dist.is_initialized():
            # Gather (index, loss) pairs from all ranks so each rank records the whole global batch.
            # Use float64 explicitly: it holds integer indices exactly up to 2**53 (float32 only to 2**24),
            # and torch.cat then never has to promote between long and float.
            iv = torch.cat([indices.reshape(1, -1).to(torch.float64),
                            loss_val.reshape(1, -1).to(torch.float64)], dim=0)
            iv_whole_group = concat_all_gather(iv, 1)
            indices = iv_whole_group[0]
            loss_val = iv_whole_group[1]
        # Match the scores dtype explicitly; index-put does not accept mismatched dtypes.
        self.scores[indices.cpu().long()] = loss_val.cpu().to(self.scores.dtype)
        return values.mean()

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        """Return ``(index, sample)`` so the batch carries its dataset indices."""
        return index, self.dataset[index]

    def prune(self):
        """Draw this epoch's subset of sample indices.

        ``random_len`` indices are drawn without replacement using NumPy's global
        RNG. If ``top_q`` is smaller than that, only the ``top_q`` with the largest
        scores are kept. The result is shuffled in place.

        Returns:
            numpy.ndarray: the selected dataset indices, in random order.
        """
        selected_indices = np.random.choice(len(self.dataset), self.random_len, replace=False)
        if self.top_q < self.random_len:
            top_positions = torch.topk(self.scores[selected_indices], self.top_q,
                                       dim=0, largest=True, sorted=False)[1]
            # Index the NumPy array with a NumPy array. A torch tensor index is ambiguous
            # for a single element, because NumPy may treat it as a scalar index.
            result_indices = selected_indices[top_positions.numpy()]
        else:
            # top_q >= random_len (e.g. top_q_ratio >= 1): keep the whole random draw.
            result_indices = selected_indices

        self.num_pruned_samples += len(self.dataset) - len(result_indices)
        for idx in result_indices.tolist():
            self.sample_usage[idx] += 1
        np.random.shuffle(result_indices)
        return result_indices

    @property
    def sampler(self):
        """A fresh :class:`IBSampler` over this dataset. Each access builds a new one,
        and construction itself calls :meth:`prune` once."""
        return IBSampler(self)

    def no_prune(self):
        """Return every index ``0..len(self)-1`` in random order (NumPy global RNG)."""
        samples_indices = list(range(len(self)))
        np.random.shuffle(samples_indices)
        return samples_indices

    def mean_score(self):
        """Mean of the recorded per-sample scores."""
        return self.scores.mean()

    def get_weights(self, indexes):
        """Per-sample weights for ``indexes`` (all ones unless modified externally)."""
        return self.weights[indexes]

    def get_sample_usage(self):
        """Mapping from sample index to the number of prune() calls that selected it."""
        return self.sample_usage

    def get_pruned_count(self):
        """Cumulative number of samples left out across all prune() calls."""
        return self.num_pruned_samples

    @property
    def stop_prune(self):
        """``num_epochs * delta``: the epoch count after which pruning is meant to stop."""
        return self.num_epochs * self.delta

    def reset_weights(self):
        """Set every per-sample weight back to 1."""
        self.weights[:] = 1


class IBSampler:
    """Per-epoch index sampler backed by an :class:`OrderDP` dataset.

    Every :meth:`reset` does two things. It reseeds NumPy's *global* RNG with the
    number of resets performed so far, then asks the dataset for a fresh subset
    via ``prune()``. A reset happens once at construction and again at the start
    of every ``iter()``. Iteration yields that subset in the order returned.

    NOTE(review): ``stop_prune`` is copied from the dataset but never acted on;
    every reset calls ``prune()`` (the ``no_prune()`` path is disabled). Confirm
    whether pruning is meant to stop after ``stop_prune`` resets.
    """

    def __init__(self, dataset: OrderDP):
        self.dataset = dataset
        self.stop_prune = dataset.stop_prune
        # Count of reset() calls so far; doubles as the seed for the next reset.
        self.iterations = 0
        self.sample_indices = None
        self.iter_obj = None
        # Draw the first subset eagerly, which also makes __len__/__getitem__ usable
        # before the first iter().
        self.reset()

    def __getitem__(self, idx):
        current = self.sample_indices
        return current[idx]

    def reset(self):
        """Reseed NumPy's global RNG and draw a new subset from the dataset."""
        seed = self.iterations
        np.random.seed(seed)
        fresh_indices = self.dataset.prune()
        self.sample_indices = fresh_indices
        self.iter_obj = iter(fresh_indices)
        self.iterations = seed + 1

    def __next__(self):
        # Propagates StopIteration once the current subset is exhausted.
        return next(self.iter_obj)

    def __len__(self):
        return len(self.sample_indices)

    def __iter__(self):
        self.reset()
        return self


class DistributedIBSampler(DistributedSampler):
    """
    Distributed wrapper around an :class:`IBSampler`.

    Every ``iter()`` resets the wrapped sampler, which draws a new pruned index
    list. That list is then split across ``num_replicas`` ranks, mirroring
    :class:`torch.utils.data.DistributedSampler`:

    * optionally shuffled with ``seed + epoch``;
    * padded (``drop_last=False``) or truncated (``drop_last=True``) to a
      multiple of ``num_replicas``;
    * strided by ``rank``.

    Each rank receives an exclusive share of the indices for that epoch. This
    presumes every rank's IBSampler draws the same list. That holds when the
    global NumPy RNG is reseeded identically and scores are synchronised; TODO
    confirm for your setup.

    .. note::
        The wrapped sampler can change size between epochs, so ``num_samples``
        and ``total_size`` are recomputed on every ``iter()``.
    """
    class DatasetFromSampler(Dataset):
        """Expose an IBSampler's current index list as a map-style dataset."""

        def __init__(self, sampler: "IBSampler"):
            self.dataset = sampler
            self.indices = None

        def reset(self):
            """Reset the wrapped sampler, drawing a new index list."""
            self.indices = None
            self.dataset.reset()

        def __len__(self):
            return len(self.dataset)

        def __getitem__(self, index: int):
            """Gets element of the dataset.
            Args:
                index: index of the element in the dataset
            Returns:
                Single element by index
            """
            return self.dataset[index]

    def __init__(self, dataset: "IBSampler", num_replicas: Optional[int] = None,
                 rank: Optional[int] = None, shuffle: bool = True,
                 seed: int = 0, drop_last: bool = True) -> None:
        sampler = self.DatasetFromSampler(dataset)
        super(DistributedIBSampler, self).__init__(
            sampler, num_replicas, rank, shuffle, seed, drop_last)
        self.sampler = sampler
        # DistributedSampler stored the adapter as self.dataset; expose the real dataset instead.
        self.dataset = sampler.dataset.dataset
        self.iter_obj = None

    def __iter__(self) -> Iterator[int]:
        """
        Reset the wrapped sampler and return an iterator over this rank's share.

        ``self.sampler`` is the :class:`DatasetFromSampler` adapter around the
        IBSampler. ``self.dataset`` is the IBSampler's underlying dataset.
        """
        self.sampler.reset()
        if self.drop_last and len(self.sampler) % self.num_replicas != 0:  # type: ignore[arg-type]
            # Split to nearest available length that is evenly divisible.
            # This is to ensure each rank receives the same amount of data when
            # using this Sampler.
            self.num_samples = math.ceil(
                (len(self.sampler) - self.num_replicas) /
                self.num_replicas  # type: ignore[arg-type]
            )
        else:
            self.num_samples = math.ceil(
                len(self.sampler) / self.num_replicas)  # type: ignore[arg-type]
        self.total_size = self.num_samples * self.num_replicas

        if self.shuffle:
            # deterministically shuffle based on epoch and seed
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            indices = torch.randperm(len(self.sampler), generator=g).tolist()  # type: ignore[arg-type]
        else:
            indices = list(range(len(self.sampler)))  # type: ignore[arg-type]

        if not self.drop_last:
            # add extra samples to make it evenly divisible
            padding_size = self.total_size - len(indices)
            if padding_size <= len(indices):
                indices += indices[:padding_size]
            else:
                indices += (indices * math.ceil(padding_size /
                            len(indices)))[:padding_size]
        else:
            # remove tail of data to make it evenly divisible.
            indices = indices[:self.total_size]
        assert len(indices) == self.total_size
        indices = indices[self.rank:self.total_size:self.num_replicas]
        # Build an explicit list. operator.itemgetter returns a bare scalar for one
        # index (not iterable) and raises for zero indices.
        self.iter_obj = iter([self.sampler[i] for i in indices])
        return self.iter_obj
   
