import torch
import warnings
import os
import numpy as np
# from sklearn.model_selection import train_test_split
import torchvision.transforms as T
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision.datasets import CIFAR10, CIFAR100, ImageFolder, voc, Flowers102, Food101, Caltech101, Caltech256,\
    StanfordCars, EuroSAT

# from functions.caltech import Caltech101, Caltech256
# from functions.aircraft import Aircraft

from .cub200 import Cub2011
from .fewshot import find_fewshot_indices
from .dtd import DTD
from .aircraft import FGVCAircraft
from .dogs import StanfordDogs
from .sun397 import SUN397

# Silence every Python warning for the whole process as a side effect of
# importing this module.
# NOTE(review): this filter is global and also hides warnings raised by
# unrelated code; consider scoping it with ``warnings.catch_warnings()``.
warnings.filterwarnings('ignore')


class TransformedDataset(Dataset):
    """Wrap a dataset and apply transforms lazily when an item is fetched.

    Each item of ``ds`` must be a ``(sample, label)`` pair. On access the optional
    ``transform`` runs first. Then a sample whose leading dimension is 1 is tiled
    three times along that dimension, so single-channel images become
    three-channel. Finally the optional ``normalize`` is applied. The label is
    returned untouched.

    Args:
        ds: indexable dataset yielding ``(sample, label)`` pairs.
        transform: callable applied to the raw sample, or ``None``.
        normalize: callable applied after channel expansion, or ``None``.
    """

    def __init__(self, ds, transform=None, normalize=None):
        self.transform = transform
        self.ds = ds
        self.normalize = normalize
        # Mirror the wrapped dataset's class names when it exposes them.
        self.classes = getattr(ds, 'classes', None)

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        image, target = self.ds[idx]
        image = self.transform(image) if self.transform else image
        # After the transform the sample must expose ``.shape``. A lone leading
        # channel is tiled to three so RGB models accept it.
        if image.shape[0] == 1:
            image = image.repeat(3, 1, 1)
        return (self.normalize(image) if self.normalize else image), target


