import numpy as np
import scipy.fft as fft

from rayuela.occ_semiring.occ_semiring import OccurrenceWeight


def compute_powers_old(log_Z, K):
    """Compute log-space self-convolution powers of ``log_Z`` via the FFT.

    ``log_Z`` holds the element-wise logarithm of a length-``n`` sequence
    ``z``. For ``k >= 2``, entry ``k`` of the returned list is the log of the
    ``k``-fold linear convolution ``z * z * ... * z``, truncated to its first
    ``n`` coefficients.

    Args:
        log_Z: 1-D array of log-values (length ``n``).
        K: highest power to compute.

    Returns:
        ``[]`` when ``K == 0``. Otherwise a list of ``K + 1`` entries:
        ``OccurrenceWeight.get_one(n - 1).x[1]`` (the zeroth power), a copy of
        ``log_Z`` (the first power), then the higher powers, each a length-``n``
        array.

    Note:
        The convolution runs on max-shifted ``exp`` values. The log is taken of
        ``abs`` of the complex inverse FFT, so coefficients far smaller than the
        largest one are limited by floating-point round-off. Exact zeros yield
        ``-inf`` (with a numpy RuntimeWarning).
    """
    if K == 0:
        return []

    n = len(log_Z)
    max_log_Z = np.max(log_Z)
    # Padded length >= 2n - 1: enough to hold the full linear convolution of
    # two length-n sequences without circular wrap-around.
    n_padded = int(2 ** np.ceil(np.log2(2 * n - 1)))
    fft_Z = fft.fft(np.exp(log_Z - max_log_Z), n=n_padded)

    # NOTE(review): presumably get_one(m) builds vectors of length m + 1, so
    # passing n - 1 yields an identity of the same length as log_Z — confirm
    # against OccurrenceWeight.get_one.
    one = OccurrenceWeight.get_one(n - 1).x[1]
    powers = [one, log_Z.copy()]
    current_power = log_Z.copy()

    for _ in range(1, K):
        shift = np.max(current_power)
        fft_current = fft.fft(np.exp(current_power - shift), n=n_padded)
        product = fft.ifft(fft_current * fft_Z)
        # Keep only the first n coefficients. Terms of degree >= n never
        # contribute to degrees < n in later products. Keeping them (up to
        # 2n - 1) would make the next product length 3n - 2, which can exceed
        # n_padded and wrap around into, and corrupt, the low coefficients.
        current_power = np.log(np.abs(product[:n])) + shift + max_log_Z
        powers.append(current_power)

    return powers


def compute_powers(Z, K):
    """Return the powers ``[Z**0, Z**1, ..., Z**K]`` of an occurrence weight.

    Powers are built by repeated multiplication with ``*`` on the weight type,
    starting from the multiplicative identity.

    Args:
        Z: weight exposing ``x[1]``, ``copy()`` and ``*`` (an
            ``OccurrenceWeight``, assumed from usage).
        K: highest power to compute (expected to be non-negative).

    Returns:
        A list of ``K + 1`` weights. Element 0 is the identity from
        ``OccurrenceWeight.get_one``, element 1 is a copy of ``Z``, and element
        ``k`` is ``Z`` multiplied by itself ``k`` times. For ``K == 0`` the
        result is ``[one]``.
    """
    # get_one appears to add 1 to its size argument internally, so pass
    # len - 1 to obtain an identity whose x[1] matches Z.x[1] in length.
    one = OccurrenceWeight.get_one(len(Z.x[1]) - 1)
    if K == 0:
        # Return a list, not the bare identity, so the result type is the
        # same for every K.
        return [one]

    # The first two powers are the identity and Z itself.
    powers = [one, Z.copy()]
    current_power = Z.copy()

    for _ in range(1, K):
        current_power = current_power * Z
        powers.append(current_power)

    return powers
