'''Code for VLM analysis'''
import pandas as pd
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import gc
import seaborn as sns
from matplotlib.colors import TwoSlopeNorm
from matplotlib.patches import Rectangle
from scipy import stats


def stat(arr1, arr2, name1, name2):
    """Print the Pearson and Spearman correlation between two samples.

    One line is printed per test, showing the statistic and its p-value,
    each rounded to six decimal places.

    Args:
        arr1: First sample, handed unchanged to ``scipy.stats``.
        arr2: Second sample, handed unchanged to ``scipy.stats``.
        name1: Label used for ``arr1`` in the printed lines.
        name2: Label used for ``arr2`` in the printed lines.
    """
    def _rounded(result):
        # result[0] is the statistic, result[1] the p-value.
        return round(result[0].item(), 6), round(result[1].item(), 6)  # type: ignore

    pearson = stats.pearsonr(arr1, arr2)
    spearman = stats.spearmanr(arr1, arr2)

    r_value, r_pvalue = _rounded(pearson)
    print(
        f'Pearsonr between {name1} and {name2}: value = {r_value}, p = {r_pvalue}')
    rho_value, rho_pvalue = _rounded(spearman)
    print(
        f'Spearmanr between {name1} and {name2}: value = {rho_value}, p = {rho_pvalue}')


def correlation(info):
    """Print pairwise correlations between the first three columns of ``info``.

    Columns 0, 1 and 2 are labelled 'grad_in', 'grad_out' and 'loss'. Every
    unordered pair of columns is reported through ``stat``, followed by the
    difference (column 1 - column 0) against column 2.

    Args:
        info: 2-D array-like supporting ``info[:, k]`` indexing
            (presumably one row per sample -- TODO confirm with the caller).
    """
    print('====== Starting Correlation ======')
    labels = ('grad_in', 'grad_out', 'loss')
    n_cols = len(labels)
    # All (first, second) column pairs with first < second, in row-major order.
    column_pairs = [(a, b) for a in range(n_cols) for b in range(a + 1, n_cols)]
    for a, b in column_pairs:
        stat(info[:, a], info[:, b], labels[a], labels[b])
    grad_delta = info[:, 1] - info[:, 0]
    stat(grad_delta, info[:, 2], 'grad out - grad in', 'loss')
    print('====== Ending Correlation ======')


def get_loss(outputs, target_token_id):
    """Cross-entropy of the last-position logits against the target token.

    Args:
        outputs: Object exposing ``logits``, indexed here as ``[:, -1, :]``
            (assumed to be (batch, seq, vocab) -- TODO confirm).
        target_token_id: Value that ``torch.tensor`` can turn into the
            class-index target expected by ``F.cross_entropy``.

    Returns:
        The loss tensor produced by ``F.cross_entropy`` with its defaults.
    """
    last_step_logits = outputs.logits[:, -1, :]
    # Build the target on the same device as the logits.
    target = torch.tensor(target_token_id, device=last_step_logits.device)
    return F.cross_entropy(last_step_logits, target)


def get_saliency_matrix(outputs, target_token_id) -> torch.Tensor:
    """Compute the attention saliency |attention * d(loss)/d(attention)|.

    The loss from ``get_loss`` is back-propagated once, and each attention
    map is multiplied element-wise by its own gradient.

    Args:
        outputs: Model outputs exposing ``logits`` and ``attentions``. Each
            attention map must be part of the autograd graph of the logits
            (assumed shape per map: (batch, heads, S, S) -- TODO confirm).
        target_token_id: Target passed through to ``get_loss``.

    Returns:
        CPU tensor without autograd history. For a batch of one the shape is
        (layer_num, head_num, S, S); the batch axis is kept for larger batches.
    """
    loss = get_loss(outputs, target_token_id)

    # Attention maps are intermediate (non-leaf) tensors, so autograd keeps
    # their gradients only if asked to before the backward pass.
    for att in outputs.attentions:  # type: ignore
        att.retain_grad()
    loss.backward()

    # Stacking yields (L, B, H, S, S). Squeeze only the batch axis: a bare
    # squeeze() would also drop a layer/head/sequence axis of size 1.
    attn_grad = torch.stack(
        [att.grad for att in outputs.attentions]).squeeze(1)  # type: ignore
    # Detach so the product does not build a second autograd graph.
    attn = torch.stack(outputs.attentions).detach().squeeze(1)
    saliency = torch.abs(attn * attn_grad).cpu()

    # Drop local references first so the cached device memory can be freed.
    del attn, attn_grad, loss, outputs
    gc.collect()
    torch.cuda.empty_cache()

    return saliency


