import os
import json
from collections import OrderedDict
import numpy as np
import torch
from torch.utils.data import DataLoader, SubsetRandomSampler
from torchvision import datasets, transforms
from sklearn.model_selection import train_test_split

from .dataset_lmdb import COOPLMDBDataset
from .abide import ABIDE

from .const import GTSRB_LABEL_MAP, IMAGENETNORMALIZE
from torch.utils.data import DataLoader, Subset, ConcatDataset, Dataset
from torchvision.datasets import CIFAR10, CIFAR100, SVHN, GTSRB, Food101, SUN397, EuroSAT, UCF101, StanfordCars, Flowers102, DTD, OxfordIIITPet, MNIST, ImageNet, ImageFolder
from PIL import Image

def refine_classnames(class_names):
    """Normalize class names in place and return the same list.

    Every name is lower-cased and its underscores and hyphens are replaced
    by single spaces. The list passed in is itself modified, so any other
    holder of that list sees the cleaned names too.
    """
    class_names[:] = [
        raw.lower().replace('_', ' ').replace('-', ' ') for raw in class_names
    ]
    return class_names


def get_class_names_from_split(root):
    """Return the class names found in ``<root>/split.json``, ordered by label.

    Only the ``"test"`` section of the JSON file is read. For each entry the
    second-to-last field is used as the label key and the last field as the
    class name; the names are returned sorted by their label key.
    """
    split_file = os.path.join(root, "split.json")
    with open(split_file) as handle:
        test_entries = json.load(handle)["test"]

    # Later entries overwrite earlier ones that share the same label key.
    label_to_name = {}
    for entry in test_entries:
        label_to_name[entry[-2]] = entry[-1]

    return [name for _, name in sorted(label_to_name.items())]



            # train_transform = transforms.Compose([
            #     transforms.Resize((int(args.input_size*9/8), int(args.input_size*9/8))),
            #     transforms.RandomCrop(args.input_size),
            #     # transforms.RandomResizedCrop(args.input_size),
            #     transforms.RandomHorizontalFlip(),
            #     transforms.Lambda(lambda x: x.convert('RGB') if hasattr(x, 'convert') else x),
            #     transforms.ToTensor(),
            # ])
            # test_transform = transforms.Compose([
            #     transforms.Resize((args.input_size, args.input_size)),
            #     transforms.Lambda(lambda x: x.convert('RGB') if hasattr(x, 'convert') else x),
            #     transforms.ToTensor(),
            # ])

def _expansive_to_rgb(img):
    """Return ``img`` converted to RGB when it has a ``convert`` method, else unchanged."""
    return img.convert('RGB') if hasattr(img, 'convert') else img


def _expansive_loaders(train_data, test_data, batch_size, num_workers):
    """Wrap a train/test dataset pair in a shuffled and an unshuffled DataLoader."""
    return {
        'train': DataLoader(train_data, batch_size, shuffle=True, num_workers=num_workers),
        'test': DataLoader(test_data, batch_size, shuffle=False, num_workers=num_workers),
    }


