#-------------
# 2025.09.23
# Quantitative evaluation code for ControlSwap
# run: python quan.py > metrics_results_ControlSwap.txt
#-------------

import os
import glob
from PIL import Image
import numpy as np
from tqdm import tqdm

import torch
import clip
import lpips
from skimage.metrics import peak_signal_noise_ratio, structural_similarity, mean_squared_error
import sys
import re



# ------------------------------
# util
# ------------------------------
def load_image(path, size=None):
    """Open the image at ``path`` and return it converted to RGB.

    Args:
        path: Image file location, handed to ``Image.open``.
        size: Optional target size passed to ``Image.resize``. When it is
            falsy (the default ``None``), the image keeps its original size.

    Returns:
        The RGB image, Lanczos-resampled to ``size`` when one was given.
    """
    rgb = Image.open(path).convert("RGB")
    return rgb.resize(size, Image.LANCZOS) if size else rgb

def apply_mask(img_tensor, mask_tensor):
    """Return ``img_tensor`` multiplied element-wise by ``mask_tensor``.

    The inputs are not modified; the product is a new value. Any
    broadcasting is whatever the ``*`` operator of the operands provides.
    """
    masked = img_tensor * mask_tensor
    return masked

def pil_to_tensor(img):
    """Convert an image to a float32 tensor scaled by 1/255.

    The image is turned into a NumPy array, its three axes are reordered
    from (0, 1, 2) to (2, 0, 1) (channel-last to channel-first for an RGB
    image), and every value is divided by 255.
    """
    pixels = np.array(img)
    channel_first = torch.from_numpy(pixels).permute(2, 0, 1)
    return channel_first.to(torch.float32).div(255.0)

def tensor_to_pil(tensor):
    """Convert a channel-first tensor back into a PIL image.

    The tensor is moved to the CPU, reordered to channel-last, multiplied by
    255 and cast to ``uint8`` (the cast truncates; it does not round).
    """
    channel_last = tensor.permute(1, 2, 0).cpu().numpy()
    pixels = (channel_last * 255).astype(np.uint8)
    return Image.fromarray(pixels)

# ------------------------------
# metrics
# ------------------------------
def compute_clip_image_similarity(img1, img2, model, preprocess, device):
    """Return the cosine similarity between the embeddings of two images.

    Args:
        img1, img2: Images accepted by ``preprocess``.
        model: Object exposing ``encode_image``.
        preprocess: Callable turning an image into a tensor (no batch axis).
        device: Device the preprocessed batches are moved to.

    Returns:
        A Python float: the dot product of the two L2-normalised embeddings.
    """
    batches = [preprocess(image).unsqueeze(0).to(device) for image in (img1, img2)]

    # Embeddings are only read, so no autograd graph is needed.
    with torch.no_grad():
        emb_a = model.encode_image(batches[0])
        emb_b = model.encode_image(batches[1])

    # Normalise in place so the matrix product below is a cosine similarity.
    emb_a.div_(emb_a.norm(dim=-1, keepdim=True))
    emb_b.div_(emb_b.norm(dim=-1, keepdim=True))

    similarity = emb_a.matmul(emb_b.T)
    return similarity.item()

def compute_clip_text_similarity(text, img, model, preprocess, device):
    """Return the cosine similarity between a text prompt and an image.

    Args:
        text: Prompt string, tokenised with ``clip.tokenize``.
        img: Image accepted by ``preprocess``.
        model: Object exposing ``encode_image`` and ``encode_text``.
        preprocess: Callable turning an image into a tensor (no batch axis).
        device: Device the image batch and the tokens are moved to.

    Returns:
        A Python float: the dot product of the L2-normalised image and text
        embeddings.
    """
    image_batch = preprocess(img).unsqueeze(0).to(device)

    # Embeddings are only read, so no autograd graph is needed.
    with torch.no_grad():
        image_emb = model.encode_image(image_batch)
        tokens = clip.tokenize([text]).to(device)
        text_emb = model.encode_text(tokens)

    # Normalise in place so the matrix product below is a cosine similarity.
    image_emb.div_(image_emb.norm(dim=-1, keepdim=True))
    text_emb.div_(text_emb.norm(dim=-1, keepdim=True))

    similarity = image_emb.matmul(text_emb.T)
    return similarity.item()

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Shared CLIP ("ViT-B/32") model with its preprocessing callable, and the
# LPIPS metric (net='alex'); both are placed on `device` once at start-up.
model, preprocess = clip.load("ViT-B/32", device=device)
lpips_fn = lpips.LPIPS(net='alex').to(device)

