import torch
import torchvision
from torchvision import datasets
from torch.utils.data import sampler, DataLoader
from torch.utils.data.sampler import BatchSampler
import torch.distributed as dist
import numpy as np
import json
import os
import math
from torch.utils.data import Dataset, DataLoader, RandomSampler, ConcatDataset


class DistributedBatchSchedulerSampler(torch.utils.data.sampler.Sampler):
    """Distributed sampler over a ``ConcatDataset`` where every per-process
    batch is drawn from a single sub-dataset.

    All ranks build the *same* global index list from a ``torch.Generator``
    seeded with ``seed + epoch``. Each rank then keeps every
    ``num_replicas``-th index starting at ``rank``.

    Each sub-dataset contributes contiguous runs of
    ``batch_size * num_replicas`` indices. As a result, every rank receives
    ``batch_size`` consecutive indices from one sub-dataset per run. A
    ``DataLoader`` using this sampler with the same ``batch_size`` therefore
    yields single-dataset batches.

    Call :meth:`set_epoch` before each epoch to change the shuffle order;
    otherwise every epoch repeats the same order.

    Args:
        dataset: ``ConcatDataset``-like object exposing ``datasets`` and
            ``cumulative_sizes``.
        batch_size: per-process batch size.
        num_replicas: number of participating processes.
        rank: index of this process, in ``[0, num_replicas)``.
        seed: base random seed, must be identical on all ranks (default 0).
    """

    def __init__(self, dataset, batch_size, num_replicas, rank, seed=0):
        """Initialize the sampler with dataset and distributed parameters."""
        self.dataset = dataset
        self.batch_size = batch_size
        self.number_of_datasets = len(dataset.datasets)
        self.largest_dataset_size = max(len(cur_dataset) for cur_dataset in dataset.datasets)
        self.num_replicas = num_replicas
        self.rank = rank
        self.seed = seed
        self.epoch = 0
        # Number of whole steps that fit in the largest dataset (a trailing
        # partial step is dropped; this is 0 if largest < batch_size * num_replicas).
        self.num_steps = self.largest_dataset_size // (self.batch_size * self.num_replicas)
        # Total samples per process
        self.num_samples = self.num_steps * self.number_of_datasets * self.batch_size
        # Total samples across all processes
        self.total_size = self.num_samples * self.num_replicas

    def __len__(self):
        """Return the number of samples per process."""
        return self.num_samples

    def __iter__(self):
        """Generate the shared index list and yield this rank's share of it."""
        indices = self.generate_indices()

        # Defensive padding; generate_indices already returns total_size items.
        if len(indices) < self.total_size:
            indices += indices[:(self.total_size - len(indices))]
        assert len(indices) == self.total_size

        # Subsample indices for this process. Every run of
        # batch_size * num_replicas indices is aligned to a multiple of
        # num_replicas, so this rank gets batch_size consecutive indices per run.
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)

    def generate_indices(self):
        """Build the global index list shared by all ranks.

        The list holds ``total_size`` indices into the concatenated dataset,
        laid out as ``num_steps`` steps. Each step visits every sub-dataset
        once, in a random order, and emits ``batch_size * num_replicas``
        indices from it. Sub-datasets smaller than the largest one are
        reshuffled and reused when exhausted.

        Returns:
            list[int]: indices into the concatenated dataset.

        Raises:
            ValueError: if any sub-dataset is empty.
        """
        sizes = [len(cur_dataset) for cur_dataset in self.dataset.datasets]
        if any(size == 0 for size in sizes):
            raise ValueError("DistributedBatchSchedulerSampler does not support empty sub-datasets")
        # Offset of each sub-dataset inside the concatenated index space
        offsets = [0] + list(self.dataset.cumulative_sizes[:-1])

        # Every random draw below comes from this one generator, so all ranks
        # that share seed and epoch produce exactly the same list. (Relying on
        # each process's global RNG would give each rank a different list and
        # break the rank-strided sharding in __iter__.)
        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch)

        perms = [torch.randperm(size, generator=g).tolist() for size in sizes]
        cursors = [0] * self.number_of_datasets
        samples_per_visit = self.batch_size * self.num_replicas

        final_samples_list = []
        for _ in range(self.num_steps):
            # Randomly permute datasets for this step
            dataset_perm = torch.randperm(self.number_of_datasets, generator=g).tolist()
            for dataset_idx in dataset_perm:
                for _ in range(samples_per_visit):
                    if cursors[dataset_idx] == sizes[dataset_idx]:
                        # Sub-dataset exhausted: reshuffle and start over
                        perms[dataset_idx] = torch.randperm(sizes[dataset_idx], generator=g).tolist()
                        cursors[dataset_idx] = 0
                    local_idx = perms[dataset_idx][cursors[dataset_idx]]
                    cursors[dataset_idx] += 1
                    final_samples_list.append(local_idx + offsets[dataset_idx])

        return final_samples_list

    def set_epoch(self, epoch):
        """Set the epoch used (together with ``seed``) to seed the shuffle."""
        self.epoch = epoch


