{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "e09d4317",
   "metadata": {},
   "source": [
    "# Motivational Figures\n",
    "\n",
    "This notebook generates figures used to motivate the online model selection problem and illustrate key concepts from the paper."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cb824045",
   "metadata": {},
   "source": [
    "## Setup and Data Loading"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0c7ff5b5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import random\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.patches as patches\n",
    "import matplotlib.image as mpimg\n",
    "import os\n",
    "import numpy as np\n",
    "from PIL import Image\n",
    "from transformers import CLIPProcessor, CLIPModel\n",
    "import torch\n",
    "from tqdm import tqdm\n",
    "\n",
    "# --- Config ---\n",
    "metadata_path = \"../datasets/ms-coco/metadata.json\"\n",
    "image_root_dir = \"../datasets/ms-coco\"\n",
    "# Text-to-image models under study; a prompt is kept below only if every one of them has >= 4 scored images\n",
    "models = [\"Sana\", \"Unidiffuser\", \"LCM\", \"SDXL-Turbo\", \"SSD-1B\", \"Koala\"]\n",
    "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
    "\n",
    "# --- CLIP setup: processor and model share one checkpoint ---\n",
    "clip_checkpoint = \"openai/clip-vit-base-patch32\"\n",
    "clip_processor = CLIPProcessor.from_pretrained(clip_checkpoint)\n",
    "clip_model = CLIPModel.from_pretrained(clip_checkpoint)\n",
    "clip_model = clip_model.to(device).eval()  # inference only: eval mode, on `device`\n",
    "\n",
    "def get_clip_embedding(prompt):\n",
    "    \"\"\"Return the L2-normalised CLIP text embedding of `prompt` as a 1-D CPU tensor.\n",
    "\n",
    "    Prompts longer than the tokenizer's maximum length are truncated rather than\n",
    "    passed to the text encoder at full length, which would fail.\n",
    "    \"\"\"\n",
    "    inputs = clip_processor(text=[prompt], return_tensors=\"pt\", padding=True, truncation=True).to(device)\n",
    "    with torch.no_grad():\n",
    "        features = clip_model.get_text_features(**inputs)\n",
    "    # features: (1, D) -> unit-normalise, drop the batch axis, move to CPU\n",
    "    return (features / features.norm(dim=-1, keepdim=True)).squeeze(0).cpu()\n",
    "\n",
    "def load_and_resize(path, size=(128, 128)):\n",
    "    \"\"\"Open `path` as RGB, bicubic-resize it to `size` (width, height) and scale pixels to [0, 1].\"\"\"\n",
    "    resized = Image.open(path).convert(\"RGB\").resize(size, Image.BICUBIC)\n",
    "    return np.asarray(resized) / 255.0\n",
    "\n",
    "# --- Load metadata ---\n",
    "with open(metadata_path, \"r\") as f:\n",
    "    raw = json.load(f)\n",
    "\n",
    "# Index as prompt -> model -> entry, keeping only the models under study and\n",
    "# entries that carry at least 4 CLIP scores.\n",
    "prompt_map = {}\n",
    "for entry in raw:\n",
    "    if entry[\"model\"] not in models:\n",
    "        continue\n",
    "    prompt = entry[\"prompt\"]\n",
    "    entry_scores = entry.get(\"clip_scores\")\n",
    "    if entry_scores and len(entry_scores) >= 4:\n",
    "        prompt_map.setdefault(prompt, {})[entry[\"model\"]] = entry\n",
    "\n",
    "# Filter: only prompts with all models and >= 4 images per model\n",
    "complete_prompts = []\n",
    "for candidate, by_model in prompt_map.items():\n",
    "    if all(m in by_model and len(by_model[m][\"clip_scores\"]) >= 4 for m in models):\n",
    "        complete_prompts.append(candidate)\n",
    "print(f\"{len(complete_prompts)} prompts complets avec ≥ 4 images par modèle\")\n",
    "\n",
    "# CLIP text embedding of every complete prompt, keyed by prompt\n",
    "embeddings = {}\n",
    "for prompt_text in tqdm(complete_prompts, desc=\"CLIP embeddings\"):\n",
    "    embeddings[prompt_text] = get_clip_embedding(prompt_text)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c2e35b49",
   "metadata": {},
   "source": [
    "## Figure 1: Model Performance Variability Across Prompts\n",
    "\n",
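    "Shows the first four images generated by each text-to-image (T2I) model for the same prompt, titled with their mean CLIPScore; the model with the highest mean CLIPScore is outlined in green."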
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f64b22c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- Selected prompt ---\n",
    "selected_prompt = \"girl cracking an egg into a bowl of flour\"\n",
    "\n",
    "# --- Best model: highest mean CLIP score over the prompt's first 4 images ---\n",
    "selected_entries = prompt_map[selected_prompt]\n",
    "model_scores = dict((name, np.mean(selected_entries[name][\"clip_scores\"][:4])) for name in models)\n",
    "best_model = max(model_scores, key=model_scores.get)\n",
    "\n",
    "# --- Plotting function ---\n",
    "def plot_prompt_grid_highlight(prompt, save_path=None, highlight_model=None):\n",
    "    \"\"\"Plot a 2x2 grid of the first 4 images of every model in `models` for `prompt`.\n",
    "\n",
    "    Each panel is titled with the model name and the mean CLIP score of those images.\n",
    "    The panel of `highlight_model` is outlined in green; by default it is the model\n",
    "    with the highest mean score for *this* prompt (not a global best model).\n",
    "\n",
    "    Args:\n",
    "        prompt: key into `prompt_map`; every model in `models` must have an entry.\n",
    "        save_path: optional output file; its parent directory is created if needed.\n",
    "        highlight_model: optional model name to outline instead of the best one.\n",
    "    \"\"\"\n",
    "    img_size = 128   # must match the default size of load_and_resize\n",
    "    padding = 2      # white gap between tiles, in pixels\n",
    "    grid_size = 2\n",
    "    canvas_size = img_size * grid_size + padding * (grid_size - 1)\n",
    "\n",
    "    entries = {model: prompt_map[prompt][model] for model in models}\n",
    "    mean_scores = {model: np.mean(entries[model][\"clip_scores\"][:4]) for model in models}\n",
    "    if highlight_model is None:\n",
    "        highlight_model = max(mean_scores, key=mean_scores.get)\n",
    "\n",
    "    # squeeze=False keeps `axs` 2-D even when there is a single model\n",
    "    fig, axs = plt.subplots(1, len(models), figsize=(22, 5), facecolor='white', squeeze=False)\n",
    "\n",
    "    for ax, model in zip(axs[0], models):\n",
    "        # White canvas onto which the images are pasted row by row\n",
    "        grid_img = np.ones((canvas_size, canvas_size, 3))\n",
    "        for i, img_path in enumerate(entries[model][\"filenames\"][:4]):\n",
    "            img = load_and_resize(os.path.join(image_root_dir, img_path))\n",
    "            r, c = divmod(i, grid_size)\n",
    "            top = r * (img_size + padding)\n",
    "            left = c * (img_size + padding)\n",
    "            grid_img[top:top+img_size, left:left+img_size, :] = img\n",
    "\n",
    "        ax.imshow(grid_img)\n",
    "        ax.axis('off')\n",
    "        ax.set_title(f\"{model}\\nCLIPScore = {mean_scores[model]:.2f}\", fontsize=22, pad=5)\n",
    "\n",
    "        # Outline the highlighted model in green. imshow centres pixels on integer\n",
    "        # coordinates, so the image spans [-0.5, canvas_size - 0.5] on both axes.\n",
    "        if model == highlight_model:\n",
    "            rect = patches.Rectangle(\n",
    "                (-0.5, -0.5), canvas_size, canvas_size,\n",
    "                linewidth=15, edgecolor='limegreen', facecolor='none'\n",
    "            )\n",
    "            ax.add_patch(rect)\n",
    "\n",
    "    plt.tight_layout()\n",
    "    if save_path:\n",
    "        os.makedirs(os.path.dirname(save_path) or \".\", exist_ok=True)\n",
    "        plt.savefig(save_path, dpi=300, bbox_inches='tight', facecolor='white')\n",
    "    plt.show()\n",
    "\n",
    "# --- Render and save Figure 1 ---\n",
    "prompt_example_path = \"plots/motivation/prompt_example.pdf\"\n",
    "plot_prompt_grid_highlight(selected_prompt, save_path=prompt_example_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "41a3c376",
   "metadata": {},
   "source": [
    "## Figure 2: Algorithm Comparison on Single Prompt\n",
    "\n",
    "Shows the image produced by the model each algorithm selects for a single prompt; algorithms whose choice differs from the best model for that prompt (highest mean CLIPScore) are titled in red."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b2615101",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Algorithms compared in Figure 2; \"Optimal\" is never flagged as suboptimal\n",
    "algos = [\"PAK-UCB\", \"LinUCB\", \"BALROG\", \"Optimal\"]\n",
    "image_size = 512\n",
    "\n",
    "def load_and_resize_large(path, size=image_size):\n",
    "    \"\"\"Return `path` as a (size, size, 3) uint8 RGB array, resampled with a Lanczos filter.\"\"\"\n",
    "    target = (size, size)\n",
    "    rgb_image = Image.open(path).convert(\"RGB\")\n",
    "    return np.array(rgb_image.resize(target, Image.Resampling.LANCZOS))\n",
    "\n",
    "def get_best_model(prompt):\n",
    "    \"\"\"Return the model with the highest mean CLIP score over its first 4 images for `prompt`.\n",
    "\n",
    "    Considers every model stored for the prompt in `prompt_map`; ties go to the first one.\n",
    "    \"\"\"\n",
    "    per_model = prompt_map[prompt]\n",
    "\n",
    "    def _mean_first_four(name):\n",
    "        return np.mean(per_model[name][\"clip_scores\"][:4])\n",
    "\n",
    "    return max(per_model, key=_mean_first_four)\n",
    "\n",
    "def plot_single_prompt_comparison(selected_prompt, selection, save_path=None):\n",
    "    \"\"\"Show, for one prompt, the first image of the model each algorithm in `algos` picked.\n",
    "\n",
    "    Args:\n",
    "        selected_prompt: key into both `prompt_map` and `selection`.\n",
    "        selection: mapping prompt -> {algorithm name -> chosen model}.\n",
    "        save_path: optional output file; its parent directory is created if needed.\n",
    "\n",
    "    A panel title is red when a non-\"Optimal\" algorithm picked a model other than\n",
    "    get_best_model(selected_prompt). The score in each title is that of the displayed\n",
    "    (first) image only, not the mean used to pick the best model.\n",
    "    \"\"\"\n",
    "    # squeeze=False keeps `axs` 2-D even when there is a single algorithm\n",
    "    fig, axs = plt.subplots(1, len(algos), figsize=(len(algos)*3.5, 4), squeeze=False)\n",
    "\n",
    "    best_model = get_best_model(selected_prompt)\n",
    "    for ax, algo in zip(axs[0], algos):\n",
    "        model = selection[selected_prompt][algo]\n",
    "        entry = prompt_map[selected_prompt][model]\n",
    "        img_path = entry[\"filenames\"][0]\n",
    "        clip = entry[\"clip_scores\"][0]\n",
    "        img = load_and_resize_large(os.path.join(image_root_dir, img_path))\n",
    "\n",
    "        ax.imshow(img)\n",
    "        ax.axis('off')\n",
    "        is_suboptimal = (algo != \"Optimal\") and (model != best_model)\n",
    "        title_color = 'red' if is_suboptimal else 'black'\n",
    "        ax.set_title(f\"{algo} → {model}\\nCLIP = {clip:.2f}\", fontsize=9, color=title_color)\n",
    "\n",
    "    fig.suptitle(f'Prompt: \"{selected_prompt}\"', fontsize=11, y=1.07)\n",
    "    plt.tight_layout(pad=1.2)\n",
    "    if save_path:\n",
    "        os.makedirs(os.path.dirname(save_path) or \".\", exist_ok=True)\n",
    "        plt.savefig(save_path, dpi=300, bbox_inches='tight')\n",
    "    plt.show()\n",
    "\n",
    "# Example selection (you need to provide actual algorithm selections)\n",
    "selected_prompt = \"celebrating their show : actor joins the cast and crew of science fiction tv program for the red carpet event\"\n",
    "\n",
    "# Model chosen by each algorithm for this prompt. These are hand-written placeholder\n",
    "# values: replace them with the selections actually made by each algorithm.\n",
    "algo_choices = {\n",
    "    \"PAK-UCB\": \"Unidiffuser\",\n",
    "    \"LinUCB\": \"Sana\",\n",
    "    \"BALROG\": \"SDXL-Turbo\",\n",
    "    \"Optimal\": \"SDXL-Turbo\",\n",
    "}\n",
    "selection = {selected_prompt: algo_choices}\n",
    "\n",
    "comparison_path = \"plots/motivation/algorithm_comparison.pdf\"\n",
    "plot_single_prompt_comparison(selected_prompt, selection, save_path=comparison_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "72d989a1",
   "metadata": {},
   "source": [
    "## Figure 3: Tsybakov's Constant (α) Estimation\n",
    "\n",
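    "Estimates the Tsybakov margin exponent $\\alpha$ from the distribution of the per-prompt margin (gap between the best and second-best model's mean CLIPScore), assuming $P(\\text{margin} < \\varepsilon) \\approx C\\,\\varepsilon^{\\alpha}$ for small $\\varepsilon$; $\\alpha$ is the slope of the fit of $\\log P$ against $\\log \\varepsilon$. The cell also estimates the intrinsic dimension of the CLIP prompt embeddings."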
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0fb1128b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import defaultdict\n",
    "from sklearn.neighbors import NearestNeighbors\n",
    "\n",
    "# --- Configuration ---\n",
    "metadata_path = \"../datasets/carrot-bowl/metadata.json\"\n",
    "\n",
    "# --- Prompt embedding with the CLIP model loaded in the setup cell ---\n",
    "@torch.no_grad()\n",
    "def embed_prompts(prompts):\n",
    "    \"\"\"Return L2-normalised CLIP text embeddings of `prompts`, one row per prompt.\n",
    "\n",
    "    Prompts are encoded one at a time; those longer than the tokenizer's maximum\n",
    "    length are truncated instead of making the text encoder fail.\n",
    "\n",
    "    Returns:\n",
    "        CPU tensor of shape (len(prompts), D). An empty (0, D) tensor is returned for\n",
    "        an empty input, where torch.cat would otherwise raise.\n",
    "    \"\"\"\n",
    "    embeddings = []\n",
    "    for p in tqdm(prompts, desc=\"Embedding prompts\"):\n",
    "        inputs = clip_processor(text=[p], return_tensors=\"pt\", padding=True, truncation=True).to(device)\n",
    "        outputs = clip_model.get_text_features(**inputs)\n",
    "        emb = outputs / outputs.norm(dim=-1, keepdim=True)\n",
    "        embeddings.append(emb.cpu())\n",
    "    if not embeddings:\n",
    "        return torch.empty((0, clip_model.config.projection_dim))\n",
    "    return torch.cat(embeddings)\n",
    "\n",
    "# --- Read the JSON metadata ---\n",
    "with open(metadata_path, \"r\") as f:\n",
    "    raw = json.load(f)\n",
    "\n",
    "# --- Organise as prompt -> model -> pooled CLIP scores (entries without scores skipped) ---\n",
    "prompt_map_alpha = defaultdict(lambda: defaultdict(list))\n",
    "for e in raw:\n",
    "    if e.get(\"clip_scores\"):\n",
    "        prompt_map_alpha[e[\"prompt\"]][e[\"model\"]].extend(e[\"clip_scores\"])\n",
    "\n",
    "# Keep prompts where all models are present. Since only entries with scores were\n",
    "# inserted above, every stored list is non-empty; the check must therefore be made\n",
    "# against the full set of models seen in the dataset.\n",
    "alpha_models = sorted({m for per_model in prompt_map_alpha.values() for m in per_model})\n",
    "valid_prompts = [p for p in prompt_map_alpha if all(m in prompt_map_alpha[p] for m in alpha_models)]\n",
    "print(f\"{len(valid_prompts)} valid prompts\")\n",
    "\n",
    "# --- Embeddings ---\n",
    "X = embed_prompts(valid_prompts).numpy()\n",
    "\n",
    "# --- Intrinsic dimension d (Levina-Bickel MLE) ---\n",
    "def estimate_intrinsic_dim(X, k=50):\n",
    "    \"\"\"Estimate the intrinsic dimension of the rows of `X` (n_samples x n_features).\n",
    "\n",
    "    Levina-Bickel maximum-likelihood estimator: for a point x with sorted neighbour\n",
    "    distances T_1(x) <= ... <= T_k(x),\n",
    "        1 / m_k(x) = 1/(k-1) * sum_{j=1}^{k-1} log(T_k(x) / T_j(x)).\n",
    "    The per-point inverses are averaged and the mean is inverted (MacKay-Ghahramani\n",
    "    variant), so a single point with a near-zero log-ratio sum cannot dominate.\n",
    "\n",
    "    Args:\n",
    "        X: array of shape (n_samples, n_features).\n",
    "        k: number of nearest neighbours used per point (must be >= 2).\n",
    "\n",
    "    Returns:\n",
    "        The estimated dimension, or NaN when every point has a zero-distance neighbour.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if k < 2.\n",
    "    \"\"\"\n",
    "    if k < 2:\n",
    "        raise ValueError(\"k must be >= 2\")\n",
    "    nbrs = NearestNeighbors(n_neighbors=k+1, algorithm='auto').fit(X)\n",
    "    dists, _ = nbrs.kneighbors(X)\n",
    "    dists = dists[:, 1:]  # remove self distance\n",
    "\n",
    "    # A zero distance (duplicate point) makes log(T_k / T_j) undefined: skip such points\n",
    "    dists = dists[np.all(dists > 0, axis=1)]\n",
    "    if dists.shape[0] == 0:\n",
    "        return float(\"nan\")\n",
    "\n",
    "    inv_dims = np.mean(np.log(dists[:, -1:] / dists[:, :-1]), axis=1)\n",
    "    return float(1.0 / np.mean(inv_dims))\n",
    "\n",
    "d_hat = estimate_intrinsic_dim(X, k=10)\n",
    "print(f\"Estimated intrinsic dimension d ≈ {d_hat:.2f}\")\n",
    "\n",
    "# --- Tsybakov margin exponent alpha ---\n",
    "def estimate_tsybakov_alpha(prompt_map, prompts, n_bins=20,\n",
    "                            save_path=\"plots/motivation/tsybakov_alpha_estimate.pdf\"):\n",
    "    \"\"\"Estimate the Tsybakov margin exponent alpha and plot the log-log fit.\n",
    "\n",
    "    For each prompt the margin is the gap between the best and second-best model's\n",
    "    mean CLIP score. Assuming P(margin < eps) ~ C * eps**alpha, alpha is the slope of\n",
    "    a linear fit of log P(margin < eps) against log eps, at `n_bins` log-spaced eps\n",
    "    values in [1e-3, 1].\n",
    "\n",
    "    Args:\n",
    "        prompt_map: mapping prompt -> model -> list of CLIP scores.\n",
    "        prompts: prompts to use; those with fewer than two models are skipped.\n",
    "        n_bins: number of eps values.\n",
    "        save_path: output file for the figure; its parent directory is created.\n",
    "\n",
    "    Returns:\n",
    "        The fitted slope, or NaN if fewer than two eps values have P(margin < eps) > 0.\n",
    "    \"\"\"\n",
    "    margins = []\n",
    "    for p in prompts:\n",
    "        scores = [np.mean(prompt_map[p][m]) for m in prompt_map[p]]\n",
    "        if len(scores) < 2:\n",
    "            continue\n",
    "        top2 = sorted(scores, reverse=True)[:2]\n",
    "        margins.append(top2[0] - top2[1])\n",
    "\n",
    "    margins = np.array(margins)\n",
    "    epsilons = np.logspace(-3, 0, n_bins)\n",
    "    probabilities = np.array([(margins < eps).mean() if margins.size else 0.0 for eps in epsilons])\n",
    "\n",
    "    # Fit only where P > 0: log(0) is undefined, and flooring it at log(1e-12) ≈ -27.6\n",
    "    # would add artificial points that dominate the least-squares slope.\n",
    "    positive = probabilities > 0\n",
    "    if positive.sum() < 2:\n",
    "        print(\"Not enough margins below the tested eps values to estimate alpha.\")\n",
    "        return float(\"nan\")\n",
    "    log_eps = np.log(epsilons[positive])\n",
    "    log_prob = np.log(probabilities[positive])\n",
    "\n",
    "    # Linear regression in log-log space: log P(eps) ≈ alpha * log(eps) + log C\n",
    "    alpha_hat, log_c = np.polyfit(log_eps, log_prob, deg=1)\n",
    "\n",
    "    plt.figure(figsize=(8, 6))\n",
    "    plt.plot(log_eps, log_prob, 'o', markersize=8, label=\"Empirical\")\n",
    "    plt.plot(log_eps, alpha_hat * log_eps + log_c, '-', linewidth=2, label=f\"Fit: α ≈ {alpha_hat:.2f}\")\n",
    "    plt.xlabel(\"log ε\", fontsize=14)\n",
    "    plt.ylabel(\"log P(margin < ε)\", fontsize=14)\n",
    "    plt.title(\"Estimation of Tsybakov's Constant α\", fontsize=16)\n",
    "    plt.grid(True, alpha=0.3)\n",
    "    plt.legend(fontsize=12)\n",
    "    os.makedirs(os.path.dirname(save_path) or \".\", exist_ok=True)\n",
    "    plt.savefig(save_path, dpi=300, bbox_inches='tight')\n",
    "    plt.show()\n",
    "\n",
    "    return alpha_hat\n",
    "\n",
    "alpha_hat = estimate_tsybakov_alpha(prompt_map_alpha, valid_prompts)\n",
    "print(f\"Estimated Tsybakov alpha ≈ {alpha_hat:.2f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "60acecbc",
   "metadata": {},
   "source": [
    "## Figure 4: Error vs Number of Generations Regression\n",
    "\n",
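    "Fits $\\Delta(g) = a + b/\\sqrt{g}$ to the error $\\Delta(g)$ as a function of the number of image generations $g$ per prompt (the data points are currently hard-coded in the cell)."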
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7b528374",
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.optimize import curve_fit\n",
    "from sklearn.metrics import r2_score\n",
    "\n",
    "# --- Data (hard-coded; replace with your actual measurements) ---\n",
    "g = np.array([1, 2, 3, 4, 5])\n",
    "delta_g = np.array([1.2702, 0.9315, 0.7951, 0.7152, 0.6671])\n",
    "\n",
    "# --- Model: Δ(g) = a + b / sqrt(g) ---\n",
    "def model(g, a, b):\n",
    "    \"\"\"Error model Δ(g) = a + b / sqrt(g), vectorised over `g`.\"\"\"\n",
    "    return a + b / np.sqrt(g)\n",
    "\n",
    "# --- Fit ---\n",
    "params, _ = curve_fit(model, g, delta_g)\n",
    "a, b = params\n",
    "delta_fit = model(g, a, b)  # fitted values at the observed g (used for R²)\n",
    "\n",
    "# --- R² ---\n",
    "r2 = r2_score(delta_g, delta_fit)\n",
    "\n",
    "# --- Plot ---\n",
    "# Evaluate the fitted curve on a dense grid: joining the 5 fitted values with straight\n",
    "# segments would not show the 1/sqrt(g) shape of the model.\n",
    "g_dense = np.linspace(g.min(), g.max(), 200)\n",
    "plt.figure(figsize=(8, 6))\n",
    "plt.plot(g, delta_g, 'o', markersize=10, label='Data', color='blue')\n",
    "plt.plot(g_dense, model(g_dense, a, b), '--', linewidth=2, label='Regression', color='red')\n",
    "plt.xlabel('Number of generations (g)', fontsize=14)\n",
    "plt.ylabel(r'Error $\\Delta(g)$', fontsize=14)\n",
    "plt.title(r'Regression: $\\Delta(g) = a + \\frac{b}{\\sqrt{g}}$', fontsize=16)\n",
    "\n",
    "# --- Equation + R² annotation ---\n",
    "eqn = r\"$\\Delta(g) = {:.4f} + \\frac{{{:.4f}}}{{\\sqrt{{g}}}},\\quad R^2 = {:.4f}$\".format(a, b, r2)\n",
    "plt.annotate(eqn, xy=(0.5, 0.95), xycoords='axes fraction',\n",
    "             ha='center', va='top', fontsize=12,\n",
    "             bbox=dict(boxstyle='round,pad=0.5', fc='wheat', ec='black', lw=1))\n",
    "\n",
    "plt.legend(fontsize=12)\n",
    "plt.grid(True, alpha=0.3)\n",
    "plt.tight_layout()\n",
    "\n",
    "os.makedirs(\"plots/motivation\", exist_ok=True)\n",
    "plt.savefig(\"plots/motivation/error_vs_generations.pdf\", dpi=300, bbox_inches='tight')\n",
    "plt.show()\n",
    "\n",
    "print(f\"Fitted parameters: a = {a:.4f}, b = {b:.4f}\")\n",
    "print(f\"R² = {r2:.4f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "376eb2a3",
   "metadata": {},
   "source": [
    "## Figure 5: Quantile Curves for Query Trigger Strategies\n",
    "\n",
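    "Plots quantile curves of candidate query-trigger metrics: the top-two mean CLIPScore gap per prompt, the variance of rewards among a prompt's nearest neighbours in CLIP embedding space, and, when `ucbs.pkl` is available, the UCB exploration bonus."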
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "035fa5f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "# --- 1. Load and sample prompts ---\n",
    "dataset = \"../datasets/flickr\"\n",
    "with open(f\"{dataset}/metadata.json\", \"r\", encoding=\"utf-8\") as f:\n",
    "    raw = json.load(f)\n",
    "\n",
    "# Build score map per prompt: prompt -> model -> pooled CLIP scores.\n",
    "# Entries without scores are skipped so that a model counted as present for a prompt\n",
    "# always has scores to average (np.mean of an empty list is NaN).\n",
    "selected_models = sorted({e[\"model\"] for e in raw})\n",
    "scores_map = defaultdict(lambda: defaultdict(list))\n",
    "for e in raw:\n",
    "    if e.get(\"clip_scores\"):\n",
    "        scores_map[e[\"prompt\"]][e[\"model\"]].extend(e[\"clip_scores\"])\n",
    "\n",
    "valid_prompts = [p for p, m in scores_map.items() if all(model in m for model in selected_models)]\n",
    "print(f\"{len(valid_prompts)} valid prompts found with all models.\")\n",
    "\n",
    "random.seed(42)\n",
    "prompts_list = random.sample(valid_prompts, min(2000, len(valid_prompts)))\n",
    "\n",
    "# --- 2. Compute CLIP embeddings in batches ---\n",
    "# truncation=True: prompts longer than the tokenizer's maximum length are cut instead\n",
    "# of making the text encoder fail.\n",
    "embeddings = []\n",
    "batch_size = 32\n",
    "for i in range(0, len(prompts_list), batch_size):\n",
    "    batch = prompts_list[i:i+batch_size]\n",
    "    inputs = clip_processor(text=batch, return_tensors=\"pt\", padding=True, truncation=True).to(device)\n",
    "    with torch.no_grad():\n",
    "        feats = clip_model.get_text_features(**inputs)\n",
    "        feats = feats / feats.norm(dim=-1, keepdim=True)\n",
    "    embeddings.append(feats.cpu())\n",
    "X = torch.cat(embeddings, dim=0).numpy()  # one unit-norm row per prompt in prompts_list\n",
    "\n",
    "# --- 3. Compute quantities ---\n",
    "delta_list = []\n",
    "var_list = []\n",
    "\n",
    "# Variant 1: gap between the best and second-best mean CLIP score of each prompt\n",
    "for p in prompts_list:\n",
    "    ranked = np.sort(np.array([np.mean(scores_map[p][m]) for m in selected_models]))\n",
    "    delta_list.append(ranked[-1] - ranked[-2])\n",
    "\n",
    "# Variant 4: for every (prompt, model), variance of that model's rewards pooled over the\n",
    "# prompt's k nearest neighbours in CLIP space\n",
    "dists = 1 - X.dot(X.T)  # cosine distance, since the rows of X are unit-normalised\n",
    "# NOTE(review): k is drawn at random in [20, 200] (reproducible only through the seed\n",
    "# set above); consider a fixed, documented value.\n",
    "k = random.randint(20, 200)\n",
    "for row in range(len(prompts_list)):\n",
    "    # position 0 of the sort is presumably the prompt itself (distance ~0)\n",
    "    neighbour_ids = np.argsort(dists[row])[1:k+1]\n",
    "    for m in selected_models:\n",
    "        pooled = [r for j in neighbour_ids for r in scores_map[prompts_list[j]][m]]\n",
    "        var_list.append(np.var(pooled))\n",
    "\n",
    "# Load UCB bonuses if available (optional third panel).\n",
    "# NOTE: pickle.load can execute arbitrary code; only load a ucbs.pkl you produced yourself.\n",
    "try:\n",
    "    with open(\"ucbs.pkl\", \"rb\") as f:\n",
    "        ucbs_data = pickle.load(f)\n",
    "except FileNotFoundError:\n",
    "    print(\"Warning: ucbs.pkl not found, skipping UCB plot\")\n",
    "    ucb_list = []\n",
    "else:\n",
    "    ucb_list = ucbs_data['ucbs'][10:]  # first 10 values dropped (presumably warm-up; TODO confirm)\n",
    "\n",
    "# --- 4. Quantile curves ---\n",
    "alphas = np.linspace(0, 1, 1001)\n",
    "quant_delta = np.percentile(delta_list, 100 * alphas)\n",
    "quant_var = np.percentile(var_list, 100 * alphas)\n",
    "\n",
    "# --- 5. Plot ---\n",
    "# Font sizes are applied through rc_context so they only affect this figure instead of\n",
    "# permanently changing plt.rcParams for every later cell.\n",
    "plot_style = {\n",
    "    'font.size': 14,\n",
    "    'axes.titlesize': 14,\n",
    "    'axes.labelsize': 14,\n",
    "    'xtick.labelsize': 12,\n",
    "    'ytick.labelsize': 12\n",
    "}\n",
    "\n",
    "# One panel per metric: (quantile values, colour, title, y-label)\n",
    "panels = [\n",
    "    (quant_delta, 'blue', \"Top-Two Mean Gap (Δ)\", \"Δ value\"),                # Variant 1\n",
    "    (quant_var, 'green', \"Local Neighbor Reward Variance\", \"Variance\"),       # Variant 4\n",
    "]\n",
    "if len(ucb_list) > 0:\n",
    "    quant_ucb = np.percentile(ucb_list, 100 * alphas)\n",
    "    panels.append((quant_ucb, 'purple', \"UCB Bonus\", \"UCB bonus\"))           # Variant 2\n",
    "\n",
    "with plt.rc_context(plot_style):\n",
    "    # squeeze=False gives a 2-D axes array whatever the number of panels\n",
    "    fig, axes = plt.subplots(1, len(panels), figsize=(6*len(panels), 5), squeeze=False)\n",
    "    for ax, (quantiles, color, title, ylabel) in zip(axes[0], panels):\n",
    "        ax.plot(alphas, quantiles, color=color, linewidth=2)\n",
    "        ax.set_title(title)\n",
    "        ax.set_xlabel(r\"Quantile level $\\alpha$\")\n",
    "        ax.set_ylabel(ylabel)\n",
    "        ax.grid(True, alpha=0.3)\n",
    "\n",
    "    plt.tight_layout()\n",
    "    os.makedirs(\"plots/motivation\", exist_ok=True)\n",
    "    plt.savefig(\"plots/motivation/quantile_curves.pdf\", dpi=600, bbox_inches='tight')\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e787197a",
   "metadata": {},
   "source": [
    "## Figure 6: Correlation Analysis - Distance vs Score Difference\n",
    "\n",
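    "For a single model (SDXL-Turbo), plots the absolute difference in mean CLIPScore against the cosine distance between the prompts' CLIP text embeddings, for pairs of nearby prompts, to check whether similar prompts receive similar scores."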
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "811e4df9",
   "metadata": {},
   "outputs": [],
   "source": [
    "from itertools import combinations\n",
    "from sklearn.metrics.pairwise import cosine_distances\n",
    "from scipy.stats import pearsonr\n",
    "from transformers import CLIPTokenizer, CLIPTextModel\n",
    "\n",
    "random.seed(42)\n",
    "\n",
    "# --- Load metadata ---\n",
    "with open(\"../datasets/flowers/metadata.json\", \"r\") as f:\n",
    "    data = json.load(f)\n",
    "\n",
    "# --- Keep a single model, and only entries that actually carry CLIP scores ---\n",
    "# (np.mean of an empty list is NaN, which would propagate into pearsonr below)\n",
    "model_name = \"SDXL-Turbo\"\n",
    "entries = [d for d in data if d[\"model\"] == model_name and d.get(\"clip_scores\")]\n",
    "\n",
    "# Extract prompts and mean scores\n",
    "prompts = [d[\"prompt\"] for d in entries]\n",
    "clip_means = [np.mean(d[\"clip_scores\"]) for d in entries]\n",
    "\n",
    "# Subsample to at most 500 prompts: the number of prompt pairs grows quadratically\n",
    "if len(prompts) > 500:\n",
    "    sample_indices = random.sample(range(len(prompts)), 500)\n",
    "    prompts = [prompts[i] for i in sample_indices]\n",
    "    clip_means = [clip_means[i] for i in sample_indices]\n",
    "\n",
    "# --- CLIP text-encoder embeddings (pooled output of CLIPTextModel, not the projected get_text_features) ---\n",
    "tokenizer = CLIPTokenizer.from_pretrained(\"openai/clip-vit-base-patch32\")\n",
    "text_model = CLIPTextModel.from_pretrained(\"openai/clip-vit-base-patch32\").to(device)\n",
    "\n",
    "encoded = tokenizer(prompts, padding=True, truncation=True, return_tensors=\"pt\").to(device)\n",
    "with torch.no_grad():\n",
    "    embeddings_corr = text_model(**encoded).pooler_output.cpu().numpy()\n",
    "\n",
    "# --- Pairwise distances and score differences ---\n",
    "max_distance = 0.30\n",
    "\n",
    "# Full cosine-distance matrix in one call rather than one sklearn call per pair; the\n",
    "# upper-triangle indices list pairs (i < j) in the same order as combinations().\n",
    "pair_i, pair_j = np.triu_indices(len(prompts), k=1)\n",
    "pair_dists = cosine_distances(embeddings_corr)[pair_i, pair_j]\n",
    "# Keep close pairs, excluding (near-)identical prompts\n",
    "keep = (pair_dists > 0.0001) & (pair_dists <= max_distance)\n",
    "\n",
    "clip_means_arr = np.asarray(clip_means)\n",
    "distances = pair_dists[keep]\n",
    "diff_scores = np.abs(clip_means_arr[pair_i[keep]] - clip_means_arr[pair_j[keep]])\n",
    "\n",
    "if len(distances) > 1:\n",
    "    corr, p_value = pearsonr(distances, diff_scores)\n",
    "    m, b = np.polyfit(distances, diff_scores, 1)\n",
    "    # The fit is a straight line, so its two end points are enough to draw it\n",
    "    line_x = np.array([distances.min(), distances.max()])\n",
    "    reg_line = m * line_x + b\n",
    "\n",
    "    plt.figure(figsize=(8, 6))\n",
    "    plt.scatter(distances, diff_scores, alpha=0.6, s=20, label='Prompt pairs')\n",
    "    plt.plot(line_x, reg_line, color='red', linewidth=2,\n",
    "             label=f'Linear fit (r={corr:.2f}, p={p_value:.2e})')\n",
    "    plt.xlabel(\"Cosine distance between prompts (CLIP embeddings)\", fontsize=14)\n",
    "    plt.ylabel(\"Difference in CLIP score\", fontsize=14)\n",
    "    plt.title(f\"Prompt Distance vs Score Difference ({model_name})\", fontsize=16)\n",
    "    plt.legend(fontsize=12)\n",
    "    plt.grid(True, alpha=0.3)\n",
    "    os.makedirs(\"plots/motivation\", exist_ok=True)\n",
    "    plt.savefig(\"plots/motivation/correlation_analysis.pdf\", dpi=600, bbox_inches='tight')\n",
    "    plt.show()\n",
    "\n",
    "    print(f\"Correlation: r={corr:.3f}, p-value={p_value:.2e}\")\n",
    "else:\n",
    "    print(\"Not enough pairs after filtering.\")"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