def compute_metrics(subject_img, source_img, target_img, target_prompt, mask_img, diff_prompt):
    """Compute foreground, background and prompt-alignment metrics for one sample.

    All images are resized to 1024x1024 first (Lanczos for the RGB images,
    nearest-neighbour for the mask). The mask is binarised by thresholding
    its channel mean at 0.5; mask == 1 is treated as foreground and
    mask == 0 as background.

    Args:
        subject_img: PIL image of the reference subject.
        source_img: PIL image before swapping.
        target_img: PIL image after swapping (the result being evaluated).
        target_prompt: Prompt compared against the whole target image.
        mask_img: PIL mask image; bright pixels mark the swapped region.
        diff_prompt: Prompt compared against the masked (foreground) target.

    Returns:
        dict mapping metric name to float:
        ``"CLIP-I (mask=1)"``, ``"CLIP-T (mask=1)"``, ``"PSNR (mask=0)"``,
        ``"LPIPS (mask=0)"``, ``"MSE (mask=0)"``, ``"SSIM (mask=0)"`` and
        ``"CLIP-T (all)"``.
    """
    size = (1024, 1024)
    subject_img = subject_img.resize(size, Image.LANCZOS)
    source_img  = source_img.resize(size,  Image.LANCZOS)
    target_img  = target_img.resize(size,  Image.LANCZOS)
    # Nearest-neighbour keeps the mask free of interpolated in-between values.
    mask_img    = mask_img.resize(size,    Image.NEAREST)

    target_tensor  = pil_to_tensor(target_img).to(device)
    subject_tensor = pil_to_tensor(subject_img).to(device)
    source_tensor  = pil_to_tensor(source_img).to(device)
    mask_tensor    = (pil_to_tensor(mask_img).mean(dim=0, keepdim=True) > 0.5).float().to(device)  # (1,H,W)
    inv_mask_tensor = 1 - mask_tensor

    # (1) FG: CLIP-I between the full subject image and the target restricted
    # to the mask. The subject is deliberately left unmasked: the mask is
    # loaded from the source folder, not derived from the subject image.
    masked_target = apply_mask(target_tensor, mask_tensor)
    clip_i_score = compute_clip_image_similarity(
        tensor_to_pil(subject_tensor), tensor_to_pil(masked_target),
        model, preprocess, device
    )

    # (2) BG: PSNR/LPIPS/MSE/SSIM (source vs target on background)
    masked_source_bg = apply_mask(source_tensor, inv_mask_tensor)
    masked_target_bg = apply_mask(target_tensor, inv_mask_tensor)

    source_bg_np = np.array(tensor_to_pil(masked_source_bg)) / 255.0
    target_bg_np = np.array(tensor_to_pil(masked_target_bg)) / 255.0

    psnr_val = peak_signal_noise_ratio(source_bg_np, target_bg_np, data_range=1.0)
    mse_val  = mean_squared_error(source_bg_np, target_bg_np)
    ssim_val = structural_similarity(source_bg_np, target_bg_np, channel_axis=-1, data_range=1.0)

    # The tensors are in [0, 1]; LPIPS expects [-1, 1] unless normalize=True
    # is passed, in which case it rescales the inputs itself. no_grad avoids
    # building an autograd graph for a value that is only read.
    with torch.no_grad():
        lpips_val = lpips_fn(
            masked_source_bg.unsqueeze(0),
            masked_target_bg.unsqueeze(0),
            normalize=True,
        ).item()

    # (3) CLIP-T (all): full target image vs the full target prompt
    clip_t_score = compute_clip_text_similarity(target_prompt, target_img, model, preprocess, device)

    # (3a) CLIP-T (mask=1): masked target vs the subject-specific prompt
    clip_t_fg_score = compute_clip_text_similarity(
        diff_prompt, tensor_to_pil(masked_target), model, preprocess, device
    )

    return {
        "CLIP-I (mask=1)": clip_i_score,
        "CLIP-T (mask=1)": clip_t_fg_score,
        "PSNR (mask=0)": psnr_val,
        "LPIPS (mask=0)": lpips_val,
        "MSE (mask=0)": mse_val,
        "SSIM (mask=0)": ssim_val,
        "CLIP-T (all)": clip_t_score,
    }

