#
# CEBRA: Consistent EmBeddings of high-dimensional Recordings using Auxiliary variables
# © Mackenzie W. Mathis & Steffen Schneider (v0.4.0+)
# Source code:
# https://github.com/AdaptiveMotorControlLab/CEBRA
#
# Please see LICENSE.md for the full license document:
# https://github.com/AdaptiveMotorControlLab/CEBRA/blob/main/LICENSE.md
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Helper functions for training embeddings on DeepLabCut outputs."""

import pathlib
import warnings
from typing import List, Optional, Tuple, Union

import numpy as np
import numpy.typing as npt

_IS_PANDAS_AVAILABLE = True
try:
    import pandas as pd
except ModuleNotFoundError:
    _IS_PANDAS_AVAILABLE = False
    warnings.warn(
        "pandas module was not found, be sure it is installed in your env.",
        ImportWarning,
    )


class _DLCLoader:
    """Helper class to easily load HDF5 outputs from DeepLabCut as a :py:func:`numpy.array`.

    Args:
        dlc_filepath: The path to the ``.h5`` DLC output file.
        keypoints: A list of keypoints, corresponding to the bodypoints columns generated by DLC.
            If ``None``, all bodyparts present in the file are kept.

    Example:

        >>> import cebra
        >>> import cebra.integrations.deeplabcut as cebra_dlc
        >>> import cebra.helper as cebra_helper
        >>> url = ANNOTATED_DLC_URL = "https://github.com/DeepLabCut/DeepLabCut/blob/main/examples/Reaching-Mackenzie-2018-08-30/labeled-data/reachingvideo1/CollectedData_Mackenzie.h5?raw=true"
        >>> file = cebra_helper.download_file_from_url(url) # an .h5 example file
        >>> # By default, all bodyparts are selected ...
        >>> full_data = cebra_dlc._DLCLoader(dlc_filepath=file).load_data()
        >>> # ... but keypoints of interest can be provided too!
        >>> core_data = cebra_dlc._DLCLoader(dlc_filepath=file, keypoints=["Hand", "Joystick1"]).load_data()

    """

    def __init__(
        self,
        dlc_filepath: Union[pathlib.Path, str],
        keypoints: Optional[Union[list, str]] = None,
    ):
        # The file is read eagerly, so invalid files or keypoints are reported
        # when the loader is constructed rather than when data is loaded.
        self.dlc_df, self.keypoints_list, self.scorer = self.read_dlc(
            dlc_filepath, keypoints=keypoints)

    # NOTE: pandas types are spelled as strings in the signatures below so that
    # defining this class does not raise a NameError when pandas is missing.
    def read_dlc(
        self,
        filepath: Union[pathlib.Path, str],
        keypoints: Optional[List[str]],
    ) -> Tuple["pd.DataFrame", List[str], str]:
        """Read a DLC file and extract df, bodyparts and scorer.

        See also:
            Inspired from `dlc2kinematics loading method <https://github.com/AdaptiveMotorControlLab/DLC2Kinematics/blob/82e7e60e00e0efb3c51e024c05a5640c91032026/src/dlc2kinematics/preprocess.py#L12>`_.

        Args:
            filepath: The path to the ``.h5`` DLC output file.
            keypoints: A list of keypoints, corresponding to the bodypoints columns generated by DLC,
                or ``None`` to keep all of them.

        Returns:
            The dataframe generated by DLC and containing the x and y positions at each timestep for each keypoints, as well as the likelihood
            of those positions (first return), the tracked bodyparts (second return) and the scorer used (third return).

        Raises:
            ModuleNotFoundError: if pandas could not be imported.

        """
        if not _IS_PANDAS_AVAILABLE:
            raise ModuleNotFoundError("pandas could not be imported.")
        try:
            df = pd.read_hdf(filepath, "df_with_missing")
        except KeyError:
            # No ``df_with_missing`` key in the store: read the file without an
            # explicit key instead.
            df = pd.read_hdf(filepath)
        scorer = df.columns.get_level_values("scorer")[0]
        keypoints_list = self.initialize_keypoints_list(df, keypoints)
        return df, keypoints_list, scorer

    def initialize_keypoints_list(self, df: "pd.DataFrame",
                                  keypoints: Optional[List[str]]) -> List[str]:
        """Initialize the list of keypoints to keep from ``df``.

        If ``keypoints`` is None, then all available keypoints bodyparts from ``df`` are kept. If ``keypoints`` is
        provided, ``keypoints`` is returned, after each keypoint of the list is checked to be contained in ``df``.

        Args:
            df: The DLC output dataframe.
            keypoints: The keypoints to keep in the data.

        Returns:
            A list of keypoints to keep.

        Raises:
            ValueError: if ``keypoints`` is neither ``None`` nor a list.
            AttributeError: if one of the keypoints is not a bodypart of ``df``.

        """
        bodyparts = df.columns.get_level_values("bodyparts").unique().to_list()
        if keypoints is None:
            return bodyparts
        if isinstance(keypoints, list):
            self.check_valid_keypoints(keypoints, bodyparts)
            return keypoints
        raise ValueError(
            f"Invalid value for keypoints: expected a list of str corresponding to the bodyparts from your DLC dataframe, got {keypoints}"
        )

    def check_valid_keypoints(self, keypoints: List[str],
                              valid_keypoints: List[str]):
        """Check if the keypoints in the ``keypoints`` list are present in ``valid_keypoints``.

        Args:
            keypoints: A list of strings to check if they are present in the ``valid_keypoints`` list.
            valid_keypoints: A list of keypoints present in the instance's DLC dataframe.

        Raises:
            AttributeError: if one of the keypoints is not in the valid list of keypoints.
        """
        for keypoint in keypoints:
            if keypoint not in valid_keypoints:
                raise AttributeError(
                    f"Invalid bodypart: got {keypoint}, please provide a list of bodyparts present in {valid_keypoints}."
                )

    def load_data(self, pcutoff: float = 0.6) -> npt.NDArray:
        """Get the data from ``dlc_df``, check for likelihood of the position
        and interpolate the NaNs values.

        For every frame except the first and the last one, positions whose
        likelihood is below ``pcutoff`` (only when a ``likelihood`` coordinate
        is present) and positions that are NaN are replaced by the median of the
        raw ``(x, y)`` positions of the same keypoint in the previous and the
        next frame. The first and last frames are neither thresholded nor
        interpolated. Frames that still contain NaNs afterwards are removed.
        Only the first two coordinates (``x`` and ``y``) are kept in the output.

        Args:
            pcutoff: Drop-out threshold. If the likelihood value on the estimated positions a sample is
                smaller than that threshold, then the sample is set to nan. Then, the nan values are
                interpolated.

        Returns:
            A 2D array containing the interpolated and selected data from DLC, with one row
            per kept frame and the ``(x, y)`` values of each selected keypoint as columns.

        Raises:
            ValueError: if the columns of ``dlc_df`` do not have 3 (or 4) levels.
            NotImplementedError: if the file is a multi-animal (4 levels) DLC file.

        """
        n_levels = self.dlc_df.columns.nlevels
        if n_levels < 3 or n_levels > 4:
            raise ValueError(
                "Invalid DLC file, expects 3 columns indexes: scorer, bodyparts and coords, "
                f"got {n_levels} columns: {self.dlc_df.columns.names}")
        if n_levels == 4:
            raise NotImplementedError(
                "Multi-animals DLC files are not handled. Please provide a single-animal file."
            )

        dlc_df_coords = (
            self.dlc_df.columns.get_level_values("coords").unique().to_list())
        n_coords = len(dlc_df_coords)
        n_frames = len(self.dlc_df)

        # Select every frame at once instead of row by row. The reshape assumes
        # the coords of each bodypart occupy consecutive columns, giving
        # ``frames`` the shape (n_frames, n_keypoints, n_coords).
        values = self.dlc_df.loc[:, self.scorer].loc[:, self.keypoints_list]
        values = values.to_numpy()
        frames = values.reshape(n_frames, values.shape[1] // n_coords,
                                n_coords)

        # ``raw_xy`` is never modified: missing positions are interpolated from
        # the raw neighbouring positions, not from already filled values.
        raw_xy = frames[:, :, :2]
        xy = raw_xy.copy()

        # Only interior frames have both a previous and a next frame.
        if n_frames > 2:
            interior = xy[1:-1]  # basic slice, i.e. a view: writes go to ``xy``
            if "likelihood" in dlc_df_coords and n_coords > 2:
                # The third coordinate is taken to be the likelihood.
                unreliable = frames[1:-1, :, 2] < pcutoff
                interior[unreliable] = np.nan
            median = np.median(np.stack([raw_xy[:-2], raw_xy[2:]]), axis=0)
            # Boolean mask: replace exactly the NaN entries (an integer
            # ``argwhere`` array would index whole rows instead).
            missing = np.isnan(interior)
            interior[missing] = median[missing]

        array = xy.reshape(n_frames, xy.shape[1] * xy.shape[2])
        # Remove frames that could not be fully recovered (NaNs in the first or
        # last frame, or where a neighbour is NaN at the same position).
        return array[~np.isnan(array).any(axis=1)]

    # NOTE(celia): dlc2kinematics integration, to preprocess DLC prediction data and
    #              compute kinematic features to be integrated into the dataset.

    # def compute_joint_features(self):

    #     # Define joints: core keypoints model
    #     joints_dict={}
    #     joints_dict['R-side']  = ['nose', 'rightearbase', 'tailbase']
    #     joints_dict['L-side']  = ['nose', 'leftearbase', 'tailbase']

    #     joint_angles = dlc2kinematics.compute_joint_angles(self.dlc_df, joints_dict,
    #                                           dropnan=True, smooth=True, save=True)
    #     joint_vel = dlc2kinematics.compute_joint_velocity(joint_angles)
    #     corr = dlc2kinematics.compute_correlation(joint_vel, plot=True,colormap='viridis')

    #     return joint_angles, joint_vel, corr

    # def load_single_index_dataframe(self, filepath: Union[pathlib.Path, str]):
    #     df = pd.read_hdf(filepath)
    #     pred_xy = []
    #     for _, file_name in enumerate(df.index):
    #         data = df.loc[file_name]
    #         kpts = data.to_numpy().reshape((-1,2))[:,:2]
    #         kpts = np.expand_dims(kpts, axis = 0)
    #         pred_xy.append(kpts)
    #     return np.concatenate(pred_xy, axis = 0).reshape((len(df), -1))


def load_deeplabcut(
    filepath: Union[pathlib.Path, str],
    keypoints: Optional[list] = None,
    pcutoff: float = 0.6,
) -> npt.NDArray:
    """Load DLC data from h5 files.

    Convenience wrapper that builds a DLC loader for ``filepath`` restricted to
    ``keypoints`` and returns its thresholded, interpolated ``(x, y)`` data.

    Args:
        filepath: Path to the ``.h5`` file containing DLC output data.
        keypoints: List of keypoints to keep in the output ``numpy.array``. When
            ``None``, every bodypart found in the file is kept.
        pcutoff: Drop-out threshold. If the likelihood value on the estimated positions a sample is
            smaller than that threshold, then the sample is set to nan. Then, the nan values are
            interpolated.

    Returns:
        A 2D array (``n_samples x n_features``) containing the data (``x`` and ``y``) generated
        by DLC for each keypoint of interest. Note that the ``likelihood`` is dropped.

    Example:

        >>> import cebra
        >>> url = ANNOTATED_DLC_URL = "https://github.com/DeepLabCut/DeepLabCut/blob/main/examples/Reaching-Mackenzie-2018-08-30/labeled-data/reachingvideo1/CollectedData_Mackenzie.h5?raw=true"
        >>> file = cebra.helper.download_file_from_url(url) # an .h5 example file
        >>> dlc_data = cebra.load_deeplabcut(file, keypoints=["Hand", "Joystick1"], pcutoff=0.6)

    """
    loader = _DLCLoader(dlc_filepath=filepath, keypoints=keypoints)
    return loader.load_data(pcutoff=pcutoff)
