from fastai.vision import *
from fastprogress import fastprogress

from experiments.datasets import DataLoaders

# Width setting for fastprogress output (its MAX_COLS attribute), fixed at 80.
_PROGRESS_MAX_COLS = 80
fastprogress.MAX_COLS = _PROGRESS_MAX_COLS

# Imagenette class folder ids -> human-readable class names.
# Keys are kept in ascending order; do not reorder, since callers may pair
# these names with classes positionally.
lbl_dict = {
    'n01440764': 'tench',
    'n02102040': 'English springer',
    'n02979186': 'cassette player',
    'n03000684': 'chain saw',
    'n03028079': 'church',
    'n03394916': 'French horn',
    'n03417042': 'garbage truck',
    'n03425413': 'gas pump',
    'n03445777': 'golf ball',
    'n03888257': 'parachute',
}


def _build_databunch(path, *, train, valid, tfms, size, bs, val_bs, workers):
    """Build a normalized image DataBunch from the `train`/`valid` sub-folders of *path*.

    Images are labelled by their parent folder name, transformed with *tfms*,
    presized to *size* with ``scale=(0.35, 1)`` and normalized with
    ``imagenet_stats``.
    """
    return (ImageList.from_folder(path)
            .split_by_folder(train=train, valid=valid)
            .label_from_folder()
            .transform(tfms, size=size)
            .databunch(bs=bs, val_bs=val_bs, num_workers=workers)
            .presize(size, scale=(0.35, 1))
            .normalize(imagenet_stats))


def get_data(*, size, woof, bs, val_bs, test_only=False, workers=8):
    """Download Imagenette/Imagewoof and return ``(train_dl, test_dl, train_eval_dl)``.

    Args:
        size: target image size. It also picks the downloaded variant:
            160px for ``size <= 128``, 320px for ``size <= 224``, full size
            otherwise.
        woof: use Imagewoof instead of Imagenette.
        bs: ``bs`` passed to every DataBunch.
        val_bs: ``val_bs`` passed to every DataBunch.
        test_only: when true, every loader reads the ``val`` folder and the
            ``train`` folder is never used.
        workers: number of DataLoader workers (default 8).

    Returns:
        train_dl: training loader augmented with ``get_transforms()``.
        test_dl: loader over the ``val`` folder.
        train_eval_dl: training loader of a second DataBunch built from the
            same folder as ``train_dl`` but without ``get_transforms()``.
    """
    if size <= 128:
        url = URLs.IMAGEWOOF_160 if woof else URLs.IMAGENETTE_160
    elif size <= 224:
        url = URLs.IMAGEWOOF_320 if woof else URLs.IMAGENETTE_320
    else:
        url = URLs.IMAGEWOOF if woof else URLs.IMAGENETTE
    path = untar_data(url)

    # In test_only mode the 'val' folder also stands in for the training folder.
    train_folder = 'val' if test_only else 'train'
    common = dict(size=size, bs=bs, val_bs=val_bs, workers=workers)

    databunch: ImageDataBunch = _build_databunch(
        path, train=train_folder, valid='val', tfms=get_transforms(), **common)
    # Replace folder ids with readable names by mapping each id, rather than
    # assigning list(lbl_dict.values()) positionally. lbl_dict only covers the
    # 10 Imagenette ids, so ids it lacks (presumably every Imagewoof folder)
    # keep their id instead of silently receiving an unrelated Imagenette
    # name. Length and order of the class list are preserved.
    for split in databunch.label_list.lists[:2]:
        split.y.classes = [lbl_dict.get(c, c) for c in split.y.classes]

    # NOTE(review): train_eval_dl is this DataBunch's *train* loader, so it
    # uses batch size `bs` and presumably fastai's training-side sampling and
    # presize crop; confirm that is intended for evaluation.
    eval_databunch: ImageDataBunch = _build_databunch(
        path, train=train_folder, valid=train_folder, tfms=([], []), **common)
    eval_train = eval_databunch.label_list.lists[0]
    eval_train.y.classes = [lbl_dict.get(c, c) for c in eval_train.y.classes]

    return databunch.train_dl, databunch.valid_dl, eval_databunch.train_dl


def dataloaders(train_batch_size, test_batch_size, test_only=False) -> DataLoaders:
    """Return Imagenette loaders at 224px wrapped in a ``DataLoaders``.

    ``train_batch_size`` and ``test_batch_size`` are forwarded to ``get_data``
    as ``bs`` and ``val_bs``. ``test_only`` is forwarded unchanged.
    """
    loaders = get_data(
        size=224,
        woof=False,
        bs=train_batch_size,
        val_bs=test_batch_size,
        test_only=test_only,
    )
    train, test, train_eval = loaders
    return DataLoaders(train=train, test=test, train_eval=train_eval)
