import os

import torch
import torch.nn.functional as F
from torchvision.utils import save_image

import seaborn as sns
import matplotlib.pyplot as plt

from dataset import test_data

def save_heatmap(matrix, filename='heatmap.png', cmap='viridis', annot=False):
    """
    Plot and save a heatmap from a matrix.

    The heatmap is drawn without a colorbar or tick labels on a 6x6-inch figure
    and written at 300 dpi. The figure is always closed, even if plotting or
    saving raises, so repeated calls do not accumulate open matplotlib figures.

    Args:
        matrix (2D array-like): The input matrix to plot.
        filename (str): Output filename (e.g., 'heatmap.png').
        cmap (str): Colormap for the heatmap.
        annot (bool): If True, annotate cells with their values.
    """
    # Hold explicit handles instead of relying on pyplot's "current figure",
    # so we draw on and close exactly the figure we created.
    fig, ax = plt.subplots(figsize=(6, 6))
    try:
        sns.heatmap(matrix, cmap=cmap, annot=annot, cbar=False,
                    xticklabels=False, yticklabels=False, ax=ax)
        fig.tight_layout()
        fig.savefig(filename, dpi=300)
    finally:
        # Release the figure even on failure; otherwise it stays registered
        # with pyplot and leaks memory across calls.
        plt.close(fig)

def format_img(ele):
    """
    Turn a stack of 2-D maps into inverted, single-channel 128x128 images.

    Args:
        ele (torch.Tensor): 3-D tensor. A channel axis is inserted at dim 1,
            so ``F.interpolate`` receives a 4-D ``(N, 1, H, W)`` batch.

    Returns:
        torch.Tensor: ``(N, 1, 128, 128)`` tensor, resized with
        ``F.interpolate``'s default (nearest) mode and inverted as ``1 - x``.
    """
    # NOTE(review): the inversion assumes values lie in [0, 1] — confirm upstream.
    return 1 - F.interpolate(ele.unsqueeze(1), (128, 128))

def save_example_plot(model, model_type):
    """
    Save a prediction, its target, and an attention heatmap for each data split.

    For each split in ("train", "test"), the first sample of ``test_data(split)``
    is run through ``model``. Three files are written to ``outputs/``:
    ``{model_type}_{split}_pred.jpg``, ``{model_type}_{split}_tgt.jpg`` and
    ``{model_type}_{split}_heatmap.png``.

    Args:
        model: Called as ``pred, _ = model(img)``. After that call this function
            reads ``model.ret_dict["attn_map"]``, so the model is expected to
            populate it during the forward pass.
        model_type (str): Prefix used in the output filenames.

    Note:
        The model input is moved with ``.cuda()``, so a CUDA device is required.
    """
    output_path = "outputs"
    os.makedirs(output_path, exist_ok=True)
    with torch.no_grad():
        for split in ("train", "test"):
            # Load the split named by the loop variable. Previously this always
            # loaded "train", so the "*_test_*" outputs silently showed train data.
            tgt = test_data(split)
            # NOTE(review): the model input drops the last index along dim 1 —
            # presumably that final step is what the model predicts; confirm.
            img = tgt[:, :-1].cuda()
            pred, _ = model(img)

            save_image(
                format_img(pred[0].detach().cpu()),
                os.path.join(output_path, f"{model_type}_{split}_pred.jpg"),
                nrow=3,
            )
            # tgt is only consumed on the CPU, so it is never copied to the GPU.
            save_image(
                format_img(tgt[0].detach().cpu()),
                os.path.join(output_path, f"{model_type}_{split}_tgt.jpg"),
                nrow=3,
            )
            save_heatmap(
                model.ret_dict["attn_map"][0][0].detach().cpu(),
                os.path.join(output_path, f"{model_type}_{split}_heatmap.png"),
            )
