from __future__ import print_function

import os
import numpy as np
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from PIL import Image
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import RandomSampler, SequentialSampler

"""
mean = {
    'cifar100': (0.5071, 0.4867, 0.4408),
}

std = {
    'cifar100': (0.2675, 0.2565, 0.2761),
}
"""


def get_data_folder():
    """Return the dataset root directory, creating it if it does not exist.

    Returns:
        str: The relative path ``'./data/'``.
    """
    data_folder = './data/'

    # exist_ok avoids the check-then-create race that occurs when several
    # processes (e.g. distributed workers) call this at the same time.
    os.makedirs(data_folder, exist_ok=True)

    return data_folder


class CIFAR100Instance(datasets.CIFAR100):
    """CIFAR-100 dataset whose items also carry their own dataset index."""

    def __getitem__(self, index):
        """Return ``(img, target, index)`` for the sample stored at ``index``.

        The stored array is converted to a PIL Image before ``transform`` is
        applied; ``target_transform`` is applied to the label when it is set.
        """
        raw_image, label = self.data[index], self.targets[index]

        # Hand the transform a PIL Image, matching what other datasets return.
        picture = Image.fromarray(raw_image)

        if self.transform is not None:
            picture = self.transform(picture)
        if self.target_transform is not None:
            label = self.target_transform(label)

        return picture, label, index
    

def get_cifar100_dataloaders_default(batch_size=128, num_workers=8, is_instance=False, opt=None, mean=(0.5071, 0.4867, 0.4408), std=(0.2675, 0.2565, 0.2761)):
    """Build the basic CIFAR-100 train and test loaders.

    Args:
        batch_size: Training batch size; the test loader uses half of it.
        num_workers: Worker count for training; the test loader uses half.
        is_instance: When true, the training set is ``CIFAR100Instance`` and
            the number of training samples is returned as a third value.
        opt: Accepted for signature compatibility; not read by this function.
        mean: Per-channel mean passed to ``transforms.Normalize``.
        std: Per-channel std passed to ``transforms.Normalize``.

    Returns:
        ``(train_loader, test_loader)``, or
        ``(train_loader, test_loader, n_data)`` when ``is_instance`` is true.
    """
    root = get_data_folder()

    # Training images get a padded random crop and a random horizontal flip;
    # test images are only converted to tensors and normalized.
    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    train_cls = CIFAR100Instance if is_instance else datasets.CIFAR100
    train_set = train_cls(root=root, download=True, train=True, transform=train_transform)
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers)

    test_set = datasets.CIFAR100(root=root, download=True, train=False, transform=test_transform)
    test_loader = DataLoader(
        test_set,
        batch_size=int(batch_size / 2),
        shuffle=False,
        num_workers=int(num_workers / 2),
    )

    if not is_instance:
        return train_loader, test_loader
    return train_loader, test_loader, len(train_set)