def _expansive_transforms(dataset, resize, dataAug):
    """Build the ``(train, test)`` transform pair used by ``prepare_expansive_data``.

    ``resize`` may be ``None`` only for the CIFAR datasets, in which case both
    transforms are the same ``ToTensor``-only pipeline. Augmentation is skipped
    only when ``dataAug is False``; any other value enables it.
    """
    def rgb():
        # A named module-level function (instead of a lambda) keeps the
        # transform picklable for DataLoader worker processes.
        return transforms.Lambda(_expansive_to_rgb)

    if dataset in ("cifar10", "cifar100"):
        if resize is None:
            native = transforms.Compose([transforms.ToTensor()])
            return native, native
        test_tf = transforms.Compose([
            transforms.Resize(resize),
            rgb(),
            transforms.ToTensor(),
        ])
        if dataAug is False:
            train_tf = transforms.Compose([
                transforms.Resize(resize),
                rgb(),
                transforms.ToTensor(),
            ])
        else:
            # Upscale to 9/8 of the target size, then randomly crop back to it.
            enlarged = int(resize * 9 / 8)
            train_tf = transforms.Compose([
                transforms.Resize((enlarged, enlarged)),
                transforms.RandomCrop(resize),
                transforms.RandomHorizontalFlip(),
                rgb(),
                transforms.ToTensor(),
            ])
        return train_tf, test_tf

    if dataset == "gtsrb":
        test_tf = transforms.Compose([
            transforms.Resize((resize, resize)),
            rgb(),
            transforms.ToTensor(),
        ])
        if dataAug is False:
            train_tf = transforms.Compose([
                transforms.Resize((resize, resize)),
                rgb(),
                transforms.ToTensor(),
            ])
        else:
            # NOTE(review): the 32-pixel random crop runs before the resize,
            # so it presumably expects roughly 32x32 inputs — confirm.
            train_tf = transforms.Compose([
                transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(),
                transforms.Resize((resize, resize)),
                rgb(),
                transforms.ToTensor(),
            ])
        return train_tf, test_tf

    # food101 / stanfordcars / dtd / oxfordpets
    test_tf = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.CenterCrop(resize),
        rgb(),
        transforms.ToTensor(),
    ])
    if dataAug is False:
        train_tf = transforms.Compose([
            transforms.Resize((resize, resize)),
            rgb(),
            transforms.ToTensor(),
        ])
    else:
        # dtd/oxfordpets crop directly at the target size; the other two
        # datasets crop at a fixed 224x224 before the final resize.
        crop_size = resize if dataset in ("dtd", "oxfordpets") else (224, 224)
        train_tf = transforms.Compose([
            transforms.RandomResizedCrop(crop_size),
            transforms.RandomHorizontalFlip(),
            rgb(),
            transforms.Resize((resize, resize)),
            transforms.ToTensor(),
        ])
    return train_tf, test_tf


def prepare_expansive_data(dataset, data_path, resize=None, dataAug=False):
    """Create train/test DataLoaders and a config dict for ``dataset``.

    Args:
        dataset: One of ``"cifar10"``, ``"cifar100"``, ``"gtsrb"``, ``"svhn"``,
            ``"abide"``, ``"food101"``, ``"stanfordcars"``, ``"dtd"`` or
            ``"oxfordpets"``.
        data_path: Parent directory; the data root is ``data_path/dataset``.
        resize: Target side length in pixels. Ignored by ``"svhn"`` and
            ``"abide"``. May be ``None`` only for the CIFAR datasets, where the
            images are then used at their native 32x32 size.
        dataAug: Training augmentation is disabled only when this is exactly
            ``False``; any other value enables it. Ignored by ``"svhn"`` and
            ``"abide"``.

    Returns:
        ``(loaders, configs)`` where ``loaders`` has the keys ``'train'`` and
        ``'test'`` and ``configs`` has the keys ``'class_names'`` and ``'mask'``
        (a zero array of the image size, or the ABIDE mask).

    Raises:
        NotImplementedError: If ``dataset`` is not supported. This includes
            eurosat, sun397, ucf101 and flowers102, which have no loader here.
        TypeError: If ``resize`` is ``None`` for a dataset that has no fixed
            native image size.
    """
    data_path = os.path.join(data_path, dataset)

    if dataset == "svhn":
        preprocess = transforms.Compose([
            transforms.ToTensor(),
        ])
        train_data = datasets.SVHN(root=data_path, split="train", download=True, transform=preprocess)
        test_data = datasets.SVHN(root=data_path, split="test", download=True, transform=preprocess)
        loaders = _expansive_loaders(train_data, test_data, 128, 2)
        configs = {
            'class_names': [f'{i}' for i in range(10)],
            'mask': np.zeros((32, 32)),
        }
        return loaders, configs

    if dataset == "abide":
        preprocess = transforms.ToTensor()
        D = ABIDE(root=data_path)
        # Fixed-seed stratified 90/10 split of the full ABIDE data.
        X_train, X_test, y_train, y_test = train_test_split(
            D.data, D.targets, test_size=0.1, stratify=D.targets, random_state=1)
        train_data = ABIDE(root=data_path, transform=preprocess)
        train_data.data = X_train
        train_data.targets = y_train
        test_data = ABIDE(root=data_path, transform=preprocess)
        test_data.data = X_test
        test_data.targets = y_test
        loaders = _expansive_loaders(train_data, test_data, 64, 2)
        configs = {
            'class_names': ["non ASD", "ASD"],
            'mask': D.get_mask(),
        }
        return loaders, configs

    resizable = ("cifar10", "cifar100", "gtsrb", "food101", "stanfordcars", "dtd", "oxfordpets")
    if dataset not in resizable:
        raise NotImplementedError(f"{dataset} not supported")

    if resize is None:
        if dataset not in ("cifar10", "cifar100"):
            # Fail before any download: without a target size no mask can be
            # built for these datasets.
            raise TypeError(f"resize must be an int for dataset '{dataset}', got None")
        mask_size = 32  # native CIFAR image side length
    else:
        mask_size = resize

    train_tf, test_tf = _expansive_transforms(dataset, resize, dataAug)

    num_workers = 8
    if dataset == "cifar10":
        train_data = datasets.CIFAR10(root=data_path, train=True, download=True, transform=train_tf)
        test_data = datasets.CIFAR10(root=data_path, train=False, download=True, transform=test_tf)
        num_workers = 2
    elif dataset == "cifar100":
        train_data = datasets.CIFAR100(root=data_path, train=True, download=True, transform=train_tf)
        test_data = datasets.CIFAR100(root=data_path, train=False, download=True, transform=test_tf)
    elif dataset == "gtsrb":
        train_data = datasets.GTSRB(root=data_path, split="train", download=True, transform=train_tf)
        test_data = datasets.GTSRB(root=data_path, split="test", download=True, transform=test_tf)
        num_workers = 2
    elif dataset == "food101":
        train_data = Food101(data_path, split='train', transform=train_tf, download=True)
        test_data = Food101(data_path, split='test', transform=test_tf, download=True)
    elif dataset == "stanfordcars":
        # NOTE(review): loaded from the current working directory rather than
        # data_path, exactly as before — confirm this is intended.
        train_data = StanfordCars("./", split='train', transform=train_tf, download=True)
        test_data = StanfordCars("./", split='test', transform=test_tf, download=True)
    elif dataset == "dtd":
        # Train on the union of the official train and val splits.
        train_data = ConcatDataset([
            DTD(root=data_path, split='train', transform=train_tf, download=True),
            DTD(root=data_path, split='val', transform=train_tf, download=True),
        ])
        test_data = DTD(data_path, split='test', transform=test_tf, download=True)
    else:  # "oxfordpets"
        train_data = OxfordIIITPet(data_path, split='trainval', transform=train_tf, download=True)
        test_data = OxfordIIITPet(data_path, split='test', transform=test_tf, download=True)

    loaders = _expansive_loaders(train_data, test_data, 128, num_workers)

    if dataset == "gtsrb":
        class_names = refine_classnames(list(GTSRB_LABEL_MAP.values()))
    else:
        class_names = refine_classnames(test_data.classes)
    configs = {
        'class_names': class_names,
        'mask': np.zeros((mask_size, mask_size)),
    }
    return loaders, configs
