import matplotlib
import matplotlib.pyplot as plt
from matplotlib.ticker import FormatStrFormatter
import numpy as np
import seaborn as sns
from matplotlib.colors import Normalize
import torch


# sphinx_gallery_thumbnail_number = 2


def heatmap(data, row_labels, col_labels, ax=None,
            cbar_kw=None, cbarlabel="", cbar_shrink=0.25, fontsize=32, **kwargs):
    """
    Create a heatmap from a 2D array, with a colorbar and a white cell grid.

    Arguments:
        data        : A 2D numpy array of shape (N,M)
        row_labels  : A list or array of length N. Accepted for API
                      compatibility only: the tick labels are cleared below,
                      so these labels are not drawn.
        col_labels  : A list or array of length M. Not drawn (see row_labels).
    Optional arguments:
        ax          : A matplotlib.axes.Axes instance to which the heatmap
                      is plotted. If not provided, use current axes or
                      create a new one.
        cbar_kw     : A dictionary with extra keyword arguments for
                      :meth:`matplotlib.figure.Figure.colorbar`. None (the
                      default) means no extra arguments.
        cbarlabel   : The label for the colorbar.
        cbar_shrink : Value passed as ``shrink`` to the colorbar.
        fontsize    : Font size of the colorbar label (and x tick labels).
    All other keyword arguments are passed on to the imshow call.

    Returns:
        (im, cbar): the AxesImage created by imshow and the Colorbar.
    """
    # None instead of a shared mutable ``{}`` default.
    if cbar_kw is None:
        cbar_kw = {}
    if ax is None:
        ax = plt.gca()

    # Plot the heatmap
    im = ax.imshow(data, **kwargs)

    # Create colorbar
    cbar = ax.figure.colorbar(im, shrink=cbar_shrink, ax=ax, **cbar_kw)
    cbar.ax.set_ylabel(cbarlabel, rotation=-90, va="bottom", fontsize=fontsize)

    # One major tick per cell...
    ax.set_xticks(np.arange(data.shape[1]))
    ax.set_yticks(np.arange(data.shape[0]))

    # ...but with no labels and no visible tick marks.
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.tick_params(axis='both', which='both', length=0)

    # Let the horizontal axes labeling appear on top.
    ax.tick_params(top=True, bottom=False,
                   labeltop=True, labelbottom=False)

    # Rotate the tick labels and set their alignment.
    plt.setp(ax.get_xticklabels(), rotation=-45, ha="right",
             rotation_mode="anchor", fontsize=fontsize)

    # Turn spines off and draw a white grid on the minor ticks,
    # which sit on the cell boundaries (offset by half a cell).
    for spine in ax.spines.values():
        spine.set_visible(False)

    ax.set_xticks(np.arange(data.shape[1] + 1) - .5, minor=True)
    ax.set_yticks(np.arange(data.shape[0] + 1) - .5, minor=True)
    ax.grid(which="minor", color="w", linestyle='-', linewidth=3)
    ax.tick_params(which="minor", bottom=False, left=False)

    return im, cbar


def annotate_heatmap(im, data=None, valfmt="{x:.2f}",
                     textcolors=("black", "white"),
                     threshold=None, **textkw):
    """
    Write each cell's value as text on top of a heatmap image.

    Arguments:
        im         : The AxesImage to be labeled.
    Optional arguments:
        data       : Data used to annotate (list or numpy array). Any other
                     value (including None) means the image's data is used.
        valfmt     : The format of the annotations inside the heatmap.
                     This should either use the string format method, e.g.
                     "$ {x:.2f}", or be a callable taking ``(value, pos)``
                     such as a :class:`matplotlib.ticker.Formatter`.
        textcolors : A sequence of two color specifications. The first is
                     used for values at or below the threshold, the second
                     for those above.
        threshold  : Value in data units according to which the colors from
                     textcolors are applied. If None (the default), half of
                     the normalized value of ``data.max()`` is used.

    Further keyword arguments are passed on to the created text labels.

    Returns:
        A list with one Text per cell, in row-major order.
    """
    if isinstance(data, list):
        # Lists are accepted, but ``data.shape`` / ``data[i, j]`` below
        # require an array.
        data = np.asarray(data)
    elif not isinstance(data, np.ndarray):
        data = im.get_array()

    # Normalize the threshold to the images color range.
    if threshold is not None:
        threshold = im.norm(threshold)
    else:
        threshold = im.norm(data.max()) / 2.

    # Centered by default; textkw may override either alignment.
    kw = dict(horizontalalignment="center",
              verticalalignment="center")
    kw.update(textkw)

    # Get the formatter in case a string is supplied
    if isinstance(valfmt, str):
        valfmt = matplotlib.ticker.StrMethodFormatter(valfmt)

    # Loop over the data and create a `Text` for each "pixel",
    # choosing its color depending on the normalized value.
    texts = []
    for i in range(data.shape[0]):
        for j in range(data.shape[1]):
            # int(): a NumPy bool is not reliably accepted as a sequence index.
            kw["color"] = textcolors[int(im.norm(data[i, j]) > threshold)]
            texts.append(im.axes.text(j, i, valfmt(data[i, j], None), **kw))

    return texts

