from pathlib import Path
from typing import Callable, Optional, Tuple

import lightning as L
import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset

from .imagenet import ImageNet


class LatentImageNet(Dataset):
    """Sequential reader over pre-computed latent shards stored as ``*.pt`` files.

    Every ``*.pt`` file found (recursively) under ``<root>/latents`` is treated
    as one shard: a mapping with an ``"x0"`` entry and a ``"y"`` entry that are
    both indexable along their first dimension.  Shards are visited in sorted
    path order and the cycle restarts once all of them have been consumed.

    NOTE(review): ``__getitem__`` ignores the requested index and always
    returns the next entry of the shard that is currently open.  A shuffling
    sampler therefore does not change the order of the samples, and every
    copy of this dataset (e.g. one per DataLoader worker process) advances its
    own independent cursor -- confirm this is the intended behaviour.
    """

    # Nominal number of entries per shard; only used to report ``__len__``.
    SHARD_SIZE = 4096

    def __init__(
        self,
        root: str,
    ):
        """Collect the shard paths below ``<root>/latents`` and open the first.

        Args:
            root: Dataset root directory containing a ``latents`` sub-directory.

        Raises:
            FileNotFoundError: If no ``*.pt`` shard exists below
                ``<root>/latents``.
        """
        latent_dir = Path(root) / "latents"
        self.file_list = sorted(latent_dir.rglob("*.pt"))
        if not self.file_list:
            # Fail here with a clear message instead of an opaque
            # "pop from empty list" IndexError inside open_shard().
            raise FileNotFoundError(
                f"No '*.pt' latent shards found under {latent_dir}"
            )
        self.unused_files = self.file_list.copy()
        self.open_shard()

    def __len__(self) -> int:
        """Return the nominal size: number of shards times ``SHARD_SIZE``."""
        return len(self.file_list) * self.SHARD_SIZE

    def open_shard(self) -> None:
        """Load the next unused shard and reset the read cursor to its start."""
        if not self.unused_files:
            # Every shard has been consumed once: start a new pass.
            self.unused_files = self.file_list.copy()

        shard_path = self.unused_files.pop(0)
        # Always materialise on CPU so that shards saved from GPU tensors can
        # be read inside DataLoader worker processes.
        shard = torch.load(shard_path, map_location="cpu")
        self.current_x = shard["x0"]
        self.current_y = shard["y"]
        self.current_idx = 0

    def __getitem__(self, index: int) -> dict:
        """Return the next sample of the current shard (``index`` is unused).

        Returns:
            A dict with keys ``"y"``, ``"x0"``, ``"y_latent"`` and
            ``"x0_latent"``; the ``*_latent`` entries are the same objects as
            their non-latent counterparts.
        """
        if self.current_idx >= self.current_x.shape[0]:
            self.open_shard()

        image = self.current_x[self.current_idx]
        degraded = self.current_y[self.current_idx]
        self.current_idx += 1

        return {
            "y": degraded,
            "x0": image,
            "y_latent": degraded,
            "x0_latent": image,
        }


