from typing import Optional

import numpy as np
import torch
from lightning import LightningDataModule
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Subset, random_split

from src.datamodules.subset_wrapper import SubsetWrapper
from src.datamodules.transforms import val_transform, train_transform
from src.datasets.imagenet import ImageNet


class ImagenetDataModule(LightningDataModule):
    """Lightning data module for ImageNet.

    The ``'train'`` split is partitioned into a training subset and a small
    held-out validation subset using a seeded ``random_split``. The ``'val'``
    split is used as the test set.

    Args:
        data_dir: Root directory passed to ``ImageNet(root=...)``.
        batch_size: Batch size used by every dataloader.
        num_workers: Number of ``DataLoader`` worker processes.
        pin_memory: Forwarded to ``DataLoader(pin_memory=...)``.
        train_fraction: Fraction of the ``'train'`` split kept for training;
            the remainder becomes the validation subset. Must be in ``(0, 1]``.
        split_seed: Seed of the generator given to ``random_split`` so the
            train/val partition is reproducible.

    Raises:
        ValueError: If ``train_fraction`` is outside ``(0, 1]``.
    """

    def __init__(
        self,
        data_dir: str = 'data/',
        batch_size: int = 64,
        num_workers: int = 8,
        pin_memory: bool = True,
        train_fraction: float = 0.99,
        split_seed: int = 42,
    ):
        super().__init__()
        if not 0.0 < train_fraction <= 1.0:
            raise ValueError(f'train_fraction must be in (0, 1], got {train_fraction!r}')
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.train_fraction = train_fraction
        self.split_seed = split_seed
        self.num_classes = 1000
        self.name = 'imagenet'

    def prepare_data(self) -> None:
        """Instantiate both ImageNet splits once before ``setup`` runs.

        Nothing is stored. Presumably this triggers any one-time preparation
        done inside the ``ImageNet`` constructor (download / metadata
        parsing) -- confirm against ``src.datasets.imagenet``.
        """
        ImageNet(root=self.data_dir, split='train')
        ImageNet(root=self.data_dir, split='val')

    def setup(self, stage: Optional[str] = None) -> None:
        """Build ``train_data``, ``val_data`` and ``test_data``.

        All three datasets are built regardless of ``stage``.

        Args:
            stage: Lightning stage name; accepted for API compatibility but unused.
        """
        # Built without a transform; per-subset transforms are applied below.
        trainval = ImageNet(root=self.data_dir, split='train')
        total = len(trainval)
        train_length = int(total * self.train_fraction)
        val_length = total - train_length
        # A fixed-seed generator yields the same partition on every call/run,
        # so validation samples never end up in the training subset.
        generator = torch.Generator().manual_seed(self.split_seed)
        train_data, val_data = random_split(trainval, [train_length, val_length], generator=generator)
        # Both subsets index into the same underlying dataset, so transforms are
        # attached per subset via the wrapper rather than on ``trainval``.
        self.train_data = SubsetWrapper(train_data, transform=train_transform)
        self.val_data = SubsetWrapper(val_data, transform=val_transform)
        self.test_data = ImageNet(root=self.data_dir, split='val', transform=val_transform)

    def _make_loader(self, dataset, shuffle: bool = False) -> DataLoader:
        """Build a ``DataLoader`` with this module's shared batch/worker settings."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            shuffle=shuffle,
        )

    def train_dataloader(self) -> DataLoader:
        """Return a shuffled loader over the training subset."""
        return self._make_loader(self.train_data, shuffle=True)

    def val_dataloader(self) -> DataLoader:
        """Return an unshuffled loader over the held-out validation subset."""
        return self._make_loader(self.val_data)

    def test_dataloader(self) -> DataLoader:
        """Return an unshuffled loader over the ``'val'`` split."""
        return self._make_loader(self.test_data)
