import os
import torch
import pickle
import collections
import torchvision
import torchvision.transforms as transforms
import datasets.word_classification as torch_cls_datasets
import datasets.decathlon_datasets as decathlon_datasets

def create_transforms(args):
    """Build the 224x224 train/test preprocessing pipelines.

    Args:
        args: Options object. ``args.backbone`` selects the normalisation
            statistics (0.5 mean/std per channel when it contains ``'vit'``,
            otherwise the 0.485/0.456/0.406 mean and 0.229/0.224/0.225 std
            values). ``args.cropped`` selects the geometric steps: when
            truthy the image is resized straight to 224x224; otherwise
            training uses a random resized crop and testing uses a resize
            followed by a centre crop.

    Returns:
        ``(train_transform, test_transform)`` as ``transforms.Compose``
        objects. Training additionally applies a random horizontal flip.
    """
    if 'vit' in args.backbone:
        mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
    else:
        mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
    # One Normalize instance is shared by both pipelines.
    normalize = transforms.Normalize(mean=mean, std=std)

    if args.cropped:
        # Inputs are already cropped: just force the target resolution.
        test_geometry = [transforms.Resize([224, 224])]
        train_geometry = [
            transforms.Resize([224, 224]),
            transforms.RandomHorizontalFlip(),
        ]
    else:
        test_geometry = [
            transforms.Resize(224),
            transforms.CenterCrop(224),
        ]
        train_geometry = [
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
        ]

    test_transform = transforms.Compose(
        test_geometry + [transforms.ToTensor(), normalize]
    )
    train_transform = transforms.Compose(
        train_geometry + [transforms.ToTensor(), normalize]
    )
    return train_transform, test_transform

def create_decathlon_transforms(args):
    """Build the 72x72 train/test pipelines for the decathlon datasets.

    Both pipelines are the same sequence: resize to 72x72, convert to a
    tensor, normalise. No flipping or cropping is applied.

    Args:
        args: Accepted for signature parity with ``create_transforms``;
            not read by this function.

    Returns:
        ``(train_transform, test_transform)`` as two separate
        ``transforms.Compose`` objects sharing one ``Normalize`` instance.
    """
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    def _resize_and_normalize():
        # Fresh Compose (and fresh Resize/ToTensor) on every call.
        return transforms.Compose([
            transforms.Resize((72, 72)),
            transforms.ToTensor(),
            normalize,
        ])

    test_transform = _resize_and_normalize()
    train_transform = _resize_and_normalize()
    return train_transform, test_transform

def load_imagenet2sketch_benchmark(args):
    """Load every dataset directory found under ``args.dataset_name``.

    Each sub-directory of ``args.dataset_name`` is treated as one dataset and
    loaded with ``load_imagenet2sketch_benchmark_dataset``. Its position in
    the alphabetically sorted directory listing is passed along as the
    dataset index.

    Args:
        args: Options object. ``args.dataset_name`` is the benchmark root
            directory; ``args`` is also forwarded to ``create_transforms``.

    Returns:
        ``(train_sets, val_sets, num_classes, dataset_names)``: four
        parallel lists with one entry per dataset. The datasets are kept
        separate (they are not concatenated).
    """
    train_transform, test_transform = create_transforms(args)
    root = args.dataset_name

    # Keep only directories: os.listdir also returns stray files (e.g.
    # ".DS_Store", README), which would otherwise be handed to ImageFolder
    # as if they were datasets and fail.
    dataset_names = sorted(
        entry for entry in os.listdir(root)
        if os.path.isdir(os.path.join(root, entry))
    )

    train_sets = []
    val_sets = []
    num_classes = []
    for j, dataset_name in enumerate(dataset_names):
        train_dataset, val_dataset, classes = \
            load_imagenet2sketch_benchmark_dataset(
                os.path.join(root, dataset_name),
                train_transform, test_transform, j)
        num_classes.append(classes)
        train_sets.append(train_dataset)
        val_sets.append(val_dataset)

    return train_sets, val_sets, num_classes, dataset_names

