from typing import Any, Callable, Dict, Optional, Tuple

from .datasetfolder import DatasetFolder
from ..types import ByteBinarySample


# ATTN: a potentially cleaner solution is to define loader/get_metadata_path as class/instance methods and require users to override them.
class MalwareFolder(DatasetFolder):
    """Folder dataset whose samples may each have a companion metadata file.

    Every sample is loaded by calling the user-supplied ``loader`` with two
    arguments: the sample path and the path returned by
    ``get_metadata_path`` for that sample. After the parent class has
    collected its samples, any entry whose path is itself the metadata path
    of some sample is removed, so metadata files are never served as samples.

    Attributes:
        get_metadata_path: Callable mapping a sample path to its metadata
            path. Defaults to a function that always returns ``None``.
    """

    def __init__(
        self,
        root: str,
        loader: Callable[[str, str], ByteBinarySample],
        extensions: Optional[Tuple[str, ...]] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        transforms: Optional[Callable] = None,
        is_valid_file: Optional[Callable[[str], bool]] = None,
        get_metadata_path: Optional[Callable[[str], str]] = None
    ) -> None:
        """Build the dataset and drop metadata files from the sample list.

        Args:
            root: Forwarded unchanged to ``DatasetFolder``.
            loader: Called as ``loader(sample_path, metadata_path)`` to load
                one sample. ``metadata_path`` is ``None`` when no
                ``get_metadata_path`` was supplied.
            extensions: Forwarded unchanged to ``DatasetFolder``.
            transform: Forwarded unchanged to ``DatasetFolder``.
            target_transform: Forwarded unchanged to ``DatasetFolder``.
            transforms: Forwarded unchanged to ``DatasetFolder``.
            is_valid_file: Forwarded unchanged to ``DatasetFolder``.
            get_metadata_path: Maps a sample path to the path of its
                metadata file. When omitted, every sample maps to ``None``.
        """
        # Both callables are stored before the parent constructor runs so
        # that the bound method handed to it is usable immediately.
        self._sample_loader = loader
        # A static method (rather than a lambda) is used for the default so
        # the attribute stays picklable by qualified name.
        if get_metadata_path is None:
            get_metadata_path = self._no_metadata_path
        self.get_metadata_path = get_metadata_path

        # A bound method is passed instead of a lambda: lambdas cannot be
        # pickled, which would make the whole dataset unpicklable if the
        # parent keeps a reference to its loader (presumably it does; verify
        # against DatasetFolder).
        super().__init__(
            root,
            self._load_sample,
            extensions=extensions,
            transform=transform,
            target_transform=target_transform,
            transforms=transforms,
            is_valid_file=is_valid_file,
        )

        # Clean metadata files mistakenly read as samples: collect every
        # metadata path first, then keep only entries that are not one.
        metadata_paths = {
            self.get_metadata_path(path) for path, _ in self.samples
        }
        kept = [
            (path, target)
            for path, target in self.samples
            if path not in metadata_paths
        ]
        self.samples = kept
        self.targets = [target for _, target in kept]

    @staticmethod
    def _no_metadata_path(path: str) -> None:
        """Default metadata locator: report that ``path`` has no metadata."""
        return None

    def _load_sample(self, path: str) -> ByteBinarySample:
        """Load one sample via the user loader and its metadata path.

        ``self.get_metadata_path`` is looked up at call time, so the loader
        and the public attribute can never disagree.
        """
        return self._sample_loader(path, self.get_metadata_path(path))

    def __getitem__(self, index: int) -> Tuple[ByteBinarySample, Any]:
        """Return the parent's item for ``index``.

        Overridden only to narrow the annotated return type.
        """
        return super().__getitem__(index)
