import torch
import os
import numpy as np


class CachedFeatureLoader:
    """Load cached feature/label tensors from disk and subsample them into few-shot tasks.

    Cached tensors are read from
    ``<path_to_cache_dir>/<dataset>/{features,labels}_<split>_-1_<feature_extractor>.pt``.
    """

    # Datasets whose cached "validation" split is appended to the training data.
    _DATASETS_WITH_VALIDATION_SPLIT = ("dtd", "oxford_flowers102", "diabetic_retinopathy_detection")
    # The "amount of data" component of the cached file names ("-1" = complete split).
    _AMOUNT_OF_DATA = "-1"

    def __init__(self, dataset: str, path_to_cache_dir: str, feature_extractor: str):
        self.dataset = dataset
        self.path_to_cache_dir = path_to_cache_dir
        self.feature_extractor = feature_extractor

    def _load_split(self, kind: str, split: str) -> torch.Tensor:
        """Load a single cached tensor; ``kind`` is "features" or "labels"."""
        file_name = f"{kind}_{split}_{self._AMOUNT_OF_DATA}_{self.feature_extractor}.pt"
        # NOTE: torch.load unpickles the file; only point this loader at trusted cache directories.
        return torch.load(os.path.join(self.path_to_cache_dir, self.dataset, file_name))

    def _load_complete_data_from_disk(self, train: bool):
        """Return ``(features, labels)`` for the train or test split.

        For the datasets in ``_DATASETS_WITH_VALIDATION_SPLIT`` the validation split
        is concatenated (along dim 0) to the training split.
        """
        split = "train" if train else "test"
        features = self._load_split("features", split)
        labels = self._load_split("labels", split)

        if train and self.dataset in self._DATASETS_WITH_VALIDATION_SPLIT:
            features = torch.cat([features, self._load_split("features", "validation")], dim=0)
            labels = torch.cat([labels, self._load_split("labels", "validation")], dim=0)
        return features, labels

    def _subsample_classes(self, labels: torch.Tensor, n_of_classes_to_select: int, sampling_method: str):
        """Choose which label values take part in the task.

        Args:
            labels: label tensor of the full split.
            n_of_classes_to_select: number of classes, or -1 for all classes.
            sampling_method: only "random" is supported (uses NumPy's global RNG).

        Returns:
            NumPy array with the selected label values.

        Raises:
            ValueError: on an invalid class count or an unknown sampling method.
        """
        # Sample from the label values actually present, so non-contiguous labels
        # (e.g. 1..K or gaps) work. For labels 0..n-1 this draws exactly as before.
        available_classes = torch.unique(labels).cpu().numpy()
        n_available_classes = len(available_classes)

        if n_available_classes < n_of_classes_to_select:
            raise ValueError("There are not enough classes available to select.")
        if n_of_classes_to_select == 0 or n_of_classes_to_select < -1:
            raise ValueError("No classes selected")

        # just return all classes
        if n_of_classes_to_select == n_available_classes or n_of_classes_to_select == -1:
            return available_classes

        if sampling_method != "random":
            raise ValueError(f"Unknown sampling method: {sampling_method!r}")

        selected_classes = np.random.choice(available_classes, size=n_of_classes_to_select, replace=False)
        assert n_of_classes_to_select == len(selected_classes)
        return np.array(selected_classes)

    def _subsample_shots(self, features: torch.Tensor, labels: torch.Tensor, classes: list, shots: int):
        """Keep ``shots`` random samples of each class in ``classes`` (all samples if -1).

        Returns:
            ``(selected_features, selected_labels, selected_elements, class_mapping)`` where
            ``selected_labels`` are remapped to 0..k-1 (ascending original label order),
            ``selected_elements`` is the boolean mask into the input, and ``class_mapping``
            maps ``str(original_label)`` to the new label.

        Raises:
            ValueError: if ``shots`` is neither -1 nor a positive count that every
                selected class can provide.
        """
        if shots != -1:
            if shots < 1 or shots * len(classes) > len(labels):
                raise ValueError("The number of shots is not appropriate for the dataset.")
            # Validate every class up front so no RNG state is consumed before failing.
            for c in classes:
                n_in_class = int((labels == c).sum())
                if n_in_class < shots:
                    raise ValueError(
                        f"Class {c} has only {n_in_class} samples, but {shots} shots were requested."
                    )

        selected_elements = np.zeros(len(labels), dtype=bool)
        for c in classes:
            # get all elements of class
            class_mask = np.array(labels == c)
            if shots != -1:
                # drop all but `shots` randomly chosen samples of this class
                class_indices = np.flatnonzero(class_mask)
                dropped = np.random.choice(class_indices, len(class_indices) - shots, replace=False)
                class_mask[dropped] = False
            selected_elements |= class_mask

        selected_features = features[selected_elements, :]
        original_labels = labels[selected_elements]

        # map labels to interval without gaps (e.g., 0, 1 instead of 20, 50);
        # read from an untouched copy so earlier assignments can never be re-matched
        selected_labels = original_labels.clone()
        class_mapping = dict()
        for new_label, c in enumerate(torch.unique(original_labels, sorted=True)):
            selected_labels[original_labels == c] = new_label
            class_mapping[str(c.item())] = new_label

        if shots != -1:
            assert len(selected_labels) == len(classes) * shots
            assert len(selected_features) == len(classes) * shots

        # the opposite check is quite tricky with unbalanced classes
        return selected_features, selected_labels, selected_elements, class_mapping

    def obtain_feature_dim(self):
        """Returns the feature dimension of the cached data."""
        # Only the training features file is needed to read the dimension.
        train_features = self._load_split("features", "train")
        return train_features.shape[1]

    def load_train_data(self, shots: int, n_classes: int, seed: int):
        """
        Loads training data based on shots and n_classes.
        The sampling is done at random.

        Args:
            shots: samples per class, or -1 for all samples.
            n_classes: number of classes, or -1 for all classes.
            seed: seeds NumPy's *global* RNG (affects other NumPy random users).

        Returns:
            ``(train_features, train_labels, data_indicies, class_mapping)``; pass
            ``class_mapping`` to ``load_test_data`` to select matching test data.
        """
        np.random.seed(seed)
        all_train_features, all_train_labels = self._load_complete_data_from_disk(train=True)

        selected_classes = self._subsample_classes(all_train_labels, n_classes, sampling_method="random")
        train_features, train_labels, data_indicies, class_mapping = self._subsample_shots(
            all_train_features, all_train_labels, selected_classes, shots
        )

        return train_features, train_labels, data_indicies, class_mapping

    def load_test_data(self, class_mapping=None):
        """
        Loads test data.
        If a class mapping (as returned by ``load_train_data``) is given, only the test
        samples of those classes are returned, relabeled with the same mapping.
        """
        all_test_features, all_test_labels = self._load_complete_data_from_disk(train=False)
        if class_mapping is None:
            return all_test_features, all_test_labels

        # mapping keys are str(original_label); compare numerically so "7" and "7.0" both match
        selected_elements = np.zeros(len(all_test_labels), dtype=bool)
        for c in class_mapping:
            selected_elements |= np.array(all_test_labels == float(c))

        selected_test_features = all_test_features[selected_elements, :]
        original_test_labels = all_test_labels[selected_elements]

        # use the correct label (the same mapping as in the selection of the training)
        selected_test_labels = original_test_labels.clone()
        for c, new_label in class_mapping.items():
            selected_test_labels[original_test_labels == float(c)] = new_label

        remapped_classes = torch.unique(selected_test_labels).tolist()
        assert len(remapped_classes) == len(class_mapping)
        assert remapped_classes == sorted(class_mapping.values())
        return selected_test_features, selected_test_labels
