from typing import Optional

import numpy as np
import torch
from lightning import LightningDataModule
from sklearn.model_selection import train_test_split
from torch.utils.data import random_split, DataLoader, Subset

from src.datamodules.subset_wrapper import SubsetWrapper
from src.datamodules.transforms import val_transform, train_transform
from src.datasets.imagenette import Imagenette


class ImagenetteDataModule(LightningDataModule):
    """LightningDataModule for the Imagenette dataset.

    The dataset's ``'train'`` split is divided into a stratified
    train/validation split; the dataset's ``'val'`` split is used as the
    test set.

    Args:
        data_dir: Root directory passed to ``Imagenette``.
        batch_size: Batch size used by every dataloader.
        num_workers: Number of DataLoader worker processes.
        val_size: Number of samples held out from the ``'train'`` split for
            validation (passed to ``train_test_split`` as ``test_size``).
        seed: Random state used for the train/validation split.
    """

    def __init__(self, data_dir: str = 'data/', batch_size: int = 64,
                 num_workers: int = 8, val_size: int = 969, seed: int = 42):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.val_size = val_size
        self.seed = seed
        self.num_classes = 10
        self.name = 'imagenette'
        # Filled in by ``setup``; ``None`` means setup has not run yet.
        self.train_data = None
        self.val_data = None
        self.test_data = None

    def prepare_data(self) -> None:
        """Instantiate both splits once before ``setup`` runs.

        NOTE(review): presumably the ``Imagenette`` constructor downloads or
        extracts data when missing -- confirm against its implementation.
        """
        Imagenette(root=self.data_dir, split='train')
        Imagenette(root=self.data_dir, split='val')

    def setup(self, stage: Optional[str] = None) -> None:
        """Build the train, validation and test datasets.

        Lightning may call this once per stage (fit/validate/test/predict).
        The split is deterministic, so the datasets are built on the first
        call and reused afterwards. This avoids re-reading every training
        label on each call.

        Args:
            stage: Lightning stage name; unused because all three datasets
                are always built together.
        """
        if self.train_data is not None:
            return

        trainval = Imagenette(root=self.data_dir, split='train')
        # NOTE(review): this indexes every sample only to read its label,
        # which may decode every image; if Imagenette exposes labels
        # directly, prefer that.
        y_trainval = [item['y'] for item in trainval]

        # Stratify so the validation subset keeps the class proportions.
        train_idx, validation_idx = train_test_split(
            np.arange(len(trainval)),
            test_size=self.val_size,
            random_state=self.seed,
            stratify=y_trainval,
        )

        # Both subsets share the untransformed base dataset; the wrapper
        # applies a split-specific transform on top.
        self.train_data = SubsetWrapper(Subset(trainval, train_idx), transform=train_transform)
        self.val_data = SubsetWrapper(Subset(trainval, validation_idx), transform=val_transform)
        self.test_data = Imagenette(root=self.data_dir, split='val', transform=val_transform)

    def _make_loader(self, dataset, shuffle: bool = False) -> DataLoader:
        """Create a DataLoader with this module's shared loader settings."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=True,
            shuffle=shuffle,
        )

    def train_dataloader(self) -> DataLoader:
        """Return the shuffled training loader."""
        return self._make_loader(self.train_data, shuffle=True)

    def val_dataloader(self) -> DataLoader:
        """Return the (unshuffled) validation loader."""
        return self._make_loader(self.val_data)

    def test_dataloader(self) -> DataLoader:
        """Return the (unshuffled) test loader over the ``'val'`` split."""
        return self._make_loader(self.test_data)
