import shutil

import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib import rc

# Typeset text with LaTeX when a `latex` executable is on PATH. Otherwise fall
# back to matplotlib's built-in mathtext with Computer Modern-style math glyphs,
# so the script still runs (instead of failing at save time) without TeX.
_HAVE_LATEX = shutil.which("latex") is not None
if _HAVE_LATEX:
    rc("font", **{"family": "serif", "serif": ["Computer Modern"]})
else:
    # "Computer Modern" is not a system font outside LaTeX; requesting it would
    # only produce font-not-found warnings, so use the default serif family.
    rc("font", family="serif")
    rc("mathtext", fontset="cm")
rc("text", usetex=_HAVE_LATEX)

# Paths to the weight matrix snapshots taken before and after training.
input_path_init = "../W1_init.pth"
input_path_final = "../W1.pth"

# Load both snapshots onto the CPU. The norms derived from them are later handed
# to NumPy (np.max) and matplotlib, which cannot consume tensors living on a GPU.
W1_init = torch.load(input_path_init, map_location="cpu")
W1_final = torch.load(input_path_final, map_location="cpu")

# The two spectra are compared frequency by frequency, so the shapes must agree.
if W1_init.shape != W1_final.shape:
    raise ValueError(
        f"W1_init and W1_final have different shapes: "
        f"{tuple(W1_init.shape)} vs {tuple(W1_final.shape)}"
    )

# Matrix dimensions (P: number of rows, H: number of columns)
P, H = W1_init.shape

# Quick sanity check: mean absolute value of the final weights.
print(W1_final.abs().mean())

# Discrete Fourier transform along the row dimension (dim 0), one spectrum per
# column. no_grad keeps the results out of any autograd graph.
with torch.no_grad():
    F_init = torch.fft.fft(W1_init, dim=0)  # FFT of initial weights
    F_final = torch.fft.fft(W1_final, dim=0)  # FFT of final weights


def compute_fourier_norms(F):
    """Return squared norms of the real and imaginary parts of each positive frequency.

    Args:
        F: Complex tensor holding an FFT taken along dim 0 (e.g.
            ``torch.fft.fft(W, dim=0)``); ``F[k]`` is the slice for frequency ``k``.

    Returns:
        A list of ``(||Re F[k]||^2, ||Im F[k]||^2)`` tuples (0-d tensors, squared
        Euclidean/Frobenius norms over the whole slice) for ``k = 1, ..., n // 2``
        with ``n = F.shape[0]``. The DC term ``k = 0`` is skipped. For real-valued
        input the remaining bins are complex conjugates of these and carry no
        extra information.
    """
    # Derive the length from F itself rather than the module-level P, so the
    # function is correct for any spectrum passed in.
    n = F.shape[0]
    return [(F[k].real.norm() ** 2, F[k].imag.norm() ** 2) for k in range(1, n // 2 + 1)]


# Per-frequency squared real/imaginary norms for both weight snapshots.
n_init, n_final = (compute_fourier_norms(spectrum) for spectrum in (F_init, F_final))

# Common y-range bounds: vmin is fixed at 0, vmax is the largest squared norm
# found in either snapshot.
# NOTE(review): vmin/vmax are not referenced by the plotting code in this file.
vmin = 0
vmax = max(np.max(norms) for norms in (n_init, n_final))


# Create side-by-side norm plots with non-overlapping bars and save to a single image
def save_combined_plot_norms(
    n_init,
    n_final,
    filename="comparison_plot_norms.pdf",
    *,
    titles=("Initial ($t=0$)", "Final ($t=5000$)"),
):
    """Save a two-panel bar chart comparing initial and final Fourier-norm spectra.

    Each panel shows, per frequency, the squared norm of the real part (red)
    next to the squared norm of the imaginary part (blue).

    Args:
        n_init: Sequence of ``(real_sq, imag_sq)`` pairs for the initial weights,
            assumed to start at frequency k = 1 (as compute_fourier_norms returns).
        n_final: Same structure for the final weights.
        filename: Output path; the format is inferred from the extension.
        titles: Two-element ``(left, right)`` panel titles.

    Raises:
        ValueError: If ``titles`` does not contain exactly two entries.
    """
    init_title, final_title = titles

    # sharey puts both panels on the same y-scale so bar heights are directly
    # comparable between the initial and final spectra.
    fig, axes = plt.subplots(
        ncols=2,
        figsize=(8, 2),
        sharey=True,
        gridspec_kw={"wspace": 0.1},
        layout="compressed",
    )

    # Each bar is 0.4 wide, so a real/imag pair spans 0.8 and never overlaps
    # the pair at the neighbouring frequency.
    width = 0.4
    _plot_norm_panel(axes[0], n_init, init_title, width)
    _plot_norm_panel(axes[1], n_final, final_title, width)

    # With a shared y-axis only the left panel needs an axis label.
    axes[0].set_ylabel("Squared Norm", labelpad=10, fontsize=16)

    fig.savefig(filename)
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)


def _plot_norm_panel(ax, norms, title, width):
    """Draw side-by-side real/imag squared-norm bars for one spectrum on ``ax``."""
    # norms[j] belongs to frequency k = j + 1 (the DC term is excluded), so the
    # bars are placed at their true frequency rather than at the list index.
    freqs = np.arange(1, len(norms) + 1)
    real_sq = [float(re) for re, _ in norms]
    imag_sq = [float(im) for _, im in norms]

    ax.bar(freqs - width / 2, real_sq, width=width, color="red", alpha=0.7, label="Real(F$_k$)")
    ax.bar(freqs + width / 2, imag_sq, width=width, color="blue", alpha=0.7, label="Imag(F$_k$)")
    ax.set_title(title, fontsize=16)
    ax.set_xlabel("Frequency ($k$)", fontsize=16)
    ax.legend(loc="lower right", fontsize=10)
    ax.tick_params(axis="both", which="major", labelsize=14)


# Render the initial/final spectra side by side and write them to a PDF file.
save_combined_plot_norms(n_init, n_final, filename="comparison_plot_norms.pdf")
