import os
import textwrap
import torch as t
import numpy as np
import pandas as pd
from PIL import Image, ImageOps, ImageFilter
from tqdm import tqdm
import seaborn as sns
from matplotlib import pyplot as plt

from utils import utils
from utils.itdiffusion import DiffusionModel
from utils.stablediffusion import StableDiffuser
from configs.visual_configs import parse_args_and_update_config
import cv2
import gc

import torchvision.transforms as T
import json
import clip
# Run CLIP on the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if t.cuda.is_available() else "cpu"
# Load the ViT-B/32 CLIP model and its image preprocessing pipeline once, at import time.
# Both are read as module-level globals by the CLIP similarity helpers in this file.
clip_model, preprocess = clip.load("ViT-B/32", device = device)

# Global plotting defaults; they apply to every figure created after this module is imported.
plt.style.use('seaborn-v0_8-paper')
sns.set(style="whitegrid")
sns.set_context("paper", font_scale=1, rc={"lines.linewidth": 2.5})


def extract_pid_mask(img):
    """Return a boolean mask of the pixels that stand out in ``img``.

    The input is min-max rescaled with ``cv2.normalize``, quantized to
    ``uint8`` (0..255) and thresholded at one standard deviation above the
    mean of the quantized map.
    """
    rescaled = cv2.normalize(np.array(img), None, 0.0, 1.0, cv2.NORM_MINMAX)
    quantized = (rescaled * 255).astype(np.uint8)
    threshold = quantized.mean() + 1 * quantized.std()
    return quantized > threshold


def extract_intersect_mask(img1, img2):
    """Return the element-wise intersection of the saliency masks of two maps.

    Each map is min-max rescaled with ``cv2.normalize``, quantized to
    ``uint8`` and thresholded at one standard deviation above its own mean;
    the two boolean masks are then multiplied together.
    """
    masks = []
    for raw in (img1, img2):
        rescaled = cv2.normalize(np.array(raw), None, 0.0, 1.0, cv2.NORM_MINMAX)
        quantized = (rescaled * 255).astype(np.uint8)
        masks.append(quantized > (quantized.mean() + 1 * quantized.std()))

    first, second = masks
    return first * second

    
# ``no_grad`` is instantiated explicitly: the bare ``@t.no_grad`` form is only
# accepted by recent PyTorch releases, while ``@t.no_grad()`` works everywhere.
@t.no_grad()
def text_img_clip_sim(img, text):
    """Cosine similarity between an image and a text prompt in CLIP space.

    Args:
        img: image tensor. It is permuted with ``(2, 0, 1)`` before being
            handed to ``ToPILImage``, so a channels-last layout is assumed
            -- TODO confirm expected dtype / value range with callers.
        text: prompt string, tokenized with ``clip.tokenize``.

    Returns:
        NumPy scalar: dot product of the L2-normalized image and text
        embeddings produced by the module-level ``clip_model``.
    """
    to_pil = T.ToPILImage()
    pixels = preprocess(to_pil(img.permute(2, 0, 1))).unsqueeze(0).to(device)
    img_feats = clip_model.encode_image(pixels).cpu().numpy()
    img_feats = img_feats / np.linalg.norm(img_feats, axis=1).reshape((-1, 1))

    tokens = clip.tokenize([text]).to(device)
    text_feats = clip_model.encode_text(tokens).cpu().numpy()
    text_feats = text_feats / np.linalg.norm(text_feats, axis=1).reshape((-1, 1))

    # Both batches hold exactly one row, so compare the two single embeddings.
    score = np.sum(img_feats[0] * text_feats[0], axis=-1)

    return score


def get_clip_sim(img, mask, obj1, obj2):
    """Score a masked image against two object prompts with CLIP.

    ``mask`` is broadcast over a trailing axis of ``img`` before the product
    is taken. Only the first element of ``obj1`` and of ``obj2`` is used as
    the prompt. Returns ``[score_for_obj1, score_for_obj2]``.
    """
    masked_img = img * mask[..., None]
    first_prompt, second_prompt = obj1[0], obj2[0]

    return [
        text_img_clip_sim(masked_img, first_prompt),
        text_img_clip_sim(masked_img, second_prompt),
    ]


def find_intersection(img1, img2, caption, obj1, obj2, where):
    """Average two saliency maps over the region where both are strongly active.

    Each map is min-max rescaled with ``cv2.normalize``, quantized to
    ``uint8`` (0..255) and thresholded at ``mean + 1.5 * std`` of the
    quantized map. Pixels that pass both thresholds receive the mean of the
    two quantized values; every other pixel is 0.

    Args:
        img1, img2: array-likes accepted by ``np.array``; they must be
            broadcastable against each other.
        caption, obj1, obj2, where: not used by the computation; kept so
            existing callers keep working.

    Returns:
        ``torch`` tensor (float64) with the combined map.
    """
    def _quantize(raw):
        # Rescale to [0, 1], then quantize to 8-bit intensities.
        return (cv2.normalize(np.array(raw), None, 0.0, 1.0, cv2.NORM_MINMAX) * 255).astype(np.uint8)

    quant1 = _quantize(img1)
    mask1 = quant1 > (quant1.mean() + 1.5 * quant1.std())
    quant2 = _quantize(img2)
    mask2 = quant2 > (quant2.mean() + 1.5 * quant2.std())
    both = mask1 & mask2

    # Sum in float64. Adding the uint8 maps directly wraps around modulo 256
    # whenever the two intensities total more than 255, which is the common
    # case inside the intersection because both values are above threshold.
    res = (quant1.astype(np.float64) + quant2.astype(np.float64)) / 2 * both

    return t.tensor(res)


def adjust_gamma(image, gamma=1.0):
    """Gamma-correct ``image`` through a 256-entry lookup table.

    Entry ``v`` of the table holds ``((v / 255) ** (1 / gamma)) * 255`` cast
    to ``uint8``; the table is applied with ``cv2.LUT``.
    """
    exponent = 1.0 / gamma

    # Map every pixel value in [0, 255] to its gamma-adjusted value.
    entries = []
    for level in np.arange(0, 256):
        entries.append(((level / 255.0) ** exponent) * 255)
    lookup = np.array(entries).astype("uint8")

    return cv2.LUT(image, lookup)


def min_max_norm(im, min = None, max = None, exp = False):
    """Min-max normalize a map (or a stack of maps) and boost its contrast.

    The input tensor is never modified; normalization happens on a copy, so
    callers may safely pass views into tensors they use again afterwards.

    Args:
        im: torch tensor. A 2-D tensor is treated as a single map and gets a
            leading axis; otherwise the first axis is iterated over.
        min: lower bound of the value range. Defaults to ``im.min()``
            (taken after exponentiation when ``exp`` is set). An explicit
            value is used as given.
        max: upper bound of the value range. Defaults to ``im.max()``
            (taken after exponentiation when ``exp`` is set). An explicit
            value is used as given.
        exp: when true, ``torch.exp`` is applied to ``im`` first.

    Returns:
        NumPy float array with one entry per map along the first axis. Each
        map is quantized to 8 bits, multiplied by a fixed contrast of 1.2,
        saturated at 255 and rescaled to [0, 1].
    """
    if exp:
        im = t.exp(im)

    if min is None:
        min = im.min()
    if max is None:
        max = im.max()

    span = max - min
    # A constant map has a zero range; divide by 1 instead so the result is
    # an all-zero map rather than NaN.
    if t.is_tensor(span):
        span = t.where(span == 0, t.ones_like(span), span)
    elif span == 0:
        span = 1

    # Out-of-place on purpose: in-place ``-=`` / ``/=`` would overwrite the
    # caller's data.
    im = (im - min) / span

    if len(im.shape) == 2:
        im = im.unsqueeze(0)

    brightness = 0
    contrast = 1.2
    final = []
    for single in im:
        image = np.uint8(single.cpu().numpy() * 255)
        # Scale, round to nearest and saturate to the uint8 range. This is
        # what cv2.addWeighted(image, contrast, zeros, 0, brightness) yields.
        boosted = np.clip(np.rint(image * contrast + brightness), 0, 255)
        final.append(boosted / 255.)

    return np.stack(final)

def add_noise(itd, sdm, latent_images, logsnrs):
    '''
     Decode noisy versions of a batch of latents, one per requested log-SNR,
     for visualization.

     For every entry of ``logsnrs`` the latents go through
     ``itd.noisy_channel`` and the result is decoded with
     ``sdm.decode_latents``. The decoded batches are stacked along a new
     leading axis (one slice per noise level).
    '''
    batch_size = latent_images.shape[0]
    decoded_per_level = []
    for level in tqdm(range(len(logsnrs))):
        # Same log-SNR for every sample of the batch.
        level_logsnr = logsnrs[level] * t.ones(batch_size).to(logsnrs.device)
        noisy_latents, _ = itd.noisy_channel(latent_images, level_logsnr.to(latent_images.device))
        decoded = sdm.decode_latents(noisy_latents)
        decoded_per_level.append(t.tensor(decoded))
    return t.stack(decoded_per_level)