def _to_numpy(matrix):
    """Return *matrix* as a NumPy array; torch tensors are detached and moved to CPU first."""
    if isinstance(matrix, torch.Tensor):
        return matrix.detach().cpu().numpy()
    return np.asarray(matrix)


def _save_reconstruction_heatmap(matrix, stem, out_dir="./plots"):
    """Draw *matrix* as a large seaborn heatmap and save it to ``{out_dir}/{stem}.png`` and ``.pdf``."""
    fig, ax = plt.subplots(figsize=(36, 24))
    try:
        sns.heatmap(matrix, ax=ax, cmap="RdBu", cbar=True, square=True,
                    xticklabels=100, yticklabels=100)
        # sns.heatmap returns the Axes; the colorbar hangs off its mesh.
        colorbar = ax.collections[0].colorbar
        colorbar.ax.tick_params(labelsize=36)
        # At most 5 colorbar ticks, formatted with two decimals.
        colorbar.ax.yaxis.set_major_locator(plt.MaxNLocator(5))
        colorbar.ax.yaxis.set_major_formatter(FormatStrFormatter('%.2f'))
        # Hide tick labels and tick marks on the matrix axes.
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.tick_params(axis='both', which='both', length=0)
        # Add black border
        for spine in ax.spines.values():
            spine.set_visible(True)
            spine.set_color('black')
            spine.set_linewidth(2)
        for ext in ("png", "pdf"):
            fig.savefig(f"{out_dir}/{stem}.{ext}", dpi=300,
                        bbox_inches="tight", pad_inches=0.1)
    finally:
        # Always release the figure, even if drawing/saving fails.
        plt.close(fig)


def plot_heatmaps(original_weight, original_weight1, lora_B, lora_A, semsvd_B, semsvd_A, title="Comparison of LoRA and Semantic-Enhanced SVD Initialization Methods (512×512 Matrix, Rank=8)"):
    """
    Save heatmaps of the low-rank reconstructions ``lora_B @ lora_A`` and
    ``semsvd_B @ semsvd_A``.

    Writes ``./plots/kaiming_init.{png,pdf}`` (LoRA reconstruction) and
    ``./plots/semsvd_init.{png,pdf}`` (SemSVD reconstruction), creating
    ``./plots`` if needed.

    Arguments:
        original_weight, original_weight1 : Currently unused; kept so
                                            existing callers keep working.
        lora_B, lora_A     : Factors whose product is the LoRA reconstruction
                             (torch tensors or array-likes).
        semsvd_B, semsvd_A : Factors whose product is the SemSVD
                             reconstruction (torch tensors or array-likes).
        title              : Currently unused; kept for call compatibility.
    """
    from pathlib import Path

    out_dir = "./plots"
    Path(out_dir).mkdir(parents=True, exist_ok=True)

    lora_reconstruction = _to_numpy(lora_B @ lora_A)
    semsvd_reconstruction = _to_numpy(semsvd_B @ semsvd_A)

    # TrueType (Type 42) fonts in the PDFs instead of Type 3. Scoped so both
    # PDFs get it and the caller's global rcParams are left untouched.
    with matplotlib.rc_context({"pdf.fonttype": 42, "font.family": "DejaVu Sans"}):
        _save_reconstruction_heatmap(lora_reconstruction, "kaiming_init", out_dir)
        _save_reconstruction_heatmap(semsvd_reconstruction, "semsvd_init", out_dir)