import os
import sys
from typing import Dict, List

import clip
import cv2
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from PIL import Image
import matplotlib.pyplot as plt
from scipy.stats import spearmanr, wasserstein_distance, wilcoxon, ttest_rel
from torchvision.transforms import Compose, Resize, ToTensor, Normalize, InterpolationMode

def plot_all_together(
    model_name, orig_img_t, adv_img_t,
    cv2_img,
    sim_map_clean, sim_map_adv,
    sim_map_clean_fs, sim_map_adv_fs,
    all_texts, target_texts,
    src_model=None, tgt_model=None,
    topk_ratio=0.1
):
    """Show a 9-panel comparison row for every target prompt.

    Each row contains, left to right: the original image, the adversarial
    image, CLIP clean / adversarial heatmaps, an (adversarial - clean)
    difference overlay, FaithShield I clean / adversarial heatmaps, and the
    "FaithShield II" top-k masks painted red on the clean / adversarial
    images. Every heatmap is min-max normalised on its own, so the shared
    0-1 colorbar shows relative, not absolute, intensity.

    Args:
        model_name: Used in the row labels and, without a transfer pair, in
            the figure title.
        orig_img_t, adv_img_t: Tensors normalised with CLIP's mean/std; they
            are broadcast against a (1, 3, 1, 1) view, squeezed on dim 0 and
            permuted to HWC -- presumably shape (1, 3, H, W), TODO confirm.
        cv2_img: Base image for the overlays; its height/width define the
            panel resolution. It is blended as BGR and converted to RGB for
            display. Floating-point images with max <= 1 are rescaled to 0-255.
        sim_map_clean, sim_map_adv: CLIP similarity maps, indexed
            ``[0, :, :, text_idx]``.
        sim_map_clean_fs, sim_map_adv_fs: FaithShield I similarity maps with
            the same indexing.
        all_texts: All prompts; ``all_texts.index(t)`` gives the last-axis
            index of prompt ``t`` in the similarity maps.
        target_texts: Prompts to plot, one row each.
        src_model, tgt_model: When both are truthy the title shows the
            transfer direction ``δ(src_model) → tgt_model``.
        topk_ratio: Fraction of pixels marked by the top-k masks (clamped so
            at least one and at most all pixels are selected; ties at the
            threshold are also marked).

    Raises:
        ValueError: If ``target_texts`` is empty or contains a prompt missing
            from ``all_texts``.
    """
    import matplotlib.pyplot as plt
    import matplotlib as mpl
    import numpy as np
    import cv2
    import torch

    def denorm_clip(img_1x3HW):
        # Undo CLIP's Normalize(mean, std) and clamp to [0, 1]. detach() so a
        # tensor that still requires grad (e.g. after an attack) can be
        # converted to numpy; float() keeps the arithmetic in float32.
        mean = torch.tensor([0.48145466, 0.4578275, 0.40821073], device=img_1x3HW.device).view(1, 3, 1, 1)
        std = torch.tensor([0.26862954, 0.26130258, 0.27577711], device=img_1x3HW.device).view(1, 3, 1, 1)
        return (img_1x3HW.detach().float() * std + mean).clamp(0, 1)

    def to_numpy_map(sim_map, idx):
        # cv2.resize does not accept float16, so always hand it float32.
        return sim_map[0, :, :, idx].detach().float().cpu().numpy()

    def to_uint8(img):
        # Only floating images in [0, 1] are rescaled; integer images are cast.
        if np.issubdtype(img.dtype, np.floating) and img.max() <= 1.0:
            return (img * 255).astype(np.uint8)
        return img.astype(np.uint8)

    def blend_map(sim_map, base_bgr):
        # Local min-max normalisation, JET colormap, 40/60 blend, BGR -> RGB.
        H, W = base_bgr.shape[:2]
        sim_map_norm = (sim_map - sim_map.min()) / (sim_map.max() - sim_map.min() + 1e-6)
        heat = (cv2.resize(sim_map_norm, (W, H)) * 255).astype(np.uint8)
        heat = cv2.applyColorMap(heat, cv2.COLORMAP_JET)
        return cv2.cvtColor(cv2.addWeighted(base_bgr, 0.4, heat, 0.6, 0), cv2.COLOR_BGR2RGB)

    def topk_mask(sim_map, ratio=0.1):
        # Boolean mask of the top `ratio` fraction of pixels; k is clamped to
        # [1, size] so np.partition never receives an out-of-range index.
        flat = sim_map.reshape(-1)
        k = min(flat.size, max(1, int(ratio * flat.size)))
        thr = np.partition(flat, -k)[-k]
        return sim_map >= thr

    def paint_mask(img_rgb, mask):
        # Mark masked pixels pure red (RGB order) on a copy of the image.
        out = img_rgb.copy()
        out[mask] = (255, 0, 0)
        return out

    if not target_texts:
        raise ValueError("target_texts must contain at least one prompt")

    # --- denorm images and bring them to the overlay resolution ---
    H, W = cv2_img.shape[:2]
    base_bgr = to_uint8(cv2_img)  # loop-invariant: convert the base image once
    orig_img = denorm_clip(orig_img_t).squeeze(0).permute(1, 2, 0).cpu().numpy()
    adv_img = denorm_clip(adv_img_t).squeeze(0).permute(1, 2, 0).cpu().numpy()
    orig_resized = cv2.resize((orig_img * 255).astype("uint8"), (W, H))
    adv_resized = cv2.resize((adv_img * 255).astype("uint8"), (W, H))

    n_targets = len(target_texts)
    # squeeze=False keeps `axs` 2-D even for a single row.
    fig, axs = plt.subplots(
        n_targets, 9,
        figsize=(22, 1.6 * n_targets),
        gridspec_kw={'hspace': 0.05, 'wspace': 0.01},
        squeeze=False,
    )

    panel_titles = [
        "Original Image", "Adversarial Image (X-Shift)",
        "CLIP Clean Heatmap", "CLIP Attack Heatmap", "Diff Overlay (Adv - Clean)",
        "FaithShield I Clean Heatmap", "FaithShield I Adversarial Heatmap",
        "FaithShield II Clean", "FaithShield II Adversarial",
    ]

    for row, target_text in enumerate(target_texts):
        idx = all_texts.index(target_text)

        clip_clean = to_numpy_map(sim_map_clean, idx)
        clip_adv = to_numpy_map(sim_map_adv, idx)
        fs_clean = to_numpy_map(sim_map_clean_fs, idx)
        fs_adv = to_numpy_map(sim_map_adv_fs, idx)

        # --- Stage II masks ---
        # NOTE(review): these masks are built from the CLIP maps, not the
        # FaithShield maps, although the panels are titled "FaithShield II" —
        # confirm this is intended.
        mask_clean = topk_mask(cv2.resize(clip_clean, (W, H)), ratio=topk_ratio)
        mask_adv = topk_mask(cv2.resize(clip_adv, (W, H)), ratio=topk_ratio)

        panels = [
            orig_resized,
            adv_resized,
            blend_map(clip_clean, base_bgr),
            blend_map(clip_adv, base_bgr),
            # Adversarial minus clean, matching the panel title.
            blend_map(clip_adv - clip_clean, base_bgr),
            blend_map(fs_clean, base_bgr),
            blend_map(fs_adv, base_bgr),
            paint_mask(orig_resized, mask_clean),
            paint_mask(adv_resized, mask_adv),
        ]

        for ax, img, title in zip(axs[row], panels, panel_titles):
            ax.imshow(img)
            ax.axis("off")
            if row == 0:
                ax.set_title(title, fontsize=9)

        axs[row, 0].annotate(
            f"{model_name}, Label:{target_text}",
            xy=(-0.08, 0.5), xycoords='axes fraction',
            fontsize=9, rotation=90, va='center', ha='center'
        )

    # Leave room on the right for the shared colorbar.
    fig.subplots_adjust(
        top=0.98, bottom=0.02,
        left=0.10, right=0.88,
        hspace=0.05, wspace=0.01
    )

    # ---- Title showing the transfer direction (or just the model) ----
    if src_model and tgt_model:
        transfer_title = f"Adversarial Transfer: δ({src_model}) → {tgt_model}"
    else:
        transfer_title = f"Model: {model_name}"

    # Anchored on the top-left axes: x=4.5 axes-widths puts it roughly over
    # the middle of the 9 panels, y=1.3 just above the column titles.
    axs[0, 0].text(
        4.5, 1.3, transfer_title,
        fontsize=9, fontweight='normal',
        color='black',
        ha='center', va='center',
        transform=axs[0, 0].transAxes,
        bbox=dict(facecolor='white', alpha=0.7, edgecolor='none', pad=4)
    )

    # ---- One shared, thin vertical colorbar for the normalised heatmaps ----
    mappable = mpl.cm.ScalarMappable(
        cmap=plt.get_cmap("jet"),
        norm=mpl.colors.Normalize(vmin=0.0, vmax=1.0),
    )
    mappable.set_array([])
    cax = fig.add_axes([0.885, 0.15, 0.004, 0.7])
    cbar = fig.colorbar(mappable, cax=cax)
    cbar.ax.tick_params(labelsize=8)

    plt.show()


