from typing import Optional

import numpy as np
from lightning import LightningDataModule
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Subset
from torchvision import transforms

from gnn_experiments.config import DATA_DIR
from gnn_experiments.datamodules.subset_wrapper import SubsetWrapper
from gnn_experiments.datamodules.transforms import train_transform, val_transform
from gnn_experiments.datasets.cub import CUB200

# Deterministic (non-random) pipeline handed to CUB200 as its
# ``segmentation_transform``: scale the shorter side to 256 px, keep the
# central 224x224 region, then convert the result with ``ToTensor``.
# NOTE(review): if this is applied to label masks, Resize's default
# interpolation may blend mask values along edges — confirm that is intended.
_segmentation_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
]
segmentation_transform = transforms.Compose(_segmentation_steps)


class CubDataModule(LightningDataModule):
    """Lightning data module for the CUB-200 dataset.

    The ``"train"`` split of ``CUB200`` is divided into stratified
    train/validation subsets. The ``"test"`` split is used as-is for testing.

    Args:
        data_dir: Root directory passed to ``CUB200``.
        batch_size: Batch size used by every dataloader.
        num_workers: Number of ``DataLoader`` worker processes. With ``0``,
            loading happens in the main process and persistent workers are
            disabled, because ``DataLoader`` rejects
            ``persistent_workers=True`` without workers.
        val_size: Fraction of the ``"train"`` split held out for validation.
        seed: Random seed for the train/validation split. A fixed seed keeps
            the split identical across runs and across ``setup`` calls.
    """

    def __init__(
        self,
        data_dir: str = DATA_DIR,
        batch_size: int = 64,
        num_workers: int = 4,
        val_size: float = 0.2,
        seed: int = 42,
    ):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.val_size = val_size
        self.seed = seed
        # Identifier and class count exposed to callers of this module.
        self.name = "cub"
        self.num_classes = 200

    def prepare_data(self) -> None:
        """Instantiate both ``CUB200`` splits once.

        Presumably this triggers any download/extraction done by the
        ``CUB200`` constructor (TODO confirm). The instances are discarded.
        """
        CUB200(root=self.data_dir, split="train")
        CUB200(root=self.data_dir, split="test")

    def setup(self, stage: Optional[str] = None) -> None:
        """Build ``train_data``, ``val_data`` and ``test_data``.

        Args:
            stage: Stage name supplied by Lightning. It is not used, so all
                three datasets are rebuilt on every call. The seeded,
                stratified split keeps train/validation membership identical
                between calls.
        """
        trainval = CUB200(
            root=self.data_dir,
            split="train",
            segmentation_transform=segmentation_transform,
        )
        # NOTE(review): this iterates every item, including any image loading
        # or transforms CUB200 performs per item, just to read the labels.
        # If CUB200 exposes its labels directly, prefer that.
        y_category = [item["y"] for item in trainval]

        train_idx, validation_idx = train_test_split(
            np.arange(len(trainval)),
            test_size=self.val_size,
            random_state=self.seed,
            stratify=y_category,
        )

        # Both subsets share one underlying dataset. SubsetWrapper is assumed
        # to apply its ``transform`` per item, so train and validation can
        # use different augmentations.
        self.train_data = SubsetWrapper(
            Subset(trainval, train_idx), transform=train_transform
        )
        self.val_data = SubsetWrapper(
            Subset(trainval, validation_idx), transform=val_transform
        )
        self.test_data = CUB200(
            root=self.data_dir,
            split="test",
            transform=val_transform,
            segmentation_transform=segmentation_transform,
        )

    def _build_dataloader(
        self, dataset, *, shuffle: bool = False, persistent: bool = True
    ) -> DataLoader:
        """Create a ``DataLoader`` with the module's shared batch/worker settings.

        Args:
            dataset: Dataset to wrap.
            shuffle: Whether to reshuffle every epoch.
            persistent: Request persistent workers. This only takes effect
                when ``num_workers > 0``.
        """
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            pin_memory=True,
            # DataLoader raises ValueError for persistent workers when
            # num_workers == 0.
            persistent_workers=persistent and self.num_workers > 0,
        )

    def train_dataloader(self) -> DataLoader:
        """Return the shuffled training loader."""
        return self._build_dataloader(self.train_data, shuffle=True)

    def val_dataloader(self) -> DataLoader:
        """Return the (unshuffled) validation loader."""
        return self._build_dataloader(self.val_data)

    def test_dataloader(self) -> DataLoader:
        """Return the (unshuffled) test loader, without persistent workers."""
        return self._build_dataloader(self.test_data, persistent=False)
