# -*- coding: utf-8 -*-
import pandas as pd
from tqdm import tqdm
import numpy as np
import scipy
from scipy.signal import find_peaks


def find_peak_intervals(signal):
    """Locate the local maxima of *signal* and measure the gaps between them.

    Args:
        signal: 1-D sequence of samples, passed straight to
            ``scipy.signal.find_peaks`` with no extra peak criteria.

    Returns:
        A ``(peak_intervals, peaks)`` tuple: ``peaks`` holds the sample
        indices of the detected peaks, and ``peak_intervals`` holds the
        index distance between each pair of consecutive peaks (one element
        shorter than ``peaks``, or empty when fewer than two peaks exist).
    """
    # find_peaks returns (indices, properties); only the indices are needed.
    peak_positions = find_peaks(signal)[0]
    # Distance from each peak to the one that follows it.
    spacing = peak_positions[1:] - peak_positions[:-1]
    return spacing, peak_positions

def find_significant_periods_peak_detection(data, num_periods=10):
    """Rank peak-to-peak intervals by how often they occur across all columns.

    Each column of *data* is scanned for local maxima; the index distances
    between consecutive maxima are pooled over all columns and counted.

    Args:
        data: 2-D array indexed as ``data[:, i]``; each column is treated as
            one signal.
        num_periods: Maximum number of intervals to return.

    Returns:
        A ``(significant_intervals, counts)`` tuple of integer arrays of equal
        length (at most ``num_periods``). ``significant_intervals`` lists the
        interval lengths ordered by descending occurrence count, with ties
        resolved toward the shorter interval; ``counts`` holds the matching
        occurrence counts. Only intervals that actually occurred are
        reported, so both arrays are empty when no column has two peaks.
    """
    _, n_dims = data.shape
    all_intervals = []

    for i in range(n_dims):
        # Indices of the local maxima of this column.
        column_peaks, _ = find_peaks(data[:, i])
        # Gaps between adjacent peaks, pooled across every column.
        all_intervals.extend(np.diff(column_peaks).tolist())

    if not all_intervals:
        # Nothing to count: return empty results explicitly instead of
        # depending on how bincount treats an empty input.
        empty = np.array([], dtype=np.intp)
        return empty, empty.copy()

    # interval_counts[k] == number of times an interval of length k was seen.
    interval_counts = np.bincount(all_intervals)
    # A stable sort makes the order of equally frequent intervals
    # deterministic (ascending interval length).
    order = np.argsort(-interval_counts, kind="stable")
    # bincount also has slots for lengths that never occurred (e.g. 0);
    # those must not be reported as significant periods.
    order = order[interval_counts[order] > 0][:num_periods]
    return order, interval_counts[order]

def autocorrelation(signal):
    """Return the non-negative-lag half of the correlation of *signal* with itself.

    The values are the raw output of ``scipy.signal.correlate`` in ``'full'``
    mode; nothing is divided out or normalized here.

    Args:
        signal: Sequence of samples (used as a 1-D signal by the callers in
            this file).

    Returns:
        The second half of the full correlation, starting at the middle
        element; for a 1-D input of length N this is N values, the first of
        which is the zero-lag term.
    """
    full_corr = scipy.signal.correlate(signal, signal, mode='full')
    # The middle element of the 'full' output corresponds to zero lag.
    zero_lag = full_corr.size // 2
    return full_corr[zero_lag:]

def find_significant_periods_autocorr(data, num_periods=10, max_lag=None):
    """Rank candidate periods by summed autocorrelation strength over all columns.

    For every column of *data* the autocorrelation is computed (optionally
    truncated to its first ``max_lag`` values); each lag where the sign of the
    slope decreases is taken as a candidate period, and the autocorrelation
    value at that lag is added to a running total for the lag.

    Args:
        data: 2-D array indexed as ``data[:, i]``; each column is one signal.
        num_periods: Maximum number of periods to return.
        max_lag: When not ``None``, only the first ``max_lag`` autocorrelation
            values of each column are examined.

    Returns:
        A ``(significant_periods, significant_period_strength)`` tuple of
        plain Python lists: the lags (``int``) with the largest accumulated
        strength, in descending order of strength, and the matching totals
        (``float``).
    """
    _, n_dims = data.shape
    totals = {}

    for col in tqdm(range(n_dims)):
        acf = autocorrelation(data[:, col])
        if max_lag is not None:
            acf = acf[:max_lag]

        # Sign of the slope between consecutive lags; a drop in that sign
        # marks the lag in between as a local-peak candidate.
        slope_sign = np.sign(np.diff(acf))
        candidate_lags = np.flatnonzero(np.diff(slope_sign) < 0) + 1

        # Accumulate the strength seen at each candidate lag across columns.
        for lag in candidate_lags:
            value = acf[lag]
            if lag in totals:
                totals[lag] += value
            else:
                totals[lag] = value

    # Strongest accumulated lags first; the sort is stable, so lags with equal
    # totals keep their first-seen order.
    ranked = list(totals.items())
    ranked.sort(key=lambda pair: pair[1], reverse=True)
    top = ranked[:num_periods]

    significant_periods = [int(lag) for lag, _ in top]
    significant_period_strength = [float(total) for _, total in top]
    return significant_periods, significant_period_strength


# Dataset name -> path of its CSV file (relative to the directory the script
# is run from). The __main__ section below reads every entry with pandas,
# expects a "date" column (which it drops before analysis), and uses these
# names as the top-level keys of the JSON reports it writes.
data_dict = {
    "ETTh1": "./dataset/ETT-small/ETTh1.csv",
    "ETTh2": "./dataset/ETT-small/ETTh2.csv",
    "ETTm1": "./dataset/ETT-small/ETTm1.csv",
    "ETTm2": "./dataset/ETT-small/ETTm2.csv",
    "weather": "./dataset/weather/weather.csv",
    "electricity": "./dataset/electricity/electricity.csv",
    "traffic": "./dataset/traffic/traffic.csv",
}


if __name__ == '__main__':
    import json

    # Each cutoff is half of the sequence length it is reported for
    # (max_lag = seq_len // 2); one JSON report is written per cutoff.
    for max_lag in [48, 168, 256]:
        print(f"max_lag {max_lag} for seq_len {max_lag * 2}")
        output = {}
        for data_name, file_path in data_dict.items():
            data = pd.read_csv(file_path, parse_dates=["date"])
            # The timestamp column is not a signal; analyse only the rest.
            data = data.drop(columns=["date"])
            # Standardize every column (zero mean, unit std) so that columns
            # contribute on a comparable scale to the summed strengths.
            data_norm = (data - data.mean()) / data.std()
            significant_periods, significant_period_strength = find_significant_periods_autocorr(
                data_norm.values, 20, max_lag=max_lag
            )
            output[data_name] = {
                "significant_periods": significant_periods,
                "period_strength": significant_period_strength,
            }
        # Open the report only after every dataset has been processed, so a
        # failure while reading or analysing data cannot leave behind an
        # empty file or wipe out a report from an earlier run.
        with open(f"./autocorr_{max_lag}.json", "w") as f:
            json.dump(output, f, indent=4)