
import numpy as np
from torch.utils.data import DataLoader, Subset, WeightedRandomSampler

def upsample(cluster_labels):
    """Balance clusters by oversampling each smaller cluster up to the largest cluster size.

    Every non-empty cluster that is smaller than the largest cluster is replaced
    by ``max_cluster_size`` draws *with replacement* from its own members (so an
    individual original sample is not guaranteed to appear). Clusters that
    already have the maximum size are kept unchanged. Cluster ids that never
    occur in ``cluster_labels`` are skipped.

    Args:
        cluster_labels: 1-D array-like of non-negative integer cluster ids, one
            per sample.

    Returns:
        Tuple ``(upsampled_indices, upsampled_clusters)``: positions into
        ``cluster_labels`` and the cluster id of each position, grouped by
        ascending cluster id.
    """
    cluster_labels = np.asarray(cluster_labels)
    if cluster_labels.size == 0:
        # nothing to resample; avoid np.max / np.concatenate on empty input
        return np.empty(0, dtype=np.intp), np.empty(0, dtype=np.intp)

    # get the number of samples in each cluster and in the largest cluster
    cluster_sizes = np.bincount(cluster_labels)
    max_cluster_size = np.max(cluster_sizes)

    upsampled_indices = []
    upsampled_clusters = []
    for cluster_id, cluster_size in enumerate(cluster_sizes):
        if cluster_size == 0:
            # unused label id: np.random.choice cannot draw from an empty set
            continue

        cluster_indices = np.flatnonzero(cluster_labels == cluster_id)
        if cluster_size < max_cluster_size:
            # upsample this cluster by randomly sampling from the cluster
            cluster_indices = np.random.choice(cluster_indices, size=max_cluster_size, replace=True)

        upsampled_indices.append(cluster_indices)
        upsampled_clusters.append(np.full(len(cluster_indices), cluster_id))

    # return the new dataset with the upsampled data
    return np.concatenate(upsampled_indices), np.concatenate(upsampled_clusters)


def upsample_by_factor(dataset, cluster_labels, args):
    """Upsample the smaller clusters of each class by ``args.upsample_factor``.

    For every class ``0 .. args.num_classes - 1`` the clusters present in that
    class are inspected. A cluster smaller than the largest cluster of the class
    is replaced by ``size * args.upsample_factor`` draws with replacement from
    its members; the largest cluster(s) are kept as they are.

    Args:
        dataset: object exposing ``targets`` (class label per sample).
        cluster_labels: 1-D array of non-negative integer cluster ids, aligned
            with ``dataset.targets``.
        args: namespace providing ``num_classes``, ``upsample_factor`` and
            ``logger`` (only ``logger.info`` is used).

    Returns:
        Tuple ``(upsampled_indices, upsampled_clusters)``: dataset-level sample
        indices and the cluster id of each selected sample.
    """
    cluster_labels = np.asarray(cluster_labels)
    # convert once instead of once per class
    targets = np.asarray(dataset.targets)

    upsampled_indices = []
    upsampled_clusters = []

    for class_id in range(args.num_classes):
        class_indices = np.flatnonzero(targets == class_id)
        class_cluster_labels = cluster_labels[class_indices]
        class_cluster_sizes = np.bincount(class_cluster_labels)
        args.logger.info(f"Class {class_id} cluster sizes: {class_cluster_sizes}")

        if class_indices.size == 0:
            # a class without samples has no cluster to upsample
            continue

        max_cluster_size = np.max(class_cluster_sizes)

        for cluster_id, cluster_size in enumerate(class_cluster_sizes):
            if cluster_size == 0:
                continue

            # positions are relative to this class; map them back to dataset indices
            member_indices = class_indices[np.flatnonzero(class_cluster_labels == cluster_id)]

            if cluster_size < max_cluster_size:
                args.logger.info(f"Upsampling cluster {cluster_id} of class {class_id} by {args.upsample_factor}x")

                # cast so a non-integer factor still yields a valid sample count
                sample_size = int(cluster_size * args.upsample_factor)

                # upsample this cluster by randomly sampling from the cluster
                upsampled_indices.append(np.random.choice(member_indices, size=sample_size, replace=True))
                upsampled_clusters.append(np.full(sample_size, cluster_id))
            else:
                # add this cluster to the new dataset unchanged
                upsampled_indices.append(member_indices)
                upsampled_clusters.append(np.full(cluster_size, cluster_id))

    if not upsampled_indices:
        return np.empty(0, dtype=np.intp), np.empty(0, dtype=np.intp)

    # return the new dataset with the upsampled data
    return np.concatenate(upsampled_indices), np.concatenate(upsampled_clusters)


