# For licensing see accompanying LICENSE file.
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
from pathlib import Path
import typing as t
import os
import json
import logging
from copy import deepcopy
import torch
from tqdm import tqdm
import numpy as np
import pandas as pd
from collections import defaultdict


class ResponsesLoader:
    """
    Discover and load responses generated by foundation models and stored on disk.

    Response files are looked up below ``root`` in a directory hierarchy whose
    levels follow ``ORDERED_ATTRIBUTES``::

        <tag>/<model_name>/<dataset>/<subset>/<module_names>/<pooling_op>/<file>

    Attributes:
        ORDERED_ATTRIBUTES (list): The order in which attributes are expected, i.e.,
            ["tag", "model_name", "dataset", "subset", "module_names", "pooling_op"].

        root (Path): Root directory where all responses are stored.

        attribute_values (dict): Maps each attribute name to the set of directory
            names found for it while parsing.

        file_format (str): Glob pattern used to find response files ("*.pt" by default).

        columns (list or None): Forwarded to ``ResponsesDataset`` when loading.

        label_map (dict or None): Forwarded to ``ResponsesDataset`` when loading.

        file_trees (list): One parsed tree per entry of ``from_folders``; empty when
            no folders were given.

        dataset: The ``ResponsesDataset`` built by the most recent call to
            ``load_data_subset`` (the attribute does not exist before that call).
    """

    ORDERED_ATTRIBUTES = [
        "tag",
        "model_name",
        "dataset",
        "subset",
        "module_names",
        "pooling_op",
    ]

    def __init__(
        self,
        root: Path,
        from_folders: t.Optional[t.List[Path]] = None,
        file_format: str = "*.pt",
        columns: t.Optional[t.List[str]] = None,
        label_map: t.Optional[t.Dict[str, int]] = None,
    ):
        """
        Initialize the loader and, if requested, parse the directory tree.

        Args:
            root (Path): Root directory where all responses are stored.
            from_folders (list): Optional folders (relative to `root`, glob patterns allowed)
                to parse right away. When None, nothing is parsed.
            file_format (str): Glob pattern matching the response files.
            columns (list): Optional column names forwarded to the dataset.
            label_map (dict): Optional subset -> label mapping forwarded to the dataset.
        """
        self.root = Path(root)
        self.attribute_values = {k: set() for k in self.ORDERED_ATTRIBUTES}
        self.file_format = file_format
        self.columns = columns
        self.label_map = label_map
        # Always defined, so load_data_subset() does not hit an AttributeError
        # when the loader was built without from_folders.
        self.file_trees = []

        if self.label_map:
            logging.info(f"Using label map {self.label_map}")

        if from_folders is not None:
            self.file_trees = self.parse_folder(from_folders)
            logging.info("Parsed directory tree")
            logging.info(
                json.dumps(
                    {k: str(v) for k, v in self.attribute_values.items()},
                    indent=2,
                    sort_keys=True,
                )
            )

    def _parse_folder_tree(
        self, root: Path, path_parts: t.List
    ) -> t.Union[dict, t.List[Path]]:
        """
        Recursively parse a folder tree and identify unique values of attributes based on ORDERED_ATTRIBUTES order.

        Args:
            root (Path): The directory to start parsing from.

            path_parts (list): Remaining path components (glob patterns allowed) to be matched
                            below 'root', one per directory level.

        Returns:
            A nested dictionary keyed by directory name, one level per remaining path part.
            When no path part remains, the list of files in `root` matching file_format.

        Side effects:
            Updates attribute_values with the directory names matched by path_parts.
        """
        if len(path_parts) == 0:
            return list(root.glob(self.file_format))
        # Path parts are aligned with the END of ORDERED_ATTRIBUTES: the last
        # part is always the "pooling_op" level.
        attribute_name = self.ORDERED_ATTRIBUTES[-len(path_parts)]
        current_part, remaining_parts = path_parts[0], path_parts[1:]
        ret = {}
        for part_path in root.glob(current_part):
            if part_path.is_dir():
                self.attribute_values[attribute_name].add(part_path.name)
                ret[str(part_path.name)] = self._parse_folder_tree(
                    part_path, remaining_parts
                )
        return ret

    def _filter_tree(self, tree, level: int = 0, filter: t.Optional[t.Dict] = None):
        """
        Filter a parsed folder tree based on provided attributes.

        Args:
            tree (dict): A dictionary representing the parsed folder tree to be filtered.

            level (int): The current depth in `ORDERED_ATTRIBUTES`, 0 by default.

            filter (dict): An optional dictionary specifying which keys in `tree` should be kept per attribute.
                        If an attribute is not specified in the filter, all keys present at that level are kept.
                        Values may be a single key or a list/tuple/set of keys.
                        E.g., {"subset": ["dogs", "cats"], "pooling_op": "mean"} will only include these two data subsets even if more exist.

        Returns:
            dict: A filtered version of the input tree, with only the selected keys retained.

        Raises:
            AssertionError: If no key is selected at some level (empty filter value or empty tree level).
            KeyError: If a requested key is not present in `tree` at its level.
        """
        if level == len(self.ORDERED_ATTRIBUTES):
            return tree
        filter = filter or {}
        attribute_name = self.ORDERED_ATTRIBUTES[level]
        filtered_keys = filter.get(attribute_name, list(tree.keys()))
        if isinstance(filtered_keys, (set, frozenset)):
            # Sets (e.g. the result of get_attribute_values) are accepted too;
            # sort them so the resulting record order does not depend on hashing.
            filtered_keys = sorted(filtered_keys)
        elif not isinstance(filtered_keys, (list, tuple)):
            filtered_keys = [filtered_keys]
        assert (
            len(filtered_keys) > 0
        ), f"No values selected for attribute '{attribute_name}'."
        return {
            key: self._filter_tree(tree[key], level + 1, filter=filter)
            for key in filtered_keys
        }

    def _flatten_tree(
        self, tree: dict, level: int = 0, attributes: t.Optional[dict] = None
    ) -> t.List[dict]:
        """
        Flatten a (filtered) folder tree into one record per response file.

        Args:
            tree (dict): The nested tree to flatten. At depth len(ORDERED_ATTRIBUTES) it must be
                        a list/tuple of file paths.

            level (int): The current depth in `ORDERED_ATTRIBUTES`, 0 by default.

            attributes (dict): Attribute values collected on the way down to `tree`. Never mutated.

        Returns:
            list: One dictionary per file, holding the value of every attribute on the path to the
                file plus the file itself under the key "path".
        """
        attributes = {} if attributes is None else attributes
        ret = []
        if level == len(self.ORDERED_ATTRIBUTES):
            assert isinstance(tree, (list, tuple))
            for file_path in tree:
                record = deepcopy(attributes)
                record["path"] = file_path
                ret.append(record)
            return ret
        attribute_name = self.ORDERED_ATTRIBUTES[level]
        for key, subtree in tree.items():
            child_attributes = deepcopy(attributes)
            child_attributes[attribute_name] = key
            ret.extend(self._flatten_tree(subtree, level + 1, child_attributes))
        return ret

    def get_attribute_values(
        self, attribute_name: str, filter: t.Optional[t.List[str]] = None
    ) -> set:
        """
        Retrieve unique values of a specific attribute across all parsed directories.

        Args:
            attribute_name (str): The name of the attribute to retrieve unique values for.
            filter (list): Return only those values that match any of these patterns (glob mode).
                When None or empty, all values are returned.
        Returns:
            set: A set containing the matching values of the specified attribute.

        Raises:
            AssertionError: If `attribute_name` is not one of the ORDERED_ATTRIBUTES.

        Note:
            Values come from the `attribute_values` dictionary, which is filled during
            parsing operations.
        """
        import fnmatch

        assert attribute_name in self.ORDERED_ATTRIBUTES
        patterns = filter or ["*"]
        known_values = self.attribute_values[attribute_name]
        ret = set()
        for pattern in patterns:
            ret.update(fnmatch.filter(known_values, pattern))
        return ret

    def parse_folder(self, folders: t.List[Path]) -> t.List[dict]:
        """
        Parse a response folder structure for unique values of attributes and prepare it for data loading.

        Args:
            folders (list): A directory, or a list of directories, to parse, as Paths / strings
                relative to `root`. Each one is padded with "*" components until it has as many
                parts as ORDERED_ATTRIBUTES.

        Returns:
            list: One nested dictionary per folder, representing the parsed folder tree with the
                paths of the files matching file_format at its leaves.

        Raises:
            IndexError: If a folder has more path components than ORDERED_ATTRIBUTES.

        Side effects:
            Updates attribute_values with unique names found within directories in 'folders'.
        """
        if not isinstance(folders, (list, tuple)):
            folders = [folders]
        ret = []
        for folder in map(Path, folders):
            folder_parts = folder.parts
            while len(folder_parts) < len(self.ORDERED_ATTRIBUTES):
                folder = folder / "*"
                folder_parts = folder.parts

            ret.append(self._parse_folder_tree(self.root, list(folder_parts)))
        return ret

    def load_data_subset(
        self,
        attribute_values: t.Optional[t.Dict] = None,
        batch_size: int = 128,
        num_workers: int = 0,
    ) -> dict:
        """
        Load a subset of responses based on provided attributes and prepare it for further processing.

        Args:
            attribute_values (dict): An optional dictionary specifying, per attribute, which values
                                    should be loaded. If an attribute is not specified, all parsed
                                    values for that attribute are used. None loads everything.

            batch_size (int): The number of samples to load at a time, default is 128.

            num_workers (int): The number of subprocesses to use for data loading, default is 0.
                                None is treated as 0.

        Returns:
            dict: One entry per key of the loaded samples. Columns that are collated into tensors
                are concatenated over all batches and returned as float32 numpy arrays; every
                other column is returned as a flat list with one element per sample.

        Raises:
            AssertionError: If `attribute_values` has keys outside ORDERED_ATTRIBUTES.
            IndexError: If no response file matches the requested subset.

        Side effects:
            Sets the torch multiprocessing sharing strategy to "file_system" and stores the
            dataset that was built in `self.dataset`.
        """
        # The signature defaults to None, which must behave like "no filter".
        attribute_values = attribute_values or {}
        torch.multiprocessing.set_sharing_strategy("file_system")
        logging.info(
            f"Loading data subset: {attribute_values if attribute_values else 'all_data'}"
        )
        # Making sure the passed keys are valid keys according to the current schema.
        assert all(
            k in self.ORDERED_ATTRIBUTES for k in attribute_values.keys()
        ), f"Some of the keys in {attribute_values.keys()} are not valid.\nValid keys are {self.ORDERED_ATTRIBUTES}."
        # No None values for num_workers
        num_workers = num_workers or 0

        flattened_tree = []
        for tree in self.file_trees:
            filtered_tree = self._filter_tree(tree, filter=attribute_values)
            flattened_tree.extend(self._flatten_tree(filtered_tree))

        self.dataset = ResponsesDataset(
            flattened_tree, columns=self.columns, label_map=self.label_map
        )
        if len(self.dataset) == 0:
            raise IndexError(
                f"No response files found under {self.root} for subset {attribute_values}."
            )

        loader = torch.utils.data.DataLoader(
            self.dataset,
            shuffle=False,
            batch_size=batch_size,
            num_workers=num_workers,
            drop_last=False,
        )

        # Columns are created from the batches themselves, which avoids loading
        # the first sample an extra time just to discover its keys.
        data: t.Dict[str, t.Any] = {}
        for batch in tqdm(loader, disable=(logging.getLogger().level > logging.INFO)):
            for k, v in batch.items():
                column = data.setdefault(k, [])
                if isinstance(v[0], torch.Tensor):
                    # Tensor batches are kept whole and concatenated once at the end.
                    column.append(v)
                else:
                    column.extend(v)

        for k, v in data.items():
            if isinstance(v[0], torch.Tensor):
                data[k] = torch.cat(v).to(torch.float32).numpy()
        return data


