import os
import sys
from typing import Dict, List

import clip
import cv2
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from PIL import Image
import matplotlib.pyplot as plt
from scipy.stats import spearmanr, wasserstein_distance, wilcoxon, ttest_rel
from torchvision.transforms import Compose, Resize, ToTensor, Normalize, InterpolationMode

def plot_all_together(
    model_name, orig_img_t, adv_img_t,
    cv2_img,
    sim_map_clean, sim_map_adv,
    sim_map_clean_fs, sim_map_adv_fs,
    all_texts, target_texts,
    src_model=None, tgt_model=None,
    topk_ratio=0.1
):
    """Show one 9-panel comparison row per target label.

    Per row the panels are: the de-normalised original and adversarial
    images, heatmap overlays of the clean / adversarial similarity maps and
    of their difference, heatmap overlays of the two ``*_fs`` maps, and the
    two images with their top-k similarity pixels painted red.

    Args:
        model_name: Text used in the rotated row label.
        orig_img_t, adv_img_t: Normalised image tensors, handled as
            ``1 x 3 x H x W`` (see ``denorm_clip``).
        cv2_img: Background image for the overlays. It is blended and then
            converted BGR -> RGB, so it is treated as a BGR array; float
            images with values in [0, 1] are rescaled to uint8.
        sim_map_clean, sim_map_adv, sim_map_clean_fs, sim_map_adv_fs:
            Tensors indexed as ``[0, :, :, text_index]``.
        all_texts: Labels in the order of the maps' last axis.
        target_texts: Labels to plot; each must occur in ``all_texts``.
        src_model, tgt_model: When both are truthy, a transfer-direction
            title is drawn above the first row.
        topk_ratio: Fraction of pixels highlighted in the last two panels.

    Raises:
        ValueError: If ``target_texts`` is empty or contains a label that
            is not in ``all_texts``.
    """
    # Only ``matplotlib.pyplot`` is bound at module scope, not ``mpl``.
    import matplotlib as mpl

    n_targets = len(target_texts)
    if n_targets == 0:
        raise ValueError("target_texts must contain at least one label")

    def denorm_clip(img_1x3HW):
        """Invert the per-channel mean/std normalisation and clamp to [0, 1]."""
        mean = torch.tensor(
            [0.48145466, 0.4578275, 0.40821073], device=img_1x3HW.device
        ).view(1, 3, 1, 1)
        std = torch.tensor(
            [0.26862954, 0.26130258, 0.27577711], device=img_1x3HW.device
        ).view(1, 3, 1, 1)
        return (img_1x3HW * std + mean).clamp(0, 1)

    def to_numpy_map(sim_maps, text_idx):
        """Extract one 2-D map as float32 so OpenCV can always resize it."""
        return sim_maps[0, :, :, text_idx].detach().float().cpu().numpy()

    def blend_map(sim_map, img):
        """Min-max normalise ``sim_map`` and overlay it (JET) on ``img``."""
        height, width = img.shape[:2]
        sim_map = np.asarray(sim_map, dtype=np.float32)
        lo = sim_map.min()
        # Local (per-map) normalisation; the epsilon guards constant maps.
        sim_map_norm = (sim_map - lo) / (sim_map.max() - lo + 1e-6)
        heat = cv2.resize(sim_map_norm, (width, height))
        heat = cv2.applyColorMap((heat * 255).astype('uint8'), cv2.COLORMAP_JET)
        # Only float images in [0, 1] need rescaling; integer images are
        # already in display range, however dark they are.
        if np.issubdtype(img.dtype, np.floating) and img.max() <= 1.0:
            img_disp = (img * 255).astype('uint8')
        else:
            img_disp = img.astype('uint8')
        blended = cv2.addWeighted(img_disp, 0.4, heat, 0.6, 0)
        return cv2.cvtColor(blended, cv2.COLOR_BGR2RGB)

    def topk_mask(sim_map, ratio=0.1):
        """Return a uint8 mask of the entries >= the k-th largest value."""
        flat = sim_map.reshape(-1)
        k = max(1, int(ratio * flat.size))
        thr = np.partition(flat, -k)[-k]
        return (sim_map >= thr).astype(np.uint8)

    def paint_mask(img, mask):
        """Copy ``img`` and paint the masked pixels pure red (RGB)."""
        painted = img.copy()
        painted[mask == 1] = (255, 0, 0)
        return painted

    # --- denorm images and bring them to the size of the background ---
    orig_img = denorm_clip(orig_img_t).squeeze(0).permute(1, 2, 0).cpu().numpy()
    adv_img = denorm_clip(adv_img_t).squeeze(0).permute(1, 2, 0).cpu().numpy()
    H, W = cv2_img.shape[:2]
    orig_resized = cv2.resize((orig_img * 255).astype("uint8"), (W, H))
    adv_resized = cv2.resize((adv_img * 255).astype("uint8"), (W, H))

    panel_titles = [
        "Original Image", "Adversarial Image (X-Shift)",
        "CLIP Clean", "CLIP Attack", "Diff Overlay (Adv - Clean)",
        "FaithShield I Clean", "FaithShield I Adversarial",
        "FaithShield II Clean", "FaithShield II Adversarial",
    ]

    # squeeze=False keeps ``axs`` two-dimensional even for a single row.
    fig, axs = plt.subplots(
        n_targets, len(panel_titles),
        figsize=(22, 1.6 * n_targets),
        gridspec_kw={'hspace': 0.05, 'wspace': 0.01},
        squeeze=False,
    )

    for row, target_text in enumerate(target_texts):
        idx = all_texts.index(target_text)

        clip_clean = to_numpy_map(sim_map_clean, idx)
        clip_adv = to_numpy_map(sim_map_adv, idx)
        fs_clean = to_numpy_map(sim_map_clean_fs, idx)
        fs_adv = to_numpy_map(sim_map_adv_fs, idx)

        # --- Stage II masks ---
        # NOTE(review): these masks are derived from the plain similarity
        # maps, not from the *_fs maps -- confirm that this is intended for
        # the "FaithShield II" panels.
        mask_clean = topk_mask(cv2.resize(clip_clean, (W, H)), ratio=topk_ratio)
        mask_adv = topk_mask(cv2.resize(clip_adv, (W, H)), ratio=topk_ratio)

        panels = [
            orig_resized,
            adv_resized,
            blend_map(clip_clean, cv2_img),
            blend_map(clip_adv, cv2_img),
            blend_map(clip_adv - clip_clean, cv2_img),  # difference map
            blend_map(fs_clean, cv2_img),
            blend_map(fs_adv, cv2_img),
            paint_mask(orig_resized, mask_clean),
            paint_mask(adv_resized, mask_adv),
        ]

        for ax, img, title in zip(axs[row], panels, panel_titles):
            ax.imshow(img)
            ax.axis("off")
            if row == 0:
                ax.set_title(title, fontsize=9)

        axs[row][0].annotate(
            f"{model_name}, Label:{target_text}",
            xy=(-0.08, 0.5), xycoords='axes fraction',
            fontsize=9, rotation=90, va='center', ha='center'
        )

    # Leave room on the left for row labels and on the right for the colorbar.
    fig.subplots_adjust(
        top=0.98, bottom=0.02,
        left=0.10, right=0.88,
        hspace=0.05, wspace=0.01
    )

    # ---- Transfer-direction title, only when both model names are given ----
    if src_model and tgt_model:
        # Positioned in the axes coordinates of the top-left panel; x=4.5
        # lies about halfway across the nine panel widths.
        axs[0][0].text(
            4.5, 1.3, f"Adversarial Transfer: δ({src_model}) → {tgt_model}",
            fontsize=9, fontweight='normal',
            color='black',
            ha='center', va='center',
            transform=axs[0][0].transAxes,
            bbox=dict(facecolor='white', alpha=0.7, edgecolor='none', pad=4)
        )

    # ---- One shared vertical colorbar (maps are normalised to [0, 1]) ----
    sm = mpl.cm.ScalarMappable(
        cmap=plt.get_cmap("jet"),
        norm=mpl.colors.Normalize(vmin=0.0, vmax=1.0),
    )
    sm.set_array([])

    # Thin axes placed just right of the panel grid.
    cax = fig.add_axes([0.885, 0.15, 0.004, 0.7])
    cbar = fig.colorbar(sm, cax=cax)
    cbar.ax.tick_params(labelsize=8)

    plt.show()