def downsample(cluster_labels):
    """Balance clusters by subsampling each larger cluster down to the smallest cluster size.

    The target size is the size of the smallest *non-empty* cluster; cluster ids
    that never occur in ``cluster_labels`` are ignored (otherwise the minimum
    would be zero and every cluster would be emptied). Larger clusters are
    reduced by drawing without replacement; clusters already at the target size
    are kept unchanged.

    Args:
        cluster_labels: 1-D array-like of non-negative integer cluster ids, one
            per sample.

    Returns:
        Tuple ``(downsampled_indices, downsampled_clusters)``: positions into
        ``cluster_labels`` and the cluster id of each position, grouped by
        ascending cluster id.
    """
    cluster_labels = np.asarray(cluster_labels)
    if cluster_labels.size == 0:
        # nothing to resample; avoid np.min / np.concatenate on empty input
        return np.empty(0, dtype=np.intp), np.empty(0, dtype=np.intp)

    # get the number of samples in each cluster
    cluster_sizes = np.bincount(cluster_labels)

    # size of the smallest cluster that actually has members
    min_cluster_size = np.min(cluster_sizes[cluster_sizes > 0])

    downsampled_indices = []
    downsampled_clusters = []
    for cluster_id, cluster_size in enumerate(cluster_sizes):
        if cluster_size == 0:
            continue

        cluster_indices = np.flatnonzero(cluster_labels == cluster_id)
        if cluster_size > min_cluster_size:
            # downsample this cluster by randomly sampling from the cluster
            cluster_indices = np.random.choice(cluster_indices, size=min_cluster_size, replace=False)

        downsampled_indices.append(cluster_indices)
        downsampled_clusters.append(np.full(len(cluster_indices), cluster_id))

    # return the new dataset with the downsampled data
    return np.concatenate(downsampled_indices), np.concatenate(downsampled_clusters)


def sample_by_cluster(dataset, cluster_labels, args):
    """Compute one (unnormalised) sampling weight per sample from its cluster's size.

    Side effect: when ``args.upsample_by_cluster_size`` or
    ``args.sample_by_silhouette`` is set, ``args.sample_by_cluster_power`` is
    overwritten with the derived power before the weights are computed.

    Args:
        dataset: only ``len(dataset)`` is used, to size the output.
        cluster_labels: array of non-negative integer cluster ids; it is used as
            a boolean mask over the output, so it must have ``len(dataset)`` entries.
        args: namespace read for ``upsample_by_cluster_size``,
            ``sample_by_silhouette``, ``adaptive_sample_power``,
            ``sample_by_cluster_power``, ``logger`` and, depending on the branch,
            ``num_groups``, ``num_classes``, ``dataset``, ``silhouette`` and
            ``silhouette_score``.

    Returns:
        Float array of length ``len(dataset)`` holding each sample's weight.
    """
    info = args.logger.info

    # number of samples in each cluster
    cluster_sizes = np.bincount(cluster_labels)

    if args.upsample_by_cluster_size:
        info('Upsampling by cluster size')
        clusters_per_class = args.num_groups // args.num_classes
        # NOTE(review): each window sums only the first (clusters_per_class - 1)
        # clusters of a block, not the whole block -- confirm this is intended.
        window_sums = [
            np.sum(cluster_sizes[start:start + (clusters_per_class - 1)])
            for start in range(0, len(cluster_sizes), clusters_per_class)
        ]
        reference_sizes = np.repeat(np.array(window_sums), clusters_per_class)

        take_sqrt = args.dataset == 'waterbirds'
        if take_sqrt:
            info(f'Adaptive sample power: sqrt({reference_sizes} / {cluster_sizes})')
        else:
            info(f'Adaptive sample power: {reference_sizes} / {cluster_sizes}')
        size_ratio = reference_sizes / cluster_sizes
        args.sample_by_cluster_power = np.sqrt(size_ratio) if take_sqrt else size_ratio
        info(f'Adaptive sample power: {args.sample_by_cluster_power}')
    elif args.sample_by_silhouette:
        assert args.silhouette
        # exponentiation binds tighter than division: 1 / (score ** power)
        args.sample_by_cluster_power = 1 / (args.silhouette_score ** args.sample_by_cluster_power)
        info(f'Sampling power by silhouette score: {args.sample_by_cluster_power}')

    # per-cluster weight: the power is a multiplier in adaptive mode, an exponent otherwise
    if args.adaptive_sample_power:
        cluster_sample_probs = (1 / cluster_sizes) * args.sample_by_cluster_power
    else:
        cluster_sample_probs = 1 / (cluster_sizes ** args.sample_by_cluster_power)

    info(f'Cluster sample probabilities: {cluster_sample_probs}')

    # broadcast each cluster's weight to all of its members
    sample_probs = np.zeros(len(dataset))
    for cluster_id in range(len(cluster_sizes)):
        members = cluster_labels == cluster_id
        sample_probs[members] = cluster_sample_probs[cluster_id]

    return sample_probs


