from pathlib import Path
from typing import Callable, Optional, Tuple

import lightning as L
import numpy as np
from PIL import Image
from torch.utils.data import DataLoader, Dataset


class FFHQ(Dataset):
    """FFHQ image dataset yielding (clean, corrupted) pairs.

    Image paths are read from a text file inside ``root`` that lists one path
    (relative to ``root``) per line; blank lines are ignored.

    Each sample is a dict ``{"y": degraded, "x0": image}`` in which both
    entries have been mapped through ``2 * v - 1``. That mapping lands in
    [-1, 1] only if its inputs are in [0, 1]. When ``transform`` is ``None``,
    the uint8 image is rescaled to float32 in [0, 1] here. When a transform is
    given, it is responsible for that range (NOTE(review): presumably a
    ``ToTensor``-style transform — confirm). Likewise ``corruption`` is assumed
    to keep values in [0, 1] — TODO confirm.

    Args:
        root: Dataset root directory; also the base of every listed path.
        file_list: Name of the file-list text file located inside ``root``.
        corruption: Callable applied to the (transformed) clean image to
            produce the degraded observation ``y``.
        transform: Optional callable applied to the ``H x W x 3`` uint8 RGB
            array before corruption.
    """

    def __init__(self, root, file_list, corruption, transform=None):
        self.transform = transform
        self.corruption = corruption

        root = Path(root)
        with open(root / file_list, "r", encoding="utf-8") as f:
            # Skip blank lines: `root / ""` is the root directory itself,
            # which would only fail later when opened as an image.
            files = [line for line in f.read().splitlines() if line.strip()]
        self.image_paths = [root / file for file in files]

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        # PIL opens files lazily; the context manager releases the handle as
        # soon as the pixels have been copied into the array.
        with Image.open(img_path) as img:
            image = np.array(img.convert("RGB"))
        return self._build_sample(image)

    def _build_sample(self, image):
        """Turn a loaded RGB array into the ``{"y", "x0"}`` sample dict."""
        if self.transform is not None:
            image = self.transform(image)
        else:
            # `2 * image - 1` on raw uint8 pixels would wrap around; bring
            # them to float32 in [0, 1] first.
            image = image.astype(np.float32) / 255.0

        degraded = self.corruption(image)

        # Map [0, 1] -> [-1, 1] for both the target and the observation.
        image = 2 * image - 1
        degraded = 2 * degraded - 1

        return {"y": degraded, "x0": image}


class FFHQDataModule(L.LightningDataModule):
    """Lightning data module serving the FFHQ train/val/test splits.

    The splits are defined by ``train_file_list.txt``, ``val_file_list.txt``
    and ``test_file_list.txt`` inside ``data_dir``. Validation and test share
    ``val_transform`` and ``val_bsz``; only the training loader shuffles.

    Args:
        data_dir: Root directory holding the images, the file lists and the
            ``full1k.npz`` statistics file.
        corruption: Callable producing the degraded observation per image.
        train_transform: Optional transform for training images.
        val_transform: Optional transform for validation and test images.
        split: Deprecated and ignored; still stored on ``self.split`` for
            backward compatibility.
        train_bsz: Training batch size.
        val_bsz: Batch size for both validation and test.
        num_workers: ``num_workers`` passed to every DataLoader.
    """

    def __init__(
        self,
        data_dir: str,
        corruption: Callable,
        train_transform: Optional[Callable] = None,
        val_transform: Optional[Callable] = None,
        split: Tuple = (0.98, 0.002, 0.018),  # Deprecated NOT USED
        train_bsz: int = 16,
        val_bsz: int = 8,
        num_workers: int = 8,
    ):
        super().__init__()
        # Data source.
        self.data_dir = data_dir
        self.corruption = corruption
        self.split = split  # deprecated, never read
        # Per-split preprocessing.
        self.train_transform = train_transform
        self.val_transform = val_transform
        # Loader settings.
        self.train_bsz = train_bsz
        self.val_bsz = val_bsz
        self.num_workers = num_workers

    def _make_split(self, list_name, split_transform):
        """Build the FFHQ dataset for the split listed in ``list_name``."""
        return FFHQ(
            root=self.data_dir,
            file_list=list_name,
            corruption=self.corruption,
            transform=split_transform,
        )

    def setup(self, stage=None):
        """Create all three datasets; ``stage`` is accepted but not used."""
        self.train_dataset = self._make_split("train_file_list.txt", self.train_transform)
        self.val_dataset = self._make_split("val_file_list.txt", self.val_transform)
        self.test_dataset = self._make_split("test_file_list.txt", self.val_transform)

    def _make_loader(self, dataset, batch_size, shuffle=False):
        """Wrap ``dataset`` in a DataLoader with pinned host memory."""
        return DataLoader(
            dataset,
            batch_size=batch_size,
            num_workers=self.num_workers,
            shuffle=shuffle,
            pin_memory=True,
        )

    def train_dataloader(self):
        return self._make_loader(self.train_dataset, self.train_bsz, shuffle=True)

    def val_dataloader(self):
        return self._make_loader(self.val_dataset, self.val_bsz)

    def test_dataloader(self):
        return self._make_loader(self.test_dataset, self.val_bsz)

    def get_inception_statistics(self):
        """Load the precomputed statistics archive ``full1k.npz`` from ``data_dir``.

        The ``np.load`` result is returned as-is. For an ``.npz`` archive this
        is a lazily-reading ``NpzFile`` that the caller should close.
        """
        stats_file = Path(self.data_dir) / "full1k.npz"
        return np.load(stats_file)
