import os
import math
import torch
import torchvision.datasets as datasets
import numpy as np
from copy import deepcopy
from PIL import Image
from . import samplers,transform_manager


def get_dataset(data_path,is_training,transform_type,pre):

    dataset = datasets.ImageFolder(
        data_path,
        loader = lambda x: image_loader(path=x,is_training=is_training,transform_type=transform_type,pre=pre))

    return dataset



def meta_train_dataloader(data_path,way,shots,transform_type):

    dataset = get_dataset(data_path=data_path,is_training=True,transform_type=transform_type,pre=None)

    loader = torch.utils.data.DataLoader(
        dataset,
        batch_sampler = samplers.meta_batchsampler(data_source=dataset,way=way,shots=shots),
        num_workers = 8,
        pin_memory = False)

    return loader



def meta_test_dataloader(data_path,way,shot,pre,transform_type=None,query_shot=16,trial=1000):

    dataset = get_dataset(data_path=data_path,is_training=False,transform_type=transform_type,pre=pre)

    loader = torch.utils.data.DataLoader(
        dataset,
        batch_sampler = samplers.random_sampler(data_source=dataset,way=way,shot=shot,query_shot=query_shot,trial=trial),
        num_workers = 8,
        pin_memory = False)

    return loader


def normal_train_dataloader(data_path,batch_size,transform_type):

    dataset = get_dataset(data_path=data_path,is_training=True,transform_type=transform_type,pre=None)

    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size = batch_size,
        shuffle = True,
        num_workers = 8,
        pin_memory = False,
        drop_last=True)

    return loader


def image_loader(path,is_training,transform_type,pre):

    p = Image.open(path)
    p = p.convert('RGB')

    final_transform = transform_manager.get_transform(is_training=is_training,transform_type=transform_type,pre=pre)

    p = final_transform(p)

    return p