def parse_image_id(image_path: str) -> tuple:
    """
    Split an image path into (image_id, label_name).

    image_id is the file name up to its first '.'.
    label_name is taken from the parent directory name: everything after its
    first '_', with every '-' replaced by '/'.

    input: /inference/images/ambulance_-m-012n7d_inpaint/0cda1863ee21cd0e.jpg
    output: ('0cda1863ee21cd0e', '/m/012n7d_inpaint')

    NOTE(review): any text after the label in the directory name (here
    '_inpaint') stays in label_name -- confirm this is what callers expect.
    """
    parts = image_path.split('/')
    image_id = parts[-1].partition('.')[0]
    # Drop the leading "<class name>_" prefix, then restore the '/' separators.
    label_name = parts[-2].partition('_')[2].replace('-', '/')
    return image_id, label_name


def bbox_areas(df: pd.DataFrame, image_path: str):
    """
    Return the bounding boxes recorded for the image/label encoded in a path.

    The path is decoded with parse_image_id; rows of df whose 'ImageID' and
    'LabelName' both match are selected. The result is a list of
    [XMin, YMin, XMax, YMax] lists, or [] when no row matches.
    """
    image_id, label_name = parse_image_id(image_path)
    is_match = (df['ImageID'] == image_id) & (df['LabelName'] == label_name)
    boxes = df.loc[is_match, ['XMin', 'YMin', 'XMax', 'YMax']]
    return boxes.values.tolist() if len(boxes) else []


def extract_all_bbox(df: pd.DataFrame) -> dict:
    """Group every bounding box in ``df`` by image.

    Args:
        df: Frame with the columns 'ImageID', 'XMin', 'YMin', 'XMax', 'YMax'.
            It is not modified.

    Returns:
        Dict mapping each ImageID to a list of [XMin, YMin, XMax, YMax]
        lists, in the row order of ``df``.
    """
    coord_cols = ['XMin', 'YMin', 'XMax', 'YMax']
    # Take an explicit copy: adding a column to a column-subset of ``df`` is
    # chained assignment and triggers pandas' SettingWithCopyWarning.
    df_bbox = df[['ImageID', *coord_cols]].copy()
    df_bbox['bbox'] = df_bbox[coord_cols].values.tolist()
    return df_bbox.groupby('ImageID')['bbox'].apply(list).to_dict()


def single_bbox_mask(bbox: list, num_patches: int, transpose=True):
    """Mark the grid patches that lie entirely inside one bounding box.

    The unit square [0, 1] x [0, 1] is divided into num_patches x num_patches
    equal patches. A patch counts as inside only when all four of its edges
    fall within the box (boundaries included).

    Args:
        bbox: [XMin, YMin, XMax, YMax] in the same 0-1 coordinates.
        num_patches: Number of patches along each axis.
        transpose: When False the mask is indexed [x_index, y_index];
            when True (default) axes 0 and 1 are swapped.

    Returns:
        Boolean tensor of shape (num_patches, num_patches).
    """
    box_left, box_top, box_right, box_bottom = bbox

    # num_patches + 1 evenly spaced edges; each patch spans [lower, upper].
    edges = torch.linspace(0, 1, num_patches + 1)
    lower, upper = edges[:-1], edges[1:]

    # A patch is inside the box iff it is inside along x AND inside along y,
    # so the 2-D mask is the outer AND of two 1-D tests.
    inside_x = (lower >= box_left) & (upper <= box_right)
    inside_y = (lower >= box_top) & (upper <= box_bottom)
    inside = inside_x[:, None] & inside_y[None, :]

    return torch.transpose(inside, 0, 1) if transpose else inside