def uniform_sample(cluster_labels):
    """Build a sampler that weights every sample by the inverse size of its cluster.

    Args:
        cluster_labels: 1-D array of non-negative integer ids (cluster or class),
            one per sample.

    Returns:
        A ``WeightedRandomSampler`` that draws ``len(cluster_labels)`` indices per
        pass, with weight ``1 / count(label)`` for each sample.
    """
    # inverse frequency of each label id
    inverse_counts = 1.0 / np.bincount(cluster_labels)

    # look up the weight of each individual sample through its label
    per_sample_weights = inverse_counts[cluster_labels]

    return WeightedRandomSampler(
        weights=per_sample_weights,
        num_samples=len(per_sample_weights),
    )


def _current_targets(base_dataset, base_indices):
    """Return class targets aligned with the current (possibly resampled) training set."""
    targets = np.asarray(base_dataset.targets)
    return targets if base_indices is None else targets[base_indices]


def _log_sampled_sizes(args, base_dataset, base_indices, sampled_clusters):
    """Log the cluster sizes and true group sizes of the selected samples."""
    cluster_sizes = np.bincount(sampled_clusters)
    args.logger.info(f'Sampled cluster sizes: {cluster_sizes}')

    true_group_sizes = np.bincount(np.array(base_dataset.group_array)[base_indices])
    args.logger.info(f'Sampled true group sizes: {true_group_sizes}')