class ResponsesDataset(torch.utils.data.Dataset):
    """
    A PyTorch Dataset that loads responses data from a list of dictionaries
    containing metadata and file paths. Each file is loaded using torch.load and
    is expected to hold a dictionary; the record's metadata (minus "path") is merged
    into it. Optionally a 'unit' index array is added and a specific unit is selected
    along the last axis of the 'responses' field. When a 'label_map' is given, only
    records whose subset appears in it are kept, and a 'label' can be requested as
    a column.

    Attributes:
        records (list): A list of dictionaries containing metadata about each record
                        including file path to load data from (key "path") and its
                        "subset" (needed when a label_map is used).
        add_unit_number (bool): If True, adds 'unit' number to the loaded data.
        select_unit (int or None): If not None, only this index of the last axis of
                                    'responses' (and of 'unit') is kept. Defaults to None.
        columns (list): A list of column names to include in the returned dictionary
                        from each sample. The default is None, which means all available
                        columns will be included.
        label_map (dict): A dictionary mapping subsets to labels. Used to drop records of
                          unmapped subsets and to fill the 'label' column when requested.
                          The default is None, which keeps all records.
    """

    def __init__(
        self,
        records: t.List[dict],
        add_unit_number: bool = False,
        select_unit: t.Optional[int] = None,
        columns: t.Optional[t.List[str]] = None,
        label_map: t.Optional[dict] = None,
    ):
        """
        Store the configuration and keep the records usable with `label_map`.

        Raises:
            AssertionError: If "label" is requested in `columns` but no `label_map` is given.
        """
        self.add_unit_number = add_unit_number
        self.select_unit = select_unit
        self.columns = columns
        self.label_map = label_map
        # `columns` may be None (its default), so guard the membership test.
        assert label_map is not None or "label" not in (
            columns or ()
        ), "A label_map is required when 'label' is one of the requested columns."
        if label_map is None:
            # No label map: nothing to filter on, keep every record.
            self.records = list(records)
        else:
            self.records = [r for r in records if r["subset"] in label_map]

    def __getitem__(self, item) -> dict:
        """
        Get a data record at the specified index.

        Args:
            item (int): The index of the record to retrieve.

        Returns:
            dict: The content of the loaded file merged with the record's metadata (without "path").
                  If `add_unit_number` is True, "unit" holds the indices of the last axis of "responses".
                  If `select_unit` is set, "responses" (and "unit") are reduced to that index.
                  If `columns` is set, only those keys are returned; "label" is filled from
                  `label_map` when it is one of the requested columns.
        """
        meta = self.records[item]
        # NOTE(security): torch.load unpickles the file; only load response
        # files that come from a trusted source.
        datum = torch.load(meta["path"])
        if self.add_unit_number:
            datum["unit"] = np.arange(datum["responses"].shape[-1])
            if self.select_unit is not None:
                datum["unit"] = datum["unit"][self.select_unit]
        if self.select_unit is not None:
            datum["responses"] = datum["responses"][..., self.select_unit]
        # Metadata wins over keys of the same name stored in the file.
        datum.update(meta)
        del datum["path"]
        if self.columns:
            if "label" in self.columns:
                datum["label"] = self.label_map[datum["subset"]]
            datum = {k: datum[k] for k in self.columns}
        return datum

    def __len__(self) -> int:
        """
        Get the total number of data records in this dataset.

        Returns:
            int: The number of records kept after the optional label_map filtering.
        """
        return len(self.records)
