import glob
import os

import pytorch_lightning as pl
from PIL import Image
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader

from convexrobust.utils import dirs

# Adapted from
# https://www.kaggle.com/code/bootiu/dog-vs-cat-transfer-learning-by-pytorch-lightning/notebook


class KaggleCatsDogsDataModule(pl.LightningDataModule):
    """Lightning data module for the Kaggle "Dogs vs. Cats Redux" images.

    Only the labelled ``train`` folder of the competition data is used. It is
    split into train/val/test subsets with a fixed seed, and each subset is
    wrapped in a ``CatsDogsDataset``.

    Args:
        train_transforms: Transform applied to images of the training subset.
        val_transforms: Transform applied to images of the validation subset.
        test_transforms: Transform applied to images of the test subset.
        batch_size: Batch size of the train and val loaders. The test loader
            always uses a batch size of 1.
        num_workers: Number of worker processes used by every loader.
        shuffle: Whether the training loader shuffles its samples. The val
            and test loaders keep the DataLoader default (no shuffling).
    """

    def __init__(self, train_transforms, val_transforms, test_transforms,
                 batch_size=64, num_workers=4, shuffle=True):
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.shuffle = shuffle

        # Keyed by phase name; CatsDogsDataset picks the entry matching its phase.
        self.data_trans = {
            'train': train_transforms, 'val': val_transforms, 'test': test_transforms
        }

        # Flattened input size; presumably the transforms resize to 3x224x224
        # -- TODO confirm against the transforms passed by callers.
        self.in_n = 3 * 224 * 224
        # Fixed seed used for both splits in setup().
        self.seed = 42

    def setup(self, stage=None):
        """Build the train/val/test datasets from the images on disk.

        ``stage`` is accepted for Lightning compatibility but ignored: all
        three datasets are (re)built on every call.
        """
        # glob returns matches in arbitrary, filesystem-dependent order. Sort
        # them so the seeded splits below are identical on every machine and
        # run; otherwise the "held-out" subsets can differ between runs.
        all_img_paths = sorted(glob.glob(dirs.data_path(
            'dogs-vs-cats-redux-kernels-edition', 'train', '*.jpg')))

        # 10% of all images are held out for testing ...
        train_img_path, test_img_path = train_test_split(
            all_img_paths, test_size=0.1, random_state=self.seed)
        # ... and 20% of the remainder (18% of the total) for validation.
        train_img_path, val_img_path = train_test_split(
            train_img_path, test_size=0.2, random_state=self.seed)

        self.dataset_train = CatsDogsDataset(train_img_path, self.data_trans, phase='train')
        self.dataset_val = CatsDogsDataset(val_img_path, self.data_trans, phase='val')
        self.dataset_test = CatsDogsDataset(test_img_path, self.data_trans, phase='test')

    def train_dataloader(self):
        """Return the training loader (shuffled according to ``self.shuffle``)."""
        return DataLoader(self.dataset_train, batch_size=self.batch_size,
                          num_workers=self.num_workers, shuffle=self.shuffle)

    def val_dataloader(self):
        """Return the validation loader."""
        return DataLoader(self.dataset_val, batch_size=self.batch_size,
                          num_workers=self.num_workers)

    def test_dataloader(self):
        """Return the test loader, which yields one sample per batch."""
        return DataLoader(self.dataset_test, batch_size=1, num_workers=self.num_workers)


class CatsDogsDataset(Dataset):
    """Map-style dataset over cat/dog image files labelled by file name.

    Args:
        file_list: Sequence of image file paths. The label is read from the
            file name, which is expected to start with ``dog.`` or ``cat.``
            (e.g. ``dog.123.jpg``).
        data_transform: Mapping from phase name to a callable that is applied
            to the loaded PIL image.
        phase: Key selecting which entry of ``data_transform`` to use.
    """

    def __init__(self, file_list, data_transform, phase='train'):
        self.file_list = file_list
        self.data_transform = data_transform
        self.phase = phase

    def __len__(self):
        """Return the number of image files in the dataset."""
        return len(self.file_list)

    @staticmethod
    def _label_from_path(img_path):
        """Return 1 if the file name starts with ``dog.``, otherwise 0."""
        # basename honours the platform's path separators, unlike a
        # hard-coded split on '/'.
        file_name = os.path.basename(img_path)
        return 1 if file_name.split('.')[0] == 'dog' else 0

    def __getitem__(self, idx):
        """Load one image and return ``(transformed_image, label)``.

        The label is 1 for dogs and 0 for everything else.
        """
        img_path = self.file_list[idx]
        # Image.open is lazy and keeps the file open; the context manager
        # closes it deterministically. convert() loads the pixel data and
        # returns an independent copy, which also normalises grayscale/CMYK
        # files to three channels.
        with Image.open(img_path) as img:
            img_rgb = img.convert('RGB')
        img_transformed = self.data_transform[self.phase](img_rgb)

        return img_transformed, self._label_from_path(img_path)
