import numpy as np
from collections import defaultdict
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset, DistributedSampler, \
                            WeightedRandomSampler, Sampler, Subset, ConcatDataset
from torchvision import datasets, transforms

from utils.augmentations import get_transforms
from utils.dataset import SimCLRDataset


def get_dataset(dataset_name, dataset_path,
                augment_both_views=True,
                batch_size=64, num_workers=8,
                shuffle=True, **kwargs):
    """Build the training (and optionally test) dataset and dataloader.

    Args:
        dataset_name: ``'imagenet'`` (Hugging Face ``timm/mini-imagenet``),
            ``'cifar10'``, ``'cifar100'`` or ``'svhn'``. ``None`` falls back
            to ``'cifar10'``.
        dataset_path: Root directory for the torchvision datasets (not used
            for ``'imagenet'``).
        augment_both_views: Forwarded to ``SimCLRDataset`` for the train split.
        batch_size: Global batch size; divided by ``world_size`` when
            ``multi_gpu`` is set.
        num_workers: DataLoader worker count (overridden to 32 for
            ``'imagenet'``).
        shuffle: Shuffle the train split for ``'SSL'``/``'CL'`` supervision.
            In multi-GPU mode it is forwarded to the ``DistributedSampler``.
        **kwargs: Optional ``multi_gpu`` (bool), ``world_size`` (int),
            ``supervision`` (``'SSL'``, ``'CL'`` or ``'SCL'``; default
            ``'SSL'``), ``test`` (also build the test split when not None)
            and ``classes`` (label values to keep). Other keys are ignored.

    Returns:
        ``(train_dataset, train_dataloader)``; when ``test`` is not None,
        ``(train_dataset, train_dataloader, test_dataset, test_dataloader,
        labels, labels_test)``.

    Raises:
        ValueError: if ``supervision`` is not one of the known modes.
        NotImplementedError: if ``dataset_name`` is unknown.
    """
    multi_gpu = kwargs.pop('multi_gpu', False)
    world_size = kwargs.pop('world_size', 1)
    supervision = kwargs.pop('supervision', 'SSL')
    test = kwargs.pop('test', None)
    classes = kwargs.pop('classes', None)

    # Fail fast, before any download; an unknown mode used to surface only as
    # an UnboundLocalError on train_dataloader at the very end.
    if supervision not in ('SSL', 'CL', 'SCL'):
        raise ValueError(
            f"unknown supervision mode {supervision!r}; expected 'SSL', 'CL' or 'SCL'")

    if dataset_name is None:
        # default to cifar10
        dataset_name = 'cifar10'

    if dataset_name == 'imagenet':
        dataset_download = load_dataset('timm/mini-imagenet')
        train_dataset_download = dataset_download['train']
        test_dataset_download = dataset_download['test']
        train_transforms, basic_transforms = get_transforms('imagenet')
        num_workers = 32
        labels = np.array(train_dataset_download['label'])
        labels_test = np.array(test_dataset_download['label'])

    elif dataset_name in ('cifar10', 'cifar100'):
        cifar_cls = datasets.CIFAR10 if dataset_name == 'cifar10' else datasets.CIFAR100
        train_dataset_download = cifar_cls(root=dataset_path, train=True,
                                           download=True, transform=None)
        test_dataset_download = cifar_cls(root=dataset_path, train=False,
                                          download=True, transform=None)
        train_transforms, basic_transforms = get_transforms('cifar')
        # labels are used for stratified sampling
        labels = np.array(train_dataset_download.targets)
        labels_test = np.array(test_dataset_download.targets)

    elif dataset_name == 'svhn':
        train_split = datasets.SVHN(root=dataset_path, split="train",
                                    download=True, transform=None)
        extra_split = datasets.SVHN(root=dataset_path, split="extra",
                                    download=True, transform=None)
        test_dataset_download = datasets.SVHN(root=dataset_path, split="test",
                                              download=True, transform=None)
        # The 'train' and 'extra' splits together form the training set.
        labels = np.concatenate([train_split.labels, extra_split.labels])
        labels_test = np.array(test_dataset_download.labels)
        train_dataset_download = ConcatDataset([train_split, extra_split])
        train_transforms, basic_transforms = get_transforms("svhn")

    else:
        raise NotImplementedError(f'no known dataset named {dataset_name}')

    if classes is not None:
        train_dataset_download, labels = filter_class_indices(train_dataset_download, classes, labels)
        test_dataset_download, labels_test = filter_class_indices(test_dataset_download, classes, labels_test)

    train_dataset = SimCLRDataset(train_dataset_download,
                                  train_transforms, basic_transforms,
                                  augment_both_views=augment_both_views,
                                  dataset_name=dataset_name)

    # Under DDP every process loads its share of the global batch.
    effective_batch_size = batch_size // world_size if multi_gpu else batch_size

    if supervision == 'SCL':
        print("Using stratified sampling")
        # Approximate stratified sampling
        if multi_gpu:
            rank = torch.distributed.get_rank()
            if dataset_name == 'svhn':  # classes are not balanced
                batch_sampler = DistributedStratifiedBatchSamplerSoftBalance(
                    labels, effective_batch_size,
                    num_replicas=world_size, rank=rank, drop_last=True)
            else:
                batch_sampler = DistributedStratifiedBatchSampler(
                    labels, effective_batch_size,
                    num_replicas=world_size, rank=rank, drop_last=True)
        else:
            batch_sampler = ApproxStratifiedSampler(labels, batch_size)
        train_dataloader = DataLoader(train_dataset, batch_sampler=batch_sampler,
                                      num_workers=num_workers, pin_memory=True)
    else:  # 'SSL' or 'CL'
        # DataLoader rejects shuffle=True together with a sampler, so under DDP
        # the caller's shuffle flag goes to the DistributedSampler instead.
        sampler = (DistributedSampler(train_dataset, num_replicas=world_size, shuffle=shuffle)
                   if multi_gpu else None)
        train_dataloader = DataLoader(train_dataset, batch_size=effective_batch_size,
                                      shuffle=shuffle and not multi_gpu,
                                      num_workers=num_workers, pin_memory=True,
                                      drop_last=multi_gpu,  # avoid uneven batches across DDP ranks
                                      sampler=sampler)

    if test is not None:
        test_dataset = SimCLRDataset(test_dataset_download,
                                     train_transforms, basic_transforms,
                                     augment_both_views=False,
                                     dataset_name=dataset_name)
        test_dataloader = DataLoader(test_dataset, batch_size=effective_batch_size,
                                     shuffle=True, num_workers=num_workers,
                                     pin_memory=True)
        return train_dataset, train_dataloader, test_dataset, test_dataloader, labels, labels_test

    return train_dataset, train_dataloader