def sample_train_data(args, train_dataset, cluster_labels):
    """Resample the training set according to ``args`` and build its DataLoader.

    Stages, each optional:
      1. ``args.sample`` in {'upsample', 'upsample_by_factor', 'downsample'}
         rebalances clusters; 'none' skips this stage.
      2. ``args.sample_by_cluster`` computes per-sample weights; they either
         drive a ``WeightedRandomSampler`` (``args.weighted_sampler``) or are
         used to draw ``args.train_size`` samples, split evenly across classes.
      3. ``args.uniform_group_sampler`` / ``args.uniform_class_sampler`` build an
         inverse-frequency sampler over clusters / classes; either one replaces
         a sampler created in stage 2.

    Class targets and cluster labels are kept aligned with the current dataset
    by tracking indices into the original dataset, so the stages can be
    combined. Any resampled dataset is a single ``Subset`` of the original one.

    Args:
        args: namespace providing ``sample``, ``sample_by_cluster``,
            ``weighted_sampler``, ``uniform_group_sampler``,
            ``uniform_class_sampler``, ``num_classes``, ``train_size``,
            ``batch_size``, ``num_workers`` and ``logger``, plus whatever the
            selected resampling helpers read.
        train_dataset: dataset exposing ``targets`` and ``group_array`` where
            the chosen options need them.
        cluster_labels: array of integer cluster ids aligned with
            ``train_dataset``.

    Returns:
        Tuple ``(train_dataset, train_loader)``.

    Raises:
        ValueError: if ``args.sample`` is not a known mode, or if
            ``args.weighted_sampler`` is set but no option produced a sampler.
    """
    base_dataset = train_dataset
    cluster_labels = np.asarray(cluster_labels)

    # indices into base_dataset of the current training samples (None = all, in order)
    base_indices = None
    sampled_clusters = cluster_labels
    sampler = None

    # upsample/downsample to balance groups
    if args.sample != 'none':
        if args.sample == 'upsample':
            sampled_indices, sampled_clusters = upsample(cluster_labels)
        elif args.sample == 'upsample_by_factor':
            sampled_indices, sampled_clusters = upsample_by_factor(train_dataset, cluster_labels, args)
        elif args.sample == 'downsample':
            sampled_indices, sampled_clusters = downsample(cluster_labels)
        else:
            raise ValueError(f"Unknown sample mode: {args.sample!r}")

        base_indices = np.asarray(sampled_indices)
        train_dataset = Subset(base_dataset, base_indices)
        _log_sampled_sizes(args, base_dataset, base_indices, sampled_clusters)

    # sample by cluster if specified
    if args.sample_by_cluster:
        args.logger.info('Sampling by cluster...')
        # use labels aligned with the current dataset, not the original ones
        sample_probs = sample_by_cluster(train_dataset, sampled_clusters, args)
        targets = _current_targets(base_dataset, base_indices)

        if args.weighted_sampler:
            # normalize sample probabilities by class label
            for class_id in range(args.num_classes):
                class_indices = np.where(targets == class_id)[0]
                sample_probs[class_indices] /= np.sum(sample_probs[class_indices])

            sampler = WeightedRandomSampler(sample_probs, len(sample_probs))
        else:
            # for each class, sample the same number of data by the computed probabilities
            picked = []
            size_per_class = args.train_size // args.num_classes
            for class_id in range(args.num_classes):
                args.logger.info('Sampling {} samples from class {}...'.format(size_per_class, class_id))
                class_indices = np.where(targets == class_id)[0]
                class_probs = sample_probs[class_indices] / np.sum(sample_probs[class_indices])
                picked.append(np.random.choice(class_indices, size=size_per_class, p=class_probs, replace=True))

            # concatenate the sampled indices
            picked = np.concatenate(picked)

            # integer division may leave a remainder; fill it uniformly from one random class
            shortfall = args.train_size - len(picked)
            if shortfall > 0:
                class_id = np.random.randint(args.num_classes)
                args.logger.info('Sampling {} more samples from class {}...'.format(shortfall, class_id))
                class_indices = np.where(targets == class_id)[0]
                picked = np.concatenate([picked, np.random.choice(class_indices, size=shortfall, replace=True)])

            # picked indexes the current dataset; translate to base-dataset indices
            sampled_clusters = sampled_clusters[picked]
            base_indices = picked if base_indices is None else base_indices[picked]
            train_dataset = Subset(base_dataset, base_indices)
            _log_sampled_sizes(args, base_dataset, base_indices, sampled_clusters)

    if args.uniform_group_sampler:
        sampler = uniform_sample(sampled_clusters)
    elif args.uniform_class_sampler:
        sampler = uniform_sample(_current_targets(base_dataset, base_indices))

    if args.weighted_sampler and sampler is None:
        raise ValueError(
            "args.weighted_sampler requires args.sample_by_cluster (or a uniform sampler) to provide sample weights"
        )

    # a sampler and shuffle=True are mutually exclusive in DataLoader
    if sampler is not None:
        train_loader = DataLoader(train_dataset, batch_size=args.batch_size, sampler=sampler, num_workers=args.num_workers)
    else:
        train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)

    return train_dataset, train_loader