""" A dataset reader that extracts images from folders
and splits them into train/val splits.

Author:  
"""
import os
from typing import Dict, List, Optional, Set, Tuple, Union

from timm.data.readers.reader_image_folder import ReaderImageFolder
from timm.utils.misc import natural_key
from artihippo.data.data_utils import make_val_data

class AutoSplitReaderImageFolder(ReaderImageFolder):
    """Image-folder reader that can partition its samples into train/val subsets.

    The parent ``ReaderImageFolder`` is always initialised first with ``root``
    and ``class_map``; this class then optionally replaces ``self.samples``
    depending on ``split`` and ``samples``:

    * ``split is None``: nothing is changed after the parent constructor runs
      (``samples``, ``dev_percent`` and ``class_to_idx`` are ignored).
    * ``split`` set and ``samples is None``: the reader's samples are divided
      with ``make_val_data``. ``split`` must be ``"train"`` in this mode. The
      reader keeps the train subset in ``self.samples`` and exposes both
      subsets as ``self.train_samples`` and ``self.val_samples``.
    * ``split`` set and ``samples`` given: ``samples`` is used verbatim, e.g.
      the ``val_samples`` of a previously built train reader.

    ``train_samples`` and ``val_samples`` only exist on instances that
    performed the split themselves (second mode above).
    """

    def __init__(
            self,
            root,
            class_map='',
            samples=None,
            split=None,
            dev_percent=None,
            class_to_idx=None):
        """Build the reader and, if requested, select or create a split.

        Args:
            root: Forwarded unchanged to ``ReaderImageFolder``.
            class_map: Forwarded unchanged to ``ReaderImageFolder``.
            samples: Pre-computed ``(filename, label)`` entries to use instead
                of splitting. Only consulted when ``split`` is not ``None``.
            split: Name of the split this reader serves; stored on
                ``self.split``. ``None`` disables all split handling.
            dev_percent: Forwarded unchanged to ``make_val_data``.
                NOTE(review): presumably the share of samples held out for
                validation -- confirm against ``make_val_data``.
            class_to_idx: Class-name -> index mapping to install alongside
                ``samples``. When omitted, the mapping built by the parent
                reader is kept.

        Raises:
            ValueError: If ``split`` is set to anything other than ``"train"``
                while ``samples`` is ``None``.
            AssertionError: If the subsets returned by ``make_val_data``
                overlap or do not cover every sample exactly once.
        """
        super().__init__(root, class_map=class_map)

        self.split = split
        if split is None:
            # No split requested: keep exactly what the parent reader found.
            return

        if samples is not None:
            # A split computed elsewhere was handed in; adopt it as-is.
            self.samples = samples
            # Only override the parent's mapping when one is actually given,
            # so that omitting ``class_to_idx`` does not wipe it out.
            if class_to_idx is not None:
                self.class_to_idx = class_to_idx
            return

        # From here on the split has to be created by this instance.
        # Raised explicitly (not asserted) so the check survives ``python -O``.
        if split != "train":
            raise ValueError(
                "Split must be train if samples is None. "
                "Create the train split first"
            )

        idx_to_class = {idx: name for name, idx in self.class_to_idx.items()}
        filenames = [sample[0] for sample in self.samples]
        labels = [sample[1] for sample in self.samples]
        class_names = [idx_to_class[label] for label in labels]

        # The per-sample class names of each subset are returned as well but
        # are not needed here.
        (train_filenames, val_filenames,
         train_labels, val_labels,
         _train_class_names, _val_class_names) = make_val_data(
            filenames, labels, class_names, dev_percent)

        # Keep both subsets so a validation reader can be built from this one.
        self.train_samples = self._sorted_samples(train_filenames, train_labels)
        self.val_samples = self._sorted_samples(val_filenames, val_labels)

        # Sanity-check the partition: every sample is in exactly one subset.
        train_set = set(self.train_samples)
        val_set = set(self.val_samples)
        assert len(val_set | train_set) == len(self.samples), "Union of train and val samples must be equal to all samples. " + \
            f"Got val samples={len(self.val_samples)}, train samples={len(self.train_samples)}, all samples={len(self.samples)}"
        intersecting_samples = len(val_set & train_set)
        assert intersecting_samples == 0, f"Intersection of train and val samples must be empty. Got {intersecting_samples} intersecting samples"

        self.samples = self.train_samples

    @staticmethod
    def _sorted_samples(filenames, labels):
        """Pair filenames with labels and sort the pairs by filename via ``natural_key``."""
        return sorted(zip(filenames, labels), key=lambda pair: natural_key(pair[0]))