def plot_xai_transferability_fs(
    model_name,
    cv2_img,
    sim_map_clean, sim_map_adv,
    sim_map_clean_fs, sim_map_adv_fs,
    all_texts, target_texts,
    topk_ratio=0.1
    ):
    """Show CLIP vs. FaithShield-I heatmaps (clean / attack / diff) per prompt.

    One row per entry of ``target_texts`` with five overlays on ``cv2_img``:
    CLIP clean, CLIP attack, the (attack - clean) difference, FaithShield-I
    clean and FaithShield-I attack. Each heatmap is min-max normalised on its
    own, so the shared 0-1 colorbar shows relative intensity.

    Args:
        model_name: Currently unused; kept for interface compatibility.
        cv2_img: Base image for the overlays; its height/width define the
            panel resolution. It is blended as BGR and converted to RGB for
            display. Floating-point images with max <= 1 are rescaled to 0-255.
        sim_map_clean, sim_map_adv: CLIP similarity maps, indexed
            ``[0, :, :, text_idx]``.
        sim_map_clean_fs, sim_map_adv_fs: FaithShield-I similarity maps with
            the same indexing.
        all_texts: All prompts; ``all_texts.index(t)`` gives the last-axis
            index of prompt ``t`` in the similarity maps.
        target_texts: Prompts to plot, one row each.
        topk_ratio: Currently unused; kept for interface compatibility.

    Raises:
        ValueError: If ``target_texts`` is empty or contains a prompt missing
            from ``all_texts``.
    """
    import matplotlib.pyplot as plt
    import matplotlib as mpl
    import numpy as np
    import cv2

    def to_numpy_map(sim_map, idx):
        # cv2.resize does not accept float16, so always hand it float32.
        return sim_map[0, :, :, idx].detach().float().cpu().numpy()

    def to_uint8(img):
        # cv2.addWeighted needs both inputs to share a dtype with the uint8
        # colormap; floating images in [0, 1] are rescaled first.
        if np.issubdtype(img.dtype, np.floating) and img.max() <= 1.0:
            return (img * 255).astype(np.uint8)
        return img.astype(np.uint8)

    def blend_map(sim_map, base_bgr):
        # Local min-max normalisation, JET colormap, 40/60 blend, BGR -> RGB.
        H, W = base_bgr.shape[:2]
        sim_map_norm = (sim_map - sim_map.min()) / (sim_map.max() - sim_map.min() + 1e-6)
        heat = (cv2.resize(sim_map_norm, (W, H)) * 255).astype(np.uint8)
        heat = cv2.applyColorMap(heat, cv2.COLORMAP_JET)
        return cv2.cvtColor(cv2.addWeighted(base_bgr, 0.4, heat, 0.6, 0), cv2.COLOR_BGR2RGB)

    if not target_texts:
        raise ValueError("target_texts must contain at least one prompt")

    base_bgr = to_uint8(cv2_img)  # loop-invariant: convert the base image once

    panel_titles = [
        "CLIP Clean Heatmap",
        "CLIP Attack Heatmap",
        "Diff Overlay (Adv – Clean)",
        "FaithShield-I Clean",
        "FaithShield-I Attack"
    ]

    n_targets = len(target_texts)
    # squeeze=False keeps `axs` 2-D even for a single row.
    fig, axs = plt.subplots(
        n_targets, 5,
        figsize=(20, 2.0 * n_targets),
        gridspec_kw={'hspace': 0.05, 'wspace': 0.01},
        squeeze=False,
    )

    for row, target_text in enumerate(target_texts):
        idx = all_texts.index(target_text)

        clip_clean = to_numpy_map(sim_map_clean, idx)
        clip_adv = to_numpy_map(sim_map_adv, idx)
        fs_clean = to_numpy_map(sim_map_clean_fs, idx)
        fs_adv = to_numpy_map(sim_map_adv_fs, idx)

        # Order must match panel_titles; the diff is attack minus clean.
        panels = [
            blend_map(m, base_bgr)
            for m in (clip_clean, clip_adv, clip_adv - clip_clean, fs_clean, fs_adv)
        ]

        for ax, img, title in zip(axs[row], panels, panel_titles):
            ax.imshow(img)
            ax.axis("off")
            if row == 0:
                ax.set_title(title, fontsize=10)

        # Row label to the left of the first panel.
        axs[row, 0].annotate(
            f"Target: {target_text}",
            xy=(-0.15, 0.5), xycoords='axes fraction',
            fontsize=9, rotation=90, va='center', ha='center'
        )

    fig.subplots_adjust(top=0.97, bottom=0.04, left=0.08, right=0.95)

    # Shared colorbar for the per-panel normalised heatmaps.
    mappable = mpl.cm.ScalarMappable(
        cmap=plt.get_cmap("jet"),
        norm=mpl.colors.Normalize(vmin=0.0, vmax=1.0),
    )
    mappable.set_array([])
    cax = fig.add_axes([0.96, 0.2, 0.01, 0.6])
    cbar = fig.colorbar(mappable, cax=cax)
    cbar.ax.tick_params(labelsize=8)

    plt.show()
