import os
import sys
from typing import Dict, List

import clip
import cv2
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from PIL import Image
import matplotlib.pyplot as plt
from scipy.stats import spearmanr, wasserstein_distance, wilcoxon, ttest_rel
from torchvision.transforms import Compose, Resize, ToTensor, Normalize, InterpolationMode



def plot_all_together(
    model_name, orig_img_t, adv_img_t,
    cv2_img,
    sim_map_clean, sim_map_adv,
    sim_map_clean_fs, sim_map_adv_fs,
    all_texts, target_texts,
    topk_ratio=0.1
):
    """Show an 8-panel comparison row per target text in a single figure.

    Columns, left to right:
    - the original image;
    - the adversarial image;
    - CLIP clean and adversarial heatmaps;
    - FaithShield Stage I clean and adversarial heatmaps;
    - top-k masks, painted red, over the original and adversarial images.

    ``model_name`` is written to the left of every row and the target text
    to the right. The figure is displayed with ``plt.show()``; nothing is
    returned.

    Args:
        model_name: Text drawn vertically left of each row.
        orig_img_t, adv_img_t: CLIP-normalized image tensors. They are
            squeezed on dim 0 and permuted to HWC, so presumably shaped
            (1, 3, h, w) -- TODO confirm with callers. Tensors that still
            require grad are accepted; they are detached before conversion.
        cv2_img: Image the heatmaps are blended onto. Its height/width set
            the size of every panel. Presumably BGR (OpenCV order), since the
            blended result is converted BGR->RGB for display.
        sim_map_clean, sim_map_adv: CLIP similarity maps, indexed here as
            ``[0, :, :, text_index]``.
        sim_map_clean_fs, sim_map_adv_fs: FaithShield Stage I maps, same
            indexing.
        all_texts: All prompts, in the order of the maps' last axis.
        target_texts: Prompts to plot, one figure row each.
        topk_ratio: Fraction of pixels kept by the top-k masks.

    Raises:
        ValueError: ``target_texts`` is empty, or one of its entries is not
            in ``all_texts``.
    """
    def denorm_clip(img_1x3HW):
        # Invert CLIP's per-channel (x - mean) / std input normalization.
        mean = torch.tensor([0.48145466, 0.4578275, 0.40821073], device=img_1x3HW.device).view(1, 3, 1, 1)
        std = torch.tensor([0.26862954, 0.26130258, 0.27577711], device=img_1x3HW.device).view(1, 3, 1, 1)
        return (img_1x3HW * std + mean).clamp(0, 1)

    def to_display_uint8(img_t, size_wh):
        # detach(): adversarial images are commonly still attached to an
        # autograd graph, and Tensor.numpy() raises on tensors requiring grad.
        img = denorm_clip(img_t.detach()).squeeze(0).permute(1, 2, 0).cpu().numpy()
        return cv2.resize((img * 255).astype("uint8"), size_wh)

    def map_slice(sim_map, text_idx):
        # The 2-D map for one text prompt, as a NumPy array.
        return sim_map[0, :, :, text_idx].detach().cpu().numpy()

    def blend_map(sim_map, img):
        # Min-max normalize the map, colorize it with JET and overlay it on img.
        H, W = img.shape[:2]
        sim_map_norm = (sim_map - sim_map.min()) / (sim_map.max() - sim_map.min() + 1e-6)
        sim_map_resized = cv2.resize(sim_map_norm, (W, H))
        vis = cv2.applyColorMap((sim_map_resized * 255).astype("uint8"), cv2.COLORMAP_JET)
        # Only float images in [0, 1] need rescaling. Checking the dtype keeps
        # a very dark uint8 image (max <= 1) from being multiplied by 255.
        if np.issubdtype(img.dtype, np.floating) and img.max() <= 1.0:
            img_disp = (img * 255).astype("uint8")
        else:
            img_disp = img.astype("uint8")
        return cv2.cvtColor(cv2.addWeighted(img_disp, 0.4, vis, 0.6, 0), cv2.COLOR_BGR2RGB)

    def topk_mask(sim_map, ratio=0.1):
        # 1 where the map lies in its top `ratio` fraction of values, else 0.
        # Ties at the threshold are all included, so more than k may be set.
        flat = sim_map.reshape(-1)
        k = max(1, int(ratio * flat.size))
        thr = np.partition(flat, -k)[-k]
        return (sim_map >= thr).astype(np.uint8)

    if not target_texts:
        raise ValueError("target_texts must contain at least one text")

    H, W = cv2_img.shape[:2]
    orig_resized = to_display_uint8(orig_img_t, (W, H))
    adv_resized = to_display_uint8(adv_img_t, (W, H))

    n_targets = len(target_texts)
    # squeeze=False always yields a 2-D (n_targets, 8) array of Axes, so a
    # single row needs no special-casing.
    fig, axs = plt.subplots(n_targets, 8, figsize=(20, 1.6 * n_targets), squeeze=False)

    panel_titles = [
        "Original Image", "Adversarial Image (X-Shift)",
        "CLIP Clean Heatmap", "CLIP Attack Heatmap",
        "FaithShield I Clean Heatmap", "FaithShield I Adversarial Heatmap",
        "FaithShield II Clean", "FaithShield II Adversarial"
    ]
    # The mask panels are shown with imshow without color conversion, so
    # (255, 0, 0) renders as red.
    mask_color = (255, 0, 0)

    for row, target_text in enumerate(target_texts):
        idx = all_texts.index(target_text)  # column of this text in the maps

        map_clean = map_slice(sim_map_clean, idx)
        map_adv = map_slice(sim_map_adv, idx)
        map_clean_fs = map_slice(sim_map_clean_fs, idx)
        map_adv_fs = map_slice(sim_map_adv_fs, idx)

        # NOTE(review): the "FaithShield II" mask panels are derived from the
        # raw CLIP maps (map_clean / map_adv), not from the FaithShield maps.
        # Confirm this is intended.
        masked_clean = orig_resized.copy()
        masked_clean[topk_mask(cv2.resize(map_clean, (W, H)), ratio=topk_ratio) == 1] = mask_color
        masked_adv = adv_resized.copy()
        masked_adv[topk_mask(cv2.resize(map_adv, (W, H)), ratio=topk_ratio) == 1] = mask_color

        panels = [
            orig_resized,
            adv_resized,
            blend_map(map_clean, cv2_img),
            blend_map(map_adv, cv2_img),
            blend_map(map_clean_fs, cv2_img),
            blend_map(map_adv_fs, cv2_img),
            masked_clean,
            masked_adv,
        ]

        for ax, img, title in zip(axs[row], panels, panel_titles):
            ax.imshow(img)
            ax.axis("off")
            if row == 0:  # column titles only on the first row
                ax.set_title(title, fontsize=9)

        # Annotations (unlike axis labels) still render with the axis off.
        axs[row][0].annotate(
            f"{model_name}",
            xy=(-0.08, 0.5), xycoords='axes fraction',
            fontsize=9, rotation=90, va='center', ha='center'
        )
        axs[row][-1].annotate(
            f"{target_text}",
            xy=(1.08, 0.5), xycoords='axes fraction',
            fontsize=9, rotation=270, va='center', ha='center'
        )

    # Tight margins and near-zero gaps between panels; the left margin leaves
    # room for the model-name label.
    fig.subplots_adjust(
        top=0.98, bottom=0.02,
        left=0.10, right=0.995,
        hspace=0.05,
        wspace=0.01
    )

    plt.show()