def bbox_mask(image_path: str, df: pd.DataFrame, num_patches=16, ravel=True, maskout: None | list = None, transpose=True) -> tuple:  # type: ignore
    """
    Build patch-level inside/outside masks for the bounding boxes of an image.

    The boxes for the image are looked up with bbox_areas(df, image_path).
    A patch is "inside" when it lies fully within at least one of them
    (see single_bbox_mask).

    maskout: If None, mask_outside is simply the inverse of mask_inside.
    Otherwise it must be a list of bboxes, and mask_outside marks the
    patches that are not fully inside any of those bboxes.

    Returns (mask_inside, mask_outside) as boolean tensors, flattened to 1-D
    when ravel is True and (num_patches, num_patches) otherwise.

    Raises ValueError when maskout is neither a list nor None.
    """
    # Validate before doing any work; isinstance also accepts list subclasses.
    if maskout is not None and not isinstance(maskout, list):
        raise ValueError('maskout should be list or None')

    def _union_of_boxes(bboxes):
        # OR together the fully-inside masks of all given boxes.
        merged = torch.zeros((num_patches, num_patches), dtype=torch.bool)
        for bbox in bboxes:
            merged |= single_bbox_mask(bbox, num_patches, transpose=transpose)
        return merged

    mask_inside = _union_of_boxes(bbox_areas(df, image_path))
    if maskout is None:
        mask_outside = ~mask_inside
    else:
        mask_outside = ~_union_of_boxes(maskout)

    if ravel:
        return mask_inside.ravel(), mask_outside.ravel()
    return mask_inside, mask_outside


def plot_attention(attn_dict, title='Attention', folder_name='attn_result', pic_name=None, flag=0):
    """
    Plot averaged attention per layer and save the figure as a PNG.

    Two stacked panels are drawn: the first shows every series in attn_dict,
    the second omits the series keyed 'txt'. Each series is a line with a
    shaded band of +/- its spread; the x axis is the 0-based position in the
    series.

    attn_dict = {'other': (other, other_std), 'in': (patch_in_bbox,
                   patch_in_bbox_std), 'txt': (txt_attn, txt_attn_std)}
    Each value is an (average, spread) pair of tensors.
    Note: the std is actually 'standard error'

    The file is written to '<folder_name>/<pic_name>.png'; the folder is
    created if missing. When flag is truthy each plotted series is logged
    to stdout.
    """
    color = ['r', 'b', 'g']
    nrows, ncols = 2, 1
    fig, axes = plt.subplots(nrows, ncols, figsize=(ncols * 6, nrows * 4))
    try:
        ax1, ax2 = axes
        for j, ax in enumerate(axes):
            for idx, (k, val) in enumerate(attn_dict.items()):
                if k == 'txt' and j == 1:  # the second fig has no txt
                    continue
                if flag:
                    print(f"Plotting {k} on subplot {j}")
                avg, std = val
                avg = avg.detach().cpu().float().numpy()
                std = std.detach().cpu().float().numpy()
                # Cycle through the palette so more than three series
                # do not raise IndexError.
                series_color = color[idx % len(color)]
                ax.plot(avg, label=k, marker='+', color=series_color)
                ax.fill_between(np.arange(avg.size), avg - std,
                                avg + std, color=series_color, alpha=0.15)
            ax.set_xlabel('Layer')
            ax.set_ylabel('Averaged Attention')
            ax.legend()
        ax1.set_title('With text', fontsize=10)
        ax2.set_title('Without text', fontsize=10)

        fig.suptitle(title, fontsize=14)
        out_dir = Path(folder_name)
        out_dir.mkdir(parents=True, exist_ok=True)
        fig.savefig(out_dir / f'{pic_name}.png')
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close(fig)