def find_latest_sample_dir(base_dir: str):
    """Return the name of the highest-numbered ``sample_<num>`` subdirectory.

    Only directories directly inside ``base_dir`` whose name is ``sample_``
    followed by digits are considered. Returns ``None`` when ``base_dir`` is
    not a directory or contains no such subdirectory.
    """
    if not os.path.isdir(base_dir):
        return None

    pattern = re.compile(r"sample_(\d+)$")
    numbered = []
    for entry in os.listdir(base_dir):
        hit = pattern.match(entry)
        if hit and os.path.isdir(os.path.join(base_dir, entry)):
            numbered.append((int(hit.group(1)), entry))

    # Tuples compare by number first, so max() picks the largest index.
    return max(numbered)[1] if numbered else None


# Subject classes to evaluate; each name is also a key of diff_prompt_target.
classes = [
    "dog6", "backpack_dog", "duck_toy", "monster_toy",
    "poop_emoji", "clock", "pink_sunglasses", "robot_toy",
]

# Dataset / result locations, relative to the working directory.
BASE_RESULTS = "./ConSwapBench"
BASE_SWAPBENCH = "./ConSwapBench/SwapBench"
BASE_DREAMBOOTH = "./dataset"

# One prompt template per line; line i belongs to sample i (1-based).
target_prompt_file = "./InstantSwap/target_prompt.txt"
with open(target_prompt_file, "r", encoding="utf-8") as prompt_file:
    target_prompt_templates = list(map(str.strip, prompt_file))

# Per-class subject description, substituted for "{ }" in the prompt template
# and also used on its own as the foreground prompt.
diff_prompt_target = {
    "dog6": "sks dog playing soccer, with a soccerball",
    "backpack_dog": "sks backpack, in purple color",
    "duck_toy": "sks duck toy",
    "monster_toy": "sks monster toy, with two arms raised",
    "poop_emoji": "sks poop emoji, in yellow color",
    "clock": "sks clock, made of wood, with a natural wooden texture",
    "pink_sunglasses": "sks pink sunglasses, shining brightly under the sunlight",
    "robot_toy": "sks robot toy, with a rusty metallic body",
}

# (class name, averaged metrics) pairs, filled in by the evaluation loop.
all_results = []

# Supported method names and the result file each one is read from:
#   "SDS", "DDS"                              -> output_image.png
#   "InstantSwap550_guidance_7.5_interval_5"  -> final.png
#   "CDS", "DDSSwap", "ControlSwap"           -> output_image.png
#   "PhotoSwap"                               -> sample_<n>/output_image.png
#   "SwapAnything"                            -> sample_<n>/large_result.png
#   "pnp"                                     -> directinversion+p2p.jpg
method = "ControlSwap"  # <- edit this to evaluate another method

# Metric names accumulated per class; must match the keys compute_metrics returns.
_METRIC_KEYS = (
    "CLIP-I (mask=1)", "CLIP-T (mask=1)", "PSNR (mask=0)", "LPIPS (mask=0)",
    "MSE (mask=0)", "SSIM (mask=0)", "CLIP-T (all)",
)