def load_tf_visual_domain_decathlon_benchmark(args):
    """Load the decathlon datasets listed in ``torch_cls_datasets``.

    For every name in ``TORCH_IMAGE_CLASSIFCATON_DATASETS`` a train and a
    validation ``TORCH_CLS_Dataset`` are created and given transforms that
    normalise with the per-dataset mean/std read from
    ``decathlon_mean_std.pickle`` (keys ``<name>mean`` / ``<name>std``).

    Geometric steps per dataset group:
        * ``gtsrb``, ``omniglot``, ``svhn``: resize 72 + centre crop 72 for
          both splits.
        * ``daimlerpedcls``: resize 72 only, for both splits.
        * everything else: resize 72 + random crop 64 for training,
          resize 72 + centre crop 72 for validation.

    Args:
        args: Options object; ``args.dataset_name`` is the directory holding
            the datasets and the statistics pickle.

    Returns:
        ``(train_sets, val_sets, num_classes, dataset_names)`` where the
        first three are parallel lists and ``dataset_names`` is
        ``TORCH_IMAGE_CLASSIFCATON_DATASETS`` itself.
    """
    root = args.dataset_name
    stats_file = os.path.join(root, 'decathlon_mean_std.pickle')
    # NOTE(security): pickle.load executes arbitrary code from the file;
    # only point this at a trusted statistics file.
    with open(stats_file, 'rb') as handle:
        dict_mean_std = pickle.load(handle, encoding="iso-8859-1")

    center_cropped = ['gtsrb', 'omniglot', 'svhn']
    resize_only = ['daimlerpedcls']

    train_sets, val_sets, num_classes = [], [], []
    for name in torch_cls_datasets.TORCH_IMAGE_CLASSIFCATON_DATASETS:
        train_set = torch_cls_datasets.TORCH_CLS_Dataset(name, 'train', path=root)
        val_set = torch_cls_datasets.TORCH_CLS_Dataset(name, 'validation', path=root)

        means = dict_mean_std[name + 'mean']
        stds = dict_mean_std[name + 'std']

        # Choose the geometric steps; the tensor conversion and
        # normalisation tail is appended below for both splits.
        if name in center_cropped:
            train_steps = [transforms.Resize(72), transforms.CenterCrop(72)]
            test_steps = [transforms.Resize(72), transforms.CenterCrop(72)]
        elif name in resize_only:
            train_steps = [transforms.Resize(72)]
            test_steps = [transforms.Resize(72)]
        else:
            train_steps = [transforms.Resize(72), transforms.RandomCrop(64)]
            test_steps = [transforms.Resize(72), transforms.CenterCrop(72)]

        transform_train = transforms.Compose(
            train_steps + [transforms.ToTensor(), transforms.Normalize(means, stds)]
        )
        transform_test = transforms.Compose(
            test_steps + [transforms.ToTensor(), transforms.Normalize(means, stds)]
        )

        train_set.set_transforms(transform_train)
        val_set.set_transforms(transform_test)

        train_sets.append(train_set)
        val_sets.append(val_set)
        num_classes.append(train_set.num_classes)
        print(f"Dataset {name} has {len(train_set)} training pics, {len(val_set)} validation pics.")

    return train_sets, val_sets, num_classes, torch_cls_datasets.TORCH_IMAGE_CLASSIFCATON_DATASETS

def load_visual_domain_decathlon_benchmark(args):
    """Create data loaders for nine Visual Domain Decathlon datasets.

    Delegates to ``decathlon_datasets.prepare_data_loaders`` with a fixed
    list of dataset names, the data root ``args.dataset_name``, the
    ``annotations/`` sub-directory of that root, and a final positional
    ``True`` flag (its meaning is defined by ``prepare_data_loaders``).

    Args:
        args: Options object; ``args.dataset_name`` is the data root.

    Returns:
        ``(train_loaders, val_loaders, num_classes, dataset_names)`` exactly
        as produced by ``prepare_data_loaders``.
    """
    domain_names = [
        "aircraft",
        "cifar100",
        "daimlerpedcls",
        "dtd",
        "gtsrb",
        "omniglot",
        "svhn",
        "ucf101",
        "vgg-flowers",
    ]
    data_root = args.dataset_name
    annotation_dir = os.path.join(args.dataset_name, "annotations/")

    loaders = decathlon_datasets.prepare_data_loaders(
        domain_names, data_root, annotation_dir, True
    )
    # Unpack into exactly four values, as callers expect a 4-tuple.
    train_loaders, val_loaders, num_classes, dataset_names = loaders
    return train_loaders, val_loaders, num_classes, dataset_names
    

def load_imagenet2sketch_benchmark_dataset(path, train_transform, test_transform, idx):
    """Load one dataset laid out as ``<path>/train`` and ``<path>/test``.

    Both splits are read with ``torchvision.datasets.ImageFolder``. Every
    entry of ``samples`` is rewritten from ``(file, label)`` to
    ``(file, (label, idx))`` so the dataset index travels with the label.
    Only ``samples`` is rewritten; other attributes of the datasets (such
    as ``targets``) are left as ``ImageFolder`` built them.

    Args:
        path: Directory containing the ``train`` and ``test`` folders.
        train_transform: Transform applied to the training split.
        test_transform: Transform applied to the test split.
        idx: Dataset index stored next to every label.

    Returns:
        ``(train_dataset, val_dataset, num_classes)`` where ``num_classes``
        is the number of classes found in the training split.
    """
    # Build the split paths with os.path.join instead of gluing on '/...',
    # so trailing separators and platform separators are handled correctly.
    train_path = os.path.join(path, 'train')
    test_path = os.path.join(path, 'test')

    train_dataset = torchvision.datasets.ImageFolder(train_path, transform=train_transform)
    val_dataset = torchvision.datasets.ImageFolder(test_path, transform=test_transform)

    # Attach the dataset index to each label.
    for dataset in (train_dataset, val_dataset):
        dataset.samples = [(x, (label, idx)) for x, label in dataset.samples]

    num_classes = len(train_dataset.classes)
    print(path, " num_classes: ", num_classes)
    return train_dataset, val_dataset, num_classes