import os
import sys
from typing import Dict, List

import clip
import cv2
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from PIL import Image
import matplotlib.pyplot as plt
from scipy.stats import spearmanr, wasserstein_distance, wilcoxon, ttest_rel
from torchvision.transforms import Compose, Resize, ToTensor, Normalize, InterpolationMode


# --- Utility ---
def display_similarity(similarity_map, cv2_img, text_label, phase):
    """Show a similarity heat-map blended over an image in a matplotlib window.

    Args:
        similarity_map: NumPy array of scores, expected in [0, 1]. Values
            outside that range are clipped before the 8-bit conversion.
            Assumed to have the same height/width as ``cv2_img`` — confirm
            with callers.
        cv2_img: background image in OpenCV BGR channel order (the blend is
            converted BGR -> RGB for display).
        text_label: label appended to the figure title.
        phase: tag inserted into the figure title.
    """
    # Clip before casting: converting out-of-range floats (e.g. slightly
    # negative scores) to uint8 wraps around or is platform-dependent,
    # which would paint spurious hot spots.
    scaled = np.clip(similarity_map, 0.0, 1.0) * 255
    heat = cv2.applyColorMap(scaled.astype('uint8'), cv2.COLORMAP_JET)
    # 40% original image, 60% heat-map.
    blend = cv2_img * 0.4 + heat * 0.6
    blend = cv2.cvtColor(blend.astype('uint8'), cv2.COLOR_BGR2RGB)
    plt.figure(figsize=(5, 5))
    plt.title(f'CLIP Surgery ({phase}) - {text_label}')
    plt.imshow(blend)
    plt.axis('off')
    plt.show()


def has_tokens(z):
    """Return True when ``z`` (an ``encode_image`` output) carries per-token features.

    That means ``z`` is rank 3 and its second axis holds more than one entry.
    Rank-2 (pooled) outputs, and rank-3 outputs with a single entry on axis 1,
    give False.
    """
    if z.dim() != 3:
        return False
    return z.shape[1] > 1

def denorm_clip(img_1x3HW):
    """Undo per-channel normalization with the hard-coded CLIP mean/std and clamp to [0, 1].

    The channel statistics are viewed as (1, 3, 1, 1), so the input must have
    3 channels on axis 1 and broadcast against that shape. The constants are
    created on the input's device.
    """
    device = img_1x3HW.device
    channel_mean = torch.tensor([0.48145466, 0.4578275, 0.40821073], device=device)
    channel_std = torch.tensor([0.26862954, 0.26130258, 0.27577711], device=device)
    restored = img_1x3HW * channel_std.view(1, 3, 1, 1) + channel_mean.view(1, 3, 1, 1)
    return restored.clamp(0, 1)


def cosine_sim(a, b):
    """Return the cosine similarity of two feature vectors as a Python float.

    Similarity is computed along the last axis, so both 1-D vectors ``[D]``
    and single-row batches ``[1, D]`` are accepted. The default ``dim=1``
    would reject 1-D inputs. Inputs that yield more than one similarity
    value (e.g. ``[B, D]`` with B > 1) raise, because ``.item()`` requires
    exactly one element.
    """
    return F.cosine_similarity(a, b, dim=-1).item()

def max_delta_prob(logits1, logits2):
    """Return the largest absolute gap between two softmax distributions.

    Each logits tensor is turned into probabilities along its last axis. The
    maximum of the element-wise absolute difference is returned as a
    Python float.
    """
    probs_a = logits1.softmax(dim=-1)
    probs_b = logits2.softmax(dim=-1)
    return torch.max(torch.abs(probs_a - probs_b)).item()


def to_numpy(x):
    """Detach ``x`` from autograd, move it to the CPU, and return it as a NumPy array."""
    detached = x.detach()
    return detached.cpu().numpy()


def encode_and_norm(model, image):
    """Run ``model.encode_image`` on ``image`` and L2-normalize along the last axis.

    An epsilon of 1e-10 is added to the norm, so an all-zero feature vector
    maps to zeros instead of dividing by zero.
    """
    features = model.encode_image(image)
    magnitude = features.norm(dim=-1, keepdim=True)
    return features / (magnitude + 1e-10)

def build_similarity_map(z, text_features, remove_cls=True):
    """
    Compute patch-text similarity maps.

    z: [B, T, D] tokens
    text_features: [N, D]
    remove_cls: if True, token 0 is treated as CLS and dropped before reshaping.
    returns: [B, H, W, N] with H == W; the remaining tokens fill the grid
        row-major.

    Raises:
        ValueError: if ``z`` is not rank 3, or the number of (remaining)
            tokens is not a positive perfect square.
    """
    # Explicit checks instead of ``assert``, which is stripped under ``python -O``.
    if z.dim() != 3:
        raise ValueError(f"Expected z with shape [B, T, D], got {tuple(z.shape)}")

    sims = z @ text_features.T  # [B, T, N]
    if remove_cls:
        sims = sims[:, 1:, :]  # drop CLS token
    tokens = sims.shape[1]

    if tokens < 1:
        raise ValueError(f"Cannot reshape {tokens} tokens into square grid")
    # round() absorbs floating-point error in the square root; the product
    # check below is the exact perfect-square test.
    side = round(tokens ** 0.5)
    if side * side != tokens:
        raise ValueError(f"Cannot reshape {tokens} tokens into square grid")
    return sims.reshape(z.shape[0], side, side, -1).contiguous()

def compare_similarity_maps(map1, map2, label1, label2, cv2_img, title):
    """Show two similarity heat-maps side by side, each blended over the same image.

    Args:
        map1, map2: NumPy score arrays, expected in [0, 1]. Values outside
            that range are clipped before the 8-bit conversion. Assumed to
            share ``cv2_img``'s height/width — confirm with callers.
        label1, label2: titles for the left and right panels.
        cv2_img: background image in OpenCV BGR channel order (each blend is
            converted BGR -> RGB for display).
        title: figure-level title.
    """
    def blend_map(sim_map, img):
        # Clip before casting: out-of-range floats wrap around (or are
        # platform-dependent) when converted to uint8.
        vis = (np.clip(sim_map, 0.0, 1.0) * 255).astype('uint8')
        vis = cv2.applyColorMap(vis, cv2.COLORMAP_JET)
        # 40% original image, 60% heat-map.
        blend = img * 0.4 + vis * 0.6
        return cv2.cvtColor(blend.astype('uint8'), cv2.COLOR_BGR2RGB)

    plt.figure(figsize=(10, 4))
    plt.suptitle(title, fontsize=14)

    panels = ((map1, label1), (map2, label2))
    for position, (sim_map, label) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.imshow(blend_map(sim_map, cv2_img))
        plt.title(label)
        plt.axis('off')

    plt.tight_layout()
    plt.show()