from pathlib import Path
from typing import Callable, Optional

import lightning as L
import numpy as np
from PIL import Image
from torch.utils.data import DataLoader, Dataset

# Names of the our485 images (compared against each low-light file's name)
# that LOLDataModule holds out from training to form the validation split.
VAL_SPLIT_FILES = [
    "265.png",
    "732.png",
    "45.png",
    "544.png",
    "99.png",
]


class LOL(Dataset):
    """Paired low-light / normal-light image dataset (LOL).

    Item ``i`` pairs the image at ``low_file_list[i]`` with the image at
    ``high_file_list[i]``, so the two lists must be aligned and equally long.

    Args:
        low_file_list: Paths to the low-light (degraded) images.
        high_file_list: Paths to the corresponding normal-light images.
        transform: Optional callable applied to the ``(image, degraded)``
            tuple of HxWx3 uint8 RGB arrays. It is presumably expected to
            return the pair scaled to [0, 1] (e.g. float tensors), since the
            ``2 * x - 1`` rescaling below maps [0, 1] to [-1, 1]. uint8 numpy
            arrays (always the case when ``transform`` is ``None``) are
            converted to float32 in [0, 1] here.

    Raises:
        ValueError: If the two file lists differ in length.
    """

    def __init__(
        self,
        low_file_list: list[Path],
        high_file_list: list[Path],
        transform: Optional[Callable] = None,
    ):
        # Pairing is positional; a length mismatch would silently misalign
        # or drop samples, so reject it up front.
        if len(low_file_list) != len(high_file_list):
            raise ValueError(
                "low_file_list and high_file_list must have the same length, "
                f"got {len(low_file_list)} and {len(high_file_list)}"
            )
        self.transform = transform
        self.low_file_list = low_file_list
        self.high_file_list = high_file_list

    def __len__(self):
        return len(self.low_file_list)

    @staticmethod
    def _load_rgb(path) -> np.ndarray:
        """Read ``path`` as an HxWx3 uint8 RGB array, closing the file afterwards."""
        with Image.open(str(path)) as img:
            return np.array(img.convert("RGB"))

    @staticmethod
    def _to_unit_range(arr):
        """Scale uint8 numpy arrays to float32 in [0, 1]; return anything else unchanged."""
        # Arithmetic such as ``2 * arr - 1`` on uint8 stays uint8 in NumPy and
        # wraps around, so raw pixel arrays must be converted to float first.
        if isinstance(arr, np.ndarray) and arr.dtype == np.uint8:
            return arr.astype(np.float32) / 255.0
        return arr

    def __getitem__(self, index: int) -> dict:
        """Return the ``index``-th pair as a dict.

        Keys:
            ``"y"``: degraded (low-light) image, rescaled by ``2 * x - 1``.
            ``"x0"``: clean (normal-light) image, rescaled by ``2 * x - 1``.
            ``"target_lightness"``: mean of the clean image before rescaling.
        """
        image = self._load_rgb(self.high_file_list[index])
        degraded = self._load_rgb(self.low_file_list[index])

        if self.transform is not None:
            image, degraded = self.transform((image, degraded))

        image = self._to_unit_range(image)
        degraded = self._to_unit_range(degraded)

        # Computed before rescaling, i.e. on the [0, 1] representation.
        target_lightness = image.mean()

        image = 2 * image - 1
        degraded = 2 * degraded - 1

        return {
            "y": degraded,
            "x0": image,
            "target_lightness": target_lightness,
        }


class LOLDataModule(L.LightningDataModule):
    """Lightning data module for the LOL low-light enhancement dataset.

    Reads PNG images from ``data_dir/our485/{low,high}`` and
    ``data_dir/eval15/{low,high}``:

    * our485 files whose names are in the validation split form the
      validation set; the remaining our485 files form the training set.
    * eval15 forms the test set.

    Args:
        data_dir: Root directory of the dataset.
        train_transform: Transform applied to training pairs.
        val_transform: Transform applied to validation and test pairs.
        train_bsz: Training batch size.
        val_bsz: Batch size for the validation and test loaders.
        num_workers: Number of DataLoader worker processes.
        val_split_files: File names from our485 to hold out for validation;
            ``None`` uses ``VAL_SPLIT_FILES``.
    """

    def __init__(
        self,
        data_dir: str,
        train_transform: Callable,
        val_transform: Callable,
        train_bsz: int = 16,
        val_bsz: int = 8,
        num_workers: int = 8,
        val_split_files: Optional[list[str]] = None,
    ):
        super().__init__()
        self.train_transform = train_transform
        self.val_transform = val_transform
        self.data_dir = Path(data_dir)
        self.train_batch_size = train_bsz
        self.val_batch_size = val_bsz
        self.num_workers = num_workers
        self.val_split_files = frozenset(
            VAL_SPLIT_FILES if val_split_files is None else val_split_files
        )

    def _paired_files(self, split: str) -> tuple[list[Path], list[Path]]:
        """Return sorted low/high PNG paths under ``data_dir/split``, verified to pair up.

        Raises:
            ValueError: If ``low/`` and ``high/`` do not contain the same
                relative file paths.
        """
        low_dir = self.data_dir / split / "low"
        high_dir = self.data_dir / split / "high"
        low = sorted(low_dir.rglob("*.png"))
        high = sorted(high_dir.rglob("*.png"))

        # Samples are paired by position, so both directories must hold the
        # same relative paths; otherwise pairs would be silently misaligned.
        low_rel = [p.relative_to(low_dir) for p in low]
        high_rel = [p.relative_to(high_dir) for p in high]
        if low_rel != high_rel:
            unmatched = sorted({str(p) for p in low_rel} ^ {str(p) for p in high_rel})
            raise ValueError(
                f"{split}: low/ and high/ contain different images; "
                f"unmatched files (first 10): {unmatched[:10]}"
            )
        return low, high

    def setup(self, stage=None):
        """Build the train, validation and test datasets.

        All three datasets are built regardless of ``stage``.
        """
        our_low, our_high = self._paired_files("our485")

        train_low, train_high = [], []
        val_low, val_high = [], []
        for low, high in zip(our_low, our_high):
            if low.name in self.val_split_files:
                val_low.append(low)
                val_high.append(high)
            else:
                train_low.append(low)
                train_high.append(high)

        self.train_dataset = LOL(
            low_file_list=train_low,
            high_file_list=train_high,
            transform=self.train_transform,
        )
        self.val_dataset = LOL(
            low_file_list=val_low,
            high_file_list=val_high,
            transform=self.val_transform,
        )

        eval_low, eval_high = self._paired_files("eval15")
        self.test_dataset = LOL(
            low_file_list=eval_low,
            high_file_list=eval_high,
            transform=self.val_transform,
        )

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.train_batch_size,
            num_workers=self.num_workers,
            shuffle=True,
            pin_memory=True,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.val_batch_size,
            num_workers=self.num_workers,
            pin_memory=True,
        )

    def test_dataloader(self):
        return DataLoader(
            self.test_dataset,
            batch_size=self.val_batch_size,
            num_workers=self.num_workers,
            pin_memory=True,
        )

    def get_inception_statistics(self):
        return None  # no FID calculation as the dataset is too small for that