class TinyImageNet(Dataset):
    """Image dataset driven by a tab-separated annotation file.

    Each non-blank line of ``annotations_file`` is split on tabs: the first
    field is an image path relative to ``root_dir`` and the second field is
    the label name. Label names are turned into integer indices by sorting
    the non-blank names listed (one per line) in ``label_ids_file``.

    Args:
        root_dir: Directory that the image paths are relative to.
        annotations_file: Path to the tab-separated annotation file.
        label_ids_file: Path to the file listing all label names.
        transform: Optional callable applied to each loaded RGB image.
    """

    def __init__(self, root_dir, annotations_file, label_ids_file, transform=None):
        self.root_dir = root_dir
        self.transform = transform

        # Read inside a context manager so the file handle is closed, and skip
        # blank lines so an empty file gives an empty dataset.
        with open(annotations_file) as f:
            self.entries = [line.rstrip('\n') for line in f if line.strip()]

        with open(label_ids_file, 'r') as f:
            stripped = (line.strip() for line in f)
            # Blank lines are dropped so they cannot shift the label indices.
            self.label_names = sorted(name for name in stripped if name)
        self.label_idx = {name: idx for idx, name in enumerate(self.label_names)}

    def __len__(self):
        """Return the number of annotated images."""
        return len(self.entries)

    def __getitem__(self, index):
        """Return ``(image, label_index)`` for the annotation at ``index``."""
        fields = self.entries[index].split('\t')
        img_path, annotation = fields[0], fields[1]
        # convert() returns a new image, so the file can be closed right away.
        with Image.open(os.path.join(self.root_dir, img_path)) as img:
            image = img.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)

        return image, int(self.label_idx[annotation])
def _tinyimagenet_to_rgb(img):
    """Return ``img`` converted to RGB when it has a ``convert`` method, else unchanged."""
    return img.convert('RGB') if hasattr(img, 'convert') else img