def _resolve_target_image_path(cls, i, method):
    """Return the path of the result image for class ``cls``, sample ``i``.

    Raises:
        ValueError: if ``method`` is not one of the supported method names.
    """
    result_dir = os.path.join(BASE_RESULTS, cls, "results", str(i), method)
    if method in {"SDS", "DDS", "CDS", "DDSSwap", "ControlSwap"}:
        return os.path.join(result_dir, "output_image.png")
    if method == "InstantSwap550_guidance_7.5_interval_5":
        return os.path.join(result_dir, "final.png")
    if method in {"PhotoSwap", "SwapAnything"}:
        # These methods write into numbered sample_<n> folders; use the
        # highest-numbered one, or "sample_0" when none is found.
        sample_dir = find_latest_sample_dir(result_dir) or "sample_0"
        filename = "output_image.png" if method == "PhotoSwap" else "large_result.png"
        return os.path.join(result_dir, sample_dir, filename)
    if method == "pnp":
        return os.path.join(result_dir, "directinversion+p2p.jpg")
    raise ValueError(f"Method name error: {method}")


def _crop_target_image(img, method):
    """Crop a loaded result image down to the region that is evaluated."""
    w, h = img.size
    if method in {"SDS", "DDS"}:
        # Keep only the right half of the image.
        return img.crop((w // 2, 0, w, h))
    if method == "pnp":
        # Keep the top-right region, at most 512x512.
        return img.crop((max(0, w - 512), 0, w, min(h, 512)))
    return img


for cls in classes:
    print(f"=== Evaluating class: {cls} ===")
    class_results_sum = {k: 0.0 for k in _METRIC_KEYS}
    count = 0

    # Progress goes to stderr so that stdout holds only the printed results.
    for i in tqdm(range(1, 161), file=sys.stderr):
        # Target image path (per method). Check existence BEFORE loading so a
        # missing result is skipped instead of raising FileNotFoundError.
        target_image_path = _resolve_target_image_path(cls, i, method)
        if not os.path.exists(target_image_path):
            continue

        # Source image path (any .jpg except GT_bbox.jpg). glob order is
        # arbitrary, so sort to make the chosen file reproducible.
        source_folder = os.path.join(BASE_SWAPBENCH, str(i))
        source_images = sorted(
            f for f in glob.glob(os.path.join(source_folder, "*.jpg"))
            if not f.endswith("GT_bbox.jpg")
        )
        if not source_images:
            continue
        source_image_path = source_images[0]

        # Mask path; skip the sample when it is missing, like other inputs.
        mask_path = os.path.join(source_folder, "GT_bbox.jpg")
        if not os.path.exists(mask_path):
            continue

        # Subject image path (00.jpg)
        subject_image_path = os.path.join(BASE_DREAMBOOTH, cls, "00.jpg")
        if not os.path.exists(subject_image_path):
            continue

        # Sample i uses the i-th line (1-based) of the prompt file.
        if i > len(target_prompt_templates):
            continue
        template = target_prompt_templates[i - 1]
        target_prompt = template.replace("{ }", diff_prompt_target[cls])

        # load images
        target_img = _crop_target_image(load_image(target_image_path), method)
        subject_img = load_image(subject_image_path)
        source_img = load_image(source_image_path)
        mask_img = load_image(mask_path)

        # metric calc
        metrics = compute_metrics(subject_img, source_img, target_img, target_prompt, mask_img, diff_prompt_target[cls])

        # accumulate
        for k in class_results_sum:
            class_results_sum[k] += metrics[k]
        count += 1

    # per-class average
    if count > 0:
        avg_results = {k: v / count for k, v in class_results_sum.items()}
        avg_results["count"] = count
        all_results.append((cls, avg_results))
        print(f"[{cls}] Average over {count} samples:", avg_results)
    else:
        print(f"[{cls}] No valid samples found.")

# Overall result: unweighted mean of the per-class averages, so every class
# counts equally regardless of how many samples it contributed.
if all_results:
    per_class_avgs = [avg for _name, avg in all_results]
    metric_names = [name for name in per_class_avgs[0] if name != "count"]
    num_classes = len(per_class_avgs)

    total_avg = {}
    for name in metric_names:
        running = 0.0
        for avg in per_class_avgs:
            running += avg[name]
        total_avg[name] = running / num_classes

    print("\n=== Overall Average across all classes ===")
    for name, value in total_avg.items():
        print(f"{name}: {value:.4f}")
