from torch.utils.data import DataLoader
import torch

from dataset.imagenet import ImageNet, TieredImageNet, MetaDataset
from meta_labeler import FlatDataset
import os
from tqdm import tqdm

def parse_option(opt):
    """Normalise parsed options in place and derive the run name.

    Args:
        opt: Options namespace. Reads ``lr_decay_epochs``, ``model``,
            ``dataset``, ``train_db_size``, ``epochs``, ``data_aug``,
            ``no_replacement``, ``trial`` and ``data_root``.
            ``lr_decay_epochs`` may be a comma-separated string such as
            ``"60,80"`` or an iterable of values accepted by ``int()``, so
            calling this function twice on the same namespace is safe.

    Returns:
        The same ``opt`` object, with ``lr_decay_epochs`` converted to a
        list of ints, ``model_name`` and ``n_gpu`` set, and ``data_root``
        passed through ``os.path.expanduser``.

    Raises:
        ValueError: If a non-blank entry of ``lr_decay_epochs`` is not an
            integer.
    """
    decay_spec = opt.lr_decay_epochs
    if isinstance(decay_spec, str):
        # Skip blank tokens so "" and a trailing comma yield no entry
        # instead of crashing in int('').
        decay_spec = [tok for tok in decay_spec.split(',') if tok.strip()]
    opt.lr_decay_epochs = [int(epoch) for epoch in decay_spec]

    opt.model_name = f'{opt.model}_{opt.dataset}_f{opt.train_db_size}_e{opt.epochs}'

    if opt.data_aug:
        opt.model_name += "_aug"

    # NOTE(review): the "_replace" suffix is appended when no_replacement is
    # truthy, which reads as inverted. It is kept verbatim because existing
    # run/checkpoint names may depend on it -- confirm the intent.
    if opt.no_replacement:
        opt.model_name += "_replace"
    opt.model_name += f"_{opt.trial}"

    opt.n_gpu = torch.cuda.device_count()
    opt.data_root = os.path.expanduser(opt.data_root)

    return opt


def get_dataset(opt):
    """Create the meta-train and meta-validation loaders for ``opt.dataset``.

    ``'miniImageNet'`` wraps ``ImageNet`` and ``'tieredImageNet'`` wraps
    ``TieredImageNet``; each split is wrapped in a ``MetaDataset`` and then
    a ``DataLoader``. The training loader shuffles and uses
    ``opt.batch_size``; the validation loader uses a batch size of 1 and
    does not shuffle. Neither loader drops the last batch.

    Args:
        opt: Options namespace. Reads ``dataset``, ``sample_shape``,
            ``train_db_size``, ``test_db_size``, ``no_replacement``,
            ``batch_size`` and ``num_workers``; it is also handed to the
            dataset constructors.

    Returns:
        A ``(meta_trainloader, meta_valloader, n_cls)`` tuple, where
        ``n_cls`` is 64 for miniImageNet and 351 for tieredImageNet
        (presumably the number of training classes -- confirm).

    Raises:
        NotImplementedError: If ``opt.dataset`` is neither supported name;
            the exception carries the offending value.
    """

    def build_loader(tasks, *, train):
        # Loader options are read only after the dataset has been built.
        return DataLoader(
            tasks,
            batch_size=opt.batch_size if train else 1,
            shuffle=train,
            drop_last=False,
            num_workers=opt.num_workers,
        )

    if opt.dataset == 'miniImageNet':
        train_tasks = MetaDataset(
            ImageNet(opt, partition="train"),
            args=opt,
            sample_shape=opt.sample_shape,
            db_size=opt.train_db_size,
            no_replacement=opt.no_replacement,
        )
        meta_trainloader = build_loader(train_tasks, train=True)

        val_tasks = MetaDataset(
            ImageNet(opt, partition="val"),
            args=opt,
            db_size=opt.test_db_size,
        )
        meta_valloader = build_loader(val_tasks, train=False)
        return meta_trainloader, meta_valloader, 64

    if opt.dataset == 'tieredImageNet':
        train_tasks = MetaDataset(
            TieredImageNet(opt, partition="train"),
            opt,
            db_size=opt.train_db_size,
            sample_shape=opt.sample_shape,
            no_replacement=opt.no_replacement,
        )
        meta_trainloader = build_loader(train_tasks, train=True)

        val_tasks = MetaDataset(
            TieredImageNet(opt, partition="val"),
            opt,
            db_size=opt.test_db_size,
        )
        meta_valloader = build_loader(val_tasks, train=False)
        return meta_trainloader, meta_valloader, 351

    raise NotImplementedError(opt.dataset)


def flatten_dataset(meta_trainloader, labeler, opt):
    """Pseudo-label the tasks from a loader and collect them in a ``FlatDataset``.

    For every batch, index 0 of each batched field is taken and the first
    resulting field is used as the task samples. The samples are moved to
    CUDA and passed to ``labeler.label_samples``. When the second value it
    returns is not ``None``, the (CPU-side) samples and the pseudo labels
    moved to CPU are added to the flat dataset as one task.

    Args:
        meta_trainloader: Iterable of batches; each batch is an iterable of
            indexable fields.
        labeler: Object exposing ``label_samples(xs)`` that returns a
            ``(pseudo_ys, pseudo_cls)`` pair.
        opt: Options namespace; not used by this function.

    Returns:
        The ``FlatDataset`` after ``merge_tasks()`` has been called on it.
    """
    # [3, 84, 84] is presumably the per-sample (C, H, W) shape -- confirm.
    flat_db = FlatDataset([3, 84, 84])

    for batch_data in tqdm(meta_trainloader):
        first_of_each_field = [field[0] for field in batch_data]
        xs = first_of_each_field[0]

        pseudo_ys, pseudo_cls = labeler.label_samples(xs.cuda())
        if pseudo_cls is None:
            # The labeler produced no classes for this task; skip it.
            continue
        flat_db.add_task(xs, pseudo_ys.cpu())

    flat_db.merge_tasks()
    return flat_db