def plot_mmse_curve(logsnrs, mi, mi_appx, cmi, cmi_appx):
    """Plot exact and approximate MMSE-gap curves on a 2x5 grid of panels.

    Args:
        logsnrs: x values shared by every curve.
        mi, mi_appx, cmi, cmi_appx: indexable collections holding at least
            ten curves each; curve ``k`` is drawn in panel ``k`` (row-major)
            under the ``k``-th hard-coded class title.

    Returns:
        The matplotlib figure.
    """
    fig, ax = plt.subplots(2, 5, figsize=(23, 10))
    colors = ['skyblue', 'lightcoral']
    titles = ['airplane', 'bear', 'bed', 'cat', 'dog', 'elephant', 'horse', 'person', 'teddy bear', 'zebra']
    # Raw strings: the LaTeX commands start with backslashes that are not valid
    # Python escapes (\e, \h), which triggers a SyntaxWarning on Python 3.12+.
    labels = [r'$E[(\epsilon - \hat \epsilon_\alpha(x))^2] - E[(\epsilon - \hat \epsilon_\alpha(x|y*))^2]$',
              r'$E[(\hat \epsilon_\alpha(x) - \hat \epsilon_\alpha(x|y_*))^2]$',
              r'$E[(\epsilon - \hat \epsilon_\alpha(x|c))^2] - E[(\epsilon - \hat \epsilon_\alpha(x|y))^2]$',
              r'$E[(\hat \epsilon_\alpha(x|c) - \hat \epsilon_\alpha(x|y))^2]$']
    # (data, legend label, line style, color) for the four curves of a panel:
    # dotted = exact gap, solid = approximation; one color per quantity.
    curves = [
        (mi, labels[0], ':', colors[0]),
        (mi_appx, labels[1], '-', colors[0]),
        (cmi, labels[2], ':', colors[1]),
        (cmi_appx, labels[3], '-', colors[1]),
    ]
    for idx, panel in enumerate(ax.flat):
        panel.set_ylim(-15, 35)
        for series, label, linestyle, color in curves:
            panel.plot(logsnrs, series[idx], label=label, linestyle=linestyle, color=color)
        panel.set_title(f'{titles[idx]}', fontsize=15)
        panel.set_ylabel('bits', fontsize=15)
        panel.set_xlabel(r'$\alpha$', fontsize=15)
    # A single legend, placed in the top-right panel, serves the whole grid.
    ax[0][4].legend(loc='upper right', ncol=2, fontsize=15)
    return fig

def plot_mmse_and_mi_appx(img, nimgs, mi_appx, cmi_appx, mse_appx, cmse_appx, logsnrs):
    """Show noisy images and per-noise-level error maps next to the integrated maps.

    Layout: one column per entry of ``logsnrs`` plus a final summary column.
    Row 0 holds the noisy images (summary: the clean image), row 1 the
    unconditional maps (summary: ``mi_appx``), row 2 the conditional maps
    (summary: ``cmi_appx``).

    Args:
        img: clean image, clamped to [0, 1] before display.
        nimgs: noisy images, indexed by noise level.
        mi_appx, cmi_appx: maps shown in the summary column.
        mse_appx, cmse_appx: per-noise-level maps, indexed by noise level;
            they are normalized jointly so both rows share one scale.
        logsnrs: noise levels; also used for the column titles.

    Returns:
        The matplotlib figure.
    """
    # Raw strings: the LaTeX commands start with backslashes that are not valid
    # Python escapes (\h, \e, \m), which triggers a SyntaxWarning on Python 3.12+.
    ylabels = ['Add noise',
               r'$E[(\hat \epsilon_\alpha(x) - \hat \epsilon_\alpha(x|y_*))^2]$',
               r'$E[(\hat \epsilon_\alpha(x|c) - \hat \epsilon_\alpha(x|y))^2]$']
    snr_num = len(logsnrs)
    row_num = len(ylabels)
    cmap = 'jet'
    fig, ax = plt.subplots(row_num, snr_num + 1, figsize=(14, 4.5), frameon=False)
    ax[0][-1].imshow(t.clamp(img, 0, 1))
    ax[0][-1].set_title('Real COCO')
    mse_appxs = min_max_norm(t.stack([mse_appx, cmse_appx]))

    for j in range(snr_num):
        ax[0][j].imshow(t.clamp(nimgs[j], 0, 1))
        ax[0][j].set_title(f"$\\alpha$ = {logsnrs[j]:.2f}")
        ax[1][j].imshow(mse_appxs[0][j], cmap=cmap)
        ax[2][j].imshow(mse_appxs[1][j], cmap=cmap)

    def _as_image(arr):
        # min_max_norm prepends a singleton axis to a single 2-D map; imshow
        # cannot draw a (1, H, W) array, so strip that axis again.
        if arr.ndim == 3 and arr.shape[0] == 1:
            return arr[0]
        return arr

    mi_appx_ = _as_image(min_max_norm(mi_appx))
    cmi_appx_ = _as_image(min_max_norm(cmi_appx))
    ax[1][-1].imshow(mi_appx_, cmap=cmap)
    ax[1][-1].set_title(r'$\mathfrak{i}^o(x;y_*)$')
    ax[2][-1].imshow(cmi_appx_, cmap=cmap)
    ax[2][-1].set_title(r'$\mathfrak{i}^o(x;y_*|c)$')

    for i in range(row_num):
        # Row labels live on the first column; the long formulas need a smaller font.
        if i > 0:
            ax[i][0].set_ylabel(ylabels[i], fontsize=7)
        else:
            ax[i][0].set_ylabel(ylabels[i])
        # No ticks on any panel.
        for j in range(snr_num + 1):
            ax[i][j].set_xticks([])
            ax[i][j].set_yticks([])

    return fig


def plot_heatmaps_metric(img, caption, obj, metric, ctr, 
                      norm_over_all_data = False, norm_over_all_maps = False, exp = False, cmap = "jet"):
    """Plot an image next to the heatmap of one metric.

    Args:
        img: image to display, clamped to [0, 1].
        caption: sequence whose first element is the caption text.
        obj: sequence whose first element is the object name.
        metric: tensor holding the maps of all samples; ``metric[ctr]`` is
            the map that is drawn.
        ctr: index of the current sample inside ``metric``.
        norm_over_all_data: normalize with the min/max of the whole
            ``metric`` tensor instead of the min/max of the current map.
        norm_over_all_maps: accepted for signature parity with
            ``plot_heatmaps_pid``; not used here.
        exp: forwarded to ``min_max_norm``.
        cmap: matplotlib colormap name for the heatmap.

    Returns:
        The matplotlib figure.
    """
    titles = ['Image', 'Metric']
    sample_num = len(titles)
    fig, ax = plt.subplots(1, sample_num, figsize=(7, 3))

    print(caption[0], obj)
    print(f"Metric: Min = {metric[ctr].min()}, {metric[ctr].max()}, Mean = {metric[ctr].mean()}, Std Dev = {metric[ctr].std()}")
    print()

    # Normalize a copy: ``metric[ctr]`` is a view into the shared tensor, and an
    # in-place normalization would change ``metric.min()`` / ``metric.max()``
    # for every later sample.
    current = metric[ctr].clone()
    if norm_over_all_data:
        metric_ = min_max_norm(current, metric.min(), metric.max(), exp = exp)[0]
    else:
        # Per-map normalization. Without this assignment ``metric_`` was
        # undefined and the imshow call below crashed.
        metric_ = min_max_norm(current, exp = exp)[0]

    ax[0].imshow(t.clamp(img, 0, 1))
    ax[1].imshow(metric_, cmap=cmap, vmax=1, vmin=0.1)

    for i in range(sample_num):
        ax[i].set_title(titles[i], fontsize=10)
        ax[i].axis('off')

    # split caption if it's too long
    wrapped_text = "\n".join(textwrap.wrap(caption[0], width=70))
    objs_text = "\n".join(textwrap.wrap(f"'{obj[0]}'", width=70))
    text = f'c = {wrapped_text}\ny = {objs_text}'
    # Place the text below the image; the offset grows with the number of caption lines.
    ax[0].text(0, 730 + 50 * len(wrapped_text.split('\n')), text, va="bottom", ha='left', fontsize=10)

    return fig


def plot_heatmaps_pid(img, caption, obj1, obj2, redun, syn, uniq, ctr, 
                      norm_over_all_data = False, norm_over_all_maps = False, exp = False, cmap = "jet"):
    """Plot an image next to its redundancy, uniqueness and synergy heatmaps.

    Args:
        img: image to display, clamped to [0, 1].
        caption: sequence whose first element is the caption text.
        obj1, obj2: sequences whose first elements are the two object names.
        redun, syn: tensors with one map per sample along the first axis.
        uniq: tensor with two maps per sample (``uniq[ctr][0]`` and
            ``uniq[ctr][1]``).
        ctr: index of the current sample.
        norm_over_all_data: normalize each quantity with its min/max over
            all samples.
        norm_over_all_maps: normalize the four maps of the current sample
            jointly (only consulted when ``norm_over_all_data`` is false;
            ``exp`` is not applied in this mode).
        exp: forwarded to ``min_max_norm`` in the other two modes.
        cmap: matplotlib colormap name for the heatmaps.

    Returns:
        The matplotlib figure.
    """
    titles = ['Image', 'Redundancy', 'Uniqueness1', 'Uniqueness2', 'Synergy']
    sample_num = len(titles)
    fig, ax = plt.subplots(1, sample_num, figsize=(7, 3))

    print(caption[0], obj1, obj2)
    print(f"Redundancy: Min = {redun[ctr].min()}, {redun[ctr].max()}, Mean = {redun[ctr].mean()}, Std Dev = {redun[ctr].std()}")
    print(f"Uniqueness1: Min = {uniq[ctr][0].min()}, {uniq[ctr][0].max()}, Mean = {uniq[ctr][0].mean()}, Std Dev = {uniq[ctr][0].std()}")
    print(f"Uniqueness2: Min = {uniq[ctr][1].min()}, {uniq[ctr][1].max()}, Mean = {uniq[ctr][1].mean()}, Std Dev = {uniq[ctr][1].std()}")
    print(f"Synergy: Min = {syn[ctr].min()}, {syn[ctr].max()}, Mean = {syn[ctr].mean()}, Std Dev = {syn[ctr].std()}")
    print()

    # Every map handed to min_max_norm is a copy: the indexed tensors are views
    # into the shared result tensors, and normalizing them in place would
    # corrupt the global min/max of later samples and the "orig" mean below.
    if norm_over_all_data:
        # [0] strips the leading axis min_max_norm adds to a single 2-D map,
        # exactly as in the default branch.
        redun_ = min_max_norm(redun[ctr].clone(), redun.min(), redun.max(), exp = exp)[0]
        uniq1_ = min_max_norm(uniq[ctr][0].clone(), uniq[:, 0].min(), uniq[:, 0].max(), exp = exp)[0]
        uniq2_ = min_max_norm(uniq[ctr][1].clone(), uniq[:, 1].min(), uniq[:, 1].max(), exp = exp)[0]
        syn_ = min_max_norm(syn[ctr].clone(), syn.min(), syn.max(), exp = exp)[0]

    elif norm_over_all_maps:
        # t.stack already produces a fresh tensor, so no explicit copy is needed.
        joint = min_max_norm(t.stack([redun[ctr], uniq[ctr][0], uniq[ctr][1], syn[ctr]]))
        redun_, uniq1_, uniq2_, syn_ = joint[0], joint[1], joint[2], joint[3]

    else:
        redun_ = min_max_norm(redun[ctr].clone(), exp = exp)[0]
        # The two uniqueness maps are normalized together so they share one scale.
        uniq_ = min_max_norm(uniq[ctr].clone(), exp = exp)
        uniq1_ = uniq_[0]
        uniq2_ = uniq_[1]
        syn_ = min_max_norm(syn[ctr].clone(), exp = exp)[0]

    ax[0].imshow(t.clamp(img, 0, 1))
    for axis, heat in zip(ax[1:], (redun_, uniq1_, uniq2_, syn_)):
        axis.imshow(heat, cmap=cmap, vmax=1, vmin=0.1)

    for i in range(sample_num):
        ax[i].set_title(titles[i], fontsize=8)
        ax[i].axis('off')

    # split caption if it's too long
    wrapped_text = "\n".join(textwrap.wrap(caption[0], width=70))
    objs_text = "\n".join(textwrap.wrap(f"'{obj1[0]}' vs '{obj2[0]}'", width=70))
    text = f'c = {wrapped_text}\ny = {objs_text}\nnorm={syn_.mean()}, orig={syn[ctr].mean()}'
    # Place the text below the image; the offset grows with the number of caption lines.
    ax[0].text(0, 730 + 50 * len(wrapped_text.split('\n')), text, va="bottom", ha='left', fontsize=10)

    return fig

