import shutil

import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib import rc

# Use LaTeX for text rendering when a `latex` executable is on PATH; otherwise
# fall back to Matplotlib's built-in mathtext with Computer Modern math glyphs,
# so the script runs without LaTeX instead of failing when the figure is saved.
# Only the executable is checked: missing LaTeX packages would still surface as
# an error at save time.
if shutil.which("latex") is not None:
    rc("font", **{"family": "serif", "serif": ["Computer Modern"]})
    rc("text", usetex=True)
else:
    rc("font", family="serif")
    rc("mathtext", fontset="cm")

# Define input paths for the initial and final weight matrices
input_path_init = "../W1_init.pth"
input_path_final = "../W1.pth"

# Load the initial and final weight matrices onto the CPU, so checkpoints that
# were saved from a GPU run can also be analysed on a CPU-only machine.
W1_init = torch.load(input_path_init, map_location="cpu")
W1_final = torch.load(input_path_final, map_location="cpu")

# The initial and final checkpoints of the same layer are analysed and plotted
# side by side on a shared scale, so they must be 2-D and of identical shape;
# a mismatch almost certainly means the wrong file was loaded.
if W1_init.ndim != 2 or W1_init.shape != W1_final.shape:
    raise ValueError(
        "expected two 2-D weight matrices of equal shape, got "
        f"{tuple(W1_init.shape)} and {tuple(W1_final.shape)}"
    )

# Get matrix dimensions (P: number of rows, H: number of columns)
P, H = W1_init.shape

# Compute the Fourier Transform along the row dimension (axis 0); no_grad keeps
# the spectra detached from any autograd graph the loaded tensors may carry.
with torch.no_grad():
    F_init = torch.fft.fft(W1_init, dim=0)  # FFT of initial weights
    F_final = torch.fft.fft(W1_final, dim=0)  # FFT of final weights


def compute_magnitude_matrix(F):
    """Return the absolute cosine similarity between Re and Im parts of FFT rows.

    Let ``n = F.shape[0]``. Entry ``[a, b]`` of the result is

        |<Re F[i], Im F[j]>| / (||Re F[i]|| * ||Im F[j]||),  i = a + 1, j = b + 1,

    for the frequency rows ``1 .. n // 2`` (inclusive). Row 0 (the DC
    component) is skipped.

    When either vector has zero norm, the similarity is undefined. It is
    reported as 0 instead of NaN. This happens, for example, for the Nyquist
    row ``n // 2`` (even ``n``) of the FFT of a real signal, whose imaginary
    part is mathematically zero.

    Args:
        F: complex tensor whose rows are frequency components, i.e. an FFT
            taken along dim 0 (as done with ``torch.fft.fft(..., dim=0)``).

    Returns:
        ``numpy.ndarray`` of shape ``(n // 2, n // 2)``.
    """
    n = F.shape[0]
    band = F[1 : n // 2 + 1]  # skip DC (row 0); keep rows 1 .. n // 2 inclusive
    re, im = band.real, band.imag

    # All pairwise dot products at once: dots[a, b] = |Re F[a+1] . Im F[b+1]|.
    dots = (re @ im.T).abs()
    denom = re.norm(dim=1)[:, None] * im.norm(dim=1)[None, :]

    # Zero-norm rows/columns would give 0/0 = NaN; report them as 0 instead.
    sim = torch.where(denom > 0, dots / denom, torch.zeros_like(dots))
    return sim.detach().cpu().numpy()


m_init = compute_magnitude_matrix(F_init)
m_final = compute_magnitude_matrix(F_final)

# Determine the global color scale to ensure both plots use the same range.
# nanmax skips undefined (NaN) similarities, which would otherwise turn vmax
# into NaN and break the shared colour scale of both panels.
vmin = 0
vmax = max(np.nanmax(m_init), np.nanmax(m_final))


# Create side-by-side plots and save to a single image
def save_combined_plot(m_init, m_final, filename="comparison_plot.pdf", *, vmin=None, vmax=None):
    """Plot the initial and final similarity matrices side by side and save them.

    Both panels share one colour scale and a single colorbar, so they can be
    compared directly.

    Args:
        m_init: 2-D matrix (nested list or array) for the initial weights;
            rows are labelled Real(F_i), columns Imag(F_j).
        m_final: matrix of the same kind for the final weights.
        filename: output path; Matplotlib infers the format from the extension.
        vmin: lower end of the shared colour scale; defaults to 0.
        vmax: upper end of the shared colour scale; defaults to the largest
            non-NaN value across both matrices.

    Returns:
        The Matplotlib ``Figure``. It is left open so the caller can show or
        close it.
    """
    if vmin is None:
        vmin = 0
    if vmax is None:
        vmax = max(np.nanmax(m_init), np.nanmax(m_final))

    fig, axes = plt.subplots(ncols=2, figsize=(8, 3.5), gridspec_kw={"wspace": 0.1}, layout="compressed")

    cmap = "viridis"
    panels = ((m_init, "Initial ($t=0$)"), (m_final, "Final ($t=5000$)"))
    for ax, (matrix, title) in zip(axes, panels):
        im = ax.imshow(matrix, cmap=cmap, interpolation="nearest", vmin=vmin, vmax=vmax)
        ax.set_title(title, fontsize=14)
        ax.set_xlabel("Imag($F_j$)", fontsize=12)
        ax.set_ylabel("Real($F_i$)", fontsize=12)
        ax.grid(False)

    # A single colorbar serves both panels; this is valid only because both
    # images share the same vmin/vmax.
    cbar = fig.colorbar(im, ax=axes, orientation="vertical", fraction=0.02, pad=0.07)
    cbar.ax.yaxis.tick_left()  # put the colorbar ticks on its left side
    # Values are the |cosine similarity| of Re(F_i) and Im(F_j) when the inputs
    # come from compute_magnitude_matrix.
    cbar.set_label("Cosine Similarity", fontsize=12, labelpad=10)

    fig.savefig(filename)
    return fig


# Render both matrices on a shared colour scale and write the figure to disk.
save_combined_plot(m_init, m_final, filename="comparison_plot.pdf")