class ValidationLatentImageNet(Dataset):
    """Map-style dataset pairing images on disk with pre-computed latents.

    Image paths are read from a text file (one path per line, relative to
    ``root``).  Latents are read from the ``*.pt`` shards found (recursively)
    under ``<root>/latents_test``, visited in sorted path order; sample ``i``
    is looked up in shard ``i // SHARD_SIZE`` at offset ``i % SHARD_SIZE``.

    NOTE(review): this presumes the shards were written in the same order as
    the file list, with ``SHARD_SIZE`` entries per shard -- verify against the
    code that produced them.
    """

    # Number of latent entries assumed per shard when mapping index -> shard.
    SHARD_SIZE = 4096

    def __init__(
        self,
        root: str,
        transform: Optional[Callable],
        corruption: Optional[Callable],
        file_list: str,
    ):
        """Index the latent shards and read the list of image files.

        Args:
            root: Dataset root directory.
            transform: Optional callable applied to the RGB image array.
            corruption: Optional callable producing the degraded image from
                the (transformed) image.
            file_list: Name of a text file inside ``root`` listing one image
                path per line, each relative to ``root``.
        """
        self.transform = transform
        self.corruption = corruption
        self.shard_list = sorted((Path(root) / "latents_test").rglob("*.pt"))
        # Index of the shard currently held in memory; -1 means "none yet".
        self.shard_idx = -1

        with open(Path(root) / file_list, "r") as f:
            files = f.read().splitlines()
        self.file_list = [Path(root) / file for file in files]

    def __len__(self) -> int:
        """Return the number of image files listed in the file list."""
        return len(self.file_list)

    def open_shard(self, idx: int) -> None:
        """Load shard number ``idx`` and keep its ``x0``/``y`` entries."""
        # Always materialise on CPU so that shards saved from GPU tensors can
        # be read inside DataLoader worker processes.
        shard = torch.load(self.shard_list[idx], map_location="cpu")
        self.current_x = shard["x0"]
        self.current_y = shard["y"]

    def load_latent(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(degraded_latent, image_latent)`` for sample ``index``.

        The shard containing ``index`` is loaded only when it differs from the
        one already in memory.  Both results are converted with ``.float()``.
        """
        shard_idx, offset = divmod(index, self.SHARD_SIZE)

        if shard_idx != self.shard_idx:
            self.open_shard(shard_idx)
            self.shard_idx = shard_idx

        image_latent = self.current_x[offset].float()
        degraded_latent = self.current_y[offset].float()

        return degraded_latent, image_latent

    def load_images(self, index: int):
        """Return ``(degraded, image)`` for sample ``index``, rescaled by 2x-1.

        The file is opened with PIL, converted to RGB and turned into a NumPy
        array before ``transform`` (if any) is applied.  When no ``corruption``
        is configured the degraded output equals the clean image.

        NOTE(review): the ``2 * x - 1`` rescale presumably expects
        ``transform`` to yield floats in [0, 1]; with ``transform=None`` it is
        applied to the raw uint8 array -- confirm callers always pass one.
        """
        # Close the file handle deterministically once the pixels are copied.
        with Image.open(str(self.file_list[index])) as raw:
            image = np.array(raw.convert("RGB"))

        if self.transform is not None:
            image = self.transform(image)

        # Without a corruption, fall back to the clean image instead of
        # leaving ``degraded`` unbound.
        if self.corruption is not None:
            degraded = self.corruption(image)
        else:
            degraded = image

        image = 2 * image - 1
        degraded = 2 * degraded - 1

        return degraded, image

    def __getitem__(self, index: int) -> dict:
        """Return the image pair and the latent pair for sample ``index``.

        Returns:
            A dict with keys ``"y"``, ``"x0"`` (from ``load_images``) and
            ``"y_latent"``, ``"x0_latent"`` (from ``load_latent``).
        """
        y_latent, x0_latent = self.load_latent(index)
        y, x0 = self.load_images(index)

        return {
            "y": y,
            "x0": x0,
            "y_latent": y_latent,
            "x0_latent": x0_latent,
        }


class LatentImageNetDataModule(L.LightningDataModule):
    """Lightning data module: latent shards for training, ``ImageNet`` for eval.

    Training samples come from ``LatentImageNet``; the validation and test
    datasets are ``ImageNet`` instances built on the ``"val"`` split with the
    file lists ``val_file_list.txt`` and ``visual_file_list.txt`` respectively.
    """

    def __init__(
        self,
        data_dir: str,
        corruption: Callable,
        val_transform: Callable,
        test_transform: Optional[Callable] = None,
        val_test_split: Tuple = (0.01, 0.99),
        train_bsz: int = 16,
        val_bsz: int = 8,
        num_workers: int = 8,
    ):
        """Store the configuration; no dataset is created until ``setup``.

        Args:
            data_dir: Root directory handed to every dataset.
            corruption: Callable forwarded to the evaluation datasets.
            val_transform: Transform for the validation dataset.
            test_transform: Transform for the test dataset; ``val_transform``
                is used when this is ``None``.
            val_test_split: Stored on the instance; not read by any method of
                this class.
            train_bsz: Batch size of the training loader.
            val_bsz: Batch size of the validation and test loaders.
            num_workers: Worker count used by all three loaders.
        """
        super().__init__()
        self.data_dir = data_dir
        self.corruption = corruption
        self.val_test_split = val_test_split

        self.val_transform = val_transform
        if test_transform is None:
            # No dedicated test transform: reuse the validation one.
            test_transform = val_transform
        self.test_transform = test_transform

        self.train_bsz = train_bsz
        self.val_bsz = val_bsz
        self.num_workers = num_workers

    def setup(self, stage=None):
        """Create the datasets.

        The training dataset is built only when ``stage`` is ``"fit"`` or
        ``None``; the validation and test datasets are built for every stage.
        """
        if stage is None or stage == "fit":
            self.train_dataset = LatentImageNet(root=self.data_dir)

        # Arguments common to both evaluation datasets.
        shared_kwargs = {
            "root": self.data_dir,
            "split": "val",
            "corruption": self.corruption,
        }
        self.val_dataset = ImageNet(
            transform=self.val_transform,
            file_list="val_file_list.txt",
            **shared_kwargs,
        )
        self.test_dataset = ImageNet(
            transform=self.test_transform,
            file_list="visual_file_list.txt",
            **shared_kwargs,
        )

    def _eval_loader(self, dataset):
        """Build a non-shuffled loader with the validation batch size."""
        return DataLoader(
            dataset,
            batch_size=self.val_bsz,
            num_workers=self.num_workers,
            pin_memory=True,
        )

    def train_dataloader(self):
        """Return the loader over the training dataset (``shuffle=True``)."""
        loader_kwargs = {
            "batch_size": self.train_bsz,
            "num_workers": self.num_workers,
            "shuffle": True,
            "pin_memory": True,
        }
        return DataLoader(self.train_dataset, **loader_kwargs)

    def val_dataloader(self):
        """Return the loader over the validation dataset."""
        return self._eval_loader(self.val_dataset)

    def test_dataloader(self):
        """Return the loader over the test dataset."""
        return self._eval_loader(self.test_dataset)

    def get_inception_statistics(self):
        """Load and return ``full50k.npz`` from the data directory."""
        stats_path = Path(self.data_dir) / "full50k.npz"
        return np.load(stats_path)
