from PIL import Image
from matplotlib import pyplot as plt
import numpy as np
import os


def _to_display_image(image, size=(224, 224)):
    """Convert a float image to a resized ``uint8`` array ready for ``imshow``.

    Assumes ``image`` is a NumPy array with values in [0, 1] in a layout that
    ``PIL.Image.fromarray`` accepts (e.g. (H, W, 3)) -- TODO confirm against
    callers. Values are clipped to [0, 1] first so that slight overshoot
    saturates instead of wrapping around when cast to ``uint8``.
    """
    scaled = (np.clip(image, 0.0, 1.0) * 255).astype(np.uint8)
    return np.array(Image.fromarray(scaled).resize(size)).astype(np.uint8)


def _best_and_worst(score, top_k):
    """Return indices of the ``top_k`` highest and ``top_k`` lowest scores.

    The best indices are ordered by descending score and the worst by
    ascending score. ``argsort`` is computed once and reused for both.
    """
    order = np.argsort(score)
    return order[::-1][:top_k], order[:top_k]


def _collect_img2txt(scores_i2t, adv_texts_list, img2txt, show_n, top_k):
    """Summarise image->text retrieval for the first ``show_n`` image queries.

    Returns:
        ``(img_idx2txt, retrieved_texts)``:
        * ``img_idx2txt`` maps an image index to a multi-line string listing
          its top-``top_k`` and worst-``top_k`` retrieved texts.
        * ``retrieved_texts`` maps the same index to ``{"Top k": text}``.
        Top results that ``img2txt`` marks as correct are prefixed with ``"ooo "``.
    """
    img_idx2txt = {}
    retrieved_texts = {}
    for index, score in enumerate(scores_i2t):
        best, worst = _best_and_worst(score, top_k)
        if img2txt is None:
            is_correct_list = [False] * top_k
        else:
            # img2txt[index] is the container of text indices counted as
            # correct for image `index`.
            is_correct_list = [i in img2txt[index] for i in best]

        lines = []
        retrieved_texts[index] = {}
        for rank, (txt_idx, correct) in enumerate(zip(best, is_correct_list), start=1):
            marker = "ooo " if correct else ""
            txt = adv_texts_list[txt_idx]
            lines.append(f"{marker}Top {rank}: {txt}\n")
            retrieved_texts[index][f"Top {rank}"] = f"{marker} {txt}"
        lines.append("\n")
        for rank, txt_idx in enumerate(worst, start=1):
            lines.append(f"Worst {rank}: {adv_texts_list[txt_idx]}\n")
        img_idx2txt[index] = "".join(lines)

        if len(img_idx2txt) >= show_n:
            break
    return img_idx2txt, retrieved_texts


def _collect_txt2img(scores_t2i, adv_texts_list, txt2img, max_queries, top_k):
    """Map each query text to ``{label: image index}`` for its top/worst images.

    At most ``max_queries`` distinct texts are collected. Labels are
    ``"Top k"`` (prefixed with ``"ooo "`` when ``txt2img`` marks the image as
    correct) and ``"Worst k"``.
    """
    txt2img_idx_dict = {}
    for index, score in enumerate(scores_t2i):
        best, worst = _best_and_worst(score, top_k)
        if txt2img is None:
            is_correct_list = [False] * top_k
        else:
            # txt2img[index] is the single image index counted as correct.
            is_correct_list = [i == txt2img[index] for i in best]

        img_idx_dict = {}
        for rank, (img_idx, correct) in enumerate(zip(best, is_correct_list), start=1):
            marker = "ooo " if correct else ""
            img_idx_dict[f"{marker}Top {rank}"] = img_idx
        for rank, img_idx in enumerate(worst, start=1):
            img_idx_dict[f"Worst {rank}"] = img_idx

        # Keyed by text: a duplicate query text overwrites the earlier entry.
        txt2img_idx_dict[adv_texts_list[index]] = img_idx_dict
        if len(txt2img_idx_dict) >= max_queries:
            break
    return txt2img_idx_dict


