import torch

import wandb


def plot_fourier(fft_results, model_path=None, title="", ax=None, return_fig_ax=False):
    """
    Plot the per-frequency norm of a Fourier spectrum as a bar chart.

    If ``ax`` is provided, the bars are drawn on that axis and nothing is
    saved, logged or shown (the caller owns the figure). Otherwise a new
    figure is created, saved to ``model_path`` (default ``"fourier.pdf"``),
    logged to the active wandb run (if any) under the key ``"fourier"``, and
    finally shown.

    Args:
        fft_results: Tensor (or array-like) with at least two dimensions; the
            bar height for row ``i`` is the L2 norm of that row over dim 1.
        model_path: Output path used only when a new figure is created.
        title: Axis title.
        ax: Optional existing matplotlib axis to draw on.
        return_fig_ax: If true, return ``(fig, ax)``.

    Returns:
        ``(fig, ax)`` when ``return_fig_ax`` is true, otherwise ``None``.
    """
    import matplotlib.pyplot as plt
    from matplotlib.ticker import FuncFormatter

    # Decide ownership *before* ``ax`` is reassigned: re-testing ``ax is None``
    # afterwards is always False, which made the save/log branch unreachable.
    owns_figure = ax is None
    if owns_figure:
        fig, ax = plt.subplots(figsize=(6, 5))
    else:
        fig = ax.figure

    # detach()/cpu() so tensors that require grad or live on an accelerator
    # can be converted by matplotlib.
    values = torch.as_tensor(fft_results).detach().cpu()
    heights = values.square().sum(1).sqrt()
    ax.bar(torch.arange(values.shape[0]).numpy(), heights.numpy())
    ax.yaxis.set_major_formatter(FuncFormatter(_format_func))

    ax.set_xlabel("Frequency", fontsize=18)
    ax.set_ylabel("Norm of Fourier", fontsize=18)
    ax.set_title(title, fontsize=18)
    ax.tick_params(axis="x", labelsize=14)
    ax.tick_params(axis="y", labelsize=14)

    if owns_figure:
        if not model_path:
            model_path = "fourier.pdf"
        fig.tight_layout()
        # Save and log before show(): on some backends show() clears the
        # figure, so saving afterwards would write an empty image.
        fig.savefig(model_path, bbox_inches="tight", pad_inches=0)
        if wandb.run is not None:
            wandb.log({"fourier": wandb.Image(fig)})
        plt.show()

    if return_fig_ax:
        return fig, ax


def _format_func(value, tick_number):
    if value == 0:
        return "0"
    elif value >= 1000:
        return f"{int(value / 1000)}k"
    elif abs(value) < 10:
        return f"{value:.1f}"
    else:
        return str(int(value))


def apply_fourier(matrix):
    """Return the magnitude spectrum of ``matrix`` along its first dimension.

    A real-input FFT (``torch.fft.rfft``) is taken over dim 0, so an input
    with ``n`` entries along that dimension yields ``n // 2 + 1`` entries of
    non-negative magnitudes. All other dimensions are kept unchanged.

    Args:
        matrix (torch.Tensor): Real-valued tensor to transform.

    Returns:
        torch.Tensor: Absolute values of the complex rFFT coefficients.
    """
    return torch.fft.rfft(matrix, dim=0).abs()
