import os.path as osp

import numpy as np

from .utils import load_pickle
from .datasets.TestDataLoader import TestDataLoader
from .datasets.MetaTrainDataLoader import MetaTrainDataLoader

# Short module-level aliases for the os.path helpers used across the package.
ospj, ospeu = osp.join, osp.expanduser


def get_file_dir(name, dataset_path=None):
    """Return the image directory and partition-file path for a dataset.

    Args:
        name: Dataset identifier; one of 'CIFAR-100', 'TinyImageNet',
            'ImageNet-100' or 'ImageNet-1000'.
        dataset_path: Root directory containing the per-dataset folders.
            A leading '~' is expanded to the user's home directory.

    Returns:
        A tuple ``(im_dir, partition_file)`` where ``im_dir`` is
        ``<dataset_path>/<folder>`` and ``partition_file`` is
        ``<im_dir>/<folder>.pkl``.

    Raises:
        ValueError: If ``name`` is not a supported dataset.
        TypeError: If ``dataset_path`` is None.
    """
    # Each dataset keeps its partition pickle inside its own folder, and the
    # pickle is named after that folder.
    folders = {
        # CIL
        'CIFAR-100': 'cifar-100',
        'TinyImageNet': 'tiny-imagenet-200',
        'ImageNet-100': 'ImageNet-100',
        'ImageNet-1000': 'ImageNet-1000',
    }
    try:
        folder = folders[name]
    except KeyError:
        # Previously fell through every branch and died with UnboundLocalError.
        raise ValueError("Unsupported Dataset {}".format(name)) from None
    if dataset_path is None:
        raise TypeError("dataset_path must be provided")

    im_dir = osp.expanduser(osp.join(dataset_path, folder))
    # im_dir is already user-expanded, so a plain join is sufficient here.
    partition_file = osp.join(im_dir, folder + '.pkl')
    return im_dir, partition_file
    

def create_dataloader(name='CIFAR-100', order_file=None, part='phase_train', phase=5, dataset_path=None, **kwargs):
    """Build a train or test data loader for a supported dataset.

    Args:
        name: Dataset identifier; one of 'CIFAR-100', 'TinyImageNet',
            'ImageNet-100' or 'ImageNet-1000'.
        order_file: Forwarded unchanged to the loader constructor
            (presumably a class-order file — TODO confirm).
        part: 'phase_train' builds a ``MetaTrainDataLoader`` from the
            'train' partition; 'up2now_test' builds a ``TestDataLoader``
            from the 'test' partition.
        phase: Forwarded unchanged to the loader constructor.
        dataset_path: Root directory containing the dataset folders; must
            exist on disk.
        **kwargs: Extra keyword arguments forwarded to the loader constructor.

    Returns:
        The constructed loader instance.

    Raises:
        ValueError: If ``name`` or ``part`` is unsupported.
        TypeError: If ``dataset_path`` is None.
        FileNotFoundError: If ``dataset_path`` does not exist.

    Also prints a 40-character dashed separator line before returning.
    """
    supported_names = ('CIFAR-100', 'TinyImageNet', 'ImageNet-100', 'ImageNet-1000')
    # Which partition split each loader part reads its images from.
    split_by_part = {'phase_train': 'train', 'up2now_test': 'test'}

    # Explicit raises rather than asserts so validation survives `python -O`.
    if name not in supported_names:
        raise ValueError("Unsupported Dataset {}".format(name))
    if part not in split_by_part:
        raise ValueError("Unsupported Dataset Part {}".format(part))
    if dataset_path is None:
        raise TypeError("dataset_path must be provided")
    if not osp.exists(dataset_path):
        # format() instead of '+' so non-str paths (e.g. pathlib.Path) work.
        raise FileNotFoundError("The dataset path {} does not exist!".format(dataset_path))

    ########################################
    # Specify Directory and Partition File #
    ########################################
    im_dir, partition_file = get_file_dir(name, dataset_path)

    ##################
    # Create Dataset #
    ##################
    split = split_by_part[part]
    partitions = load_pickle(partition_file)
    im_names = partitions['{}_im_names'.format(split)]
    im_ids = partitions['{}_im_ids'.format(split)]
    ids2labels = partitions['{}_ids2labels'.format(split)]

    loader_cls = MetaTrainDataLoader if part == 'phase_train' else TestDataLoader
    ret_set = loader_cls(
        phase=phase,
        order_file=order_file,
        im_dir=im_dir,
        im_names=im_names,
        im_ids=im_ids,
        ids2labels=ids2labels,
        **kwargs)

    print('-' * 40)
    return ret_set