def plot_bbox(bbox_ls, axs, title):
    """Outline each bounding box on axs[1] and axs[2], then title axs[2].

    Args:
        bbox_ls: Iterable of (x_min, y_min, x_max, y_max) boxes, given in the
            data coordinates of the target axes (units not visible here --
            TODO confirm with the caller).
        axs: Indexable collection of axes; entries 1 and 2 are drawn on.
        title: Title set on axs[2].
    """
    # (axes index, outline colour): white on axs[2], red on axs[1].
    outline_styles = ((2, 'white'), (1, 'r'))
    for x_min, y_min, x_max, y_max in bbox_ls:
        box_w, box_h = x_max - x_min, y_max - y_min
        for panel_idx, edge in outline_styles:
            # A patch can belong to only one axes, so build one per panel.
            outline = Rectangle((x_min, y_min), box_w, box_h,
                                linewidth=2, edgecolor=edge, facecolor='none')
            axs[panel_idx].add_patch(outline)
    axs[2].set_title(title)


def plot_seg(seg_mask, axs, title):
    """Overlay a segmentation mask on axs[1] and axs[2], then title axs[2].

    The mask is drawn half-transparent with the 'binary' colormap, without
    interpolation, stretched over the extent [0, 224, 224, 0].

    Args:
        seg_mask: Image-like mask accepted by ``imshow``.
        axs: Indexable collection of axes; entries 1 and 2 are drawn on.
        title: Title set on axs[2].
    """
    def _overlay(panel):
        panel.imshow(seg_mask, cmap='binary', alpha=0.5,
                     interpolation='none', extent=[0, 224, 224, 0])

    for panel_idx in (1, 2):
        _overlay(axs[panel_idx])
    axs[2].set_title(title)


def grad_heatmap(grad: torch.Tensor, bbox_ls: list | None, orig_image, processed_image, title='Heatmap of Gradient', folder_name='heatmap_per_content', pic_name=None, plot_option='heatmap', seg_mask=None):
    """Save a three-panel figure: original image, processed image, heatmap.

    Args:
        grad: 2-D tensor drawn as the heatmap on the third panel.
        bbox_ls: Boxes outlined on panels 2 and 3 when plot_option is
            'heatmap'; None is treated as "no boxes".
        orig_image: Image shown on the first panel.
        processed_image: Image shown on the second panel.
        title: Title of the third panel (set by plot_bbox / plot_seg).
        folder_name: Root output folder.
        pic_name: File name without extension.
        plot_option: 'heatmap' overlays bounding boxes, 'segmentation'
            overlays seg_mask; any other value draws no overlay.
        seg_mask: Mask required when plot_option is 'segmentation'.

    The figure is written to '<folder_name>/<plot_option>/<pic_name>.png';
    the folder is created if missing.
    """
    fig, axs = plt.subplots(1, 3, figsize=(18, 6))
    try:
        axs[0].imshow(orig_image)
        axs[0].set_title("Original Image")
        axs[0].axis('off')

        axs[1].imshow(processed_image)
        axs[1].set_title("Processed Image")
        axs[1].axis('off')

        if plot_option == 'heatmap':
            # The annotation allows None; treat it as an empty box list so
            # the panel still gets its title.
            plot_bbox(bbox_ls if bbox_ls is not None else [], axs, title)
        elif plot_option == 'segmentation':
            assert seg_mask is not None
            plot_seg(seg_mask, axs, title)

        # Colour range: mean +/- 3 std, with the lower bound clamped at 0.
        mean = torch.mean(grad).item()
        std = torch.std(grad).item()
        vmin = max(mean - 3 * std, 0)
        vmax = mean + 3 * std
        if not vmax > vmin:
            # TwoSlopeNorm needs vmin < vcenter < vmax. Constant,
            # single-element or all-negative maps break that, so fall back
            # to the data range, widened if it is a single value.
            vmin, vmax = grad.min().item(), grad.max().item()
            if not vmax > vmin:
                vmax = vmin + 1.0
        norm = TwoSlopeNorm(vmin=vmin, vcenter=(vmin + vmax) / 2, vmax=vmax)
        sns.heatmap(grad, cmap="bwr", norm=norm, ax=axs[2], alpha=0.5, cbar=True)

        fig.tight_layout()
        out_dir = Path(folder_name) / plot_option
        out_dir.mkdir(parents=True, exist_ok=True)
        fig.savefig(out_dir / f'{pic_name}.png')
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close(fig)