def filter_class_indices(dataset, classes, labels):
    """Restrict ``dataset`` to the samples whose label is in ``classes``.

    Args:
        dataset: Any indexable dataset. When ``labels`` is None it must
            expose a ``targets`` attribute.
        classes: Label value(s) to keep (scalar, list, array or set).
        labels: Per-sample labels aligned with ``dataset`` (any array-like),
            or None to read them from ``dataset.targets``.

    Returns:
        ``(Subset(dataset, kept_indices), kept_labels)``, where
        ``kept_labels`` is a NumPy array aligned with the subset.
    """
    if labels is None:
        labels = dataset.targets
    # asarray so plain lists work with the index-array selection below.
    labels = np.asarray(labels)
    if isinstance(classes, (set, frozenset)):
        # np.isin treats a set as a single object and would match nothing.
        classes = list(classes)
    class_indices = np.flatnonzero(np.isin(labels, classes))
    # Subset expects plain Python ints as indices.
    return Subset(dataset, class_indices.tolist()), labels[class_indices]

class ApproxStratifiedSampler(Sampler):
    """Batch sampler yielding roughly class-balanced batches.

    Each sample is weighted by ``1 / count(its class)``, so every class holds
    the same total probability mass. Each batch is drawn independently and
    without replacement using the global NumPy RNG (``np.random.choice``).
    As a result, a sample may appear in several batches of one epoch, and
    others may not appear at all.
    """

    def __init__(self, labels, batch_size, num_batches=None):
        """
        labels: per-sample class labels (list, array or tensor)
        batch_size: number of samples per batch
        num_batches: batches per epoch (default: ``len(labels) // batch_size``)
        """
        self.labels = np.array(labels)
        self.batch_size = batch_size
        self.indices = np.arange(len(self.labels))

        # np.unique handles sparse/non-contiguous label values; np.bincount
        # would produce zero counts (and a divide-by-zero) for any value
        # missing below the max label, e.g. after class filtering.
        _, inverse, class_counts = np.unique(self.labels, return_inverse=True,
                                             return_counts=True)
        self.num_classes = len(class_counts)

        # Inverse class frequency per sample.
        sample_weights = 1.0 / class_counts[inverse.reshape(-1)]

        self.num_batches = num_batches if num_batches else len(self.labels) // batch_size

        # Weighted random sampling for rough balance
        self.probabilities = sample_weights / sample_weights.sum()

    def __iter__(self):
        """Yield batches (lists of indices) with roughly balanced classes."""
        for _ in range(self.num_batches):
            batch_indices = np.random.choice(self.indices, size=self.batch_size,
                                             p=self.probabilities, replace=False)
            yield batch_indices.tolist()

    def __len__(self):
        return self.num_batches