def vis_retrieval_results(
    adv_images_list,
    adv_texts_list,
    scores_i2t,
    scores_t2i,
    VIS_DIR,
    img2txt=None,
    txt2img=None,
    show_n=5,
    top_k=5,
):
    """
    Visualize image->text and text->image retrieval results.

    Writes into ``VIS_DIR`` (which must already exist):
      * ``retrieved_texts.txt``: top-``top_k`` texts for the first ``show_n`` images.
      * ``img2txt.png``: those images next to their top/worst retrieved texts.
      * ``txt2img.png``: up to ``show_n * 5`` query texts with their top/worst
        images. It is skipped when ``scores_t2i`` is empty.
    Each figure is also passed to ``plt.show()`` and then closed.

    Args:
        adv_images_list: Indexable collection of images. Each is assumed to be
            a float NumPy array in [0, 1] accepted by ``PIL.Image.fromarray``
            (TODO confirm layout against callers).
        adv_texts_list: Indexable collection of text strings.
        scores_i2t: Iterable of per-image score vectors over the texts. A higher
            score ranks as a better match.
        scores_t2i: Iterable of per-text score vectors over the images.
        VIS_DIR: Output directory.
        img2txt: Optional mapping from an image index to a container of correct
            text indices. Used only to mark hits with ``"ooo "``.
        txt2img: Optional mapping from a text index to its correct image index.
        show_n: Number of image queries to show.
        top_k: Number of best results, and of worst results, shown per query.
    """
    # Image -> Text: collect summaries and dump them to a text file.
    img_idx2txt, retrieved_texts = _collect_img2txt(
        scores_i2t, adv_texts_list, img2txt, show_n, top_k
    )
    save_txt_path = os.path.join(VIS_DIR, "retrieved_texts.txt")
    with open(save_txt_path, "w", encoding="utf-8") as f:
        for index, txt_dict in retrieved_texts.items():
            f.write(f"Image {index}:\n")
            for key, txt in txt_dict.items():
                f.write(f"{key}: {txt}\n")
            f.write("\n")

    # Text -> Images: collect up to show_n * 5 query texts.
    txt2img_idx_dict = _collect_txt2img(
        scores_t2i, adv_texts_list, txt2img, show_n * 5, top_k
    )

    # Image -> Text figure: show_n rows x 2 columns (image | texts).
    # squeeze=False keeps `axes` 2-D even when show_n == 1.
    fig, axes = plt.subplots(show_n, 2, figsize=(20, show_n * 5), squeeze=False)
    for ax in axes.flat:
        # Also hides rows left empty when there are fewer than show_n images.
        ax.axis("off")
    for row, (img_idx, txt) in enumerate(img_idx2txt.items()):
        axes[row, 0].imshow(_to_display_image(adv_images_list[img_idx]))
        axes[row, 1].text(0.5, 0.5, txt, ha="center", va="center", wrap=True)
    fig.savefig(os.path.join(VIS_DIR, "img2txt.png"), bbox_inches="tight")
    plt.show()
    plt.close(fig)

    # Text -> Images figure: one row per query text, 1 + top_k * 2 columns
    # (text | top_k best images | top_k worst images).
    rows = len(txt2img_idx_dict)
    if rows == 0:
        # plt.subplots(0, ...) would raise; there is nothing to draw.
        return
    cols = 1 + top_k * 2
    fig, axes = plt.subplots(
        rows, cols, figsize=(cols * 3, rows * 3), squeeze=False
    )
    for ax in axes.flat:
        # Also hides columns left empty when there are fewer than 2*top_k images.
        ax.axis("off")
    for row, (txt, img_idx_dict) in enumerate(txt2img_idx_dict.items()):
        axes[row, 0].text(0, 1.1, txt, ha="left", va="top", wrap=True)
        for col, (title, img_idx) in enumerate(img_idx_dict.items(), start=1):
            ax = axes[row, col]
            ax.imshow(_to_display_image(adv_images_list[img_idx]))
            # set_title places the label above the image in axes space. A
            # plain ax.text(0.5, 1.2, ...) after imshow is in pixel coordinates
            # and would land on the image's top-left corner.
            ax.set_title(title)
    fig.savefig(os.path.join(VIS_DIR, "txt2img.png"), bbox_inches="tight")
    plt.show()
    plt.close(fig)


def vis_img_txt_pairs(images, texts, VIS_DIR, file_name, show_n=5):
    """
    Visualize image-text pairs and save them to
    ``VIS_DIR/img_txt_pairs_{file_name}.png``.

    The figure has ``show_n`` rows. Each row shows an image on the left and its
    text on the right, folded to 5 words per line. If fewer than ``show_n``
    pairs are supplied, the remaining rows stay blank instead of raising
    ``IndexError``. The figure is closed after saving so repeated calls do not
    accumulate open pyplot figures.

    Args:
        images: Indexable collection of images. Each is assumed to be a float
            NumPy array in [0, 1] accepted by ``PIL.Image.fromarray`` (TODO
            confirm layout against callers). Values are clipped to [0, 1]
            before the ``uint8`` cast so they do not wrap around.
        texts: Indexable collection of strings paired with ``images``.
        VIS_DIR: Existing output directory.
        file_name: Suffix used in the output file name.
        show_n: Number of rows (pairs) in the figure.
    """
    words_per_line = 5
    n_pairs = min(show_n, len(images), len(texts))

    # Layout: show_n rows x 2 columns (image | text). squeeze=False keeps
    # `axes` 2-D even when show_n == 1.
    fig, axes = plt.subplots(show_n, 2, figsize=(10, show_n * 5), squeeze=False)
    for ax in axes.flat:
        ax.axis("off")

    for i in range(n_pairs):
        numpy_image = (np.clip(images[i], 0.0, 1.0) * 255).astype(np.uint8)
        resized_img = np.array(Image.fromarray(numpy_image).resize((224, 224)))
        axes[i, 0].imshow(resized_img.astype(np.uint8))

        # Fold the text manually, stacking lines downward from y=0.9. Stepping
        # through `range` avoids an empty trailing line when the word count is
        # a multiple of words_per_line.
        words = texts[i].split()
        for line_no, start in enumerate(range(0, len(words), words_per_line)):
            line = " ".join(words[start:start + words_per_line])
            axes[i, 1].text(
                0.5, 0.9 - line_no * 0.1, line, ha="center", va="center", wrap=True
            )

    save_fig_path = os.path.join(VIS_DIR, f"img_txt_pairs_{file_name}.png")
    fig.savefig(save_fig_path, bbox_inches="tight")
    plt.close(fig)