class Gray2RGB(torch.nn.Module):
    """Turn a single-channel image into a three-channel one.

    * Tensors shaped ``(..., 1, H, W)`` have their channel copied three times
      along dim ``-3``. Tensors with another channel count, or with fewer than
      three dimensions, are returned unchanged.
    * PIL images (objects exposing ``convert``) are converted with
      ``convert('RGB')``.

    Raises:
        TypeError: for inputs that are neither tensors nor PIL-like images.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        if torch.is_tensor(x):
            if x.dim() >= 3 and x.shape[-3] == 1:
                return torch.cat([x, x, x], dim=-3)
            return x
        if hasattr(x, 'convert'):
            return x.convert('RGB')
        raise TypeError('Gray2RGB expects a tensor or PIL image, got {}'.format(type(x).__name__))


class _Cutout(object):
    """Cutout augmentation (DeVries & Taylor, 2017) for ``(C, H, W)`` tensors.

    Zeroes ``n_holes`` square patches of side ``length``. Each patch is centred
    on a uniformly random pixel and clipped at the image border, so patches near
    the edge cover less area. Randomness comes from ``torch``, so it follows
    ``torch.manual_seed`` and the per-worker seeding done by ``DataLoader``.
    """

    def __init__(self, n_holes, length):
        self.n_holes = n_holes
        self.length = length

    def __call__(self, img):
        h, w = img.shape[-2], img.shape[-1]
        mask = torch.ones((h, w), dtype=img.dtype, device=img.device)
        half = self.length // 2
        for _ in range(self.n_holes):
            cy = int(torch.randint(h, (1,)))
            cx = int(torch.randint(w, (1,)))
            mask[max(cy - half, 0):min(cy + half, h), max(cx - half, 0):min(cx + half, w)] = 0
        # (H, W) mask broadcasts over the channel dimension.
        return img * mask


def build_cifar(cutout=False, use_cifar10=True, download=False, root=None):
    """Build CIFAR-10 or CIFAR-100 train/test datasets.

    Training images get a random crop (32, padding 4) and a horizontal flip.
    They are then converted to tensors, optionally passed through Cutout (one
    16x16 hole), and normalised with per-dataset mean/std. Test images are only
    converted to tensors and normalised.

    Args:
        cutout: if truthy, apply Cutout after ``ToTensor`` and before normalisation.
        use_cifar10: ``True`` for CIFAR-10, ``False`` for CIFAR-100.
        download: forwarded to the torchvision dataset constructors.
        root: dataset directory. ``None`` keeps the historical per-dataset default.

    Returns:
        ``(train_dataset, val_dataset)``.
    """
    aug = [T.RandomCrop(32, padding=4), T.RandomHorizontalFlip(), T.ToTensor()]
    if cutout:
        aug.append(_Cutout(n_holes=1, length=16))

    # No trailing comma after T.Normalize(...): a 1-tuple in the pipeline
    # would make Compose try to call a tuple.
    if use_cifar10:
        dataset_cls = CIFAR10
        normalize = T.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
        default_root = '/mnt/lustrenew2/liyuhang1/data/cifar10'
    else:
        dataset_cls = CIFAR100
        normalize = T.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))
        default_root = './raw/'
    if root is None:
        root = default_root

    transform_train = T.Compose(aug + [normalize])
    transform_test = T.Compose([T.ToTensor(), normalize])
    train_dataset = dataset_cls(root=root, train=True, download=download, transform=transform_train)
    val_dataset = dataset_cls(root=root, train=False, download=download, transform=transform_test)
    return train_dataset, val_dataset


# PASCAL VOC object categories. The position of each name is its index in the
# 20-way multi-hot target produced by ``_encode_voc_labels``.
_VOC_CLASSES = (
    'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow',
    'diningtable', 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa',
    'train', 'tvmonitor',
)


def _encode_voc_labels(target):
    """Map a ``VOCDetection`` annotation dict to a float multi-hot vector of length 20.

    Objects marked ``difficult`` are ignored. A single object given as a dict
    (rather than a list of dicts) is accepted too.
    """
    objects = target['annotation']['object']
    if isinstance(objects, dict):
        objects = [objects]
    labels = torch.zeros(len(_VOC_CLASSES))
    for obj in objects:
        if int(obj.get('difficult', 0)) == 0:
            labels[_VOC_CLASSES.index(obj['name'])] = 1.0
    return labels


def _first_n_per_class_split(labels, n_per_class):
    """Send the first ``n_per_class`` indices of every label to train, the rest to test.

    Returns ``(train_indices, test_indices)`` as ascending lists of ints.
    """
    seen = {}
    train_indices, test_indices = [], []
    for idx, label in enumerate(labels):
        count = seen.get(label, 0)
        if count < n_per_class:
            train_indices.append(idx)
        else:
            test_indices.append(idx)
        seen[label] = count + 1
    return train_indices, test_indices


def _stratified_split(targets, train_fraction, seed):
    """Randomly split indices per class, keeping ``train_fraction`` of each class for train.

    The result is deterministic for a given ``seed``. Returns
    ``(train_indices, val_indices)`` as ascending lists of ints.
    """
    targets = np.asarray(targets)
    rng = np.random.RandomState(seed)
    train_parts, val_parts = [], []
    for cls in np.unique(targets):
        cls_indices = np.flatnonzero(targets == cls)
        rng.shuffle(cls_indices)
        n_train = int(round(len(cls_indices) * train_fraction))
        train_parts.append(cls_indices[:n_train])
        val_parts.append(cls_indices[n_train:])

    def _merge(parts):
        return np.sort(np.concatenate(parts)).tolist() if parts else []

    return _merge(train_parts), _merge(val_parts)


def build_transfer_dataset(dataset='cifar10', imgnet_norm='normal', download=False, fuse=0, fewshot=None,
                           data_config=None):
    """Build train/val datasets for a transfer-learning benchmark.

    Args:
        dataset: benchmark name. One of 'cifar10', 'cifar100', 'sun397',
            'caltech101', 'flowers', 'pets', 'aircraft', 'cars', 'foods', 'cub',
            'pascalvoc', 'dtd', 'dogs', 'eurosat'.
        imgnet_norm: 'normal' for ImageNet mean/std, 'bit' for mean=std=0.5,
            anything else for identity normalisation.
        download: forwarded to the CIFAR constructors. sun397 and caltech101
            always request a download.
        fuse: controls synthetic training data.
            * ``0`` keeps the real training set.
            * A positive int replaces it with the ImageFolder at
              ``../{dataset}/data/gen{fuse}``.
            * A list concatenates one such folder per non-zero entry, plus the
              real training set if the list contains 0. The list is not modified.
        fewshot: if not ``None``, passed to ``find_fewshot_indices`` to subsample
            the real training set.
        data_config: if given, keyword arguments for timm ``create_transform``.
            The resulting transforms replace the default ones.

    Returns:
        ``(train_dataset, val_dataset, num_cls)``.

    Raises:
        NotImplementedError: for an unknown ``dataset`` name.
    """
    aug = [T.RandomResizedCrop(224), T.RandomHorizontalFlip(), T.ToTensor()]
    aug_test = [T.Resize(256), T.CenterCrop(224), T.ToTensor()]

    if imgnet_norm == 'normal':
        Normalize = T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    elif imgnet_norm == 'bit':
        Normalize = T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    else:
        Normalize = T.Normalize((0., 0., 0.), (1., 1., 1.))
    # For these datasets normalisation is deferred to TransformedDataset. That
    # class first expands 1-channel tensors to 3 channels, which presumably
    # matters because some of their images are grayscale.
    if dataset not in ['caltech101', 'octmnist']:
        aug.append(Normalize)
        aug_test.append(Normalize)
    transform_train = T.Compose(aug)
    transform_test = T.Compose(aug_test)
    if data_config is not None:
        import timm
        transform_train = timm.data.create_transform(**data_config, is_training=True)
        transform_test = timm.data.create_transform(**data_config, is_training=False)

    if dataset == 'cifar10':
        train_dataset = CIFAR10(root='../cifar10/data',
                                train=True, download=download, transform=transform_train)
        val_dataset = CIFAR10(root='../cifar10/data',
                              train=False, download=download, transform=transform_test)
        num_cls = 10
    elif dataset == 'sun397':
        train_dataset = SUN397(root='../sun397/data/', split='train', download=True, transform=transform_train)
        val_dataset = SUN397(root='../sun397/data/', split='test', transform=transform_test)
        num_cls = 397

    elif dataset == 'cifar100':
        train_dataset = CIFAR100(root='/mnt/lustre/liyuhang1/data/cifar100',
                                 train=True, download=download, transform=transform_train)
        val_dataset = CIFAR100(root='/mnt/lustre/liyuhang1/data/cifar100',
                               train=False, download=download, transform=transform_test)
        num_cls = 100

    elif dataset == 'caltech101':
        ds = Caltech101('../caltech101/data', download=True)
        NUM_TRAINING_SAMPLES_PER_CLASS = 30
        # 30 training images per category; every remaining image is used for testing.
        train_indices, test_indices = _first_n_per_class_split(ds.y, NUM_TRAINING_SAMPLES_PER_CLASS)
        train_set = Subset(ds, train_indices)
        test_set = Subset(ds, test_indices)

        train_dataset = TransformedDataset(train_set, transform=transform_train, normalize=Normalize)
        val_dataset = TransformedDataset(test_set, transform=transform_test, normalize=Normalize)
        num_cls = 101
    elif dataset == 'flowers':
        train_dataset = ImageFolder('../flowers/data/flowers-102/new/train', transform=transform_train)
        val_dataset = ImageFolder('../flowers/data/flowers-102/new/test', transform=transform_test)
        num_cls = 102
    elif dataset == 'pets':
        train_dataset = ImageFolder('../pets/data/oxford-pets/train', transform=transform_train)
        val_dataset = ImageFolder('../pets/data/oxford-pets//test/', transform=transform_test)
        num_cls = 37
    elif dataset == 'aircraft':
        train_dataset = FGVCAircraft('../aircraft/data/fgvc-aircraft-2013b', train=True, transform=transform_train)
        val_dataset = FGVCAircraft('../aircraft/data/fgvc-aircraft-2013b', train=False, transform=transform_test)
        num_cls = 100
    elif dataset == 'cars':
        train_dataset = StanfordCars('../cars/data', transform=transform_train)
        val_dataset = StanfordCars('../cars/data', split="test", transform=transform_test)
        num_cls = 196
    elif dataset == 'foods':
        train_dataset = Food101(root='../foods/data', split='train', transform=transform_train)
        val_dataset = Food101(root='../foods/data', split='test', transform=transform_test)
        num_cls = 101
    elif dataset == 'cub':
        train_dataset = Cub2011('../cub/data', train=True, transform=transform_train)
        val_dataset = Cub2011('../cub/data', train=False, transform=transform_test)
        num_cls = 200
    elif dataset == 'pascalvoc':
        train_dataset = voc.VOCDetection('/mnt/lustre/liyuhang1/data/', year='2012', image_set='train',
                                         download=False, transform=transform_train,
                                         target_transform=_encode_voc_labels)
        val_dataset = voc.VOCDetection('/mnt/lustre/liyuhang1/data/', year='2012', image_set='val',
                                       download=False, transform=transform_test,
                                       target_transform=_encode_voc_labels)
        num_cls = 20
    elif dataset == 'dtd':
        train_dataset = DTD(root='../dtd/data/dtd/dtd', train=True, transform=transform_train)
        val_dataset = DTD(root='../dtd/data/dtd/dtd', train=False, transform=transform_test)
        num_cls = 47
    elif dataset == 'dogs':
        train_dataset = StanfordDogs(root='../dogs/data/', train=True, transform=transform_train)
        val_dataset = StanfordDogs(root='../dogs/data/', train=False, transform=transform_test)
        num_cls = 120
    elif dataset == 'eurosat':
        # Two dataset objects, so train and val each keep their own transform.
        # Two Subsets of a single dataset would share one ``transform`` attribute.
        train_full = EuroSAT(root='../eurosat/data/', transform=transform_train)
        val_full = EuroSAT(root='../eurosat/data/', transform=transform_test)
        # 20% of every class for training, 80% for validation.
        train_idx, validation_idx = _stratified_split(train_full.targets, train_fraction=0.2, seed=1000)
        train_dataset = Subset(train_full, train_idx)
        val_dataset = Subset(val_full, validation_idx)
        num_cls = 10
    else:
        raise NotImplementedError('Unknown dataset: {}'.format(dataset))

    if fewshot is not None:
        indices = find_fewshot_indices(train_dataset, fewshot, num_cls)
        train_dataset = torch.utils.data.Subset(train_dataset, indices=indices)

    if isinstance(fuse, list):
        # Filter a copy instead of fuse.remove(0): the caller's list must not change.
        gen_ids = [f for f in fuse if f != 0]
        include_real = len(gen_ids) != len(fuse)
        datasets = [ImageFolder(root='../{}/data/gen{}'.format(dataset, f), transform=transform_train)
                    for f in gen_ids]
        if include_real:
            datasets.append(train_dataset)
        train_dataset = torch.utils.data.ConcatDataset(datasets)
    elif fuse > 0:
        train_dataset = ImageFolder(root='../{}/data/gen{}'.format(dataset, fuse), transform=transform_train)
    return train_dataset, val_dataset, num_cls


def build_imagenet(root='/data_smr/dataset/ImageNet'):
    """Build ImageNet train/val ``ImageFolder`` datasets with the standard transforms.

    Train uses RandomResizedCrop(224) plus a horizontal flip. Val uses
    Resize(256) plus CenterCrop(224). Both splits are converted to tensors and
    normalised with ImageNet mean/std.

    Args:
        root: directory containing ``train`` and ``val`` class-folder trees.

    Returns:
        ``(train_dataset, val_dataset)``.
    """
    normalize = T.Normalize(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225])
    train_transform = T.Compose([
        T.RandomResizedCrop(224),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        normalize,
    ])
    val_transform = T.Compose([
        T.Resize(256),
        T.CenterCrop(224),
        T.ToTensor(),
        normalize,
    ])
    train_dataset = ImageFolder(os.path.join(root, 'train'), train_transform)
    val_dataset = ImageFolder(os.path.join(root, 'val'), val_transform)
    return train_dataset, val_dataset
