"""
functions to preprocess data for augmentation

"""

import numpy as np
import pandas as pd
from pathlib import Path
from statsmodels.tsa.stattools import acf
from scipy.signal import find_peaks
import scipy.fftpack
import random
from sklearn.preprocessing import StandardScaler
import random

# ---------------------------------------------------------------------------------------------------------------- #
# Periodicity Aware Segmentation

def estimate_period_acf(time_series):
    """
    Rank candidate periods of a series by the height of their autocorrelation peaks.

    The ACF is computed (FFT-based) up to half the series length. Local maxima are
    located with ``scipy.signal.find_peaks`` while ignoring lag 0. The resulting lags
    are ordered from the highest ACF value to the lowest.

    Parameters:
    time_series (array-like): The input time series data.

    Returns:
    list: Lags of the ACF peaks, strongest autocorrelation first.
    """
    correlations = acf(time_series, nlags=len(time_series) // 2, fft=True)

    # Search from lag 1 onwards, then shift the indices back so they are real lags.
    candidate_lags = find_peaks(correlations[1:])[0] + 1

    def peak_height(lag):
        return correlations[lag]

    return sorted(candidate_lags, key=peak_height, reverse=True)


def compute_period(time_series, min_points=1120, min_windows=10):
    """
    Compute the period of a time series, using the ACF first and the FFT as a fallback.

    The period is capped at ``max_period``. This cap leaves room, after a first window
    of ``min_points`` points, for ``min_windows - 1`` further window starts spaced one
    period apart.

    Selection order:
        1. the highest-ranked ACF peak lag (see ``estimate_period_acf``) that is
           strictly below ``max_period``;
        2. otherwise, the rounded period of the strongest positive FFT frequency
           whose period is <= ``max_period``;
        3. otherwise, ``max_period`` itself.

    Parameters:
        time_series (array-like): The input time series data.
        min_points (int): Minimum points required for each window.
        min_windows (int): Minimum number of windows required. Values below 2 give
            the same cap as 2 (``len(time_series) - min_points``) instead of dividing
            by zero.

    Returns:
        period (int): The computed period for the time series (>= 1).
        window_length (int): Smallest multiple of ``period`` that is >= ``min_points``
            (at least one period).
    """
    n_points = len(time_series)

    # The divisor is clamped to 1 so that min_windows <= 1 does not raise ZeroDivisionError.
    max_period = max(1, (n_points - min_points) // max(1, min_windows - 1))

    # 1) ACF: candidates are ranked strongest first; keep the first one under the cap.
    acf_periods = estimate_period_acf(time_series)
    period = max_period
    period_source = "Fallback to max_period"
    position = len(acf_periods)
    for idx, lag in enumerate(acf_periods):
        if lag < max_period:
            period, period_source, position = lag, "ACF", idx
            break

    # 2) FFT fallback: strongest positive frequency whose period fits under the cap.
    if period_source != "ACF":
        freqs = scipy.fftpack.fftfreq(n_points, d=1)
        power_spectrum = np.abs(scipy.fftpack.fft(time_series)) ** 2
        candidates = np.flatnonzero(freqs > 0)
        candidates = candidates[1.0 / freqs[candidates] <= max_period]

        if candidates.size > 0:
            peak_index = candidates[np.argmax(power_spectrum[candidates])]
            period = round(1 / freqs[peak_index])
            period_source = "FFT"

    # Smallest whole number of periods covering at least min_points (never zero periods).
    n_periods = max(1, (min_points + period - 1) // period)
    window_length = n_periods * period

    # Diagnostics
    print(f"  - Max possible period: {max_period}")
    print(f"  - First ACF periods (ranked): {acf_periods[:(position + 1)]}")
    print(f"  - Selected period: {period} (Source: {period_source})")

    return period, window_length


def construct_windows(time_series, period, window_length=1120, min_windows=10, split='maximize-overlap', verbose=True):
    """
    Cut a 1-D time series into fixed-length windows whose starts follow the period.

    Strategies (``split``):
        'no-periodicity'   : ``min_windows`` windows with evenly spaced starts over the
                             whole series. ``period`` is only used in the diagnostics.
        'maximize-overlap' : a window starts at every multiple of ``period``.
        'minimize-overlap' : consecutive starts are shifted by the largest multiple of
                             ``period`` that fits in ``window_length`` (at least one
                             period).
                             If that yields fewer than ``min_windows`` windows, the
                             shift is reduced (still a multiple of ``period``, at least
                             one period) so that ``min_windows`` windows are produced.
                             Windows that would run past the end are clamped to the
                             end of the series.
                             If more than one period of data is then left after the
                             covered span, an extra window ending on a period boundary
                             is appended.

    Parameters:
        time_series (array-like): 1-D sliceable sequence (e.g. a numpy array).
        period (int): Alignment period. It must be >= 1 for the two overlap modes.
        window_length (int): Number of points per window.
        min_windows (int): Requested minimum number of windows.
        split (str): 'maximize-overlap', 'minimize-overlap' or 'no-periodicity'.
        verbose (bool): Print diagnostics.

    Returns:
        tuple: ``(windows, overlap)``.
            ``windows`` is a list of slices of ``time_series``.
            ``overlap`` is ``window_length - step``, where ``step`` is the shift
            between consecutive window starts. It is negative when consecutive
            windows leave a gap.

    Raises:
        ValueError: if ``split`` is unknown, or, for the two overlap modes, if
            ``period < 1`` or the series is shorter than ``window_length``.
    """
    if split not in ('maximize-overlap', 'minimize-overlap', 'no-periodicity'):
        raise ValueError("Unknown SPLIT parameters. Possible values: maximize-overlap, minimize-overlap, no-periodicity")

    n_points = len(time_series)
    if split != 'no-periodicity':
        # A non-positive step would never advance the window start.
        if period < 1:
            raise ValueError(f"period must be >= 1 for split='{split}', got {period}")
        # Without room for one full window there is nothing to align; fail clearly.
        if n_points < window_length:
            raise ValueError(
                f"Time series has {n_points} points, fewer than window_length={window_length}")

    extra_window = False
    overlap_extra_window, periods_extra_window = None, None

    if split == 'no-periodicity':
        max_start = n_points - window_length
        if min_windows <= 1 or max_start <= 0:
            windows = [time_series[:window_length]]
            overlap_step = 0
        else:
            overlap_step = max_start // (min_windows - 1)
            # i * overlap_step <= max_start for every i < min_windows, so each window fits.
            windows = [time_series[i * overlap_step:i * overlap_step + window_length]
                       for i in range(min_windows)]
        total_coverage = len(windows[0]) + (len(windows) - 1) * overlap_step

    elif split == 'maximize-overlap':
        overlap_step = period
        windows = [time_series[s:s + window_length]
                   for s in range(0, n_points - window_length + 1, overlap_step)]
        total_coverage = len(windows[0]) + (len(windows) - 1) * overlap_step

    else:  # 'minimize-overlap'
        if n_points // window_length >= min_windows:
            # Largest multiple of the period that fits in one window, but at least
            # one period: a zero step would loop forever.
            overlap_step = max(1, window_length // period) * period
            windows = [time_series[s:s + window_length]
                       for s in range(0, n_points - window_length + 1, overlap_step)]
        else:
            # Here min_windows >= 2, because n_points >= window_length already gives one window.
            max_start = n_points - window_length
            step = max(1, (max_start + (min_windows - 1)) // (min_windows - 1))
            # Round down to a multiple of the period, but never to zero (that would
            # produce min_windows copies of the first window).
            overlap_step = max(period, (step // period) * period)

            windows = []
            start = 0
            for _ in range(min_windows):
                start = min(start, max_start)  # clamp windows that would run past the end
                windows.append(time_series[start:start + window_length])
                start += overlap_step

        # Clamped windows can make the nominal span exceed the series; cap it.
        total_coverage = min(n_points,
                             len(windows[0]) + (len(windows) - 1) * overlap_step)
        if n_points - total_coverage > period:
            # Add an extra window, ending on a period boundary, to maximize total coverage.
            extra_window = True
            end_extra_window = total_coverage + ((n_points - total_coverage) // period) * period
            windows.append(time_series[end_extra_window - window_length:end_extra_window])
            overlap_extra_window = len(windows[0]) - (end_extra_window - total_coverage)
            periods_extra_window = overlap_extra_window // period
            total_coverage += len(windows[0]) - overlap_extra_window

    # Diagnostics
    if verbose:
        print(f"\n Period: {period}")
        print(f" Number of windows: {len(windows)}")
        print(f" Points in each window: {window_length}")
        print(f" Periods in a window: {window_length // period}")
        print(f" Overlapping points: {window_length - overlap_step}")
        print(f" Overlapping periods: {(window_length - overlap_step) // period}")
        if extra_window:
            print(f" Extra window overlapping points: {overlap_extra_window}")
            print(f" Extra window overlapping periods: {periods_extra_window}")
        print(f" Total coverage: {total_coverage}\n")

    return windows, window_length - overlap_step


# ---------------------------------------------------------------------------------------------------------------- #
# Data preprocessing

def interpolate_na(data):
    """
    Fill the NaN values of every row independently using linear interpolation.

    Each row is interpolated as a pandas Series with ``limit_direction="both"``, so
    leading and trailing NaNs are filled as well.

    Args:
        data (array-like): 2-D data of shape (n_channels, n_points), one time series
            per row.

    Returns:
        np.array: Array with one interpolated row per input row.
    """
    filled_rows = [
        pd.Series(row).interpolate(method="linear", limit_direction="both").values
        for row in data
    ]
    return np.array(filled_rows)


def standardscale_train_val_test(preprocessed_train, preprocessed_val=None, preprocessed_test=None):
    """
    Standard-scale train/val/test data channel by channel, fitting only on train.

    For each channel index a new ``StandardScaler`` is fitted on
    ``preprocessed_train[var]``. StandardScaler normalises each column, i.e. each
    timestamp position across the rows ("by timestamps"). The same fitted scaler is
    then applied to ``preprocessed_val[var]`` and ``preprocessed_test[var]``.

    Parameters:
        preprocessed_train (sequence): Per-channel 2-D array-likes to fit and transform.
        preprocessed_val (sequence, optional): Per-channel 2-D array-likes to transform.
            ``None`` or an empty sequence means no validation data.
        preprocessed_test (sequence, optional): Same as ``preprocessed_val``, for test data.

    Returns:
        tuple of np.ndarray: ``(train, val, test)`` scaled data. ``val`` and ``test``
        are empty arrays when not provided.
    """
    # Use None/len checks instead of `!= []`: for numpy arrays `!= []` is an
    # element-wise (broadcast) comparison, and for tuples it is always True.
    has_val = preprocessed_val is not None and len(preprocessed_val) > 0
    has_test = preprocessed_test is not None and len(preprocessed_test) > 0

    scaled_preprocessed_train = []
    scaled_preprocessed_val = []
    scaled_preprocessed_test = []

    for var, train_channel in enumerate(preprocessed_train):  # one scaler per channel
        # Fit on train only; val/test reuse the train statistics.
        scaler = StandardScaler()
        scaled_preprocessed_train.append(scaler.fit_transform(np.array(train_channel)))
        print('Train data standardized')

        if has_val:
            scaled_preprocessed_val.append(scaler.transform(np.array(preprocessed_val[var])))
            print('Validation data standardized')
        else:
            print('No validation data')

        if has_test:
            scaled_preprocessed_test.append(scaler.transform(np.array(preprocessed_test[var])))
            print('Test data standardized')
        else:
            print('No test data')

    return np.array(scaled_preprocessed_train), np.array(scaled_preprocessed_val), np.array(scaled_preprocessed_test)


# ---------------------------------------------------------------------------------------------------------------- #
# Data Upload
def get_data(data, start, end, id_channel):
    """
    Return rows ``start:end`` of the columns ``id_channel`` as a channels-first array.

    Parameters:
        data (pd.DataFrame): Source table. ``id_channel`` entries are column labels;
            ``start`` and ``end`` are positional row indices.
        start (int): First row (inclusive).
        end (int): Last row (exclusive).
        id_channel (list): Column labels to extract, in output row order.

    Returns:
        np.array: One row per channel with ``end - start`` points each. Values are
        cast to float32 and then NaN-filled by ``interpolate_na``.

    Raises:
        ValueError: if ``start < 0``, ``end`` exceeds the number of rows, or
            ``start > end``.
    """
    indices_channel = [data.columns.get_loc(channel) for channel in id_channel]
    n_rows = len(data)

    if start < 0 or end > n_rows:
        raise ValueError(f"Not enough data present!!! Trying to get data from indices:{start} to {end} but available data is {n_rows} ")
    if start > end:
        raise ValueError(f"Invalid range: start ({start}) is greater than end ({end})")

    # Select only the requested rows/columns before converting and transposing.
    # Transposing the whole DataFrame first would materialise every column (as an
    # object array when the dtypes are mixed).
    selected = data.iloc[start:end, indices_channel].to_numpy(dtype=np.float32).T
    return interpolate_na(selected)


def preprocess_multisample_data(data, train_channels, train_samples, train_length, create_train_val_test):
    """
    Build train (and optionally val/test) sets from distinct, randomly chosen sample columns.

    Every column whose name starts with ``train_channels[0]`` is a candidate sample.
    ``train_samples`` of them are drawn without replacement with ``random.sample``.
    When ``create_train_val_test`` is true, half as many again are drawn for
    validation and for test (a 50-25-25 split).
    The first ``train_length`` rows of each drawn column are extracted with
    ``get_data`` and standard-scaled with scalers fitted on the train set.

    Parameters:
        data (pd.DataFrame): Table whose columns are the samples.
        train_channels (list): Its first entry is the sample-column name prefix.
        train_samples (int): Number of train samples.
        train_length (int): Number of leading rows taken from each sample.
        create_train_val_test (bool): Also draw validation and test samples.

    Returns:
        tuple: ``(scaled_train, None, None, scaled_val, scaled_test)``. The two
        ``None`` entries fill the ``(train, period, overlap, val, test)`` layout
        used by the other preprocessing path.

    Raises:
        ValueError: if fewer sample columns exist than samples requested.
    """
    sample_prefix = train_channels[0]
    sample_columns = [col for col in data.columns if col.startswith(sample_prefix)]

    if create_train_val_test:  # 50-25-25: val and test each get half the train count
        val_samples = int(train_samples * 0.50)
        test_samples = int(train_samples * 0.50)
    else:
        val_samples = 0
        test_samples = 0
    total_samples = train_samples + val_samples + test_samples

    # Having exactly as many columns as requested is enough for random.sample.
    if len(sample_columns) < total_samples:
        raise ValueError(f"Number of samples ({len(sample_columns)}) is less than "
                         f"the total number of samples wanted ({total_samples}).")
    print('Enough samples, choosing windows in different samples')

    id_samples = random.sample(sample_columns, total_samples)
    id_train_samples = id_samples[:train_samples]
    id_val_samples = id_samples[train_samples:train_samples + val_samples]
    id_test_samples = id_samples[train_samples + val_samples:total_samples]

    start, end = 0, train_length

    # Each set is wrapped in a one-element list so the scaler treats it as a single
    # "channel" whose rows are the samples.
    train_data = [get_data(data, start, end, id_train_samples).tolist()]
    val_data = []
    test_data = []
    if create_train_val_test:
        val_data = [get_data(data, start, end, id_val_samples).tolist()]
        test_data = [get_data(data, start, end, id_test_samples).tolist()]

    scaled_preprocessed_train, scaled_preprocessed_val, scaled_preprocessed_test = standardscale_train_val_test(
        train_data, val_data, test_data)

    return (scaled_preprocessed_train, None, None, scaled_preprocessed_val, scaled_preprocessed_test)


def preprocess_uni_multi_variate_data(data, data_name, train_channels, train_length,
                                      create_train_val_test,
                                      min_windows_length, min_windows_number, train_splitting):
    """
    Slice consecutive train/val/test segments, window them with a shared period, and scale.

    Parameters:
        data (pd.DataFrame): Table passed to ``get_data`` (columns looked up by name).
        data_name (str): For 'bikesharing' and 'etth1' the segments are laid out
            backwards so that the test segment ends at row 15800 / 15000. Any other
            name starts the train segment at row 0.
        train_channels (list): Column names of the channels to use.
        train_length (int): Number of points in the train segment.
        create_train_val_test (bool): If true, val and test segments are also
            windowed. Each has half the train length and half the train window count.
        min_windows_length (int): Passed to ``compute_period`` as ``min_points`` and
            to ``construct_windows`` as ``window_length``.
        min_windows_number (int): Minimum number of train windows.
        train_splitting (str): 'minimize-overlap', 'maximize-overlap' or 'no-periodicity'.

    Returns:
        tuple: ``(scaled_train, chosen_period, overlap, scaled_val, scaled_test)``.

    Raises:
        ValueError: if ``train_splitting`` is not one of the supported strategies.
    """
    n_windows_train = min_windows_number
    if create_train_val_test:
        # Val and test each get half the train length and half the train window count.
        val_length = test_length = int(train_length * 0.5)
        n_windows_val = n_windows_test = int(n_windows_train * 0.5)
    else:
        val_length = test_length = 0
        n_windows_val = n_windows_test = 0

    # Some datasets are laid out backwards from a fixed last test row.
    fixed_test_end = {'bikesharing': 15800, 'etth1': 15000}
    if data_name in fixed_test_end:
        end_test = fixed_test_end[data_name]
        end_val = end_test - test_length
        end_train = end_val - val_length
        start = end_train - train_length
    else:
        start = 0
        end_train = train_length
        end_val = end_train + val_length
        end_test = end_val + test_length

    train_data = get_data(data, start, end_train, train_channels)
    val_data = get_data(data, end_train, end_val, train_channels)
    test_data = get_data(data, end_val, end_test, train_channels)

    if train_splitting not in ['minimize-overlap', 'maximize-overlap', 'no-periodicity']:
        raise ValueError("Temporary Error: Fix output of preprocessing when no train_splitting given")

    # Estimate a period per channel, then keep the most frequent one so that every
    # channel is windowed identically (on ties, the earliest channel's value wins).
    channel_periods = []
    for idx, series in enumerate(train_data):
        channel_name = train_channels[idx]
        print(f' CHANNEL: {channel_name}')
        channel_period, _ = compute_period(series, min_points=min_windows_length, min_windows=n_windows_train)
        print(f'Chosen period for channel {channel_name}: {channel_period}')
        channel_periods.append(channel_period)

    chosen_period = max(channel_periods, key=channel_periods.count)
    print(f'Overall chosen period: {chosen_period}')

    preprocessed_train, preprocessed_val, preprocessed_test = [], [], []
    for idx, series in enumerate(train_data):
        show_diagnostics = idx == 0  # print windowing details for the first channel only
        train_windows, overlap = construct_windows(series, chosen_period, min_windows_length,
                                                   n_windows_train, train_splitting, show_diagnostics)
        preprocessed_train.append(train_windows)

        if create_train_val_test:
            val_windows, _ = construct_windows(val_data[idx], chosen_period, min_windows_length,
                                               n_windows_val, train_splitting, show_diagnostics)
            preprocessed_val.append(val_windows)
            test_windows, _ = construct_windows(test_data[idx], chosen_period, min_windows_length,
                                                n_windows_test, train_splitting, show_diagnostics)
            preprocessed_test.append(test_windows)

    scaled_train, scaled_val, scaled_test = standardscale_train_val_test(
        preprocessed_train, preprocessed_val, preprocessed_test)

    return (scaled_train, chosen_period, overlap, scaled_val, scaled_test)



def upload_and_preprocess_data(data_name, train_channels, train_length,
                train_samples, augmentation_strategy,
                min_windows_length, min_windows_number=10, train_splitting=None,
                create_train_val_test=False):
    """
    Load ``data/<data_name>.csv`` and run the preprocessing that matches the strategy.

    Parameters:
        data_name (str): Dataset name. The file read is ``data/<data_name>.csv``,
            relative to the current working directory.
        train_channels (list): Channel column names. For 'multisample', the first
            entry is used as the prefix of the sample columns.
        train_length (int): Number of points in the train segment / per sample.
        train_samples (int): Number of train samples ('multisample' only).
        augmentation_strategy (str): 'multisample', 'univariate' or 'multivariate'.
        min_windows_length (int): Window length (univariate/multivariate only).
        min_windows_number (int): Minimum number of train windows (univariate/multivariate only).
        train_splitting (str or None): Window splitting strategy (univariate/multivariate only).
        create_train_val_test (bool): Also build validation and test data.

    Returns:
        tuple: ``(train_data, period, overlap, val_data, test_data)``. ``period`` and
        ``overlap`` are None for 'multisample'.

    Raises:
        ValueError: if ``augmentation_strategy`` is not supported. This is checked
            before the CSV is read.
    """
    # Exact membership test on a real tuple: `x in ('multisample')` would be a
    # substring test, because the parentheses alone do not make a tuple.
    if augmentation_strategy not in ('multisample', 'univariate', 'multivariate'):
        raise ValueError(
            f"Unknown augmentation_strategy: {augmentation_strategy!r}. "
            "Possible values: multisample, univariate, multivariate")

    data = pd.read_csv(Path('data') / f'{data_name}.csv')

    if augmentation_strategy == 'multisample':
        return preprocess_multisample_data(data, train_channels, train_samples,
                                           train_length, create_train_val_test)

    return preprocess_uni_multi_variate_data(data, data_name, train_channels, train_length,
                                             create_train_val_test,
                                             min_windows_length, min_windows_number, train_splitting)
