from typing import Optional

import numpy as np
import torch
from lightning import LightningDataModule
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Subset

from src.datasets.cars import Cars
from src.datamodules.subset_wrapper import SubsetWrapper
from src.datamodules.transforms import train_transform, val_transform


class CarsDataModule(LightningDataModule):
    """LightningDataModule for the Cars dataset (196 classes).

    The official ``'train'`` split is divided into a stratified train /
    validation split (90/10 with seed 42 by default). The official ``'test'``
    split is used as the test set.

    Args:
        data_dir: Root directory passed to ``Cars``.
        batch_size: Batch size used by all three dataloaders.
        num_workers: Number of DataLoader worker processes. Defaults to 8,
            the value previously hard-coded in every loader.
        val_size: Portion of the train split held out for validation,
            forwarded to ``train_test_split`` as ``test_size``.
        seed: ``random_state`` for the train/validation split, which keeps
            the split reproducible across runs.
    """

    def __init__(self, data_dir: str = 'data/', batch_size: int = 64,
                 num_workers: int = 8, val_size: float = 0.1, seed: int = 42):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.val_size = val_size
        self.seed = seed
        self.num_classes = 196
        self.name = 'cars'
        # Filled in by ``setup``; ``None`` means "not built yet".
        self.train_data = None
        self.val_data = None
        self.test_data = None

    def prepare_data(self) -> None:
        """Instantiate both splits once before ``setup`` runs.

        Lightning calls this hook on a single process. It is presumably
        where ``Cars`` downloads or extracts data on first use (confirm
        against ``Cars``); no state is stored here.
        """
        Cars(root=self.data_dir, split='train')
        Cars(root=self.data_dir, split='test')

    def setup(self, stage: Optional[str] = None) -> None:
        """Build the train, validation and test datasets.

        Lightning may call this once per stage (e.g. ``'fit'`` and then
        ``'test'``). All three datasets are built on the first call, so
        ``stage`` is not inspected, and later calls return immediately.

        Args:
            stage: Lightning stage name. Accepted for API compatibility
                but unused.
        """
        if self.train_data is not None:
            # Already built; the split is deterministic, so rebuilding
            # would only repeat the full pass over the train split below.
            return

        trainval = Cars(root=self.data_dir, split='train')
        # Collecting labels iterates over every item of the train split,
        # which can be costly. This is the work the guard above avoids
        # repeating.
        y_trainval = [item['y'] for item in trainval]

        train_idx, validation_idx = train_test_split(np.arange(len(trainval)),
                                                     test_size=self.val_size,
                                                     random_state=self.seed,
                                                     stratify=y_trainval)

        self.train_data = SubsetWrapper(Subset(trainval, train_idx), transform=train_transform)
        self.val_data = SubsetWrapper(Subset(trainval, validation_idx), transform=val_transform)
        self.test_data = Cars(root=self.data_dir, split='test', transform=val_transform)

    def _make_loader(self, dataset, shuffle: bool = False) -> DataLoader:
        """Build a DataLoader with the module's shared settings."""
        return DataLoader(dataset, batch_size=self.batch_size, num_workers=self.num_workers,
                          pin_memory=True, shuffle=shuffle)

    def train_dataloader(self):
        """Return the shuffled training loader."""
        return self._make_loader(self.train_data, shuffle=True)

    def val_dataloader(self):
        """Return the unshuffled validation loader."""
        return self._make_loader(self.val_data)

    def test_dataloader(self):
        """Return the unshuffled test loader."""
        return self._make_loader(self.test_data)