class DistributedStratifiedBatchSampler(Sampler):
    """DDP batch sampler whose batches mix classes by round-robin interleaving.

    Every epoch, each class's indices are shuffled with an RNG seeded by the
    epoch, so all ranks compute the same global order. The per-class lists
    are then interleaved one element per class per round. Once a class is
    exhausted, later batches contain only the remaining classes. The pooled
    stream is cut into ``total_batches`` batches, and this rank takes every
    ``num_replicas``-th batch starting at ``rank``. Samples beyond
    ``total_batches * batch_size`` are not yielded in that epoch.
    ``drop_last`` is stored but not read.
    """

    def __init__(self, labels, batch_size, num_replicas=None, rank=None, drop_last=False):
        self.labels = np.array(labels)
        self.batch_size = batch_size
        self.drop_last = drop_last

        # Query the default process group only for the values not supplied.
        self.num_replicas = torch.distributed.get_world_size() if num_replicas is None else num_replicas
        self.rank = torch.distributed.get_rank() if rank is None else rank
        self.epoch = 0
        self.num_samples = len(self.labels)

        # label -> sample positions, in first-seen label order
        self.class_to_indices = defaultdict(list)
        for position, label in enumerate(self.labels):
            self.class_to_indices[label].append(position)
        self.classes = list(self.class_to_indices.keys())

        # Only whole batches, evenly divisible across replicas, are produced.
        self.num_batches_per_replica = (self.num_samples // self.batch_size) // self.num_replicas
        self.total_batches = self.num_batches_per_replica * self.num_replicas

    def set_epoch(self, epoch):
        """Select the shuffle seed for the next iteration."""
        self.epoch = epoch

    def __iter__(self):
        rng = np.random.default_rng(seed=self.epoch)
        per_class = [rng.permutation(self.class_to_indices[c]).tolist()
                     for c in self.classes]

        wanted = self.total_batches * self.batch_size
        pooled = []
        deepest = max((len(column) for column in per_class), default=0)
        for depth in range(deepest):
            # One round: the depth-th element of every class that still has one.
            pooled.extend(column[depth] for column in per_class if depth < len(column))
            if len(pooled) >= wanted:
                break
        del pooled[wanted:]

        step = self.batch_size
        all_batches = [pooled[start:start + step] for start in range(0, wanted, step)]
        yield from all_batches[self.rank::self.num_replicas]

    def __len__(self):
        return self.num_batches_per_replica
    
class DistributedStratifiedOversamplingBatchSampler(Sampler):
    """DDP batch sampler that oversamples minority classes to full balance.

    Every epoch (RNG seeded by the epoch, so all ranks agree), each class is
    brought to ``max_class_size`` entries. A class smaller than that is
    resampled with replacement; a class already that large is permuted. The
    classes are then interleaved one entry per class per round, cut into
    ``total_batches`` batches, and this rank takes every
    ``num_replicas``-th batch starting at ``rank``. ``drop_last`` is stored
    but not read.
    """

    def __init__(self, labels, batch_size, num_replicas=None, rank=None, drop_last=False):
        self.labels = np.array(labels)
        self.batch_size = batch_size
        self.drop_last = drop_last

        # Query the default process group only for the values not supplied.
        self.num_replicas = torch.distributed.get_world_size() if num_replicas is None else num_replicas
        self.rank = torch.distributed.get_rank() if rank is None else rank
        self.epoch = 0

        # label -> sample positions, in first-seen label order
        self.class_to_indices = defaultdict(list)
        for position, label in enumerate(self.labels):
            self.class_to_indices[label].append(position)
        self.classes = list(self.class_to_indices.keys())

        # Virtual epoch size: every class padded up to the largest class.
        self.max_class_size = max(map(len, self.class_to_indices.values()))
        self.num_samples = self.max_class_size * len(self.classes)

        self.num_batches_per_replica = (self.num_samples // self.batch_size) // self.num_replicas
        self.total_batches = self.num_batches_per_replica * self.num_replicas

    def set_epoch(self, epoch):
        """Select the sampling seed for the next iteration."""
        self.epoch = epoch

    def __iter__(self):
        rng = np.random.default_rng(seed=self.epoch)
        target = self.max_class_size

        columns = []
        for members in self.class_to_indices.values():
            if len(members) >= target:
                columns.append(rng.permutation(members).tolist())
            else:
                columns.append(rng.choice(members, target, replace=True).tolist())

        # Row r holds the r-th entry of every class, in class order.
        interleaved = [idx for row in zip(*columns) for idx in row]

        step = self.batch_size
        all_batches = [interleaved[start:start + step]
                       for start in range(0, self.total_batches * step, step)]
        yield from all_batches[self.rank::self.num_replicas]

    def __len__(self):
        return self.num_batches_per_replica
    
class DistributedStratifiedBatchSamplerSoftBalance(Sampler):
    """DDP batch sampler where each batch covers a random subset of classes.

    Each batch draws ``k`` distinct classes uniformly at random, where ``k`` is
    ``num_classes_per_batch`` clamped to the number of classes and to
    ``batch_size``. ``batch_size`` is split as evenly as possible among them:
    the first ``batch_size % k`` selected classes supply one extra sample.
    Each class is consumed from its own shuffled order and reshuffled once it
    cannot supply its share. The RNG is seeded by the epoch, so all ranks
    compute the same global batch list; this rank takes every
    ``num_replicas``-th batch starting at ``rank``.

    Caveat: a class with fewer samples than its share yields a short batch.
    ``drop_last`` is stored but not read.
    """

    def __init__(self, labels, batch_size, num_classes_per_batch=5, num_replicas=None, rank=None, drop_last=False):
        self.labels = np.array(labels)
        self.batch_size = batch_size
        self.num_classes_per_batch = num_classes_per_batch
        self.drop_last = drop_last

        if num_replicas is None:
            num_replicas = torch.distributed.get_world_size()
        if rank is None:
            rank = torch.distributed.get_rank()

        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0

        # Group samples by class
        self.class_to_indices = defaultdict(list)
        for idx, label in enumerate(self.labels):
            self.class_to_indices[label].append(idx)

        self.classes = list(self.class_to_indices.keys())

        # Estimate how many batches we can get
        est_total_batches = len(self.labels) // batch_size
        self.num_batches_per_replica = est_total_batches // self.num_replicas
        self.total_batches = self.num_batches_per_replica * self.num_replicas

    def set_epoch(self, epoch):
        """Select the sampling seed for the next iteration."""
        self.epoch = epoch

    def __iter__(self):
        if self.total_batches == 0:
            return

        rng = np.random.default_rng(seed=self.epoch)

        # Shuffle indices within each class
        class_indices = {
            cls: rng.permutation(idxs).tolist()
            for cls, idxs in self.class_to_indices.items()
        }
        # Track cursor per class
        class_cursors = {cls: 0 for cls in self.classes}

        # rng.choice(replace=False) raises if asked for more classes than exist,
        # and more classes than batch slots would leave some with 0 samples.
        k = min(self.num_classes_per_batch, len(self.classes), self.batch_size)
        # Spread the remainder so each batch has exactly batch_size entries
        # (plain floor division silently produced smaller batches).
        base, extra = divmod(self.batch_size, k)

        pooled_batches = []
        for _ in range(self.total_batches):
            selected_classes = rng.choice(self.classes, size=k, replace=False)
            batch = []

            for slot, cls in enumerate(selected_classes):
                quota = base + 1 if slot < extra else base
                idxs = class_indices[cls]
                cur = class_cursors[cls]

                # Replenish (reshuffle) if the class cannot supply its quota
                if cur + quota > len(idxs):
                    idxs = rng.permutation(self.class_to_indices[cls]).tolist()
                    class_indices[cls] = idxs
                    cur = 0

                batch.extend(idxs[cur:cur + quota])
                class_cursors[cls] = cur + quota

            pooled_batches.append(batch)

        # Shard across DDP replicas
        yield from pooled_batches[self.rank::self.num_replicas]

    def __len__(self):
        return self.num_batches_per_replica