import math
from typing import TypeVar, Optional, Iterator

import torch
from torch.utils.data import Sampler, Dataset
import torch.distributed as dist
import random
import numpy as np
import torch


class DistributedSamplerChunkByNode(torch.utils.data.Sampler):
    """Distributed sampler mixing globally shared and node-local datasets.

    ``all_datasets`` is split by ``chunk_or_not`` into two groups:

    * "normal" datasets (flag falsy) are partitioned across *all* processes,
      like ``torch.utils.data.DistributedSampler`` does;
    * "chunked" datasets (flag truthy) are assigned to nodes in contiguous
      groups, and each group is partitioned only across the processes of the
      node that owns it.

    Every rank yields exactly ``num_samples`` indices per epoch: its share of
    the normal indices, topped up (truncated or padded with replacement) from
    the index range owned by its node.

    NOTE(review): the index arithmetic assumes ``dataset`` indexes all normal
    datasets first (``0 .. normal_dataset_size - 1``), followed by the chunked
    datasets in the order they appear in ``all_datasets`` -- confirm against
    the caller that builds ``dataset``.

    Args:
        dataset: Sized dataset the yielded indices refer to; only ``len()``
            is used here.
        all_datasets: The individual sized datasets.
        chunk_or_not: One flag per entry of ``all_datasets``; truthy marks the
            dataset as node-local ("chunked").
        num_replicas: Total number of processes. Defaults to the world size.
        rank: Global rank of this process. Defaults to ``dist.get_rank()``.
        shuffle: Shuffle indices (deterministically from ``seed`` and the
            epoch) when true; keep them in index order otherwise.
        seed: Base random seed, expected to be identical on all processes.
        drop_last: Round ``num_samples`` down instead of up when
            ``len(dataset)`` is not divisible by ``num_replicas``.
        node_rank: Index of the node this process runs on.
        node_number: Number of nodes.
        process_num_per_node: Number of processes on each node.
        rank_within_local_node: Rank of this process within its node.

    Raises:
        RuntimeError: If a default is needed but ``torch.distributed`` is not
            available.
        ValueError: If ``rank`` or ``node_rank`` is out of range.
        AssertionError: If ``process_num_per_node * node_number`` differs from
            ``num_replicas`` or there are fewer chunked datasets than nodes.
    """

    def __init__(self,
                 dataset,
                 all_datasets,
                 chunk_or_not,
                 num_replicas: Optional[int] = None,
                 rank: Optional[int] = None,
                 shuffle: bool = True,
                 seed: int = 0,
                 drop_last: bool = False,
                 node_rank=0,
                 node_number=1, process_num_per_node=1,
                 rank_within_local_node=0) -> None:
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        if rank >= num_replicas or rank < 0:
            raise ValueError(
                "Invalid rank {}, rank should be in the interval"
                " [0, {}]".format(rank, num_replicas - 1))
        # An out-of-range node rank would otherwise leave the start of the
        # node-local range at a sentinel value and sample bogus indices.
        if node_rank >= node_number or node_rank < 0:
            raise ValueError(
                "Invalid node_rank {}, node_rank should be in the interval"
                " [0, {}]".format(node_rank, node_number - 1))
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.node_number = node_number
        self.node_rank = node_rank
        self.chunk_or_not = chunk_or_not
        self.process_num_per_node = process_num_per_node
        self.rank_within_local_node = rank_within_local_node

        assert self.process_num_per_node * self.node_number == self.num_replicas, \
            "process_num_per_node * node_number must equal num_replicas"

        # 1. Split the datasets into globally shared and node-local groups.
        normal_datasets = []
        chunked_datasets = []
        for dataset_i, chunk_i in zip(all_datasets, chunk_or_not):
            if chunk_i:
                chunked_datasets.append(dataset_i)
            else:
                normal_datasets.append(dataset_i)

        # 2. Size of the part that follows the conventional distributed sampler.
        self.normal_dataset_size = sum(len(d) for d in normal_datasets)

        # 3. Assign each node a contiguous group of chunked datasets.
        assert len(chunked_datasets) >= self.node_number, \
            "need at least one chunked dataset per node"
        chunk_size = len(chunked_datasets) // self.node_number

        # boundaries[k] is the index at which chunked dataset k starts;
        # boundaries[-1] is one past the last index overall.
        boundaries = [self.normal_dataset_size]
        for chunked in chunked_datasets:
            boundaries.append(boundaries[-1] + len(chunked))

        first_chunk = self.node_rank * chunk_size
        if self.node_rank == self.node_number - 1:
            # The last node also takes the datasets left over when the count
            # is not divisible by node_number, so none is silently dropped.
            end_chunk = len(chunked_datasets)
        else:
            end_chunk = first_chunk + chunk_size
        self.current_node_start_range = boundaries[first_chunk]
        self.current_node_end_range = boundaries[end_chunk]

        self.drop_last = drop_last
        # If the dataset length is evenly divisible by # of replicas, then there
        # is no need to drop any data, since the dataset will be split equally.
        if self.drop_last and len(self.dataset) % self.num_replicas != 0:  # type: ignore[arg-type]
            # Split to nearest available length that is evenly divisible.
            # This is to ensure each rank receives the same amount of data when
            # using this Sampler.
            self.num_samples = math.ceil(
                (len(self.dataset) - self.num_replicas) / self.num_replicas  # type: ignore[arg-type]
            )
        else:
            self.num_samples = math.ceil(len(self.dataset) / self.num_replicas)  # type: ignore[arg-type]
        self.total_size = self.num_samples * self.num_replicas
        self.shuffle = shuffle
        self.seed = seed

    def __iter__(self) -> Iterator[int]:
        """Yield this rank's ``num_samples`` indices for the current epoch."""
        # Share of the normal datasets: partitioned across every process.
        indices = self.generate_indices_within_range_with_rank(
            seed=self.seed,
            epoch=self.epoch,
            process_num=self.num_replicas,
            rank=self.rank,
            generate_length=-1,
            valid_indices=list(range(self.normal_dataset_size)),
            shuffle=self.shuffle,
            prefix="Normal "
        )

        # Share of this node's chunked datasets: partitioned across the
        # processes of the local node only, sized to reach num_samples.
        addition_indices = self.generate_indices_within_range_with_rank(
            seed=self.seed,
            epoch=self.epoch,
            process_num=self.process_num_per_node,
            rank=self.rank_within_local_node,
            generate_length=self.num_samples - len(indices),
            valid_indices=list(range(self.current_node_start_range, self.current_node_end_range)),
            shuffle=self.shuffle,
            prefix="Distribute "
        )

        indices.extend(addition_indices)
        if self.shuffle:
            # Interleave both parts with a rank-specific order. A private
            # generator gives the same permutation as seeding the module-level
            # RNG would, without clobbering global random state.
            random.Random(self.seed + self.epoch + 10 * self.rank).shuffle(indices)
        assert len(indices) == self.num_samples
        return iter(indices)

    def generate_indices_within_range_with_rank(self, seed, epoch, process_num, generate_length, valid_indices, rank=-1,
                                                shuffle=True, prefix=""):
        '''
        Pick this rank's disjoint share of ``valid_indices``.

        Use scenario : we want to sample 2500 examples from 10000 examples,
        while not sampling overlapping examples with the other three processes.
        Modified from DistributedSampler.

        Args:
            seed: Base seed; combined with ``epoch`` for the permutation.
            epoch: Current epoch number.
            process_num: Number of processes sharing ``valid_indices``.
            generate_length: Exact number of indices to return, or ``-1`` to
                return the rank's share as is. A longer share is truncated; a
                shorter one is padded with indices drawn with replacement
                from ``valid_indices``.
            valid_indices: Sequence of candidate indices.
            rank: Position of this process among the ``process_num`` sharers.
            shuffle: Permute the candidates before splitting when true.
            prefix: Label prepended to the debug output.

        Returns:
            A list of indices taken from ``valid_indices``.
        '''
        dataset_size = len(valid_indices)
        if shuffle:
            # deterministically shuffle based on epoch and seed
            g = torch.Generator()
            g.manual_seed(seed + epoch)
            order = torch.randperm(dataset_size, generator=g).tolist()  # type: ignore[arg-type]
        else:
            order = list(range(dataset_size))  # type: ignore[arg-type]

        indices = [valid_indices[i] for i in order]

        # Largest equal share per process; only the remainder that cannot be
        # split evenly is dropped.
        num_samples_normal = dataset_size // process_num
        # remove tail of data to make it evenly divisible.
        indices = indices[:num_samples_normal * process_num]

        print("\n")
        print(prefix,
              "Global Rank {}   Local Rank {}    generate_length {}    valid_indices {}    process_num {}  indices_before_subsample {} {}".format(
                  self.rank, rank, generate_length, len(valid_indices), process_num, len(indices), indices[:10]))

        # subsample
        indices = indices[rank:num_samples_normal * process_num: process_num]

        print(prefix,
              "Global Rank {}   Local Rank {}    generate_length {}    valid_indices {}    process_num {}  indices_after_subsample {} {}".format(
                  self.rank, rank, generate_length, len(valid_indices), process_num, len(indices), indices[:10]))
        print("\n")

        if generate_length != -1:
            if len(indices) > generate_length:
                indices = indices[:generate_length]
            else:
                shortfall = generate_length - len(indices)
                if shortfall > 0:
                    # Seeded per global rank so padding is reproducible for a
                    # given seed/epoch yet differs between ranks.
                    pad_rng = np.random.RandomState((seed + epoch + 10 * self.rank) % (2 ** 32))
                    indices.extend(pad_rng.choice(valid_indices, shortfall).tolist())
        return indices

    def __len__(self) -> int:
        """Return the number of indices this rank yields per epoch."""
        return self.num_samples

    def set_epoch(self, epoch: int) -> None:
        r"""
        Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas
        use a different random ordering for each epoch. Otherwise, the next iteration of this
        sampler will yield the same ordering.

        Args:
            epoch (int): Epoch number.
        """
        self.epoch = epoch
