from argparse import ArgumentParser
import os
from pathlib import Path

import pytorch_lightning as pl
import torch
from torchvision.datasets import ImageFolder
from torchvision import transforms as T


class Flowers102DataModule(pl.LightningDataModule):
    """LightningDataModule serving an ImageFolder-layout Flowers-102 dataset.

    ``data_dir`` is expected to contain ``train``, ``valid`` and ``test``
    sub-directories, each in the ``<split>/<class_name>/<image>`` layout that
    ``torchvision.datasets.ImageFolder`` reads.

    Args:
        data_dir: Root directory holding the split sub-directories.
        batch_size: Number of images per batch.
        num_workers: Number of DataLoader worker processes (coerced to int).
        distributed_sampler: If True, every split is wrapped in a
            ``torch.utils.data.DistributedSampler``, which needs an
            initialised ``torch.distributed`` process group.
    """

    # Standard ImageNet per-channel mean/std used to normalise input tensors.
    _NORM_MEAN = (0.485, 0.456, 0.406)
    _NORM_STD = (0.229, 0.224, 0.225)

    def __init__(self, data_dir, batch_size, num_workers=4, distributed_sampler=True):

        super().__init__()

        self.data_dir = data_dir
        self.batch_size = batch_size
        # DataLoader requires an integer worker count; accept e.g. 4.0 as well.
        self.num_workers = int(num_workers)
        self.distributed_sampler = distributed_sampler

    def _build_transform(self, data_split):
        """Return the preprocessing pipeline for ``data_split``.

        The ``"train"`` split gets random crop/flip/rotation augmentation; every
        other split gets a deterministic resize followed by a 224px centre crop.
        """
        normalize = T.Normalize(mean=self._NORM_MEAN, std=self._NORM_STD)
        if data_split == "train":
            return T.Compose(
                [
                    # interpolation=3 is PIL's BICUBIC filter.
                    T.RandomResizedCrop(224, interpolation=3),
                    T.RandomHorizontalFlip(),
                    T.RandomRotation(45),
                    T.ToTensor(),
                    normalize,
                ]
            )
        return T.Compose(
            [
                T.Resize(256),
                T.CenterCrop(224),
                T.ToTensor(),
                normalize,
            ]
        )

    def _create_data_loader(self, data_split):
        """Build a DataLoader over ``<data_dir>/<data_split>``.

        Only the ``"train"`` split is shuffled (either by the DataLoader itself
        or by the DistributedSampler); evaluation splits are iterated in a
        deterministic order.

        Args:
            data_split: Name of the split sub-directory ("train", "valid", "test").

        Returns:
            A ``torch.utils.data.DataLoader`` for that split.
        """
        is_train = data_split == "train"

        dataset = ImageFolder(
            root=os.path.join(self.data_dir, data_split),
            transform=self._build_transform(data_split),
        )
        print("Number of {} images loaded: {}".format(data_split, len(dataset)))

        if self.distributed_sampler:
            # DistributedSampler shuffles by default; restrict that to training.
            sampler = torch.utils.data.DistributedSampler(dataset, shuffle=is_train)
        else:
            sampler = None

        return torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            sampler=sampler,
            # DataLoader rejects shuffle=True combined with an explicit sampler,
            # so only shuffle here when no sampler is in use.
            shuffle=is_train and sampler is None,
        )

    def train_dataloader(self):
        """Return the (shuffled) training DataLoader."""
        return self._create_data_loader(data_split="train")

    def val_dataloader(self):
        """Return the validation DataLoader (reads the ``valid`` split)."""
        return self._create_data_loader(data_split="valid")

    def test_dataloader(self):
        """Return the test DataLoader."""
        return self._create_data_loader(data_split="test")

    @staticmethod
    def add_data_specific_args(parent_parser):  # pragma: no cover
        """
        Define command-line arguments that only apply to this data module.

        Args:
            parent_parser: Existing ``ArgumentParser`` whose arguments are inherited.

        Returns:
            A new ``ArgumentParser`` extending ``parent_parser`` with
            ``--data_dir`` and ``--num_workers``.
        """
        parser = ArgumentParser(parents=[parent_parser], add_help=False)

        # dataset arguments
        parser.add_argument(
            "--data_dir",
            default=None,
            type=Path,
            help="Path to data root",
        )

        # data loader arguments
        parser.add_argument(
            "--num_workers",
            default=4,
            type=int,
            help="Number of workers to use in data loader",
        )

        return parser
