from PIL import Image
from os.path import join

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision as tv
from torch.utils.data import Dataset

from nsml import DATASET_PATH

# def get_cifar10_loader(bs=256):
# #     classes = ('plane', 'car', 'bird', 'cat', 'deer',
# #                'dog', 'frog', 'horse', 'ship', 'truck')
# 
#     train_transform = tv.transforms.Compose([
#         tv.transforms.RandomCrop(32, padding=4),
#         tv.transforms.RandomHorizontalFlip(),
#         tv.transforms.ToTensor(),
#         tv.transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
#     ])
# 
#     test_transform = tv.transforms.Compose([
#         tv.transforms.ToTensor(),
#         tv.transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
#     ])
# 
#     train_dataset = tv.datasets.CIFAR10(
#         root='../data', train=True, download=False, transform=train_transform)
#     train_loader = torch.utils.data.DataLoader(
#         train_dataset, batch_size=bs, shuffle=True, num_workers=2)
# 
#     test_dataset = tv.datasets.CIFAR10(
#         root='../data', train=False, download=False, transform=test_transform)
#     test_loader = torch.utils.data.DataLoader(
#         test_dataset, batch_size=100, shuffle=False, num_workers=2)
#     
#     return train_loader, test_loader
# 
# 
# def get_cifar100_loader(bs=256):
# 
#     train_transform = tv.transforms.Compose([
#         tv.transforms.RandomCrop(32, padding=4),
#         tv.transforms.RandomHorizontalFlip(),
#         tv.transforms.ToTensor(),
#         tv.transforms.Normalize((0.507, 0.487, 0.441), (0.267, 0.256, 0.276)),
#     ])
# 
#     test_transform = tv.transforms.Compose([
#         tv.transforms.ToTensor(),
#         tv.transforms.Normalize((0.507, 0.487, 0.441), (0.267, 0.256, 0.276)),
#     ])
# 
#     train_dataset = tv.datasets.CIFAR100(
#         root='../data', train=True, download=False, transform=train_transform)
#     train_loader = torch.utils.data.DataLoader(
#         train_dataset, batch_size=bs, shuffle=True, num_workers=2)
# 
#     test_dataset = tv.datasets.CIFAR100(
#         root='../data', train=False, download=False, transform=test_transform)
#     test_loader = torch.utils.data.DataLoader(
#         test_dataset, batch_size=100, shuffle=False, num_workers=2)
#     
#     return train_loader, test_loader

def get_image_ids(metadata):
    """
    Read the ordered list of image ids from <metadata>/image_ids.txt.

    image_ids.txt has the structure

    <path>
    path/to/image1.jpg
    path/to/image2.jpg
    path/to/image3.jpg
    ...

    Args:
        metadata: Directory that contains image_ids.txt.

    Returns:
        A list with one image path string per non-blank line, in file
        order, with the trailing newline removed.
    """
    with open(join(metadata, 'image_ids.txt')) as f:
        # Iterate the file lazily instead of materialising readlines().
        stripped_lines = (line.strip('\n') for line in f)
        # Blank lines (typically a trailing empty line at end of file) are
        # not image ids; keeping them would yield '' entries that cannot be
        # resolved to a label or an image file later on.
        return [image_id for image_id in stripped_lines if image_id.strip()]
    

def get_class_labels(metadata):
    """
    Read the image-id -> class-label mapping from <metadata>/class_labels.txt.

    class_labels.txt has the structure

    <path>,<integer_class_label>
    path/to/image1.jpg,0
    path/to/image2.jpg,1
    path/to/image3.jpg,1
    ...

    Args:
        metadata: Directory that contains class_labels.txt.

    Returns:
        A dict mapping each image path string to its integer class label.

    Raises:
        ValueError: If a non-blank line has no comma or its label is not an
            integer.
    """
    class_labels = {}
    with open(join(metadata, 'class_labels.txt')) as f:
        for line in f:
            record = line.strip('\n')
            # Skip blank lines (e.g. a trailing empty line) instead of
            # failing to unpack them.
            if not record.strip():
                continue
            # Split on the LAST comma only, so a path that itself contains
            # a comma is still parsed as <path>,<label>.
            image_id, class_label_string = record.rsplit(',', 1)
            class_labels[image_id] = int(class_label_string)
    return class_labels