def plot_redun(img, vis, caption, obj1, obj2, norm_over_all_data = False, norm_over_all_maps = False, exp = False, cmap = "jet"):
    """Plot an image beside its redundancy, CMI, MI and DAAM heatmaps.

    The maps are read from ``vis`` under the keys ``'redun'``, ``'cmi'``,
    ``'mi'`` and ``'attn'`` and normalized one by one with ``min_max_norm``.
    ``norm_over_all_data`` and ``norm_over_all_maps`` are accepted but not
    used. Returns the matplotlib figure.
    """
    titles = ['Image', 'Redundancy', 'CMI', 'MI', 'DAAM']
    sample_num = len(titles)
    fig, ax = plt.subplots(1, sample_num, figsize=(7, 3))

    ax[0].imshow(t.clamp(img, 0, 1))
    # One heatmap per remaining panel, in the same order as the titles.
    for panel, key in zip(ax[1:], ('redun', 'cmi', 'mi', 'attn')):
        heat = min_max_norm(vis[key], exp = exp)[0]
        panel.imshow(heat, cmap=cmap, vmax=1, vmin=0.1)

    for panel, title in zip(ax, titles):
        panel.set_title(title, fontsize=8)
        panel.axis('off')

    # Wrap the caption and the object pair so long strings stay readable.
    wrapped_text = "\n".join(textwrap.wrap(caption[0], width=70))
    objs_text = "\n".join(textwrap.wrap(f"'{obj1[0]}' vs '{obj2[0]}'", width=70))
    text = f'c = {wrapped_text}\ny = {objs_text}'
    caption_lines = len(wrapped_text.split('\n'))
    ax[0].text(0, 730 + 50 * caption_lines, text, va="bottom", ha='left', fontsize=10)

    return fig

def plot_img_edit(vis, caption, obj, norm_over_all_data = False, norm_over_all_maps = False, exp = False, cmap = "jet"):
    """Plot an original / edited image pair beside five heatmaps.

    ``vis['orig']`` and ``vis['mod']`` are shown as they are; the maps under
    ``'redun'``, ``'uniq'``, ``'cmi'``, ``'mi'`` and ``'attn'`` are
    normalized one by one with ``min_max_norm``. ``norm_over_all_data`` and
    ``norm_over_all_maps`` are accepted but not used. Returns the matplotlib
    figure.
    """
    titles = ['Original', 'Intervention', 'Redundancy', 'Uniqueness', 'CMI', 'MI', 'DAAM']
    sample_num = len(titles)
    fig, ax = plt.subplots(1, sample_num, figsize=(7, 3))

    ax[0].imshow(vis['orig'])
    ax[1].imshow(vis['mod'])
    # One heatmap per remaining panel, in the same order as the titles.
    for panel, key in zip(ax[2:], ('redun', 'uniq', 'cmi', 'mi', 'attn')):
        heat = min_max_norm(vis[key], exp = exp)[0]
        panel.imshow(heat, cmap=cmap, vmax=1, vmin=0.1)

    for panel, title in zip(ax, titles):
        panel.set_title(title, fontsize=8)
        panel.axis('off')

    # Wrap the caption and the object name so long strings stay readable.
    wrapped_text = "\n".join(textwrap.wrap(caption[0], width=70))
    objs_text = "\n".join(textwrap.wrap(f"'{obj[0]}'", width=70))
    text = f'c = {wrapped_text}\ny = {objs_text}'
    caption_lines = len(wrapped_text.split('\n'))
    ax[0].text(0, 730 + 50 * caption_lines, text, va="bottom", ha='left', fontsize=10)

    return fig

def plot_heatmaps(img, caption, obj, mi, cmi, attn, type_):
    """Plot an image beside its CMI, MI and attention heatmaps.

    Args:
        img: image to display, clamped to [0, 1].
        caption: sequence whose first element is the caption text.
        obj: object name shown in the text below the image.
        mi, cmi: maps normalized jointly so they share one scale.
        attn: attention map, normalized on its own.
        type_: label printed in parentheses after the object name.

    Returns:
        The matplotlib figure.
    """
    # Raw strings: \m is not a valid Python escape and triggers a SyntaxWarning
    # on Python 3.12+.
    titles = ['Real COCO', r'$\mathfrak{i}^o(x;y_*|c)$', r'$\mathfrak{i}^o(x;y_*)$', 'Attention']
    sample_num = len(titles)
    cmap = 'jet'
    fig, ax = plt.subplots(1, sample_num, figsize=(7, 3))

    cmi_mi = min_max_norm(t.stack([cmi, mi]))
    cmi_ = cmi_mi[0]
    mi_ = cmi_mi[1]
    attn_ = min_max_norm(attn)
    # min_max_norm prepends a singleton axis to a single 2-D map; imshow cannot
    # draw a (1, H, W) array, so strip that axis again.
    if attn_.ndim == 3 and attn_.shape[0] == 1:
        attn_ = attn_[0]

    ax[0].imshow(t.clamp(img, 0, 1))
    ax[1].imshow(cmi_, cmap=cmap, vmax=1, vmin=0.1)
    ax[2].imshow(mi_, cmap=cmap, vmax=1, vmin=0.1)
    ax[3].imshow(attn_, cmap=cmap, vmax=1, vmin=0.1)

    for i in range(sample_num):
        ax[i].set_title(titles[i], fontsize=14)
        ax[i].axis('off')

    # split caption if it's too long
    wrapped_text = "\n".join(textwrap.wrap(caption[0], width=52))
    text = f'c = {wrapped_text}\n$y_*$ = {obj} ({type_})'
    # Place the text below the image; the offset grows with the number of caption lines.
    ax[0].text(0, 700 + 50 * len(wrapped_text.split('\n')), text, va="bottom", ha='left', fontsize=13)

    return fig

def plot_overlay(im, heatmap, fig, ax, normalize=False, vmax=None,
                 title=None, fontsize=14, last=False, inset_text=None):
    """Draw a heatmap on ``ax`` and lay the image over it with heat-dependent opacity.

    Args:
        im: PIL image (8-bit values, rescaled by 1/255) or an array-like in
            [-1, 1] (rescaled to [0, 1]).
        heatmap: torch tensor; squeezed and moved to the CPU before use.
        fig: figure owning ``ax``; only needed for the colorbar.
        ax: matplotlib axes to draw on.
        normalize: when true, draw the heatmap rescaled so its 90th
            percentile maps to 1; otherwise draw the raw values.
        vmax: upper color limit for the raw heatmap; defaults to its maximum.
            Ignored when ``normalize`` is true.
        title: axes title.
        fontsize: font size of the title and the inset text.
        last: when true, attach a tick-less colorbar to the right of ``ax``.
        inset_text: optional number rendered as "<value> bits total" in the
            lower-left corner.
    """
    heatmap = heatmap.squeeze().cpu().numpy()
    # isinstance, not an exact type check: images returned by Image.open are
    # subclasses of Image.Image and must take the PIL branch too.
    if isinstance(im, Image.Image):
        im = np.array(im) / 255.
    else:
        im = (np.array(im) + 1) / 2  # 0..1 image
    ninety = np.percentile(heatmap.flatten(), 90)
    # Clip at the 90th percentile so a few extreme pixels do not wash out the rest.
    norm_heat = np.clip((heatmap - heatmap.min()) / (ninety- heatmap.min() + 1e-8), 0, 1)
    if normalize:
        x = norm_heat
        vmax = 1
    else:
        x = heatmap
        if vmax is None:
            vmax = heatmap.max()
    out = ax.imshow(x, vmin=0, vmax=vmax, cmap='jet')
    if last:
        cax = ax.inset_axes([1.02, 0.1, 0.1, 0.8])
        cax.axis('off')
        cbar = fig.colorbar(out, ax=cax, orientation="vertical", shrink=1)
        cbar.set_ticks([])  # colorbar is shown without tick marks

    # The image is opaque where the heat is low and transparent where it is
    # high, so the heatmap underneath shows through in the salient regions.
    alpha = (1 - norm_heat)[:, :, np.newaxis]
    im = np.concatenate((im, alpha), axis=-1)
    ax.imshow(im)
    ax.set_title(title, fontsize=fontsize)
    ax.axis('off')
    if inset_text is not None:
        ax.text(
            0.05, 0.05,  # lower-left corner, in axes-relative coordinates
            "{:.1f} bits total".format(inset_text),
            fontsize=fontsize,
            color="white",
            transform=ax.transAxes,
            verticalalignment="bottom",
            horizontalalignment="left",
        )

