from typing import Any
from collections.abc import Callable, Sequence
from pathlib import Path

from PIL import Image
import numpy as np

import torch
from torch import Tensor
from torch.utils.data import Dataset, Subset



def load_image(p: Path) -> Image.Image:
    """Read the image file at *p* and return it as an RGB PIL image.

    ``convert`` yields a loaded image, so the returned object stays usable
    after the file handle opened here is closed.
    """
    with p.open("rb") as handle:
        rgb = Image.open(handle).convert("RGB")
    return rgb


class ImageDataset(Dataset):
    """Abstract base for datasets that yield ``(image, label)`` pairs.

    Concrete subclasses must override both ``__getitem__`` and ``__len__``;
    the base versions only raise ``NotImplementedError``.
    """

    def __getitem__(self, i: int) -> tuple[Any, float]:
        """Return the ``(image, label)`` pair stored at position *i*."""
        raise NotImplementedError

    def __len__(self) -> int:
        """Return the number of samples in the dataset."""
        raise NotImplementedError


class ImageTransformDataset(ImageDataset):
    """Wrap a dataset and apply ``transform`` to every image it yields.

    Labels are passed through untouched.

    Args:
        dataset: source of ``(image, label)`` pairs.
        transform: callable mapping a PIL image to a tensor.
    """

    def __init__(self,
                 dataset: ImageDataset,
                 transform: Callable[[Image.Image], Tensor]):
        self.transform = transform
        self.dataset = dataset

    def __len__(self) -> int:
        # Same length as the wrapped dataset; transforming is one-to-one.
        return len(self.dataset)

    def __getitem__(self, i: int) -> tuple[Tensor, float]:
        sample, target = self.dataset[i]
        return self.transform(sample), target


class ImageTwoTransformDataset(ImageDataset):
    """Wrap a dataset and return two independently transformed views per image.

    For each ``(image, label)`` from the wrapped dataset, ``__getitem__``
    returns ``(transform_weak(image), transform_strong(image), label)``.

    Args:
        dataset: source of ``(image, label)`` pairs.
        transform_weak: callable producing the first ("weak") view.
        transform_strong: callable producing the second ("strong") view.
    """

    def __init__(self,
                 dataset: ImageDataset,
                 transform_weak: Callable[[Image.Image], Tensor],
                 transform_strong: Callable[[Image.Image], Tensor]):
        self.dataset = dataset
        self.transform_weak = transform_weak
        self.transform_strong = transform_strong

    def __len__(self) -> int:
        return len(self.dataset)

    def __getitem__(self, i: int) -> tuple[Tensor, Tensor, float]:
        # Return annotation fixed: this yields a 3-tuple, not (Tensor, float).
        img, label = self.dataset[i]
        img_weak = self.transform_weak(img)
        img_strong = self.transform_strong(img)
        return img_weak, img_strong, label


class ImageSubset(ImageDataset):
    """View of *dataset* restricted to the positions listed in *indices*.

    Position ``i`` of the subset maps to ``dataset[indices[i]]``; the
    remapping is delegated to ``torch.utils.data.Subset``.
    """

    def __init__(self, dataset: ImageDataset, indices: Sequence[int]):
        self.dataset = Subset(dataset, indices)

    def __len__(self) -> int:
        return len(self.dataset)

    def __getitem__(self, i: int) -> tuple[Any, float]:
        return self.dataset[i]
    
    
class IndexedDataset(Dataset):
    """Wrap a dataset so each sample also carries its own index.

    The index is attached according to the sample's form:

    * tuple  -> the index is appended: ``(*sample, idx)``
    * dict   -> a shallow copy with an ``'index'`` key set to ``idx``
    * other  -> ``(sample, idx)``
    """

    def __init__(self, original_dataset):
        self.dataset = original_dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx: int):
        data = self.dataset[idx]
        if isinstance(data, tuple):
            return data + (idx,)
        if isinstance(data, dict):
            # Copy instead of mutating: the wrapped dataset may hand back a
            # dict it stores (e.g. a list of dicts), and writing 'index' into
            # it would leak state back into the underlying data.
            return {**data, 'index': idx}
        return (data, idx)



# add seed based on SSA: for reproducibility
def random_split(dataset: ImageDataset, n1: int, seed: int = 42) -> tuple[ImageSubset, ImageSubset]:
    """Randomly split *dataset* into two disjoint subsets.

    The permutation comes from ``np.random.RandomState(seed)``, so the same
    ``seed`` and dataset length always produce the same split.

    Args:
        dataset: dataset to split.
        n1: number of samples in the first subset; the second subset gets
            the remaining ``len(dataset) - n1`` samples.
        seed: seed for the permutation.

    Returns:
        ``(first, second)`` subsets of sizes ``n1`` and ``len(dataset) - n1``.

    Raises:
        ValueError: if ``n1`` is outside ``[0, len(dataset)]``. Without this
            check a negative ``n1`` silently slices from the end of the
            permutation, and an oversized ``n1`` silently returns a smaller
            first subset than requested.
    """
    n = len(dataset)
    if not 0 <= n1 <= n:
        raise ValueError(f"n1 must be in [0, {n}], got {n1}")

    rng = np.random.RandomState(seed)
    perm = rng.permutation(n).tolist()

    return ImageSubset(dataset, perm[:n1]), ImageSubset(dataset, perm[n1:])





