from librosa import stft
import numpy as np
import pandas as pd
from tqdm import tqdm


def calculate_stft_magnitude(data, n_fft=16, hop_length=4, center=False):
    """
    Compute per-segment STFT magnitudes and stack them along the frame axis.

    The input is split by the ("segment", "", "") column. For every segment the
    signal columns (everything except the metadata columns listed below) are
    transformed with `librosa.stft` and the absolute value is taken. The
    metadata columns ("subject_id", "activity_id", "segment") are not
    transformed: their value from the first row of the segment is repeated
    over every frequency bin and frame and appended after the signal columns.

    Args:
        data (pd.DataFrame): Input data with dimensions [time, columns]. Columns
            are addressed by 3-tuples and must include ("segment", "", "").
        n_fft (int): The FFT window size.
        hop_length (int): The number of samples between successive frames.
        center (bool): Whether to pad each segment (making the STFT frames centered).

    Returns:
        data_stft (np.ndarray): float32 array with dimensions
            [frequency, frame, column]. Frames of all segments are concatenated
            in the order produced by `groupby` (sorted by segment key). Along
            the last axis the signal columns come first, followed by the
            metadata columns.
        data_stft_columns (pd.DataFrame): Column labels in the same order as the
            last axis of `data_stft`, so a label's row position can be used as
            the integer index into that axis.

    Notes:
        If `data` contains no segments, an empty array of shape
        [1 + n_fft // 2, 0, number of columns] is returned instead of failing
        inside `np.concatenate`.
    """
    segment_key = ("segment", "", "")
    skip_columns = [
        ("subject_id", "", ""),
        ("activity_id", "", ""),
        segment_key,
    ]

    # The column masks are identical for every segment, so build them once
    # instead of re-evaluating `isin` inside the loop.
    skip_mask = data.columns.isin(skip_columns)
    signal_mask = ~skip_mask

    groups = data.groupby(segment_key)

    if groups.ngroups == 0:
        # Nothing to transform: keep the documented axis layout with zero frames.
        n_freq = 1 + n_fft // 2
        data_stft = np.empty((n_freq, 0, len(data.columns)), dtype=np.float32)
    else:
        data_stft_list = []
        for _, group in tqdm(groups):
            # librosa expects [channel, time]; result is [channel, frequency, frame].
            segment_signal = group.loc[:, signal_mask].values.T
            segment_magnitude = np.abs(
                stft(
                    y=segment_signal,
                    n_fft=n_fft,
                    hop_length=hop_length,
                    center=center,
                )
            ).astype(np.float32, copy=False)

            # Metadata is taken from the first row of the segment and repeated
            # over the frequency and frame axes so it can be stacked as extra
            # "channels".
            meta_values = np.asarray(
                group.iloc[0, skip_mask].values, dtype=np.float32
            )
            meta_block = np.broadcast_to(
                meta_values.reshape(-1, 1, 1),
                (len(meta_values),) + segment_magnitude.shape[1:],
            )

            # Casting each segment to float32 before stacking keeps the peak
            # memory of the final concatenation at the output precision.
            data_stft_list.append(
                np.concatenate([segment_magnitude, meta_block], axis=0)
            )

        # [column, frequency, frame] -> [frequency, frame, column]
        data_stft = np.concatenate(data_stft_list, axis=2)
        data_stft = np.transpose(data_stft, (1, 2, 0))

    # Labels follow the stacking order used above: signal columns, then metadata.
    data_stft_columns = pd.concat(
        [
            pd.DataFrame(data.columns[signal_mask]),
            pd.DataFrame(data.columns[skip_mask]),
        ],
        axis=0,
        ignore_index=True,
    )

    return data_stft, data_stft_columns


if __name__ == "__main__":
    import os
    import sys
    from pathlib import Path

    # Put the parent of this file's directory at the front of sys.path so the
    # `src` package import below can be resolved when this file is run directly.
    parent_dir = Path(os.path.dirname(__file__)) / ".."
    sys.path.insert(0, os.path.abspath(parent_dir))

    from src.data.data_import import load_data

    # Only the first two of the three loaded values are used here. The second
    # one exposes `classes_` (presumably a fitted label encoder -- confirm in
    # load_data).
    data, encoder, _ = load_data("pamap2")
    print(len(encoder.classes_))

    # Smoke run of the transform with a larger, centered window.
    stft_kwargs = {"n_fft": 64, "hop_length": 16, "center": True}
    data_magnitude, data_magnitude_columns = calculate_stft_magnitude(
        data=data, **stft_kwargs
    )