class ImagenetTestDataset(Dataset):
    """
    Dataset of (image, label) pairs driven by metadata text files.

    The list of image ids and the id -> label mapping are read from
    image_ids.txt and class_labels.txt under the relative directory
    metadata/ILSVRC/test. Images are loaded from DATASET_PATH/train, with
    each image id used as a path relative to that directory.

    NOTE(review): the images live under 'train' although the metadata
    directory is named 'test' -- presumably this mirrors the dataset
    layout; confirm.
    """

    def __init__(self, transform):
        """
        Args:
            transform: Callable applied to the RGB PIL image; its result is
                returned as the sample's image.
        """
        self.data_root = join(DATASET_PATH, 'train')
        # Relative path: resolved against the current working directory.
        self.metadata = join('metadata', 'ILSVRC', 'test')
        self.transform = transform
        self.image_ids = get_image_ids(self.metadata)
        self.image_labels = get_class_labels(self.metadata)

    def __getitem__(self, idx):
        """
        Return the transformed image and integer label at position idx.

        Raises:
            IndexError: If idx is outside the list of image ids.
            KeyError: If the image id has no entry in the label mapping.
        """
        image_id = self.image_ids[idx]
        image_label = self.image_labels[image_id]
        # Image.open is lazy and keeps the file open; convert() forces the
        # pixel data to be read and returns an independent image, so the
        # underlying file can be closed as soon as the block exits.
        with Image.open(join(self.data_root, image_id)) as source:
            image = source.convert('RGB')
        image = self.transform(image)
        return image, image_label

    def __len__(self):
        """Return the number of image ids listed in the metadata."""
        return len(self.image_ids)


def get_imagenet_loader(bs=32, num_workers=16):
    """
    Build the ImageNet training and test data loaders.

    The training set is a torchvision ImageFolder rooted at
    DATASET_PATH/train/train with random-resized-crop and horizontal-flip
    augmentation. The test set is an ImagenetTestDataset with a
    resize-256 / center-crop-224 pipeline. Both pipelines convert to a
    tensor and apply the same per-channel normalization.

    Args:
        bs: Batch size used for both loaders.
        num_workers: Number of DataLoader worker processes for each loader.
            Defaults to 16, the value that was previously hard-coded; lower
            it on machines with fewer cores.

    Returns:
        A (train_loader, test_loader) tuple. The training loader shuffles;
        the test loader does not.
    """
    # One Normalize instance is shared by both pipelines so the statistics
    # cannot drift apart between training and evaluation.
    normalize = tv.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                        std=[0.229, 0.224, 0.225])

    train_transform = tv.transforms.Compose([
        tv.transforms.RandomResizedCrop(224),
        tv.transforms.RandomHorizontalFlip(),
        tv.transforms.ToTensor(),
        normalize,
    ])
    test_transform = tv.transforms.Compose([
        tv.transforms.Resize(256),
        tv.transforms.CenterCrop(224),
        tv.transforms.ToTensor(),
        normalize,
    ])

    train_dataset = tv.datasets.ImageFolder(
        join(DATASET_PATH, 'train', 'train'), train_transform)
    # NOTE: an earlier variant built the test set with
    # tv.datasets.ImageFolder over a 'val' directory; the metadata-driven
    # ImagenetTestDataset is used instead.
    test_dataset = ImagenetTestDataset(test_transform)

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=bs, shuffle=True,
        num_workers=num_workers, pin_memory=True)
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=bs, shuffle=False,
        num_workers=num_workers, pin_memory=True)

    return train_loader, test_loader