def plot_img(im, ax, title=None, fontsize=14):
    """Show an image on ``ax`` without axis decorations.

    Args:
        im: PIL image (8-bit values, rescaled by 1/255) or an array-like in
            [-1, 1] (rescaled to [0, 1]).
        ax: matplotlib axes to draw on.
        title: axes title.
        fontsize: font size of the title.
    """
    # isinstance, not an exact type check: images returned by Image.open are
    # subclasses of Image.Image and must take the PIL branch too.
    if isinstance(im, Image.Image):
        im = np.array(im) / 255.
    else:
        im = (np.array(im) + 1) / 2  # 0..1 image
    ax.imshow(im)
    ax.set_title(title, fontsize=fontsize)
    ax.axis('off')


def plot_text(ax, c, y, yp, fontsize=14):
    """Write three color-coded strings onto ``ax`` and hide its frame.

    ``c`` is drawn in black at height 0.9, ``y`` in green at 0.5 and ``yp``
    in red at 0.25; all three start at x = 0.05 and are anchored at their
    top-left corner.
    """
    # (vertical position, string, color) for each line, top to bottom.
    entries = (
        (0.9, c, 'black'),
        (0.5, y, 'green'),
        (0.25, yp, 'red'),
    )
    for height, string, color in entries:
        ax.text(
            0.05, height,
            string,
            fontsize=fontsize,
            color=color,
            verticalalignment="top",
            horizontalalignment="left",
        )
    ax.axis('off')

    
def compute_corr(vis):
    """Print correlations of attention and CMI scores with pixel change, plus bootstrap errors.

    NOTE(review): this function cannot run as written. ``vis`` is never read,
    and the body relies on names that are not defined in the function or in
    the visible part of the file: ``pixel_changes``, ``pixel_change``,
    ``heat_map``, ``attentions_pixel``, ``cmis_pixel``, ``cmi_pixel``,
    ``cmi`` and ``ctr``. In addition ``attentions``, ``pixel_changes_score``
    and ``cmis`` are assigned further down, which makes them local names, so
    the ``.append`` calls at the top would fail before they are ever bound.
    It looks like the body of a per-sample loop that was moved into a
    function without its surrounding state -- confirm the intended inputs
    before using it.
    """
    # pixel change

    # pixel-level: one flattened map per image
    pixel_changes.append(pixel_change.flatten())
    attentions_pixel.append(heat_map.flatten())
    cmis_pixel.append(cmi_pixel.flatten())
    # image-level: one scalar per image
    pixel_changes_score.append(pixel_change.sum())
    attentions.append(heat_map.sum())
    cmis.append(cmi)

    # Pearson correlation across images between the image-level scores.
    corr_attn_score = np.corrcoef(attentions, pixel_changes_score)[0, 1]
    corr_cmi_score = np.corrcoef(cmis, pixel_changes_score)[0, 1]

    print('Image level:')
    print('Correlation between attention and pixel change: ', corr_attn_score)
    print('Correlation between CMI and pixel change: ', corr_cmi_score)

    # Pearson correlation across pixels, computed separately for every image.
    corr_attns = []
    corr_cmis = []
    for i in range(len(cmis)):
        corr_attn = np.corrcoef(attentions_pixel[i], pixel_changes[i])[0, 1]
        corr_cmi = np.corrcoef(cmis_pixel[i], pixel_changes[i])[0, 1]
        corr_attns.append(corr_attn)
        corr_cmis.append(corr_cmi)

    print('Pixel level (avg on images):')
    print('Mean correlation between attention and pixel change: ', np.mean(corr_attns))
    print('Mean correlation between CMI and pixel change: ', np.mean(corr_cmis))


    # calculate bootstrapping error bar
    corr_attn_score_list, corr_cmi_score_list = [], []
    corr_attn_list, corr_cmi_list = [], []
    # Convert to arrays so the index arrays drawn below can select resamples.
    attentions = np.array(attentions)
    pixel_changes_score = np.array(pixel_changes_score)
    cmis = np.array(cmis)

    # 100 resamples, each drawing 90% of the image count with replacement.
    selected_idx = np.random.randint(len(cmis), size=(100, int(0.9 * len(cmis)))) # repeat 100 times
    for idx in selected_idx:
        # image level
        corr_attn_score_ = np.corrcoef(attentions[idx], pixel_changes_score[idx])[0, 1]
        corr_cmi_score_ = np.corrcoef(cmis[idx], pixel_changes_score[idx])[0, 1]
        corr_attn_score_list.append(corr_attn_score_)
        corr_cmi_score_list.append(corr_cmi_score_)

        # pixel level
        corr_attns = []
        corr_cmis = []
        for i in idx:
            corr_attn = np.corrcoef(attentions_pixel[i], pixel_changes[i])[0, 1]
            corr_cmi = np.corrcoef(cmis_pixel[i], pixel_changes[i])[0, 1]
            corr_attns.append(corr_attn)
            corr_cmis.append(corr_cmi)
        corr_attn_list.append(np.mean(corr_attns))
        corr_cmi_list.append(np.mean(corr_cmis))
    print('\n\n')
    # NOTE(review): the spread over resamples is divided by 10, presumably
    # sqrt(100) for the 100 resamples -- confirm this is the intended error bar.
    print('Image level bootstrapping error:')
    print('Correlation between attention and pixel change: ', np.std(corr_attn_score_list) / 10)
    print('Correlation between CMI and pixel change: ', np.std(corr_cmi_score_list) / 10)

    print('Pixel level (avg on images) bootstrapping error:')
    print('Mean correlation between attention and pixel change: ', np.std(corr_attn_list) / 10)
    print('Mean correlation between CMI and pixel change: ', np.std(corr_cmi_list) / 10)

    print(f"Done with {ctr}")

    
def clip_similarity(img1, img2):
    """Cosine similarity between two images in CLIP embedding space.

    Args:
        img1, img2: images accepted by the module-level CLIP ``preprocess``
            pipeline.

    Returns:
        NumPy scalar: dot product of the two L2-normalized image embeddings.
    """
    img1 = preprocess(img1).unsqueeze(0).to(device)
    img2 = preprocess(img2).unsqueeze(0).to(device)

    imgs = t.cat([img1, img2], dim = 0)
    # Inference only. Without no_grad the encoder output can track gradients,
    # and calling .numpy() on such a tensor raises a RuntimeError.
    with t.no_grad():
        img_feats = clip_model.encode_image(imgs)
    img_feats = img_feats.cpu().numpy()
    img_feats = img_feats / np.linalg.norm(img_feats, axis = 1).reshape((-1, 1))

    score = np.sum(img_feats[0] * img_feats[1], axis = -1)

    return score