def plot_xai_transferability_fs_extended(
    model,
    model_surgery,
    model_name,
    cv2_img,
    image_clean,
    image_adv,
    sim_map_clean, sim_map_adv,
    sim_map_clean_fs, sim_map_adv_fs,
    text_features,
    all_texts, target_texts,
    topk_ratio=0.1
    ):
    """Plot precomputed and on-the-fly explanation maps, clean vs. attacked.

    Each target label gets one row of heatmap overlays on ``cv2_img``: the
    five precomputed maps (clean, attack, their difference, and the two
    ``*_fs`` maps) followed by a clean/attack pair for each of the
    ScoreCAM-, RISE- and GAE-style helpers defined below.

    Args:
        model: Object whose ``encode_image`` output is sliced as
            ``[:, 0, :]`` (first token) and ``[:, 1:, :]`` (remaining
            tokens) -- presumably a CLIP variant that returns the full token
            sequence; confirm against the caller.
        model_surgery: Accepted for interface compatibility; no plotted
            panel uses it.
        model_name: Text used in the rotated row label.
        cv2_img: Background image for the overlays. It is blended and then
            converted BGR -> RGB, so it is treated as a BGR array.
        image_clean, image_adv: Model inputs, unpacked as ``(_, _, H, W)``.
        sim_map_clean, sim_map_adv, sim_map_clean_fs, sim_map_adv_fs:
            Tensors indexed as ``[0, :, :, text_index]``.
        text_features: Indexed by text index; each row is multiplied with
            the image tokens.
        all_texts: Labels in the order of the maps' last axis.
        target_texts: Labels to plot; each must occur in ``all_texts``.
        topk_ratio: Accepted for interface compatibility; unused here.

    Raises:
        ValueError: If ``target_texts`` is empty or contains a label that
            is not in ``all_texts``.
    """
    # Only ``matplotlib.pyplot`` is bound at module scope, not ``mpl``.
    import matplotlib as mpl

    n_targets = len(target_texts)
    if n_targets == 0:
        raise ValueError("target_texts must contain at least one label")

    #######################################################################
    # Helpers: normalisation and heatmap blending
    #######################################################################
    def minmax(arr, eps=1e-8):
        """Rescale ``arr`` to [0, 1]; ``eps`` guards constant input."""
        lo = arr.min()
        return (arr - lo) / (arr.max() - lo + eps)

    def to_numpy_map(sim_maps, text_idx):
        """Extract one 2-D map as float32 so OpenCV can always resize it."""
        return sim_maps[0, :, :, text_idx].detach().float().cpu().numpy()

    def blend_map(sim_map, img):
        """Min-max normalise ``sim_map`` and overlay it (JET) on ``img``."""
        height, width = img.shape[:2]
        sim_map = np.asarray(sim_map, dtype=np.float32)
        heat = cv2.resize(minmax(sim_map, eps=1e-6), (width, height))
        heat = cv2.applyColorMap((heat * 255).astype('uint8'), cv2.COLORMAP_JET)
        # Same image handling as plot_all_together: float images in [0, 1]
        # are rescaled, everything else is taken as already in 0..255.
        if np.issubdtype(img.dtype, np.floating) and img.max() <= 1.0:
            img_disp = (img * 255).astype('uint8')
        else:
            img_disp = img.astype('uint8')
        blended = cv2.addWeighted(img_disp, 0.4, heat, 0.6, 0)
        return cv2.cvtColor(blended, cv2.COLOR_BGR2RGB)

    #######################################################################
    # XAI methods; each takes (image, target_idx) and returns a 2-D map
    #######################################################################
    def get_patch_tokens(model, image):
        """Return the encoder output without its first token."""
        with torch.no_grad():
            tokens = model.encode_image(image)
        return tokens[:, 1:, :]

    def xai_scorecam(image, target_idx):
        """Token/text dot products reshaped to a square grid."""
        tokens = get_patch_tokens(model, image)[0]
        sim = (tokens @ text_features[target_idx]).float().cpu().numpy()
        hw = int(np.sqrt(sim.shape[0]))
        return minmax(sim).reshape(hw, hw)

    def xai_rise(image, target_idx, N=800, p=0.5, patch=16):
        """Sum token/text scores over ``N`` randomly masked copies."""
        _, _, H, W = image.shape
        grid_h, grid_w = H // patch, W // patch
        heat = np.zeros((grid_h, grid_w))
        text_vec = text_features[target_idx]

        for _ in range(N):
            mask_small = (np.random.rand(grid_h, grid_w) < p).astype(np.float32)
            # cv2.resize takes dsize as (width, height).
            mask = cv2.resize(mask_small, (W, H), interpolation=cv2.INTER_NEAREST)
            mask_t = torch.from_numpy(mask).to(device=image.device, dtype=image.dtype)
            tokens = get_patch_tokens(model, image * mask_t)[0]
            score = (tokens @ text_vec).float().cpu().numpy()
            heat += score.reshape(grid_h, grid_w)

        return minmax(heat)

    def xai_gae(image, target_idx, out_size=14):
        """Input-gradient saliency of the first-token/text score."""
        # enable_grad makes this work even when the caller is inside
        # torch.no_grad(); detach() gives a fresh leaf to differentiate.
        with torch.enable_grad():
            image_in = image.detach().clone().requires_grad_(True)
            cls_token = model.encode_image(image_in)[:, 0, :]
            text_vec = text_features[target_idx].unsqueeze(0)
            score = (cls_token @ text_vec.T).sum()
            # autograd.grad returns only the input gradient, so nothing is
            # accumulated into the model parameters' .grad fields.
            (grad,) = torch.autograd.grad(score, image_in)

        saliency = grad.detach().abs().float().squeeze(0).mean(dim=0).cpu().numpy()
        return cv2.resize(minmax(saliency), (out_size, out_size))

    #######################################################################
    # Panels to plot, in column order
    #######################################################################
    xai_methods = [
        ("ScoreCAM", xai_scorecam),
        ("RISE", xai_rise),
        ("GAE", xai_gae),
    ]
    panel_titles = [
        "CLIP Clean Heatmap",
        "CLIP Attack Heatmap",
        "Diff Overlay (Adv – Clean)",
        "FaithShield-I Clean",
        "FaithShield-I Attack",
    ]
    for method_name, _ in xai_methods:
        panel_titles.append(f"{method_name} Clean")
        panel_titles.append(f"{method_name} Attack")

    #######################################################################
    # Plot arrangement
    #######################################################################
    num_cols = len(panel_titles)

    # squeeze=False keeps ``axs`` two-dimensional even for a single row.
    fig, axs = plt.subplots(
        n_targets, num_cols,
        figsize=(4*num_cols, 3*n_targets),
        gridspec_kw={'hspace': 0.1, 'wspace': 0.05},
        squeeze=False,
    )

    #######################################################################
    # For each target: collect all heatmaps & plot
    #######################################################################
    for row, target_text in enumerate(target_texts):
        target_idx = all_texts.index(target_text)

        sm_clean = to_numpy_map(sim_map_clean, target_idx)
        sm_adv = to_numpy_map(sim_map_adv, target_idx)

        heatmaps = [
            sm_clean,
            sm_adv,
            sm_adv - sm_clean,
            to_numpy_map(sim_map_clean_fs, target_idx),
            to_numpy_map(sim_map_adv_fs, target_idx),
        ]
        # The index is passed explicitly rather than captured by a closure.
        for _, method in xai_methods:
            heatmaps.append(method(image_clean, target_idx))
            heatmaps.append(method(image_adv, target_idx))

        for ax, title, heatmap in zip(axs[row], panel_titles, heatmaps):
            ax.imshow(blend_map(heatmap, cv2_img))
            ax.axis("off")
            if row == 0:
                ax.set_title(title, fontsize=10)

        axs[row][0].annotate(
            f"{model_name}, Label:{target_text}",
            xy=(-0.05, 0.5), xycoords='axes fraction',
            fontsize=10, rotation=90, ha='center', va='center'
        )

    #######################################################################
    # Margins and colorbar use the same numbers as plot_all_together
    #######################################################################
    fig.subplots_adjust(
        top=0.98,
        bottom=0.02,
        left=0.10,
        right=0.88,
        hspace=0.05,
        wspace=0.01
    )

    # ---- Thin shared colorbar (maps are normalised to [0, 1]) ----
    sm = mpl.cm.ScalarMappable(
        cmap=plt.get_cmap("jet"),
        norm=mpl.colors.Normalize(vmin=0.0, vmax=1.0),
    )
    sm.set_array([])

    # Keep this rectangle identical to the one in plot_all_together.
    cax = fig.add_axes([0.885, 0.15, 0.004, 0.7])
    cbar = fig.colorbar(sm, cax=cax)
    cbar.ax.tick_params(labelsize=8)

    plt.show()
