import copy

import numpy as np
import torch
import torchvision
from torch.utils.data import DataLoader, SubsetRandomSampler
from torchvision import datasets, transforms

# Commonly used per-channel (mean, std) normalization statistics, keyed by
# dataset name. The keys double as the list of supported datasets.
_DATASET_STATS = {
    "CIFAR10": ((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
    "CIFAR100": ((0.5071, 0.4867, 0.4408), (0.2673, 0.2564, 0.2762)),
    "ImageNet": ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
}


def _build_transforms(img_size, mean, std):
    """Return the ``(train, eval)`` transform pipelines for one dataset."""
    train_transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(img_size, padding=4),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.RandomRotation(degrees=15),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])
    eval_transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])
    return train_transform, eval_transform


def get_dataset(name, img_size, batch_sizes=(768, 256, 1), imagenet_dir=None):
    """Build train / validation / test loaders for a supported dataset.

    A random 10% of the training set is held out for validation. The
    training loader applies random augmentation; the validation and test
    loaders apply only the deterministic resize + normalize pipeline, so
    validation metrics are comparable to test metrics.

    Args:
        name: One of ``"CIFAR10"``, ``"CIFAR100"`` or ``"ImageNet"``.
        img_size: Side length every image is resized to (square).
        batch_sizes: ``(train, valid, test)`` batch sizes.
        imagenet_dir: Directory containing ``train/`` and ``val/`` folders
            readable by ``ImageFolder``. Required when ``name`` is
            ``"ImageNet"``; unused otherwise.

    Returns:
        A ``(train_loader, valid_loader, test_loader)`` tuple.

    Raises:
        ValueError: If ``name`` is not supported, or ``name`` is
            ``"ImageNet"`` and ``imagenet_dir`` is ``None``.
    """
    # Validate arguments before doing any (potentially slow) dataset work.
    if not isinstance(name, str) or name not in _DATASET_STATS:
        raise ValueError(
            f"Not Supported Dataset: {name!r} "
            f"(expected one of {sorted(_DATASET_STATS)})"
        )
    if name == "ImageNet" and imagenet_dir is None:
        raise ValueError("imagenet_dir must be provided when name is 'ImageNet'")

    mean, std = _DATASET_STATS[name]
    transform_train, transform_test = _build_transforms(img_size, mean, std)

    if name == "ImageNet":
        train_dataset = datasets.ImageFolder(
            root=f"{imagenet_dir}/train",
            transform=transform_train,
        )
        test_dataset = datasets.ImageFolder(
            root=f"{imagenet_dir}/val",
            transform=transform_test,
        )
    else:
        if name == "CIFAR10":
            dataset_cls = torchvision.datasets.CIFAR10
        else:
            dataset_cls = torchvision.datasets.CIFAR100
        root = f"./data/{name}/"
        train_dataset = dataset_cls(
            root=root,
            train=True,
            download=True,
            transform=transform_train,
        )
        test_dataset = dataset_cls(
            root=root,
            train=False,
            download=True,
            transform=transform_test,
        )

    # The validation split must not see training augmentation. A shallow copy
    # shares the underlying samples (no second download / directory scan) and
    # only swaps the per-sample transform, which the torchvision datasets used
    # here read from ``self.transform`` in ``__getitem__``.
    valid_dataset = copy.copy(train_dataset)
    valid_dataset.transform = transform_test

    # Train/Validation Split: shuffle all indices, hold out the first 10%.
    num_train = len(train_dataset)
    indices = list(range(num_train))
    np.random.shuffle(indices)

    split = int(np.floor(0.1 * num_train))
    valid_idx, train_idx = indices[:split], indices[split:]

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_sizes[0],
        sampler=SubsetRandomSampler(train_idx),
        num_workers=4,
    )
    valid_loader = DataLoader(
        valid_dataset,
        batch_size=batch_sizes[1],
        sampler=SubsetRandomSampler(valid_idx),
        num_workers=4,
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_sizes[2],
        shuffle=False,
        num_workers=4,
    )

    return train_loader, valid_loader, test_loader


def get_label_filtered_loader(dataset, labels, batch_sizes):
    """Build a loader over the samples of ``dataset`` whose label is in ``labels``.

    Args:
        dataset: Map-style dataset yielding ``(sample, target)`` pairs.
        labels: Container of labels to keep; membership is tested with ``in``.
        batch_sizes: Sequence of batch sizes; the entry at index 2 is used.

    Returns:
        A ``DataLoader`` that draws the matching samples in random order.
    """
    # Prefer the dataset's label list when it has one: iterating the dataset
    # itself loads and transforms every sample just to read its label. The
    # shortcut is skipped when a target_transform is set, because then the
    # stored labels may differ from what __getitem__ returns.
    targets = getattr(dataset, "targets", None)
    use_stored_targets = (
        targets is not None
        and getattr(dataset, "target_transform", None) is None
        and len(targets) == len(dataset)
    )

    if use_stored_targets:
        if hasattr(targets, "tolist"):
            # Tensor / ndarray labels -> plain Python values for ``in``.
            targets = targets.tolist()
        indices = [i for i, target in enumerate(targets) if target in labels]
    else:
        indices = [i for i, (_, target) in enumerate(dataset) if target in labels]

    sampler = SubsetRandomSampler(indices)
    loader = DataLoader(
        dataset,
        batch_size=batch_sizes[2],
        sampler=sampler,
        num_workers=4,
    )
    return loader


def get_label_loader(imagenet_dir, img_size, num_samples, batch_size=1):
    """Build a loader over a random subset of the ImageNet training images.

    Images are read from ``<imagenet_dir>/train`` with ``ImageFolder`` and
    passed through the random training augmentation pipeline.

    Args:
        imagenet_dir: Directory containing the ``train/`` folder.
        img_size: Side length every image is resized to (square).
        num_samples: Number of randomly chosen training images to expose.
            If it exceeds the dataset size, all images are used; ``None``
            also selects all images.
        batch_size: Batch size of the returned loader.

    Returns:
        A ``DataLoader`` that draws the selected images in random order.

    Raises:
        ValueError: If ``imagenet_dir`` is ``None`` or ``num_samples`` is
            negative.
    """
    # Fail early with a clear message instead of a path error on "None/train".
    if imagenet_dir is None:
        raise ValueError("imagenet_dir must be provided to load ImageNet")
    # A negative count would silently slice from the end of the shuffled list.
    if num_samples is not None and num_samples < 0:
        raise ValueError(f"num_samples must be non-negative, got {num_samples}")

    transform_train = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(img_size, padding=4),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.RandomRotation(degrees=15),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    train_dataset = datasets.ImageFolder(
        root=f"{imagenet_dir}/train",
        transform=transform_train,
    )

    # Pick the subset by shuffling all indices and keeping the first ones.
    indices = list(range(len(train_dataset)))
    np.random.shuffle(indices)
    subset_indices = indices[:num_samples]

    loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        sampler=SubsetRandomSampler(subset_indices),
        num_workers=4,
    )

    return loader