def get_tinyimagenet_dataloaders(args, resize=None, dataAug=False):
    """Create train/test DataLoaders and a config dict for Tiny ImageNet.

    The training set is an ``ImageFolder`` over ``<args.datadir>/train``; the
    test set is a ``TinyImageNet`` built from ``<args.datadir>/val/images``,
    ``<args.datadir>/val/val_annotations.txt`` and ``<args.datadir>/wnids.txt``.

    Args:
        args: Object exposing a ``datadir`` attribute (dataset root directory).
        resize: Target side length in pixels, or ``None`` to keep images at
            their stored size.
        dataAug: Training augmentation is disabled only when this is exactly
            ``False``; any other value enables it (when ``resize`` is given).

    Returns:
        ``(loaders, configs)`` where ``loaders`` has the keys ``'train'`` and
        ``'test'`` and ``configs`` has ``'class_names'`` (always ``None``) and
        ``'mask'`` (a square zero array of the image size).
    """
    traindir = os.path.join(args.datadir, 'train')

    def rgb():
        # A named module-level function (instead of a lambda) keeps the
        # transform picklable for DataLoader worker processes.
        return transforms.Lambda(_tinyimagenet_to_rgb)

    if resize is None:
        preprocess = transforms.Compose([
            rgb(),
            transforms.ToTensor(),
        ])
        preprocess_test = preprocess
        # Tiny ImageNet images are 64x64, so the mask matches that size.
        # NOTE(review): confirm args.datadir holds the standard 64x64 release.
        mask_size = 64
    else:
        mask_size = resize
        preprocess_test = transforms.Compose([
            transforms.Resize((resize, resize)),
            rgb(),
            transforms.ToTensor(),
        ])
        if dataAug is False:
            preprocess = transforms.Compose([
                transforms.Resize((resize, resize)),
                rgb(),
                transforms.ToTensor(),
            ])
        else:
            # Upscale to 9/8 of the target size, then randomly crop back to it.
            enlarged = int(resize * 9 / 8)
            preprocess = transforms.Compose([
                transforms.Resize((enlarged, enlarged)),
                transforms.RandomCrop(resize),
                transforms.RandomHorizontalFlip(),
                rgb(),
                transforms.ToTensor(),
            ])

    train_dataset = datasets.ImageFolder(traindir, preprocess)
    val_dataset = TinyImageNet(
        os.path.join(args.datadir, 'val/images'),
        os.path.join(args.datadir, 'val/val_annotations.txt'),
        os.path.join(args.datadir, 'wnids.txt'),
        transform=preprocess_test,
    )

    loaders = {
        'train': DataLoader(train_dataset, 128, shuffle=True, num_workers=8),
        'test': DataLoader(val_dataset, 128, shuffle=False, num_workers=8),
    }
    configs = {
        'class_names': None,
        'mask': np.zeros((mask_size, mask_size)),
    }
    return loaders, configs


