import pywt
import numpy as np


def _pearson_or_zero(band1, band2):
    """Return the Pearson correlation of two coefficient bands, or 0.0 if undefined.

    ``np.corrcoef`` yields NaN (plus a RuntimeWarning) when a band has fewer
    than two samples or zero variance. A NaN weight would turn the entire
    reconstructed signal into NaN, so such bands are treated as uncorrelated,
    which means nothing is subtracted from them.

    Raises:
        ValueError: if the two bands do not have the same shape.
    """
    a = np.asarray(band1, dtype=float)
    b = np.asarray(band2, dtype=float)
    if a.shape != b.shape:
        raise ValueError(
            f"coefficient band shapes differ: {a.shape} vs {b.shape}"
        )
    # np.ptp is exact for constant arrays; a std()-based check can be fooled
    # by rounding in the mean.
    if a.size < 2 or np.ptp(a) == 0 or np.ptp(b) == 0:
        return 0.0
    corr = float(np.corrcoef(a, b)[0, 1])
    return corr if np.isfinite(corr) else 0.0


def subtract_wavelet_coefficients(signal1, coeffs2, args):
    """Subtract a correlation-weighted reference decomposition from a signal.

    ``signal1`` is decomposed with ``pywt.wavedec(signal1, args.wavelet_base,
    level=args.hyper_level)``. For every band (index 0 first, then the detail
    bands, in the order returned by ``pywt.wavedec``), the matching band of
    ``coeffs2`` is scaled by ``args.hyper_strength`` times the Pearson
    correlation between the two bands and then subtracted. The adjusted
    coefficients are reconstructed with ``pywt.waverec``.

    Args:
        signal1: 1-D signal to process.
        coeffs2: Reference coefficients in the same band order and with the
            same band shapes as ``pywt.wavedec`` produces for ``signal1``
            (for example, the tuple returned by ``get_avg``).
        args: Object exposing ``wavelet_base``, ``hyper_level`` (an int, or
            ``None`` for pywt's maximum level) and ``hyper_strength``.

    Returns:
        np.ndarray: The reconstructed difference signal.
        NOTE(review): ``pywt.waverec`` can return one sample more than the
        input for odd-length signals; confirm callers handle this.

    Raises:
        ValueError: if ``coeffs2`` has a different number of bands than the
            decomposition of ``signal1``, or if any pair of bands differs in shape.
    """
    wavelet = args.wavelet_base
    strength = args.hyper_strength

    coeffs1 = pywt.wavedec(signal1, wavelet, level=args.hyper_level)
    if len(coeffs2) != len(coeffs1):
        raise ValueError(
            f"coeffs2 has {len(coeffs2)} bands but signal1 decomposes into "
            f"{len(coeffs1)}; use the same wavelet and level for both"
        )

    # Iterate the bands themselves rather than range(level): this also works
    # when level is None (pywt picks the maximum level).
    diff_coeffs = [
        band1 - band2 * strength * _pearson_or_zero(band1, band2)
        for band1, band2 in zip(coeffs1, coeffs2)
    ]

    return np.asarray(pywt.waverec(diff_coeffs, wavelet))

def get_avg(input_sample_signals, args):
    """Average the wavelet decompositions of several signals band by band.

    Each signal is decomposed with ``pywt.wavedec(signal, args.wavelet_base,
    level=args.hyper_level)``. The arrays of each band are then averaged
    element-wise across all signals.

    Args:
        input_sample_signals: Non-empty iterable of 1-D signals. All of them
            must decompose into bands of identical shapes, so in practice
            they must have the same length.
        args: Object exposing ``wavelet_base`` and ``hyper_level``.

    Returns:
        tuple of np.ndarray: One averaged array per band, in the order
        returned by ``pywt.wavedec`` (band 0 first, then the detail bands).
        This tuple can be passed as ``coeffs2`` to
        ``subtract_wavelet_coefficients``.

    Raises:
        ValueError: if ``input_sample_signals`` is empty, or if the
            decompositions differ in band count or band shape.
    """
    wavelet = args.wavelet_base
    level = args.hyper_level

    decompositions = [
        pywt.wavedec(signal, wavelet, level=level)
        for signal in input_sample_signals
    ]
    if not decompositions:
        raise ValueError("input_sample_signals must contain at least one signal")
    if len({len(coeffs) for coeffs in decompositions}) != 1:
        raise ValueError("signals decompose into different numbers of bands")

    # Detail bands have a different length at each level, so stacking every
    # band of every signal into one array is ragged (an error on NumPy >= 1.24).
    # Average each band separately instead; np.stack raises ValueError if the
    # same band has different shapes across signals.
    return tuple(
        np.mean(np.stack(bands), axis=0) for bands in zip(*decompositions)
    )