from pathlib import Path
from typing import Callable, Literal, Optional, Tuple

import lightning as L
import numpy as np
from PIL import Image
from torch.utils.data import DataLoader, Dataset


class ImageNet(Dataset):
    """ImageNet-style image dataset yielding ``(degraded, clean)`` pairs.

    Each item is a dict with ``"x0"`` (the clean, transformed image) and
    ``"y"`` (that image passed through ``corruption``); both are mapped
    with ``2 * v - 1`` before being returned.

    Args:
        root: Dataset root directory.
        split: Sub-directory of ``root`` scanned recursively for ``*.JPEG``
            files when ``file_list`` is not given; unused otherwise.
        transform: Optional callable applied to the loaded RGB array before
            corruption.
        corruption: Optional callable producing the degraded input from the
            transformed image. When ``None``, the clean image is used as ``y``.
        file_list: Optional path (relative to ``root``) of a text file listing
            image paths relative to ``root``, one per line. Blank lines are
            skipped.
    """

    def __init__(
        self,
        root: str,
        split: Literal["train", "val"],
        transform: Optional[Callable] = None,
        corruption: Optional[Callable] = None,
        file_list: Optional[str] = None,
    ):
        self.corruption = corruption
        self.transform = transform

        root_path = Path(root)
        if file_list is not None:
            with open(root_path / file_list, "r", encoding="utf-8") as f:
                # Skip blank/whitespace-only lines: ``root / ""`` would resolve
                # to the root directory itself, which cannot be read as an image.
                files = [line for line in f.read().splitlines() if line.strip()]
            self.file_list = [root_path / file for file in files]
        else:
            self.file_list = sorted((root_path / split).rglob("*.JPEG"))

    def __len__(self):
        return len(self.file_list)

    @staticmethod
    def _load_image(path: Path) -> np.ndarray:
        """Read ``path`` as an RGB numpy array, closing the file handle."""
        # ``Image.open`` keeps the file open lazily; the context manager
        # releases it deterministically instead of waiting for GC.
        with Image.open(str(path)) as img:
            return np.array(img.convert("RGB"))

    def __getitem__(self, index: int):
        """Load item ``index`` and return ``{"y": degraded, "x0": clean}``."""
        image = self._load_image(self.file_list[index])

        if self.transform is not None:
            image = self.transform(image)

        # Without a corruption the clean image doubles as the model input.
        degraded = self.corruption(image) if self.corruption is not None else image

        # NOTE(review): ``2 * v - 1`` maps [0, 1] onto [-1, 1], so this presumes
        # ``transform`` yields values in [0, 1]. With no transform the raw uint8
        # array would wrap around — confirm a transform is always supplied.
        return {
            "y": 2 * degraded - 1,
            "x0": 2 * image - 1,
        }


class ImageNetDataModule(L.LightningDataModule):
    """Lightning data module exposing ImageNet train/val/test dataloaders.

    Every split is read from an explicit file list under ``data_dir``
    (``train_file_list.txt``, ``val_file_list.txt``, ``test_file_list.txt``)
    and shares the same ``corruption`` callable.

    Args:
        data_dir: Root directory holding the file lists and images.
        corruption: Callable producing the degraded input ``y`` from an image.
        train_transform: Transform applied to training images.
        val_transform: Transform applied to validation images.
        test_transform: Transform for test images; defaults to ``val_transform``.
        val_test_split: Stored as ``self.val_test_split`` but not used by this
            class (the splits come from the file lists).
        train_bsz: Training batch size.
        val_bsz: Batch size used by both the validation and test loaders.
        num_workers: Worker processes for each ``DataLoader``.
    """

    def __init__(
        self,
        data_dir: str,
        corruption: Callable,
        train_transform: Callable,
        val_transform: Callable,
        test_transform: Optional[Callable] = None,
        val_test_split: Tuple = (0.01, 0.99),
        train_bsz: int = 16,
        val_bsz: int = 8,
        num_workers: int = 8,
    ):
        super().__init__()
        self.train_transform = train_transform
        self.val_transform = val_transform
        self.test_transform = (
            test_transform if test_transform is not None else val_transform
        )
        self.corruption = corruption
        self.data_dir = data_dir
        self.val_test_split = val_test_split
        self.train_batch_size = train_bsz
        self.val_batch_size = val_bsz
        self.num_workers = num_workers

    def _make_dataset(self, split, transform, file_list):
        """Build one ``ImageNet`` split sharing this module's root and corruption."""
        return ImageNet(
            root=self.data_dir,
            split=split,
            transform=transform,
            corruption=self.corruption,
            file_list=file_list,
        )

    def setup(self, stage=None):
        """Create the datasets needed for ``stage``.

        ``"fit"`` or ``None`` builds all three splits. ``"validate"`` builds
        the validation split and ``"test"`` builds the test split, so that
        ``Trainer.validate`` / ``Trainer.test`` work without a prior ``fit``.
        """
        build_all = stage in ("fit", None)

        if build_all:
            self.train_dataset = self._make_dataset(
                "train", self.train_transform, "train_file_list.txt"
            )
        if build_all or stage == "validate":
            self.val_dataset = self._make_dataset(
                "val", self.val_transform, "val_file_list.txt"
            )
        if build_all or stage == "test":
            self.test_dataset = self._make_dataset(
                "val", self.test_transform, "test_file_list.txt"
            )

    def _make_loader(self, dataset, batch_size, shuffle=False):
        """Wrap ``dataset`` in a pinned-memory ``DataLoader``."""
        return DataLoader(
            dataset,
            batch_size=batch_size,
            num_workers=self.num_workers,
            shuffle=shuffle,
            pin_memory=True,
        )

    def train_dataloader(self):
        return self._make_loader(
            self.train_dataset, self.train_batch_size, shuffle=True
        )

    def val_dataloader(self):
        return self._make_loader(self.val_dataset, self.val_batch_size)

    def test_dataloader(self):
        return self._make_loader(self.test_dataset, self.val_batch_size)

    def get_inception_statistics(self):
        """Load ``full50k.npz`` from ``data_dir`` via ``np.load``.

        For an ``.npz`` archive ``np.load`` returns a lazy ``NpzFile`` that
        keeps the file open; callers should close it (or use it as a context
        manager) when done. Presumably these are precomputed Inception
        statistics for FID — confirm against the producer of the file.
        """
        return np.load(Path(self.data_dir) / "full50k.npz")
