from pathlib import Path
from typing import Callable, Optional

import lightning as L
import numpy as np
from PIL import Image
from torch.utils.data import DataLoader, Dataset


class Rain1400(Dataset):
    """Paired rainy/clean image dataset driven by a text file list.

    Each non-blank line of ``root / file_list`` holds at least two
    whitespace-separated paths relative to ``root``:

    * the degraded (rainy) image first;
    * the clean target second.

    Extra columns are ignored.

    Items are dicts ``{"y": degraded, "x0": image}`` whose pixel values are
    mapped from [0, 1] to [-1, 1] via ``2 * x - 1``.
    """

    def __init__(
        self,
        root: str,
        file_list: str,
        transform: Optional[Callable] = None,
    ):
        """
        Args:
            root: Dataset root directory; also the base for every path listed
                in the file list.
            file_list: Name of the list file, resolved relative to ``root``.
            transform: Optional callable taking the ``(image, degraded)`` pair
                of HxWx3 uint8 RGB arrays and returning a pair. Its outputs are
                presumably in [0, 1] (ToTensor-style) — TODO confirm against
                the configured transforms.

        Raises:
            ValueError: if a non-blank line of the list has fewer than two
                paths.
        """
        self.transform = transform
        root_path = Path(root)
        list_path = root_path / file_list

        self.low_file_list = []   # degraded (rainy) inputs, first column
        self.high_file_list = []  # clean targets, second column
        with open(list_path, "r", encoding="utf-8") as f:
            for line_no, line in enumerate(f, start=1):
                # Whitespace split tolerates tabs, repeated spaces and
                # trailing whitespace.
                fields = line.split()
                if not fields:
                    continue  # skip blank / whitespace-only lines
                if len(fields) < 2:
                    raise ValueError(
                        f"{list_path}:{line_no}: expected '<degraded> <clean>' "
                        f"paths, got {line.rstrip()!r}"
                    )
                self.low_file_list.append(root_path / fields[0])
                self.high_file_list.append(root_path / fields[1])

    def __len__(self):
        return len(self.low_file_list)

    @staticmethod
    def _load_rgb(path: Path) -> np.ndarray:
        """Read an image file as an HxWx3 uint8 RGB array, closing the file."""
        with Image.open(path) as img:
            return np.array(img.convert("RGB"))

    @staticmethod
    def _to_signed_unit(x):
        """Map pixel values from [0, 1] to [-1, 1] with ``2 * x - 1``.

        Raw uint8 arrays (e.g. when no transform is configured) are first
        rescaled to float32 in [0, 1]. Applying ``2 * x - 1`` directly to uint8
        data would wrap around instead of producing negative values.
        """
        if isinstance(x, np.ndarray) and x.dtype == np.uint8:
            x = x.astype(np.float32) / 255.0
        return 2 * x - 1

    def __getitem__(self, index: int):
        image = self._load_rgb(self.high_file_list[index])
        degraded = self._load_rgb(self.low_file_list[index])

        if self.transform:
            # The transform sees both images together so paired augmentations
            # (crops, flips) can stay aligned.
            image, degraded = self.transform((image, degraded))

        return {
            "y": self._to_signed_unit(degraded),
            "x0": self._to_signed_unit(image),
        }


class Rain1400DataModule(L.LightningDataModule):
    """Lightning data module serving the Rain1400 train/val/test splits.

    Split definitions are read from ``train_file_list.txt``,
    ``val_file_list.txt`` and ``test_file_list.txt`` inside ``data_dir``.
    The training split uses ``train_transform``. Validation and test both use
    ``val_transform`` and ``val_bsz``.
    """

    def __init__(
        self,
        data_dir: str,
        train_transform: Callable,
        val_transform: Callable,
        train_bsz: int = 16,
        val_bsz: int = 8,
        num_workers: int = 8,
    ):
        super().__init__()
        self.data_dir = Path(data_dir)

        # Per-split preprocessing callables handed to Rain1400.
        self.train_transform = train_transform
        self.val_transform = val_transform

        # DataLoader settings.
        self.train_batch_size = train_bsz
        self.val_batch_size = val_bsz
        self.num_workers = num_workers

    def _make_dataset(self, list_name: str, transform: Callable):
        """Create a Rain1400 split from the list file ``list_name``."""
        return Rain1400(self.data_dir, list_name, transform=transform)

    def setup(self, stage=None):
        """Instantiate the datasets.

        The train split is built only when ``stage`` is ``"fit"`` or ``None``.
        The val and test splits are (re)built on every call.
        """
        if stage in ("fit", None):
            self.train_dataset = self._make_dataset(
                "train_file_list.txt", self.train_transform
            )

        self.val_dataset = self._make_dataset("val_file_list.txt", self.val_transform)
        self.test_dataset = self._make_dataset("test_file_list.txt", self.val_transform)

    def _make_loader(self, dataset, batch_size: int, **extra):
        """Wrap ``dataset`` in a DataLoader with the shared worker/pinning settings."""
        return DataLoader(
            dataset,
            batch_size=batch_size,
            num_workers=self.num_workers,
            pin_memory=True,
            **extra,
        )

    def train_dataloader(self):
        # Only the training split is shuffled.
        return self._make_loader(
            self.train_dataset, self.train_batch_size, shuffle=True
        )

    def val_dataloader(self):
        return self._make_loader(self.val_dataset, self.val_batch_size)

    def test_dataloader(self):
        return self._make_loader(self.test_dataset, self.val_batch_size)

    def get_inception_statistics(self):
        """Return Inception statistics for FID; always ``None`` for this dataset.

        FID is not computed because the dataset is too small for it.
        """
        return None