def main():
    config = parse_args_and_update_config()
    t.manual_seed(config.seed)

    # set hyper-parameters
    fig_out_dir = config.fig_out_dir
    res_in_dir = config.res_in_dir
    data_in_dir = config.data_in_dir
    sdm_version = config.sdm_version
    csv_name = config.csv_name
    visual_type = config.visual_type
    dataset_type = config.dataset_type
    z_sample_num = config.n_samples_per_point
    snr_num = config.num_steps
    int_mode = config.int_mode
    norm_over_all_data = config.norm_over_all_data
    norm_over_all_maps = config.norm_over_all_maps
    exp = config.exp
    cmap = config.cmap
    eval_metrics = config.eval_metrics

    if visual_type == 'mmse_curve':
        # assign noise levels
        logsnrs = t.linspace(-5.0, 7.0, 200)

        # load results
        in_path1 = './results/itd_full/mi/sdm_2_1_base-nll_2D-COCO-IT-logistic-1-100-mi-320.pt'
        in_path2 = './results/itd_full/cmi/sdm_2_1_base-nll_2D-COCO-IT-logistic-1-100-cmi-320.pt'
        results1 = t.load(in_path1)
        results2 = t.load(in_path2)
        mi = results1['mmses'][:, 0, :] - results1['mmses'][:, 1, :]  # 10 * snr_num
        mi_appx = results1['mmses_diff_appx'][:, 0, :]  # 10 * snr_num
        cmi = results2['mmses'][:, 0, :] - results2['mmses'][:, 1, :]  # 10 * snr_num
        cmi_appx = results2['mmses_diff_appx'][:, 0, :]  # 10 * snr_num

        # visualization
        print('Starting visualization...')
        out_path = os.path.join(fig_out_dir, f"{sdm_version}/COCO10/mmse_gap-curves")
        if not os.path.exists(out_path):
            os.makedirs(out_path)
        fig = plot_mmse_curve(logsnrs, -mi, mi_appx, -cmi, cmi_appx)
        fig.savefig(os.path.join(out_path, f"mmses.png"), dpi=300)
        print('Done')

    elif visual_type == 'scatter_plot':
        if csv_name == 'COCO-IT':
            titles = ['airplane', 'bear', 'bed', 'cat', 'dog', 'elephant', 'horse', 'person', 'teddy bear', 'zebra']
        elif csv_name == 'COCO-WL':
            titles = ['verb', 'num', 'adj', 'adv', 'prep', 'pronoun', 'conj']

        img_dir = os.path.join(data_in_dir, f'val2017')
        csv_dir = os.path.join(data_in_dir, f'{csv_name}.csv')
        annotation_file = os.path.join(data_in_dir, f'annotations/instances_val2017.json')
        dataset = utils.CocoDataset(img_dir, annotation_file, csv_dir)
        dataloader = t.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0)
        titles = []
        for batch_idx, batch in enumerate(dataloader):
            if batch_idx >= 10:
                break
            titles.append(batch['obj'])

        # load results
        # in_path1 = os.path.join(res_in_dir, f'{sdm_version}-nll_1D-{csv_name}-logistic-1-200-mi.pt')
        # in_path2 = os.path.join(res_in_dir, f'{sdm_version}-nll_1D-{csv_name}-logistic-1-200-cmi.pt')
        in_path1 = './results/itd_full/mi/sdm_2_1_base-nll_2D-COCO-IT-logistic-1-100-mi-320.pt'
        in_path2 = './results/itd_full/cmi/sdm_2_1_base-nll_2D-COCO-IT-logistic-1-100-cmi-320.pt'
        results1 = t.load(in_path1)
        results2 = t.load(in_path2)
        mi = results1['pixel_mi'][:, 0]  # N * 3
        cmi = results2['pixel_mi'][:, 0]  # N * 3

        # calculate Pearson correlation
        correlation_coefficient = np.corrcoef(mi.reshape((len(mi), -1)), cmi.reshape((len(mi), -1)))[0, 1]

        # visualization
        print('Visualization starts ...')
        colors = plt.cm.rainbow(np.linspace(0, 1, len(titles)))
        plt.plot([0, max(mi.max(), cmi.max())], [0, max(mi.max(), cmi.max())], color='red', linestyle='--', label='$\mathfrak{i}^o(x;y_*) = \mathfrak{i}^o(x;y_*|c)$')
        for i in range(len(titles)):
            # plt.scatter(mi[i * 10:(i + 1) * 10], cmi[i * 10:(i + 1) * 10], color=colors[i], alpha=0.5, s=20, label=titles[i])
            plt.scatter(mi[i : i + 1], cmi[i :(i + 1) ], color=colors[i], alpha=0.5, s=20, label=titles[i])
        plt.text(80, 1, f'Pearson correlation: {correlation_coefficient :.2f}', fontsize=12, color='black')
        plt.legend(loc='upper left', ncol=2)
        plt.title(f'{csv_name}', fontsize=14)
        plt.xlabel('$\mathfrak{i}^o(x;y_*)$', fontsize=13)
        plt.ylabel('$\mathfrak{i}^o(x;y_*|c)$', fontsize=13)
        out_path = os.path.join(fig_out_dir, f"{sdm_version}/{csv_name}/")
        if not os.path.exists(out_path):
            os.makedirs(out_path)
        plt.savefig(os.path.join(out_path, f"scatter-plot-{csv_name}.png"), dpi=300)
        print('Done')

    elif visual_type == 'denoising_diffusion':
        # load diffusion models
        if sdm_version == 'sdm_2_0_base':
            sdm = StableDiffuser("stabilityai/stable-diffusion-2-base")
        elif sdm_version == 'sdm_2_1_base':
            sdm = StableDiffuser("stabilityai/stable-diffusion-2-1-base")
        latent_shape = (sdm.channels, sdm.width, sdm.height)
        itd = DiffusionModel(sdm.unet, latent_shape, logsnr_loc=1.0, logsnr_scale=2.0, clip=3.0, logsnr2t=sdm.logsnr2t)

        # assign noise levels
        logsnrs = t.linspace(-5.0, 7.0, 10).to(sdm.device)

        # load data
        # TODO: change CSV file to JSON file
        img_dir = os.path.join(data_in_dir, f'val2017')
        csv_dir = os.path.join(data_in_dir, f'{csv_name}.csv')
        annotation_file = os.path.join(data_in_dir, f'annotations/instances_val2017.json')
        dataset = utils.CocoDataset(img_dir, annotation_file, csv_dir)
        dataloader = t.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0)
        img_list = []
        n_img_list = []
        for batch_idx, batch in enumerate(dataloader):
            image_batch = batch['image']
            latent_images = sdm.encode_latents(image_batch)
            print('Add noise to image...')
            noisy_images_batch = add_noise(itd, sdm, latent_images, logsnrs)
            print('Done\n')
            n_img_list.append(noisy_images_batch)
            img_list.append(image_batch)
        images = t.cat(img_list).squeeze()
        images = images.permute(0, 2, 3, 1)
        noisy_images = t.cat(n_img_list)
        noisy_images = noisy_images.permute(1, 0, 2, 3, 4)

        # load results
        in_path1 = os.path.join(res_in_dir, f'{sdm_version}-nll_2D-COCO10-logistic-1-200-mi.pt')
        in_path2 = os.path.join(res_in_dir, f'{sdm_version}-nll_2D-COCO10-logistic-1-200-cmi.pt')
        # in_path3 = os.path.join(res_in_dir, f'{sdm_version}-mse_2D-COCO10-logistic-50-10-mi.pt')
        in_path4 = os.path.join(res_in_dir, f'{sdm_version}-mse_2D-COCO10-logistic-50-10-cmi.pt')
        results1 = t.load(in_path1)
        results2 = t.load(in_path2)
        results3 = t.load(in_path2)
        # results3 = t.load(in_path3)
        results4 = t.load(in_path4)
        mi_appx = results1['pixel_mi'][:, 0, :, :]  # 100 * h * w
        cmi_appx = results2['pixel_mi'][:, 0, :, :]  # 100 * h * w
        mse_appx = results3['pixel_mmses_diff_appx'][:, 0, :, :]  # 100 * h * w
        cmse_appx = results4['pixel_mmses_diff_appx'][:, 0, :, :]  # 100 * h * w

        # visualization
        print('Starting visualization...')
        for i in tqdm(range(len(images))):
            out_path = os.path.join(fig_out_dir, f"{sdm_version}/{csv_name}")
            if not os.path.exists(out_path):
                os.makedirs(out_path)
            fig = plot_mmse_and_mi_appx(images[i], noisy_images[i], mi_appx[i], cmi_appx[i], mse_appx[i], cmse_appx[i], logsnrs)
            plt.subplots_adjust(wspace=0, hspace=0.15)
            fig.savefig(os.path.join(out_path, f"denoising_diffusion-{i}.png"), dpi=300)
            plt.close()
        print('Done')

    elif visual_type == 'pid':
        # load data
        # TODO: change CSV file to JSON file
        img_dir = os.path.join(data_in_dir, f'val2017')
        csv_dir = os.path.join(data_in_dir, f'{csv_name}.csv')
        annotation_file = os.path.join(data_in_dir, f'annotations/instances_val2017.json')

        if dataset_type == "COCO-IT":
            dataset = utils.CocoDataset(img_dir, annotation_file, csv_dir)
        elif dataset_type == "coco_ours":
            dataset = utils.CocoDatasetOurs(img_dir, annotation_file, csv_dir)
        elif dataset_type == "custom":
            dataset = utils.CustomDataset(csv_dir)
        dataloader = t.utils.data.DataLoader(dataset, batch_size=1 , shuffle=False, num_workers=0)

        redun_list = []
        uniq_list = []
        syn_list = []

        files = os.listdir(res_in_dir)
        idxs = [int(file[:-3].split("-")[-1]) for file in files]
        ctr = 0
        for batch_idx in sorted(idxs):
            ctr += 1
            file = f"{sdm_version}-pid-{csv_name}-{int_mode}-{z_sample_num}-{snr_num}-{eval_metrics}-{batch_idx}.pt"
            print(batch_idx, file)

            pid_path = os.path.join(res_in_dir, file)
            redun = t.load(pid_path)['pixel_redun']
            print("Loaded redundancy")
            syn = t.load(pid_path)['pixel_syn']
            print("Loaded synergy")
            uniq = t.load(pid_path)['pixel_uniq']
            print("Loaded uniqueness")

            redun_list.append(redun)
            uniq_list.append(uniq)
            syn_list.append(syn)

            del redun, uniq, syn
            gc.collect()
            t.cuda.empty_cache()

        redun_list = t.cat(redun_list)
        uniq_list = t.cat(uniq_list)
        syn_list = t.cat(syn_list)

        # visualization
        print('Starting visualization...')
        out_path = os.path.join(fig_out_dir, f"{sdm_version}/{csv_name}/")
        if not os.path.exists(out_path):
            os.makedirs(out_path)

        for ctr, batch in enumerate(dataloader):
            image_batch = batch['image']
            image = image_batch.permute(0, 2, 3, 1)[0]
            caption_batch = batch['caption']

            if dataset_type == "COCO-IT":
                obj1_batch = batch['category']
                obj2_batch = batch['context']
            
            else:
                obj1_batch = batch['obj1']
                obj2_batch = batch['obj2']

            fig = plot_heatmaps_pid(image, caption_batch, obj1_batch, obj2_batch, 
                                    redun_list, syn_list, uniq_list, ctr,
                                    norm_over_all_data, norm_over_all_maps, exp, cmap)
            # fig.savefig(os.path.join(out_path, f"{caption_batch[0]}_{obj1_batch[0]}_{obj2_batch[0]}.png"), dpi=300, bbox_inches='tight')
            fig.savefig(os.path.join(out_path, f"{ctr}.png"), dpi=300, bbox_inches='tight')
            plt.close(fig)
            print(f"Done with {ctr}")

        print('Done')
            
    elif visual_type in ['mi', 'cmi', 'attn']:
        # load data
        # TODO: change CSV file to JSON file
        img_dir = os.path.join(data_in_dir, f'val2017')
        csv_dir = os.path.join(data_in_dir, f'{csv_name}.csv')
        annotation_file = os.path.join(data_in_dir, f'annotations/instances_val2017.json')

        if dataset_type == "COCO-IT":
            dataset = utils.CocoDataset(img_dir, annotation_file, csv_dir)
        elif dataset_type == "coco_ours":
            dataset = utils.CocoDatasetOurs(img_dir, annotation_file, csv_dir)
        elif dataset_type == "custom":
            dataset = utils.CustomDataset(csv_dir)
        dataloader = t.utils.data.DataLoader(dataset, batch_size=1 , shuffle=False, num_workers=0)

        metric_list = []

        files = os.listdir(res_in_dir)
        idxs = [int(file[:-3].split("-")[-1]) for file in files]
        for batch_idx in sorted(idxs):
            if visual_type in ['mi', 'cmi']:
                file = f"{sdm_version}-nll_2D-{csv_name}-{int_mode}-{z_sample_num}-{snr_num}-{eval_metrics}-{batch_idx}.pt"
            elif visual_type in ['attn']:
                file = f"{sdm_version}-attnmaps-{csv_name}-{snr_num}-{batch_idx}.pt"

            print(batch_idx, file)

            metric_path = os.path.join(res_in_dir, file)
            if visual_type in ['mi', 'cmi']:
                metric = t.load(metric_path)['pixel_mi'][:, 0, ...]
            elif visual_type in ['attn']:
                metric = t.load(metric_path)['attnmaps'][:, ...]

            print("Loaded metric")

            metric_list.append(metric)

        metric_list = t.cat(metric_list)

        # visualization
        print('Starting visualization...')
        out_path = os.path.join(fig_out_dir, f"{sdm_version}/{csv_name}/")
        if not os.path.exists(out_path):
            os.makedirs(out_path)

        num_itrs = 1
        if dataset_type == "custom":
            num_itrs = 2

        for ctr, batch in enumerate(dataloader):
            for itr in range(num_itrs):
                image_batch = batch['image']
                image = image_batch.permute(0, 2, 3, 1)[0]
                caption_batch = batch['caption']

                idx = ctr
                if dataset_type == "COCO-IT":
                    obj_batch = batch['category']
                
                else:
                    obj_batch = batch[f'obj{itr + 1}']
                    idx = ctr * 2 + itr

                fig = plot_heatmaps_metric(image, caption_batch, obj_batch, metric_list, idx,
                                        norm_over_all_data, norm_over_all_maps, exp, cmap)
                # fig.savefig(os.path.join(out_path, f"{caption_batch[0]}_{obj_batch[0]}.png"), dpi=300, bbox_inches='tight')
                fig.savefig(os.path.join(out_path, f"{ctr}_{obj_batch[0]}.png"), dpi=300, bbox_inches='tight')
                plt.close(fig)
                print(f"Done with {ctr}")

        print('Done')

    elif visual_type == 'bias':
        # load data
        csv_dir = os.path.join(data_in_dir, f'{csv_name}.csv')
        dataset = utils.CustomDataset(csv_dir)
        dataloader = t.utils.data.DataLoader(dataset, batch_size=1 , shuffle=False, num_workers=0)

        redun_list = []
        uniq_list = []
        syn_list = []
        mi_list = []
        cmi_list = []
        attn_list = []

        files = os.listdir(os.path.join(res_in_dir, f"pid_img_lvl/pid_{csv_name}"))
        idxs = [int(file[:-3].split("-")[-1]) for file in files]
        ctr = 0
        for batch_idx in sorted(idxs):
            ctr += 1
            pid_file = f"pid_img_lvl/pid_{csv_name}/{sdm_version}-pid-{csv_name}-{int_mode}-{z_sample_num}-{snr_num}-pid-{batch_idx}.pt"
            mi_file = f"mi_img_lvl/mi_{csv_name}/{sdm_version}-nll_1D-{csv_name}-{int_mode}-{z_sample_num}-{snr_num}-mi-{batch_idx}.pt"
            cmi_file = f"cmi_img_lvl/cmi_{csv_name}/{sdm_version}-nll_1D-{csv_name}-{int_mode}-{z_sample_num}-{snr_num}-cmi-{batch_idx}.pt"
            daam_file = f"daam/daam_{csv_name}/{sdm_version}-attnmaps-{csv_name}-{snr_num}-{batch_idx}.pt"

            print(batch_idx, pid_file)

            pid_path = os.path.join(res_in_dir, pid_file)
            mi_path = os.path.join(res_in_dir, mi_file)
            cmi_path = os.path.join(res_in_dir, cmi_file)
            daam_path = os.path.join(res_in_dir, daam_file)

            redun = t.load(pid_path)['pixel_redun']
            print("Loaded redundancy")
            syn = t.load(pid_path)['pixel_syn']
            print("Loaded synergy")
            uniq = t.load(pid_path)['pixel_uniq']
            print("Loaded uniqueness")

            mi = t.load(mi_path)['mi'][:, 0, ...]
            cmi = t.load(cmi_path)['mi'][:, 0, ...]
            attn = t.load(daam_path)['attnmaps'][:, ...]

            redun_list.append(redun)
            uniq_list.append(uniq)
            syn_list.append(syn)

            mi_list.append(mi)
            cmi_list.append(cmi)
            attn_list.append(attn)

        redun_list = t.cat(redun_list)
        uniq_list = t.cat(uniq_list)
        syn_list = t.cat(syn_list)

        mi_list = t.cat(mi_list)
        cmi_list = t.cat(cmi_list)
        attn_list = t.cat(attn_list)
        attn_list = attn_list.mean(dim = (-1, -2))

        redun_list = (redun_list - redun_list.min(dim = 0)[0]) / (redun_list.max(dim = 0)[0] - redun_list.min(dim = 0)[0])
        uniq_list = (uniq_list - uniq_list.min(dim = 0)[0]) / (uniq_list.max(dim = 0)[0] - uniq_list.min(dim = 0)[0])
        syn_list = (syn_list - syn_list.min(dim = 0)[0]) / (syn_list.max(dim = 0)[0] - syn_list.min(dim = 0)[0])

        mi_list = (mi_list - mi_list.min(dim = 0)[0]) / (mi_list.max(dim = 0)[0] - mi_list.min(dim = 0)[0])
        cmi_list = (cmi_list - cmi_list.min(dim = 0)[0]) / (cmi_list.max(dim = 0)[0] - cmi_list.min(dim = 0)[0])
        attn_list = (attn_list - attn_list.min(dim = 0)[0]) / (attn_list.max(dim = 0)[0] - attn_list.min(dim = 0)[0])

        # visualization
        print('Starting visualization...')
        out_path = os.path.join(fig_out_dir, f"{sdm_version}/{csv_name}/")
        if not os.path.exists(out_path):
            os.makedirs(out_path)

        save_path = os.path.join(out_path, "values.txt")

        for ctr, batch in enumerate(dataloader):
            image_batch = batch['image']
            image = image_batch.permute(0, 2, 3, 1)[0]
            caption_batch = batch['caption']
            obj1_batch = batch['obj1']
            obj2_batch = batch['obj2']

            with open(save_path, "a+") as f:
                line = f"{caption_batch[0]},{obj1_batch[0]},{obj2_batch[0]},{redun_list[ctr]},{uniq_list[ctr][0]},{uniq_list[ctr][1]},{syn_list[ctr]},{mi_list[ctr * 2]},{mi_list[ctr * 2 + 1]},{cmi_list[ctr * 2]},{cmi_list[ctr * 2 + 1]},{attn_list[ctr * 2]},{attn_list[ctr * 2 + 1]}\n"
                f.write(line)

            print(f"Done with {ctr}")

        print('Done')
            
    elif visual_type == 'gender_bias':
        # load data
        csv_dir = os.path.join(data_in_dir, f'{csv_name}.csv')
        dataset = utils.CustomDataset(csv_dir)
        dataloader = t.utils.data.DataLoader(dataset, batch_size=1 , shuffle=False, num_workers=0)

        redun_list = []
        uniq_list = []
        syn_list = []
        mi_list = []
        cmi_list = []
        attn_list = []

        csv_names = ["occupation_male", "occupation_female"]

        for csv_name in csv_names:
            files = os.listdir(os.path.join(res_in_dir, f"pid_img_lvl/pid_{csv_name}"))
            idxs = [int(file[:-3].split("-")[-1]) for file in files]
            ctr = 0

            for batch_idx in sorted(idxs):
                ctr += 1
                pid_file = f"pid_img_lvl/pid_{csv_name}/{sdm_version}-pid-{csv_name}-{int_mode}-{z_sample_num}-{snr_num}-pid-{batch_idx}.pt"
                mi_file = f"mi_img_lvl/mi_{csv_name}/{sdm_version}-nll_1D-{csv_name}-{int_mode}-{z_sample_num}-{snr_num}-mi-{batch_idx}.pt"
                cmi_file = f"cmi_img_lvl/cmi_{csv_name}/{sdm_version}-nll_1D-{csv_name}-{int_mode}-{z_sample_num}-{snr_num}-cmi-{batch_idx}.pt"
                daam_file = f"daam/daam_{csv_name}/{sdm_version}-attnmaps-{csv_name}-{snr_num}-{batch_idx}.pt"

                print(batch_idx, pid_file)

                pid_path = os.path.join(res_in_dir, pid_file)
                mi_path = os.path.join(res_in_dir, mi_file)
                cmi_path = os.path.join(res_in_dir, cmi_file)
                daam_path = os.path.join(res_in_dir, daam_file)

                redun = t.load(pid_path)['pixel_redun']
                print("Loaded redundancy")
                syn = t.load(pid_path)['pixel_syn']
                print("Loaded synergy")
                uniq = t.load(pid_path)['pixel_uniq']
                print("Loaded uniqueness")

                mi = t.load(mi_path)['mi'][:, 0, ...]
                cmi = t.load(cmi_path)['mi'][:, 0, ...]
                attn = t.load(daam_path)['attnmaps'][:, ...]

                redun_list.append(redun)
                uniq_list.append(uniq)
                syn_list.append(syn)

                mi_list.append(mi)
                cmi_list.append(cmi)
                attn_list.append(attn)

        redun_list = t.cat(redun_list)
        uniq_list = t.cat(uniq_list)
        syn_list = t.cat(syn_list)

        mi_list = t.cat(mi_list)
        cmi_list = t.cat(cmi_list)
        attn_list = t.cat(attn_list)
        attn_list = attn_list.mean(dim = (-1, -2))

        redun_list = (redun_list - redun_list.min(dim = 0)[0]) / (redun_list.max(dim = 0)[0] - redun_list.min(dim = 0)[0])
        uniq_list = (uniq_list - uniq_list.min(dim = 0)[0]) / (uniq_list.max(dim = 0)[0] - uniq_list.min(dim = 0)[0])
        syn_list = (syn_list - syn_list.min(dim = 0)[0]) / (syn_list.max(dim = 0)[0] - syn_list.min(dim = 0)[0])

        mi_list = (mi_list - mi_list.min(dim = 0)[0]) / (mi_list.max(dim = 0)[0] - mi_list.min(dim = 0)[0])
        cmi_list = (cmi_list - cmi_list.min(dim = 0)[0]) / (cmi_list.max(dim = 0)[0] - cmi_list.min(dim = 0)[0])
        attn_list = (attn_list - attn_list.min(dim = 0)[0]) / (attn_list.max(dim = 0)[0] - attn_list.min(dim = 0)[0])

        # visualization
        print('Starting visualization...')
        out_path = os.path.join(fig_out_dir, f"{sdm_version}/{csv_name}/")
        if not os.path.exists(out_path):
            os.makedirs(out_path)

        save_path = os.path.join(out_path, "values.txt")
        num_samples = len(redun_list) // 2

        for ctr, batch in enumerate(dataloader):
            image_batch = batch['image']
            image = image_batch.permute(0, 2, 3, 1)[0]
            caption_batch = batch['caption']
            obj1_batch = batch['obj1']
            obj2_batch = batch['obj2']

            with open(save_path, "a+") as f:
                line = f"{caption_batch[0]},{obj1_batch[0]},{obj2_batch[0]},{redun_list[ctr]},{uniq_list[ctr][0]},{uniq_list[ctr][1]},{syn_list[ctr]},{mi_list[ctr * 2]},{mi_list[ctr * 2 + 1]},{cmi_list[ctr * 2]},{cmi_list[ctr * 2 + 1]},{attn_list[ctr * 2]},{attn_list[ctr * 2 + 1]}\n"
                f.write(line)
                line = f"{caption_batch[0]},Female,{obj2_batch[0]},{redun_list[num_samples + ctr]},{uniq_list[num_samples + ctr][0]},{uniq_list[num_samples + ctr][1]},{syn_list[num_samples + ctr]},{mi_list[num_samples * 2 + ctr * 2]},{mi_list[num_samples * 2 + ctr * 2 + 1]},{cmi_list[num_samples * 2 + ctr * 2]},{cmi_list[num_samples * 2 + ctr * 2 + 1]},{attn_list[num_samples * 2 + ctr * 2]},{attn_list[num_samples * 2 + ctr * 2 + 1]}\n"
                f.write(line)

            print(f"Done with {ctr}")

        print('Done')
            
    elif visual_type == 'redun':
        # load data
        csv_dir = os.path.join(data_in_dir, f'{csv_name}.csv')
        dataset = utils.CustomDataset(csv_dir)
        dataloader = t.utils.data.DataLoader(dataset, batch_size=1 , shuffle=False, num_workers=0)

        redun_list = []
        uniq_list = []
        syn_list = []
        mi_list = []
        cmi_list = []
        attn_list = []

        files = os.listdir(os.path.join(res_in_dir, f"pid/pid_{csv_name}"))
        idxs = [int(file[:-3].split("-")[-1]) for file in files]
        ctr = 0
        for batch_idx in sorted(idxs):
            ctr += 1
            pid_file = f"pid/pid_{csv_name}/{sdm_version}-pid-{csv_name}-{int_mode}-{z_sample_num}-{snr_num}-pid-{batch_idx}.pt"
            mi_file = f"mi/mi_{csv_name}/{sdm_version}-nll_2D-{csv_name}-{int_mode}-{z_sample_num}-{snr_num}-mi-{batch_idx}.pt"
            cmi_file = f"cmi/cmi_{csv_name}/{sdm_version}-nll_2D-{csv_name}-{int_mode}-{z_sample_num}-{snr_num}-cmi-{batch_idx}.pt"
            daam_file = f"daam/daam_{csv_name}/{sdm_version}-attnmaps-{csv_name}-{snr_num}-{batch_idx}.pt"

            print(batch_idx, pid_file)

            pid_path = os.path.join(res_in_dir, pid_file)
            mi_path = os.path.join(res_in_dir, mi_file)
            cmi_path = os.path.join(res_in_dir, cmi_file)
            daam_path = os.path.join(res_in_dir, daam_file)

            redun = t.load(pid_path)['pixel_redun']
            print("Loaded redundancy")
            syn = t.load(pid_path)['pixel_syn']
            print("Loaded synergy")
            uniq = t.load(pid_path)['pixel_uniq']
            print("Loaded uniqueness")

            mi = t.load(mi_path)['pixel_mi'][:, 0, ...]
            cmi = t.load(cmi_path)['pixel_mi'][:, 0, ...]
            attn = t.load(daam_path)['attnmaps'][:, ...]

            redun_list.append(redun)
            uniq_list.append(uniq)
            syn_list.append(syn)

            mi_list.append(mi)
            cmi_list.append(cmi)
            attn_list.append(attn)

        redun_list = t.cat(redun_list)
        uniq_list = t.cat(uniq_list)
        syn_list = t.cat(syn_list)

        mi_list = t.cat(mi_list)
        cmi_list = t.cat(cmi_list)
        attn_list = t.cat(attn_list)

        # visualization
        print('Starting visualization...')
        out_path = os.path.join(fig_out_dir, f"{sdm_version}/{csv_name}/")
        if not os.path.exists(out_path):
            os.makedirs(out_path)

        vis = {}
        for ctr, batch in enumerate(dataloader): 
            image_batch = batch['image']
            image = image_batch.permute(0, 2, 3, 1)[0]
            caption_batch = batch['caption']
            obj1_batch = batch['obj1']
            obj2_batch = batch['obj2']

            vis['redun'] = redun_list[ctr]
            vis['cmi'] = find_intersection(cmi_list[2 * ctr], cmi_list[2 * ctr + 1], caption_batch[0], obj1_batch[0], obj2_batch[0], "cmi")
            vis['mi'] = find_intersection(mi_list[2 * ctr], mi_list[2 * ctr + 1], caption_batch[0], obj1_batch[0], obj2_batch[0], "mi")
            vis['attn'] = find_intersection(attn_list[2 * ctr], attn_list[2 * ctr + 1], caption_batch[0], obj1_batch[0], obj2_batch[0], "attn")

            fig = plot_redun(image, vis, caption_batch, obj1_batch, obj2_batch, norm_over_all_data, norm_over_all_maps, exp, cmap)
            fig.savefig(os.path.join(out_path, f"{caption_batch[0]}_{obj1_batch[0]}_{obj2_batch[0]}.png"), dpi=300, bbox_inches='tight')
            # fig.savefig(os.path.join(out_path, f"{ctr}_{obj_batch[0]}.png"), dpi=300, bbox_inches='tight')
            plt.close(fig)

            print(f"Done with {ctr}")

        print('Done')

    elif visual_type == 'img_edit':
        # load data
        csv_dir = os.path.join(data_in_dir, f'{csv_name}.csv')
        dataset = utils.CustomDataset(csv_dir)
        dataloader = t.utils.data.DataLoader(dataset, batch_size=1 , shuffle=False, num_workers=0)

        redun_list = []
        uniq_list = []
        syn_list = []
        mi_list = []
        cmi_list = []
        attn_list = []
        orig_list = []
        mod_list = []

        img_lvl_uniq_list = []
        img_lvl_redun_list = []
        img_lvl_mi_list = []
        img_lvl_cmi_list = []
        img_lvl_attn_list = []

        files = os.listdir(os.path.join(res_in_dir, f"pid/pid_{csv_name}"))
        idxs = [int(file[:-3].split("-")[-1]) for file in files]
        ctr = 0
        for batch_idx in sorted(idxs):
            ctr += 1
            pid_file = f"pid/pid_{csv_name}/{sdm_version}-pid-{csv_name}-{int_mode}-{z_sample_num}-{snr_num}-pid-{batch_idx}.pt"
            mi_file = f"mi/mi_{csv_name}/{sdm_version}-nll_2D-{csv_name}-{int_mode}-{z_sample_num}-{snr_num}-mi-{batch_idx}.pt"
            cmi_file = f"cmi/cmi_{csv_name}/{sdm_version}-nll_2D-{csv_name}-{int_mode}-{z_sample_num}-{snr_num}-cmi-{batch_idx}.pt"
            daam_file = f"daam/daam_{csv_name}/{sdm_version}-attnmaps-{csv_name}-{snr_num}-{batch_idx}.pt"
            img_edit_file = f"img_edit/img_edit_{csv_name}/{sdm_version}-{csv_name}-{batch_idx}.pt"

            img_lvl_pid_file = f"pid_img_lvl/pid_{csv_name}/{sdm_version}-pid-{csv_name}-{int_mode}-{z_sample_num}-{snr_num}-pid-{batch_idx}.pt"
            img_lvl_cmi_file = f"cmi_img_lvl/cmi_{csv_name}/{sdm_version}-nll_1D-{csv_name}-{int_mode}-{z_sample_num}-{snr_num}-cmi-{batch_idx}.pt"
            img_lvl_mi_file = f"mi_img_lvl/mi_{csv_name}/{sdm_version}-nll_1D-{csv_name}-{int_mode}-{z_sample_num}-{snr_num}-mi-{batch_idx}.pt"

            print(batch_idx, pid_file)

            pid_path = os.path.join(res_in_dir, pid_file)
            mi_path = os.path.join(res_in_dir, mi_file)
            cmi_path = os.path.join(res_in_dir, cmi_file)
            daam_path = os.path.join(res_in_dir, daam_file)
            img_edit_path = os.path.join(res_in_dir, img_edit_file)

            img_lvl_pid_path = os.path.join(res_in_dir, img_lvl_pid_file)
            img_lvl_cmi_path = os.path.join(res_in_dir, img_lvl_cmi_file)
            img_lvl_mi_path = os.path.join(res_in_dir, img_lvl_mi_file)

            redun = t.load(pid_path)['pixel_redun']
            print("Loaded redundancy")
            syn = t.load(pid_path)['pixel_syn']
            print("Loaded synergy")
            uniq = t.load(pid_path)['pixel_uniq']
            print("Loaded uniqueness")

            mi = t.load(mi_path)['pixel_mi'][:, 0, ...]
            cmi = t.load(cmi_path)['pixel_mi'][:, 0, ...]
            attn = t.load(daam_path)['attnmaps'][:, ...]

            img_lvl_redun = t.load(img_lvl_pid_path)['pixel_redun']
            img_lvl_uniq = t.load(img_lvl_pid_path)['pixel_uniq']
            img_lvl_mi = t.load(img_lvl_mi_path)['mi'][:, 0, ...]
            img_lvl_cmi = t.load(img_lvl_cmi_path)['mi'][:, 0, ...]

            orig = t.tensor(t.load(img_edit_path)['original'])
            mod = t.tensor(t.load(img_edit_path)['modified'])

            redun_list.append(redun)
            uniq_list.append(uniq)
            syn_list.append(syn)

            mi_list.append(mi)
            cmi_list.append(cmi)
            attn_list.append(attn)

            img_lvl_uniq_list.append(img_lvl_uniq)
            img_lvl_redun_list.append(img_lvl_redun)
            img_lvl_mi_list.append(img_lvl_mi)
            img_lvl_cmi_list.append(img_lvl_cmi)

            orig_list.append(orig)
            mod_list.append(mod)

            del mi, cmi, attn, redun, syn, uniq, orig, mod
            t.cuda.empty_cache()
            gc.collect()

        redun_list = t.cat(redun_list)
        uniq_list = t.cat(uniq_list)
        syn_list = t.cat(syn_list)

        mi_list = t.cat(mi_list)
        cmi_list = t.cat(cmi_list)
        attn_list = t.cat(attn_list)

        img_lvl_uniq_list = t.cat(img_lvl_uniq_list)
        img_lvl_redun_list = t.cat(img_lvl_redun_list)
        img_lvl_mi_list = t.cat(img_lvl_mi_list)
        img_lvl_cmi_list = t.cat(img_lvl_cmi_list)
        img_lvl_attn_list = attn_list.mean(dim = (-1, -2))

        orig_list = t.cat(orig_list)
        mod_list = t.cat(mod_list)

        pix_mses = []
        img_mses = []
        min_img_mses = []
        psnrs = []
        max_psnrs = []
        clip_scores = []

        redun_clip = []
        cmi_clip = []
        mi_clip = []
        attn_clip = []

        # visualization
        print('Starting visualization...')
        out_path = os.path.join(fig_out_dir, f"{sdm_version}/{csv_name}/")
        if not os.path.exists(out_path):
            os.makedirs(out_path)

        transform = T.ToPILImage()

        vis = {}
        for ctr, batch in enumerate(dataloader):
            image_batch = batch['image']
            image = image_batch.permute(0, 2, 3, 1)[0]
            caption_batch = batch['caption']

            vis['orig'] = orig_list[ctr]
            vis['redun'] = redun_list[ctr]

            for i in range(2):
                obj_batch = batch[f'obj{i + 1}']
                vis['mod'] = mod_list[ctr, i]
                vis['uniq'] = uniq_list[ctr, i]
                vis['cmi'] = cmi_list[2 * ctr + i]
                vis['mi'] = mi_list[2 * ctr + i]
                vis['attn'] = attn_list[2 * ctr + i]

                # fig = plot_img_edit(vis, caption_batch, obj_batch, norm_over_all_data, norm_over_all_maps, exp, cmap)
                # # fig.savefig(os.path.join(out_path, f"{caption_batch[0]}_{obj_batch[0]}.png"), dpi=300, bbox_inches='tight')
                # fig.savefig(os.path.join(out_path, f"{ctr}_{obj_batch[0]}.png"), dpi=300, bbox_inches='tight')
                # plt.close(fig)

                pix_mse = np.square(np.array(vis['orig']) - np.array(vis['mod'])).sum(axis = -1).squeeze().flatten()
                img_mse = pix_mse.sum()
                psnr = cv2.PSNR(np.array(vis['orig']), np.array(vis['mod']))
                with t.no_grad():
                    clip_score = clip_similarity(transform(vis['orig'].permute(2, 0, 1)), transform(vis['mod'].permute(2, 0, 1)))

                pix_mses.append(pix_mse)
                img_mses.append(img_mse)
                psnrs.append(psnr)
                clip_scores.append(clip_score)

                print(caption_batch, batch[f'obj{i + 1}'])
                print(img_mse, img_lvl_uniq_list[ctr, i])
                print()

                vis['cmi_mse_pix'] = np.corrcoef(vis['cmi'].flatten().numpy(), pix_mse)[0, 1]
                vis['mi_mse_pix'] = np.corrcoef(vis['mi'].flatten().numpy(), pix_mse)[0, 1]
                vis['attn_mse_pix'] = np.corrcoef(vis['attn'].flatten().numpy(), pix_mse)[0, 1]
                vis['redun_mse_pix'] = np.corrcoef(vis['redun'].flatten().numpy(), pix_mse)[0, 1]
                vis['uniq_mse_pix'] = np.corrcoef(vis['uniq'].flatten().numpy(), pix_mse)[0, 1]

            min_img_mses.append(min(img_mses[-2:]))
            redun_clip += get_clip_sim(image, extract_pid_mask(vis['redun']), batch['obj1'], batch['obj2'])
            cmi_clip += get_clip_sim(image, extract_intersect_mask(cmi_list[2 * ctr], cmi_list[2 * ctr + 1]), batch['obj1'], batch['obj2'])
            mi_clip += get_clip_sim(image, extract_intersect_mask(mi_list[2 * ctr], mi_list[2 * ctr + 1]), batch['obj1'], batch['obj2'])
            attn_clip += get_clip_sim(image, extract_intersect_mask(attn_list[2 * ctr], attn_list[2 * ctr + 1]), batch['obj1'], batch['obj2'])

            print(f"Done with {ctr}")

        min_img_mses = np.array(min_img_mses)
        img_mses = np.array(img_mses)
        psnrs = np.array(psnrs)
        clip_scores = np.array(clip_scores)

        redun_clip = np.array(redun_clip)
        cmi_clip = np.array(cmi_clip)
        mi_clip = np.array(mi_clip)
        attn_clip = np.array(attn_clip)

        import pdb; pdb.set_trace()

        res = {
            'redun_min_mse_img': np.corrcoef(img_lvl_redun_list.numpy(), min_img_mses)[0, 1],
            'redun_clip_mse_img': np.corrcoef(redun_clip, img_mses)[0, 1],
            'cmi_clip_mse_img': np.corrcoef(cmi_clip, img_mses)[0, 1],
            'mi_clip_mse_img': np.corrcoef(mi_clip, img_mses)[0, 1],
            'attn_clip_mse_img': np.corrcoef(attn_clip, img_mses)[0, 1],

            'cmi_mse_img': np.corrcoef(img_lvl_cmi_list.numpy(), img_mses)[0, 1],
            'mi_mse_img': np.corrcoef(img_lvl_mi_list.numpy(), img_mses)[0, 1],
            'attn_mse_img': np.corrcoef(img_lvl_attn_list.numpy(), img_mses)[0, 1],
            'redun_mse_img': np.corrcoef(img_lvl_redun_list.repeat_interleave(2).numpy(), img_mses)[0, 1],
            'uniq_mse_img': np.corrcoef(img_lvl_uniq_list.flatten().numpy(), img_mses)[0, 1],

            'cmi_psnr_img': np.corrcoef(img_lvl_cmi_list.numpy(), psnrs)[0, 1], 'mi_psnr_img': np.corrcoef(img_lvl_mi_list.numpy(), psnrs)[0, 1],
            'attn_psnr_img': np.corrcoef(img_lvl_attn_list.numpy(), psnrs)[0, 1],
            'redun_psnr_img': np.corrcoef(img_lvl_redun_list.repeat_interleave(2).numpy(), psnrs)[0, 1],
            'uniq_psnr_img': np.corrcoef(img_lvl_uniq_list.flatten().numpy(), psnrs)[0, 1],

            'cmi_clip_img': np.corrcoef(img_lvl_cmi_list.numpy(), clip_scores)[0, 1],
            'mi_clip_img': np.corrcoef(img_lvl_mi_list.numpy(), clip_scores)[0, 1],
            'attn_clip_img': np.corrcoef(img_lvl_attn_list.numpy(), clip_scores)[0, 1],
            'redun_clip_img': np.corrcoef(img_lvl_redun_list.repeat_interleave(2).numpy(), clip_scores)[0, 1],
            'uniq_clip_img': np.corrcoef(img_lvl_uniq_list.flatten().numpy(), clip_scores)[0, 1],

            'cmi_mse_pix': vis['cmi_mse_pix'],
            'mi_mse_pix': vis['mi_mse_pix'],
            'attn_mse_pix': vis['attn_mse_pix'],
            'redun_mse_pix': vis['redun_mse_pix'],
            'uniq_mse_pix': vis['uniq_mse_pix']
        }

        out_file_path = os.path.join(fig_out_dir, f"{sdm_version}/{csv_name}/corr.json")
        with open(out_file_path, "w") as outfile: 
            json.dump(res, outfile)

        print('Done')


# Script entry point: call main() only when this file is executed directly,
# so importing the module does not trigger the run.
if __name__ == "__main__":
    main()


