import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
from matplotlib import colors

def grayscale_to_rgb(image):
    """Tile ``image`` three times along its leading dimension.

    Calls ``image.repeat(3, 1, 1)``; the two trailing dimensions are left
    untouched.  Presumably ``image`` is a single-channel ``(1, H, W)`` tensor
    so the result is ``(3, H, W)`` -- confirm against callers.
    """
    # Repeat counts per dimension: 3x on the first axis, 1x on the last two.
    repeats = (3, 1, 1)
    return image.repeat(*repeats)

def add_edge(image, green=True):
    """Surround ``image`` with a one-pixel frame marked in a single channel.

    The last two dimensions of ``image`` are zero-padded by one element on
    every side.  A mask that is 1 on the outermost rows/columns is then added
    to channel 1 when ``green`` is truthy, or to channel 0 otherwise.
    Presumably ``image`` is a ``(3, H, W)`` RGB tensor, making the frame green
    or red -- confirm against callers.

    Returns:
        A new tensor; ``image`` itself is not modified.
    """
    framed = F.pad(image, pad=(1, 1, 1, 1), mode='constant', value=0)

    # Mark the first and last index of dimensions 1 and 2 as the border.
    border = torch.zeros_like(framed)
    for side in (0, -1):
        border[:, side, :] = 1
        border[:, :, side] = 1

    result = framed.clone()
    # Index 1 is used for a truthy flag, index 0 for a falsy one.
    channel = 1 if green else 0
    result[channel, ...] += border[channel, ...]

    return result

def add_edge_batch(images, success):
    """Apply ``add_edge`` to each image and stack the results.

    ``images`` and ``success`` are iterated in lockstep; each element of
    ``success`` is passed as the ``green`` flag for the matching image.
    Iteration stops at the shorter of the two.

    Returns:
        The per-image results combined with ``torch.stack``.
    """
    framed = []
    for picture, flag in zip(images, success):
        framed.append(add_edge(picture, flag))
    return torch.stack(framed)
    

def plot_heat_map(table, file_name):
    """Render ``table`` as an annotated grey-to-black heat map and save it.

    The colour scale is fixed to the range [0, 1].  Every cell is labelled
    with its value multiplied by 100, formatted to one decimal place inside
    ``$...$`` math delimiters.  Axes are hidden, the figure is saved tightly
    cropped to ``file_name`` and then closed.

    Note:
        This updates the global ``plt.rcParams`` (LaTeX text rendering,
        Helvetica, font size 20); the settings stay in effect after the call.

    Args:
        table: 2-D array-like exposing ``.shape`` and ``table[row, col]``
            indexing.  Presumably values lie in [0, 1] -- confirm with
            callers.
        file_name: Destination passed to ``Figure.savefig``.
    """
    plt.rcParams.update({
        "text.usetex": True,
        "font.family": "Helvetica",
        "font.size" : 20
    })

    light_grey = (0.733, 0.733, 0.733, 1)
    black = (0.0, 0.0, 0.0, 1)

    # Two-stop linear colormap running from light grey (low) to black (high).
    cmap = LinearSegmentedColormap.from_list("custom_cmap", [light_grey, black])

    fig, ax = plt.subplots()
    ax.imshow(table, cmap=cmap, vmin=0, vmax=1)

    n_rows, n_cols = table.shape[0], table.shape[1]

    # Rows map to the y axis and columns to the x axis.  The loop bounds must
    # follow the indexing (row < n_rows, col < n_cols) so that non-square
    # tables are fully annotated without going out of range.
    for row in range(n_rows):
        for col in range(n_cols):
            ax.text(col, row, f"${table[row, col] * 100:.1f}$",
                    ha="center", va="center", color="w")

    ax.axis('off')
    fig.tight_layout()
    fig.savefig(file_name, bbox_inches='tight', pad_inches=0.0)
    plt.close(fig)


def plot_heat_map_2(table, file_name):
    """Render ``table`` as an annotated heat map with default styling.

    Uses matplotlib's default colormap and automatic colour scaling.  Every
    cell is labelled with its raw value formatted to five decimal places.
    The figure is saved to ``file_name`` and then closed.

    Args:
        table: 2-D array-like exposing ``.shape`` and ``table[row, col]``
            indexing.
        file_name: Destination passed to ``Figure.savefig``.
    """
    fig, ax = plt.subplots()
    ax.imshow(table)

    n_rows, n_cols = table.shape[0], table.shape[1]

    # Rows map to the y axis and columns to the x axis.  The loop bounds must
    # follow the indexing (row < n_rows, col < n_cols) so that non-square
    # tables are fully annotated without going out of range.
    for row in range(n_rows):
        for col in range(n_cols):
            ax.text(col, row, f"{table[row, col]:.5f}",
                    ha="center", va="center", color="w")

    fig.tight_layout()
    # NOTE(review): per the Matplotlib docs pad_inches only takes effect
    # together with bbox_inches='tight'; kept as-is to preserve the output.
    fig.savefig(file_name, pad_inches=0.0)
    plt.close(fig)