def get_cifar100_dataloaders(batch_size=128, num_workers=8, is_instance=False, opt=None, mean=(0.5071, 0.4867, 0.4408), std=(0.2675, 0.2565, 0.2761)):
    """Build CIFAR-100 datasets and loaders under two normalizations.

    Two families of datasets are created: one normalized with the ``mean`` /
    ``std`` arguments, and an "original" family normalized with the fixed
    statistics ``(0.5071, 0.4867, 0.4408)`` / ``(0.2675, 0.2565, 0.2761)``.
    The validation set is the training split without augmentation.

    Args:
        batch_size: Batch size for train/val loaders; test loaders use
            ``batch_size // 2``.
        num_workers: Worker count for train/val loaders; test loaders use
            ``num_workers // 2``.
        is_instance: When true, the train/val datasets are
            ``CIFAR100Instance`` and the training-set size is appended to the
            returned tuple. The test datasets are always plain
            ``datasets.CIFAR100``.
        opt: Optional options object; if it has a truthy ``distributed``
            attribute, ``DistributedSampler`` is used for every loader.
        mean: Per-channel mean for the non-"original" transforms.
        std: Per-channel std for the non-"original" transforms.

    Returns:
        ``(train_set, train_original_set, val_set, test_set,
        test_original_set, train_loader, train_original_loader, val_loader,
        test_loader, test_original_loader)``, followed by ``n_data`` when
        ``is_instance`` is true.
    """
    data_folder = get_data_folder()

    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),])

    train_original_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5071, 0.4867, 0.4408), std=(0.2675, 0.2565, 0.2761)),])

    val_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),])

    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),])

    test_original_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5071, 0.4867, 0.4408), std=(0.2675, 0.2565, 0.2761)),])

    if is_instance:
        train_set = CIFAR100Instance(root=data_folder, download=True, train=True, transform=train_transform)
        train_original_set = CIFAR100Instance(root=data_folder, download=True, train=True, transform=train_original_transform)
        val_set = CIFAR100Instance(root=data_folder, download=True, train=True, transform=val_transform)
        n_data = len(train_set)
    else:
        train_set = datasets.CIFAR100(root=data_folder, download=True, train=True, transform=train_transform)
        train_original_set = datasets.CIFAR100(root=data_folder, download=True, train=True, transform=train_original_transform)
        val_set = datasets.CIFAR100(root=data_folder, download=True, train=True, transform=val_transform)
    test_set = datasets.CIFAR100(root=data_folder, download=True, train=False, transform=test_transform)
    test_original_set = datasets.CIFAR100(root=data_folder, download=True, train=False, transform=test_original_transform)

    if (hasattr(opt, 'distributed') and opt.distributed):
        train_sampler = DistributedSampler(train_set)
        train_original_sampler = DistributedSampler(train_original_set)
        val_sampler = DistributedSampler(val_set, shuffle=False)
        test_sampler = DistributedSampler(test_set, shuffle=False)
        test_original_sampler = DistributedSampler(test_original_set, shuffle=False)
    else:
        train_sampler = RandomSampler(train_set)
        train_original_sampler = RandomSampler(train_original_set)
        # NOTE(review): validation is shuffled here but not in the distributed
        # branch above -- confirm this asymmetry is intended.
        val_sampler = RandomSampler(val_set)
        test_sampler = SequentialSampler(test_set)
        test_original_sampler = SequentialSampler(test_original_set)

    # A sampler is always supplied, so ordering is decided by the sampler and
    # the loaders' own ``shuffle`` flag stays off.
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=False, sampler=train_sampler, num_workers=num_workers)
    train_original_loader = DataLoader(train_original_set, batch_size=batch_size, shuffle=False, sampler=train_original_sampler, num_workers=num_workers)
    val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, sampler=val_sampler, num_workers=num_workers)
    test_loader = DataLoader(test_set, batch_size=batch_size//2, shuffle=False, sampler=test_sampler, num_workers=num_workers//2)
    # The "original" test loader must read from test_original_set (with its
    # own sampler) so that test_original_transform is actually applied.
    test_original_loader = DataLoader(test_original_set, batch_size=batch_size//2, shuffle=False, sampler=test_original_sampler, num_workers=num_workers//2)

    if is_instance:
        return train_set, train_original_set, val_set, test_set, test_original_set, train_loader, train_original_loader, val_loader, test_loader, test_original_loader, n_data
    else:
        return train_set, train_original_set, val_set, test_set, test_original_set, train_loader, train_original_loader, val_loader, test_loader, test_original_loader


class CIFAR100InstanceSample(datasets.CIFAR100):
    """CIFAR-100 dataset that can also return contrastive sample indices.

    At construction time the sample indices are grouped per class
    (``cls_positive``) and, for every class, the indices of all samples of
    the other classes are collected (``cls_negative``). When ``is_sample`` is
    true, ``__getitem__`` additionally returns one positive index followed by
    ``k`` negative indices.
    """

    def __init__(self, root, train=True, transform=None, target_transform=None, download=False, k=4096, mode='exact', is_sample=True, percent=1.0):
        """Load CIFAR-100 and precompute the per-class index tables.

        Args:
            root, train, transform, target_transform, download: Forwarded to
                ``datasets.CIFAR100``.
            k: Number of negative indices drawn per item.
            mode: ``'exact'`` uses the item itself as the positive;
                ``'relax'`` draws a random sample of the same class.
            is_sample: When false, items are ``(img, target, index)`` only.
            percent: If strictly between 0 and 1, each class keeps only a
                random subset of its negatives of that relative size.
        """
        super().__init__(root=root, train=train, download=download, transform=transform, target_transform=target_transform)
        self.k = k
        self.mode = mode
        self.is_sample = is_sample

        num_classes = 100

        print("==> CRD default")

        # Group the sample indices by their class label.
        per_class = [[] for _ in range(num_classes)]
        for sample_idx, cls in enumerate(self.targets):
            per_class[cls].append(sample_idx)

        # For each class, gather the indices of every sample of the other
        # classes, visiting those classes in ascending order.
        other_classes = [
            [idx for other, members in enumerate(per_class) if other != cls for idx in members]
            for cls in range(num_classes)
        ]

        positives = [np.asarray(members) for members in per_class]
        negatives = [np.asarray(members) for members in other_classes]

        if 0 < percent < 1:
            # Subset size is derived from class 0 and applied to every class.
            keep = int(len(negatives[0]) * percent)
            negatives = [np.random.permutation(members)[0:keep] for members in negatives]

        self.cls_positive = np.asarray(positives)
        self.cls_negative = np.asarray(negatives)

    def __getitem__(self, index):
        """Return the item at ``index``, plus sample indices if enabled.

        Returns:
            ``(img, target, index)`` when ``is_sample`` is false, otherwise
            ``(img, target, index, sample_idx)`` where ``sample_idx`` holds
            the positive index followed by ``k`` negative indices.

        Raises:
            NotImplementedError: If ``mode`` is neither ``'exact'`` nor
                ``'relax'`` while sampling is enabled.
        """
        # Recent torchvision datasets expose .data / .targets only.
        raw_image, label = self.data[index], self.targets[index]

        # Hand the transform a PIL Image, matching what other datasets return.
        picture = Image.fromarray(raw_image)

        if self.transform is not None:
            picture = self.transform(picture)
        if self.target_transform is not None:
            label = self.target_transform(label)

        if not self.is_sample:
            return picture, label, index

        # Choose the positive index for the contrastive set.
        if self.mode == 'exact':
            anchor = index
        elif self.mode == 'relax':
            anchor = np.random.choice(self.cls_positive[label], 1)[0]
        else:
            raise NotImplementedError(self.mode)

        # Draw with replacement only when there are fewer than k candidates.
        candidates = self.cls_negative[label]
        with_replacement = bool(self.k > len(candidates))
        drawn = np.random.choice(candidates, self.k, replace=with_replacement)

        contrast_idx = np.hstack((np.asarray([anchor]), drawn))
        return picture, label, index, contrast_idx


def get_cifar100_dataloaders_sample(batch_size=128, num_workers=8, k=4096, mode='exact', is_sample=True, percent=1.0, mean=(0.5071, 0.4867, 0.4408), std=(0.2675, 0.2565, 0.2761)):
    """Build CIFAR-100 loaders whose training items carry contrastive indices.

    Args:
        batch_size: Training batch size; the test loader uses half of it.
        num_workers: Worker count for training; the test loader uses half.
        k: Number of negative indices per training item.
        mode: Positive-selection mode forwarded to ``CIFAR100InstanceSample``
            (``'exact'`` or ``'relax'``).
        is_sample: Forwarded to ``CIFAR100InstanceSample``; when false the
            training items contain no sample indices.
        percent: Forwarded to ``CIFAR100InstanceSample``.
        mean: Per-channel mean used to normalize both train and test images.
        std: Per-channel std used to normalize both train and test images.

    Returns:
        ``(train_loader, test_loader, n_data)`` where ``n_data`` is the
        number of training samples.
    """
    data_folder = get_data_folder()

    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    # Normalize test images with the same statistics as the training images;
    # hard-coding them here would desynchronize train and test whenever the
    # caller passes custom mean/std.
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    train_set = CIFAR100InstanceSample(
        root=data_folder,
        download=True,
        train=True,
        transform=train_transform,
        k=k,
        mode=mode,
        is_sample=is_sample,
        percent=percent,
    )
    n_data = len(train_set)
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers)

    test_set = datasets.CIFAR100(root=data_folder, download=True, train=False, transform=test_transform)
    test_loader = DataLoader(test_set, batch_size=int(batch_size/2), shuffle=False, num_workers=int(num_workers/2))

    return train_loader, test_loader, n_data


def get_cifar100_dataloaders_224(batch_size=128, num_workers=8, is_instance=False, opt=None, mean=(0.5071, 0.4867, 0.4408), std=(0.2675, 0.2565, 0.2761)):
    """Build CIFAR-100 datasets and loaders with images resized to 224.

    Every transform pipeline starts with ``Resize(224)``. Training pipelines
    then apply a padded 224 random crop and a random horizontal flip; the
    validation and test pipelines apply no cropping.

    Args:
        batch_size: Batch size for train/val loaders; the test loader uses
            ``batch_size // 2``.
        num_workers: Worker count for train/val loaders; the test loader uses
            ``num_workers // 2``.
        is_instance: When true, the train/val datasets are
            ``CIFAR100Instance`` and the training-set size is appended to the
            returned tuple.
        opt: Optional options object; if it has a truthy ``distributed``
            attribute, ``DistributedSampler`` is used for every loader.
        mean: Per-channel mean for the non-"original" transforms.
        std: Per-channel std for the non-"original" transforms.

    Returns:
        ``(train_set, train_original_set, val_set, test_set, train_loader,
        train_original_loader, val_loader, test_loader)``, followed by
        ``n_data`` when ``is_instance`` is true.
    """
    root = get_data_folder()

    def augmented(normalize):
        # Training pipeline: upsample first, then crop at the 224 size.
        return transforms.Compose([
            transforms.Resize(224),
            transforms.RandomCrop(224, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])

    def plain(normalize):
        # Evaluation pipeline: upsample only, no cropping (stays 224x224).
        return transforms.Compose([
            transforms.Resize(224),
            transforms.ToTensor(),
            normalize,
        ])

    train_transform = augmented(transforms.Normalize(mean, std))
    train_original_transform = augmented(
        transforms.Normalize(mean=(0.5071, 0.4867, 0.4408), std=(0.2675, 0.2565, 0.2761)))
    val_transform = plain(transforms.Normalize(mean, std))
    test_transform = plain(transforms.Normalize(mean, std))

    # Train/val datasets optionally return the sample index as well; the
    # test dataset is always the plain torchvision class.
    train_cls = CIFAR100Instance if is_instance else datasets.CIFAR100
    train_set = train_cls(root=root, download=True, train=True, transform=train_transform)
    train_original_set = train_cls(root=root, download=True, train=True, transform=train_original_transform)
    val_set = train_cls(root=root, download=True, train=True, transform=val_transform)
    test_set = datasets.CIFAR100(root=root, download=True, train=False, transform=test_transform)

    if getattr(opt, 'distributed', False):
        train_sampler = DistributedSampler(train_set)
        train_original_sampler = DistributedSampler(train_original_set)
        val_sampler = DistributedSampler(val_set, shuffle=False)
        test_sampler = DistributedSampler(test_set, shuffle=False)
    else:
        train_sampler = RandomSampler(train_set)
        train_original_sampler = RandomSampler(train_original_set)
        val_sampler = RandomSampler(val_set)
        test_sampler = SequentialSampler(test_set)

    # A sampler is always supplied, so the loaders' own shuffle flag is off.
    train_loader = DataLoader(
        train_set, batch_size=batch_size, shuffle=False,
        sampler=train_sampler, num_workers=num_workers)
    train_original_loader = DataLoader(
        train_original_set, batch_size=batch_size, shuffle=False,
        sampler=train_original_sampler, num_workers=num_workers)
    val_loader = DataLoader(
        val_set, batch_size=batch_size, shuffle=False,
        sampler=val_sampler, num_workers=num_workers)
    test_loader = DataLoader(
        test_set, batch_size=batch_size // 2, shuffle=False,
        sampler=test_sampler, num_workers=num_workers // 2)

    result = (
        train_set, train_original_set, val_set, test_set,
        train_loader, train_original_loader, val_loader, test_loader,
    )
    if is_instance:
        result += (len(train_set),)
    return result