class BatchSchedulerSampler(torch.utils.data.sampler.RandomSampler):
    """
    Iterate over tasks and provide a random batch per task in each mini-batch.

    ``dataset`` is expected to expose ``datasets`` and ``cumulative_sizes``
    (e.g. a ``ConcatDataset``).

    Iteration is organised in rounds:

    - Each round visits every sub-dataset once, in an order drawn with
      ``np.random.permutation``.
    - At each visit, the round emits ``batch_size`` consecutive indices from
      that sub-dataset. A ``DataLoader`` using this sampler with the same
      ``batch_size`` therefore yields single-dataset batches.
    - There are ``ceil(largest / batch_size)`` rounds, so the largest
      sub-dataset is fully covered.
    - Smaller sub-datasets are reshuffled and reused when exhausted.

    Note: ``RandomSampler.__init__`` is not invoked; this class relies only on
    its own ``__len__`` and ``__iter__``.
    """

    def __init__(self, dataset, batch_size):
        self.dataset = dataset
        self.batch_size = batch_size
        self.number_of_datasets = len(dataset.datasets)
        self.largest_dataset_size = max(len(cur_dataset) for cur_dataset in dataset.datasets)

    def _num_rounds(self):
        """Number of rounds needed to cover the largest sub-dataset once."""
        return math.ceil(self.largest_dataset_size / self.batch_size)

    def __len__(self):
        return self.batch_size * self._num_rounds() * self.number_of_datasets

    def __iter__(self):
        # One fresh random permutation stream per sub-dataset
        samplers_list = [RandomSampler(cur_dataset) for cur_dataset in self.dataset.datasets]
        sampler_iterators = [iter(cur_sampler) for cur_sampler in samplers_list]
        # Offset of each sub-dataset inside the concatenated index space
        push_index_val = [0] + list(self.dataset.cumulative_sizes[:-1])

        final_samples_list = []  # indexes into the combined dataset
        for _ in range(self._num_rounds()):
            for i in np.random.permutation(self.number_of_datasets):
                for _ in range(self.batch_size):  # batch with one task/dataset
                    try:
                        cur_sample_org = next(sampler_iterators[i])
                    except StopIteration:
                        # Smaller dataset exhausted: restart its sampler so we
                        # keep drawing until the largest one is covered.
                        sampler_iterators[i] = iter(samplers_list[i])
                        cur_sample_org = next(sampler_iterators[i])
                    final_samples_list.append(cur_sample_org + push_index_val[i])

        return iter(final_samples_list)


def get_dataloader_multidataset(
        dataset,
        batch_size=None,
        num_workers=4,
        pin_memory=False,
        distributed=False,
    ):
    """
    Return a ``torch.utils.data.DataLoader`` whose batches each come from a
    single sub-dataset of ``dataset``.

    ``dataset`` must expose ``datasets`` and ``cumulative_sizes`` (e.g. a
    ``ConcatDataset``). If ``distributed`` is true, indices come from
    ``DistributedBatchSchedulerSampler``, sharded by ``torch.distributed``
    rank. Otherwise ``BatchSchedulerSampler`` is used.

    The distributed sampler's shuffle is seeded by its epoch, so call
    ``loader.sampler.set_epoch(epoch)`` before each epoch.

    Args:
        dataset: concatenated dataset to load from.
        batch_size: samples per batch (required). The same value is given to
            the sampler and the DataLoader so batches stay single-dataset.
        num_workers: number of loader worker processes; 0 loads in the main
            process.
        pin_memory: forwarded to ``DataLoader``.
        distributed: shard batches across ``torch.distributed`` ranks.

    Raises:
        ValueError: if ``batch_size`` is ``None``.
        RuntimeError: if ``distributed`` is requested but
            ``torch.distributed`` is not available.
    """
    if batch_size is None:
        raise ValueError("batch_size must be specified")

    if distributed:
        if not dist.is_available():
            raise RuntimeError("torch.distributed is not available")
        sampler = DistributedBatchSchedulerSampler(
            dataset, batch_size, dist.get_world_size(), dist.get_rank())
    else:
        sampler = BatchSchedulerSampler(dataset, batch_size)

    loader_kwargs = {}
    if num_workers > 0:
        # DataLoader rejects prefetch_factor / persistent_workers when
        # num_workers == 0, so only pass them when worker processes exist.
        loader_kwargs.update(prefetch_factor=4, persistent_workers=True)

    return DataLoader(
        dataset,
        sampler=sampler,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
        **loader_kwargs,
    )


def get_single_dataloader_dict(datasets, batch_size=None, num_workers=4, pin_memory=False):
    """Create one plain DataLoader per named dataset.

    Args:
        datasets: mapping of name -> dataset.
        batch_size: passed through to every DataLoader.
        num_workers: passed through to every DataLoader.
        pin_memory: passed through to every DataLoader.

    Returns:
        dict mapping each key of ``datasets`` to its DataLoader, in the
        same iteration order.
    """
    return {
        name: DataLoader(
            ds,
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=pin_memory,
        )
        for name, ds in datasets.items()
    }