def merged_heatmap_multi_targets(
    target_texts,
    all_texts,
    orig_img_t, adv_img_t,
    sim_map_clean_clip, sim_map_adv_clip,
    sim_map_clean_fs, sim_map_adv_fs,
    cv2_img
):
    """Show a 6-panel comparison row per target text in a single figure.

    Columns, left to right:
    - the original image;
    - the adversarial image;
    - CLIP clean and adversarial heatmaps;
    - FaithShield Stage I clean and adversarial heatmaps.

    Each row is labelled "Target: <text>" on its left side. The figure is
    displayed with ``plt.show()``; nothing is returned.

    Args:
        target_texts: Prompts to plot, one figure row each.
        all_texts: All prompts, in the order of the maps' last axis.
        orig_img_t, adv_img_t: CLIP-normalized image tensors. They are
            squeezed on dim 0 and permuted to HWC, so presumably shaped
            (1, 3, h, w) -- TODO confirm with callers. Tensors that still
            require grad are accepted; they are detached before conversion.
        sim_map_clean_clip, sim_map_adv_clip: CLIP similarity maps, indexed
            here as ``[0, :, :, text_index]``.
        sim_map_clean_fs, sim_map_adv_fs: FaithShield Stage I maps, same
            indexing.
        cv2_img: Image the heatmaps are blended onto. Its height/width set
            the size of every panel. Presumably BGR (OpenCV order), since the
            blended result is converted BGR->RGB for display.

    Raises:
        ValueError: ``target_texts`` is empty, or one of its entries is not
            in ``all_texts``.
    """
    def denorm_clip(img_1x3HW):
        # Invert CLIP's per-channel (x - mean) / std input normalization.
        mean = torch.tensor([0.48145466, 0.4578275, 0.40821073], device=img_1x3HW.device).view(1, 3, 1, 1)
        std = torch.tensor([0.26862954, 0.26130258, 0.27577711], device=img_1x3HW.device).view(1, 3, 1, 1)
        return (img_1x3HW * std + mean).clamp(0, 1)

    def to_display_uint8(img_t, size_wh):
        # detach(): adversarial images are commonly still attached to an
        # autograd graph, and Tensor.numpy() raises on tensors requiring grad.
        img = denorm_clip(img_t.detach()).squeeze(0).permute(1, 2, 0).cpu().numpy()
        return cv2.resize((img * 255).astype("uint8"), size_wh)

    def map_slice(sim_map, text_idx):
        # The 2-D map for one text prompt, as a NumPy array.
        return sim_map[0, :, :, text_idx].detach().cpu().numpy()

    def blend_map(sim_map, img):
        # Min-max normalize the map, colorize it with JET and overlay it on
        # img (resized to img's size).
        H, W = img.shape[:2]
        sim_map_norm = (sim_map - sim_map.min()) / (sim_map.max() - sim_map.min() + 1e-6)
        sim_map_resized = cv2.resize(sim_map_norm, (W, H))
        vis = cv2.applyColorMap((sim_map_resized * 255).astype('uint8'), cv2.COLORMAP_JET)
        # Only float images in [0, 1] need rescaling. Checking the dtype keeps
        # a very dark uint8 image (max <= 1) from being multiplied by 255.
        if np.issubdtype(img.dtype, np.floating) and img.max() <= 1.0:
            img_disp = (img * 255).astype('uint8')
        else:
            img_disp = img.astype('uint8')
        blend = cv2.addWeighted(img_disp, 0.4, vis, 0.6, 0)
        return cv2.cvtColor(blend, cv2.COLOR_BGR2RGB)

    def hide_frame(ax):
        # Hide ticks and spines but keep axis labels drawable. ax.axis("off")
        # also suppresses the y-label, which hid the per-row target label.
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)

    if not target_texts:
        raise ValueError("target_texts must contain at least one text")

    H, W = cv2_img.shape[:2]
    orig_img_resized = to_display_uint8(orig_img_t, (W, H))
    adv_img_resized = to_display_uint8(adv_img_t, (W, H))

    titles = [
        "Original",
        "Adversarial",
        "CLIP Clean",
        "CLIP Adversarial",
        "FaithShield–Stage I Clean",
        "FaithShield–Stage I Adversarial"
    ]

    n_targets = len(target_texts)
    # squeeze=False always yields a 2-D (n_targets, 6) array of Axes, so a
    # single row needs no special-casing.
    _, axes = plt.subplots(n_targets, 6, figsize=(22, 4 * n_targets), squeeze=False)

    for ax_row, target_text in zip(axes, target_texts):
        idx = all_texts.index(target_text)  # column of this text in the maps

        imgs = [
            orig_img_resized,
            adv_img_resized,
            blend_map(map_slice(sim_map_clean_clip, idx), cv2_img),
            blend_map(map_slice(sim_map_adv_clip, idx), cv2_img),
            blend_map(map_slice(sim_map_clean_fs, idx), cv2_img),
            blend_map(map_slice(sim_map_adv_fs, idx), cv2_img)
        ]

        for ax, title, img in zip(ax_row, titles, imgs):
            ax.imshow(img)
            ax.set_title(title, fontsize=11)
            hide_frame(ax)

        # Vertical row label; visible because hide_frame keeps axis labels.
        ax_row[0].set_ylabel(f"Target: {target_text}", fontsize=12, rotation=90, labelpad=5)

    plt.tight_layout()
    plt.show()