import logging
import numpy as np
import os
import torchvision.transforms as transforms
import torchvision.datasets as datasets


# Logger for this module. The dotted name makes it a descendant of the
# '__main__' logger in the logging hierarchy, so its records propagate to
# handlers attached to '__main__' (presumably configured by the entry-point
# script -- confirm before renaming to __name__).
module_logger = logging.getLogger('__main__.utils.dataset_util')


def gen_datasets(datadir):
    """Generate the train, validation and OOD datasets for an experiment.

    Preprocessing follows the PyTorch ImageNet example code: random resized
    crops and horizontal flips for training; resize to 256 and center crop
    to 224 for evaluation; ImageNet mean/std normalization for every split.

    Parameters
    ----------
    datadir : string
        Path to the data directory. It must contain ``train``, ``val`` and
        ``ood`` subdirectories, each laid out as expected by
        ``torchvision.datasets.ImageFolder`` (one subfolder per class).

    Returns
    -------
    Tuple of torchvision.datasets.ImageFolder objects:
        train_dataset, val_dataset, ood_dataset
    """
    # Channel statistics used by the PyTorch ImageNet example.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    # Stochastic augmentation is applied to the training split only.
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    train_dataset = datasets.ImageFolder(os.path.join(datadir, 'train'),
                                         transform=train_transform)

    # Deterministic preprocessing shared by the held-out splits.
    eval_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])
    val_dataset = datasets.ImageFolder(os.path.join(datadir, 'val'),
                                       transform=eval_transform)
    ood_dataset = datasets.ImageFolder(os.path.join(datadir, 'ood'),
                                       transform=eval_transform)
    return train_dataset, val_dataset, ood_dataset


# Names of the dataset directories accepted by gen_far_ood_datasets.
_FAR_OOD_DATASETS = (
    'iNaturalist', 'SUN', 'Places', 'Textures',
    'coarseid-fineood', 'coarseid-coarseood',
    'imagenet1000-fineood', 'imagenet1000-mediumood',
    'imagenet1000-coarseood',
)


def gen_far_ood_datasets(dset: str = "iNaturalist", root: str = "data"):
    """Load a far-OOD evaluation dataset by name.

    Images are preprocessed like the evaluation splits in ``gen_datasets``:
    resize to 256, center crop to 224, tensor conversion and ImageNet
    mean/std normalization.

    Parameters
    ----------
    dset : string
        Name of the dataset; must be one of ``_FAR_OOD_DATASETS``.
    root : string
        Directory containing one subdirectory per dataset. Defaults to
        ``"data"``, so the dataset is read from ``data/<dset>``.

    Returns
    -------
    torchvision.datasets.ImageFolder
        The dataset rooted at ``os.path.join(root, dset)``.

    Raises
    ------
    ValueError
        If ``dset`` is not a known dataset name.
    """
    if dset not in _FAR_OOD_DATASETS:
        # .format (not "+") so a non-str argument still yields ValueError.
        raise ValueError("Unknown far ood dataset: {}".format(dset))
    datadir = os.path.join(root, dset)
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])
    return datasets.ImageFolder(datadir, transform=transform)


def print_stats_of_list(prefix, dat):
    """Log min/max/mean/std/count statistics of a collection of numbers.

    Despite the name, the summary is emitted through ``module_logger`` at
    INFO level rather than printed to stdout.

    Parameters
    ----------
    prefix : string
        Text prepended to the logged line (e.g. a label for the values).
    dat : array-like
        Values to summarize. The input is flattened, so the statistics and
        the reported length both cover every element; a scalar counts as a
        single value.
    """
    values = np.ravel(dat)
    if values.size == 0:
        # min()/max() raise ValueError on an empty array; log the count only.
        module_logger.info(
            "%s Min: n/a; Max: n/a; Avg: n/a; Std: n/a; Len: 0", prefix)
        return
    # Lazy %-style args: formatting only happens if INFO is enabled.
    module_logger.info(
        "%s Min: %.4f; Max: %.4f; Avg: %.4f; Std: %.4f; Len: %d",
        prefix, values.min(), values.max(), values.mean(), values.std(),
        values.size)