def prepare_additive_data(dataset, data_path, preprocess):
    """Create train/test DataLoaders and the class-name list for ``dataset``.

    Args:
        dataset: Dataset identifier; the data root is ``data_path/dataset``.
        data_path: Parent directory of the per-dataset folders.
        preprocess: Transform applied to both splits. For ``"abide"`` it is
            ignored and replaced by a fixed ToTensor / 224x224 resize /
            ImageNet-normalize pipeline.

    Returns:
        ``(loaders, class_names)`` where ``loaders`` has the keys ``'train'``
        and ``'test'``.

    Raises:
        NotImplementedError: If ``dataset`` is not one of the handled names.
    """
    root = os.path.join(data_path, dataset)
    lmdb_large_batch = ("food101", "sun397", "eurosat", "ucf101", "stanfordcars", "flowers102")
    lmdb_small_batch = ("dtd", "oxfordpets")

    # Defaults shared by most datasets; individual branches override them.
    batch_size = 128
    workers = 2

    if dataset in ("cifar10", "cifar100"):
        cifar_cls = datasets.CIFAR10 if dataset == "cifar10" else datasets.CIFAR100
        train_set = cifar_cls(root=root, train=True, download=False, transform=preprocess)
        test_set = cifar_cls(root=root, train=False, download=False, transform=preprocess)
        class_names = refine_classnames(test_set.classes)
    elif dataset == "svhn":
        train_set = datasets.SVHN(root=root, split="train", download=False, transform=preprocess)
        test_set = datasets.SVHN(root=root, split="test", download=False, transform=preprocess)
        class_names = [str(digit) for digit in range(10)]
    elif dataset in lmdb_large_batch or dataset in lmdb_small_batch:
        train_set = COOPLMDBDataset(root=root, split="train", transform=preprocess)
        test_set = COOPLMDBDataset(root=root, split="test", transform=preprocess)
        class_names = refine_classnames(test_set.classes)
        workers = 8
        if dataset in lmdb_small_batch:
            batch_size = 64
    elif dataset == "gtsrb":
        # Unlike the other torchvision datasets here, GTSRB is downloaded on demand.
        train_set = datasets.GTSRB(root=root, split="train", download=True, transform=preprocess)
        test_set = datasets.GTSRB(root=root, split="test", download=True, transform=preprocess)
        class_names = refine_classnames(list(GTSRB_LABEL_MAP.values()))
    elif dataset == "abide":
        full_set = ABIDE(root=root)
        # The caller-supplied transform is deliberately overridden for ABIDE.
        preprocess = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize((224, 224)),
            transforms.Normalize(IMAGENETNORMALIZE['mean'], IMAGENETNORMALIZE['std']),
        ])
        # Fixed-seed stratified 90/10 split of the full ABIDE data.
        X_train, X_test, y_train, y_test = train_test_split(
            full_set.data,
            full_set.targets,
            test_size=0.1,
            stratify=full_set.targets,
            random_state=1,
        )
        train_set = ABIDE(root=root, transform=preprocess)
        train_set.data, train_set.targets = X_train, y_train
        test_set = ABIDE(root=root, transform=preprocess)
        test_set.data, test_set.targets = X_test, y_test
        class_names = ["non ASD", "ASD"]
        batch_size = 64
    else:
        raise NotImplementedError(f"{dataset} not supported")

    loaders = {
        'train': DataLoader(train_set, batch_size, shuffle=True, num_workers=workers),
        'test': DataLoader(test_set, batch_size, shuffle=False, num_workers=workers),
    }
    return loaders, class_names


def prepare_gtsrb_fraction_data(data_path, fraction, preprocess=None):
    """Create GTSRB loaders whose training loader samples a random fraction.

    A random subset of ``int(fraction * len(train_set))`` training indices is
    drawn once with ``torch.randperm`` and served through a
    ``SubsetRandomSampler``; the test loader always covers the full test split.

    Args:
        data_path: Parent directory; the data root is ``data_path/gtsrb``.
        fraction: Share of the training split to keep, in ``(0, 1]``.
        preprocess: Transform for both splits. When ``None``, a 32x32 resize
            followed by ``ToTensor`` is used.

    Returns:
        ``(loaders, configs)`` when ``preprocess`` is ``None``, where
        ``configs`` has the keys ``'class_names'`` and ``'mask'`` (a 32x32
        zero array); otherwise ``(loaders, class_names)``.

    Raises:
        ValueError: If ``fraction`` is not in ``(0, 1]``.
    """
    # Explicit check instead of ``assert`` so it still runs under ``python -O``.
    if not 0 < fraction <= 1:
        raise ValueError(f"fraction must be in (0, 1], got {fraction!r}")

    data_path = os.path.join(data_path, "gtsrb")

    use_default_preprocess = preprocess is None
    if use_default_preprocess:
        preprocess = transforms.Compose([
            transforms.Resize((32, 32)),
            transforms.ToTensor(),
        ])

    train_data = datasets.GTSRB(root=data_path, split="train", download=True, transform=preprocess)
    test_data = datasets.GTSRB(root=data_path, split="test", download=True, transform=preprocess)

    # Use the actual split length rather than a hard-coded count so every
    # training sample stays eligible whatever the installed dataset's size.
    train_size = len(train_data)
    new_length = int(fraction * train_size)
    indices = torch.randperm(train_size)[:new_length]
    sampler = SubsetRandomSampler(indices)

    loaders = {
        'train': DataLoader(train_data, 128, sampler=sampler, num_workers=2),
        'test': DataLoader(test_data, 128, shuffle=False, num_workers=2),
    }
    class_names = refine_classnames(list(GTSRB_LABEL_MAP.values()))

    if use_default_preprocess:
        configs = {
            'class_names': class_names,
            'mask': np.zeros((32, 32)),
        }
        return loaders, configs
    return loaders, class_names