import copy
import random
from collections import defaultdict, namedtuple
from typing import Optional, Sized, Iterator, T_co

import pytorch_lightning as pl
import torch
from datasets import Dataset
from torch.utils.data import DataLoader
from torch.utils.data.sampler import Sampler, BatchSampler
from torch.utils.data.distributed import DistributedSampler


# TODO: not working yet

class _RandomSampler(Sampler):
    """Sampler that replays the items of ``data_source`` in a random order.

    Each call to ``__iter__`` copies the items of ``data_source`` into a new
    list and shuffles that list in place with the module-level ``random``
    generator, so every pass yields a fresh permutation. No seed is managed
    here: the order follows whatever state ``random`` is in.
    """

    def __init__(self, data_source: Optional[Sized]) -> None:
        super().__init__(data_source)
        self.data_source = data_source

    def __len__(self) -> int:
        # Items are only permuted, never dropped, so the length is unchanged.
        return len(self.data_source)

    def __iter__(self) -> Iterator[T_co]:
        permuted = list(self.data_source)
        random.shuffle(permuted)
        return iter(permuted)


class MTDataset:
    """Multi-task dataset addressed by ``(task_index, example_index)`` pairs.

    Wraps a list of per-task datasets. ``len()`` is the total number of
    examples over all tasks, but ``__getitem__`` takes a 2-tuple index (as
    yielded by ``MTBatchedDistributedSampler``) rather than a flat integer.
    """

    def __init__(self, datasets):
        """
        Args:
            datasets: List of per-task datasets; each must support ``len()``
                and integer indexing.
        """
        assert isinstance(datasets, list), (
            "datasets must be a list of per-task datasets, "
            f"got {type(datasets).__name__}"
        )
        self.datasets = datasets
        # Per-task example counts, in the same order as ``self.datasets``.
        self.lengths = [len(ds) for ds in datasets]

    def __len__(self):
        """Return the total number of examples across all tasks."""
        return sum(self.lengths)

    def __getitem__(self, index):
        """Return one example tagged with the index it was fetched by.

        Args:
            index: ``(task_index, example_index)`` pair.

        Returns:
            A shallow copy of ``self.datasets[task_index][example_index]``
            with an added ``"index"`` entry holding ``index``.
        """
        task_index, example_index = index
        # Copy before tagging so the stored example is never mutated (for
        # example when a task dataset is a plain list of dicts).
        ex = copy.copy(self.datasets[task_index][example_index])
        ex["index"] = index
        return ex



class MTBatchedDistributedSampler(Sampler):
    """Batch sampler whose batches each come from a single task.

    Every task (domain) in ``dataset.datasets`` gets its own
    ``DistributedSampler`` (this rank's shard of that task), wrapped in a
    ``BatchSampler``. Iteration visits the tasks round-robin, one batch per
    visit. A task leaves the rotation once its batches run out, and iteration
    ends when every task is drained.

    Each yielded batch is a list of ``(task_index, example_index)`` tuples,
    the index form that ``MTDataset.__getitem__`` unpacks.

    Args:
        dataset: Object with a ``datasets`` attribute listing the per-task
            datasets (e.g. an ``MTDataset``).
        batch_size: Number of examples per batch.
        drop_last_batch: If true, drop the last incomplete batch of each task.
        num_replicas: Forwarded to every per-task ``DistributedSampler``.
        rank: Forwarded to every per-task ``DistributedSampler``.
        shuffle: If true, each rank's shard of every task is shuffled locally
            with the module-level ``random`` generator (via ``_RandomSampler``).
            The ``DistributedSampler`` itself is always built with
            ``shuffle=False``.
        seed: Forwarded to ``DistributedSampler``. NOTE(review): that sampler
            never shuffles, so ``seed`` does not currently affect the order.
            Confirm whether seeded shuffling was intended.
    """

    def __init__(
            self,
            dataset: Dataset,
            batch_size: int,
            drop_last_batch: Optional[bool] = False,
            num_replicas: Optional[int] = None,
            rank: Optional[int] = None,
            shuffle: bool = True,
            seed: int = 0) -> None:
        self.dataset = dataset
        # Plain value: a stray trailing comma used to store ``(num_replicas,)``.
        self.num_replicas = num_replicas
        self.rank = rank
        self.shuffle = shuffle
        self.seed = seed
        self.batch_size = batch_size
        self.drop_last = drop_last_batch
        # One batch sampler per task, each over this rank's shard of the task.
        self.task_samplers = [
            self._create_task_sampler(ds) for ds in dataset.datasets
        ]
        # Number of batches each task contributes on this rank.
        self.task_samples = [len(sampler) for sampler in self.task_samplers]

    def _create_task_sampler(self, ds) -> BatchSampler:
        """Build the batch sampler over this rank's shard of task dataset ``ds``."""
        dist_sampler = DistributedSampler(
            ds,
            num_replicas=self.num_replicas,
            rank=self.rank,
            shuffle=False,
            seed=self.seed,
            drop_last=False,
        )
        index_sampler = (
            _RandomSampler(dist_sampler) if self.shuffle else dist_sampler
        )
        return BatchSampler(
            index_sampler,
            batch_size=self.batch_size,
            drop_last=self.drop_last,
        )

    def __iter__(self):
        """Yield single-task batches, cycling over tasks until all are drained."""
        # Each entry is [task_index, batch_iterator, batches_remaining].
        # Tasks with no batches on this rank are skipped up front. Otherwise
        # ``next`` on their empty iterator would raise StopIteration inside
        # this generator, which PEP 479 turns into a RuntimeError.
        pending = [
            [task_index, iter(sampler), num_batches]
            for task_index, (sampler, num_batches) in enumerate(
                zip(self.task_samplers, self.task_samples))
            if num_batches > 0
        ]
        pos = 0
        while pending:
            task_index, batch_iter, remaining = pending[pos]
            batch = [(task_index, idx) for idx in next(batch_iter)]
            if remaining == 1:
                # Task drained. Deleting it shifts the next task into ``pos``,
                # but ``pos`` must wrap when the removed task was the last one.
                del pending[pos]
                if pending:
                    pos %= len(pending)
            else:
                pending[pos][2] = remaining - 1
                pos = (pos + 1) % len(pending)
            yield batch

    def __len__(self):
        """Return the total number of batches yielded per pass on this rank."""
        return sum(self.task_samples)

