import torch
import torchvision
import torchvision.transforms as transforms
from utils import *
import os


def dataloader_tiny_image_net(batch_size, root=None, ctg='train'):
    """Build a distributed sampler and data loader for a Tiny ImageNet split.

    Images are read with ``torchvision.datasets.ImageFolder`` from
    ``<root>/tiny-imagenet-200/<ctg>``.

    Args:
        batch_size: Number of samples per batch, forwarded to ``DataLoader``.
        root: Directory that contains the ``tiny-imagenet-200`` folder. May be
            a ``str`` or any ``os.PathLike``. The ``None`` default is kept for
            signature compatibility only; a real path is required.
        ctg: Name of the split sub-directory. ``'train'`` enables random
            crop / flip / rotation augmentation and sampler shuffling; any
            other value only converts to tensor and normalizes.

    Returns:
        Tuple ``(datasampler, dataloader)``. The sampler is returned next to
        the loader (presumably so the caller can call ``set_epoch`` on it
        between epochs -- confirm against the training loop).

    Raises:
        TypeError: If ``root`` is ``None``.
    """
    # Fail early with a message that names the offending argument instead of
    # the opaque "sequence item 0: expected str instance" from str.join.
    if root is None:
        raise TypeError(
            "root must be the directory that contains 'tiny-imagenet-200', "
            "got None")

    # os.path.join accepts both str and os.PathLike roots.
    data_dir = os.path.join(root, 'tiny-imagenet-200', ctg)
    is_train = ctg == 'train'

    # Per-channel normalization shared by both branches (these are the widely
    # used ImageNet RGB mean / std values).
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    if is_train:
        transform = transforms.Compose([
            transforms.RandomCrop(64, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(15),
            transforms.ToTensor(),
            normalize,
        ])
    else:
        transform = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])

    dataset = torchvision.datasets.ImageFolder(data_dir, transform=transform)

    # num_replicas / rank are not passed, so the sampler takes them from the
    # current torch.distributed process group.
    # NOTE(review): drop_last is enabled for every non-train split, which
    # makes the sampler drop tail samples rather than pad -- confirm that
    # skipping those samples during evaluation is intended.
    datasampler = torch.utils.data.distributed.DistributedSampler(
        dataset, shuffle=is_train, drop_last=not is_train)

    # shuffle stays False here: ordering is delegated to the sampler.
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=False,
        num_workers=2, pin_memory=True, sampler=datasampler)

    return datasampler, dataloader

def dataloader_tiny_image_net_supcon(batch_size, root=None, ctg='train'):
    """Build a sampler and loader for Tiny ImageNet with two-crop training.

    Images are read with ``torchvision.datasets.ImageFolder`` from
    ``<root>/tiny-imagenet-200/<ctg>``. For the ``'train'`` split the
    augmentation pipeline is wrapped in ``TwoCropTransform``.

    Args:
        batch_size: Number of samples per batch, forwarded to ``DataLoader``.
        root: Directory that contains the ``tiny-imagenet-200`` folder. May be
            a ``str`` or any ``os.PathLike``. The ``None`` default is kept for
            signature compatibility only; a real path is required.
        ctg: Name of the split sub-directory. ``'train'`` enables random
            crop / flip / rotation augmentation, the ``TwoCropTransform``
            wrapper and sampler shuffling; any other value only converts to
            tensor and normalizes.

    Returns:
        Tuple ``(datasampler, dataloader)``. The sampler is returned next to
        the loader (presumably so the caller can call ``set_epoch`` on it
        between epochs -- confirm against the training loop).

    Raises:
        TypeError: If ``root`` is ``None``.
    """
    # Fail early with a message that names the offending argument instead of
    # the opaque "sequence item 0: expected str instance" from str.join.
    if root is None:
        raise TypeError(
            "root must be the directory that contains 'tiny-imagenet-200', "
            "got None")

    # os.path.join accepts both str and os.PathLike roots.
    data_dir = os.path.join(root, 'tiny-imagenet-200', ctg)
    is_train = ctg == 'train'

    # Per-channel normalization shared by both branches (these are the widely
    # used ImageNet RGB mean / std values).
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    if is_train:
        augment = transforms.Compose([
            transforms.RandomCrop(64, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(15),
            transforms.ToTensor(),
            normalize,
        ])
        # TwoCropTransform is provided by the star import from utils;
        # presumably it applies the pipeline twice to yield two views of each
        # image -- TODO confirm in utils.
        transform = TwoCropTransform(augment)
    else:
        transform = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])

    dataset = torchvision.datasets.ImageFolder(data_dir, transform=transform)

    # num_replicas / rank are not passed, so the sampler takes them from the
    # current torch.distributed process group.
    # NOTE(review): drop_last is enabled for every non-train split, which
    # makes the sampler drop tail samples rather than pad -- confirm that
    # skipping those samples during evaluation is intended.
    datasampler = torch.utils.data.distributed.DistributedSampler(
        dataset, shuffle=is_train, drop_last=not is_train)

    # shuffle stays False here: ordering is delegated to the sampler.
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=False,
        num_workers=2, pin_memory=True, sampler=datasampler)

    return datasampler, dataloader