from typing import Optional

import numpy as np
from lightning import LightningDataModule
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Subset
from torchvision import transforms

from gnn_experiments.config import DATA_DIR
from gnn_experiments.datamodules.subset_wrapper import SubsetWrapper
from gnn_experiments.datamodules.transforms import train_transform, val_transform
from gnn_experiments.datasets.pets import Pets

# Resize (shorter side -> 256), centre-crop to 224x224, convert to a tensor.
# Uses torchvision's default (bilinear) resampling, suitable for images.
resize_transform = transforms.Compose(
    [transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor()]
)

# Same geometry as ``resize_transform`` but with nearest-neighbour resampling.
# Segmentation masks hold discrete labels; bilinear resampling would blend
# neighbouring labels into values that never occur in the source mask.
segmentation_resize_transform = transforms.Compose(
    [
        transforms.Resize(256, interpolation=transforms.InterpolationMode.NEAREST),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ]
)


def pets_target_transform(targets):
    """Resize and tensorise the segmentation part of a Pets target triple.

    Args:
        targets: A ``(category, segmentation, detail_category)`` triple.
            ``segmentation`` is presumably a PIL image mask produced by
            ``Pets`` -- TODO confirm against the dataset implementation.

    Returns:
        The same triple with ``segmentation`` resized (shorter side 256,
        nearest-neighbour), centre-cropped to 224x224 and converted with
        ``ToTensor`` (which rescales 8-bit images into [0, 1]).
        ``category`` and ``detail_category`` are returned unchanged.
    """
    category, segmentation, detail_category = targets
    # Nearest-neighbour keeps every output pixel equal to some input label.
    segmentation = segmentation_resize_transform(segmentation)
    return category, segmentation, detail_category


class PetsDataModule(LightningDataModule):
    """Lightning data module wrapping the project's ``Pets`` dataset.

    The ``trainval`` split is divided into a stratified 80/20
    train/validation split (``random_state=42``, stratified on each sample's
    ``"category"``); the ``test`` split is used as-is for testing.

    Args:
        data_dir: Root directory passed to ``Pets``.
        batch_size: Batch size used by every dataloader.
        target_type: Forwarded to ``Pets`` as ``target``. ``"binary-category"``
            gives ``num_classes == 2``; any other value gives 37.
        num_workers: Worker processes per dataloader. Persistent workers are
            only requested when this is positive, because ``DataLoader``
            rejects ``persistent_workers=True`` with ``num_workers=0``.
    """

    def __init__(
        self,
        data_dir: str = DATA_DIR,
        batch_size: int = 64,
        target_type="category",
        num_workers: int = 4,
    ):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_classes = 2 if target_type == "binary-category" else 37
        self.name = f"pets_{target_type}"
        self.target_type = target_type
        self.num_workers = num_workers
        # Filled in by ``setup``; ``None`` means "not built yet" so repeated
        # ``setup`` calls (Lightning calls it once per stage) can reuse them.
        self.train_data = None
        self.val_data = None
        self.test_data = None

    def prepare_data(self) -> None:
        """Instantiate both splits once, e.g. so ``Pets`` can fetch/cache files."""
        Pets(root=self.data_dir, split="trainval", target=self.target_type)
        Pets(root=self.data_dir, split="test", target=self.target_type)

    def setup(self, stage: Optional[str] = None) -> None:
        """Build the train, validation and test datasets (once per instance).

        The train/validation split is deterministic, so datasets built by an
        earlier call are reused instead of re-reading the whole ``trainval``
        split on every stage.

        Args:
            stage: Lightning stage name. Unused: all datasets are built
                regardless of stage, as before.
        """
        if self.train_data is None or self.val_data is None:
            self.train_data, self.val_data = self._build_train_val()
        if self.test_data is None:
            self.test_data = Pets(
                root=self.data_dir,
                split="test",
                transform=val_transform,
                target_transform=pets_target_transform,
                target=self.target_type,
            )

    def _build_train_val(self):
        """Return ``(train, val)`` wrappers over a stratified split of ``trainval``."""
        trainval = Pets(
            root=self.data_dir,
            split="trainval",
            target_transform=pets_target_transform,
            target=self.target_type,
        )
        # Index explicitly up to len() instead of iterating the dataset, so we
        # don't rely on __getitem__ raising IndexError to stop iteration.
        # NOTE(review): this loads every sample just to read its label.
        y_category = [trainval[i]["category"] for i in range(len(trainval))]

        train_idx, validation_idx = train_test_split(
            np.arange(len(trainval)), test_size=0.2, random_state=42, stratify=y_category
        )

        # Both subsets share one ``trainval`` instance; the image transforms
        # differ per split and are applied by ``SubsetWrapper``.
        train_data = SubsetWrapper(
            Subset(trainval, train_idx), transform=train_transform
        )
        val_data = SubsetWrapper(
            Subset(trainval, validation_idx), transform=val_transform
        )
        return train_data, val_data

    def _make_loader(self, dataset, *, shuffle=False, persistent=True):
        """Build a ``DataLoader`` with this module's shared settings."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            pin_memory=True,
            # persistent_workers=True is invalid without worker processes.
            persistent_workers=persistent and self.num_workers > 0,
        )

    def train_dataloader(self):
        """Shuffled loader over the training subset."""
        return self._make_loader(self.train_data, shuffle=True)

    def val_dataloader(self):
        """Unshuffled loader over the validation subset."""
        return self._make_loader(self.val_data)

    def test_dataloader(self):
        """Unshuffled loader over the test split (no persistent workers)."""
        return self._make_loader(self.test_data, persistent=False)
