{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d7600c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json, random, os\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from torch.utils.data import DataLoader, Subset\n",
    "import torchvision, torchvision.transforms as T\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.decomposition import PCA\n",
    "from sklearn.manifold import TSNE\n",
    "import pandas as pd\n",
    "\n",
    "from encoder_manager import EncoderManager\n",
    "\n",
    "SEEDS = [0]\n",
    "DATA_ROOT = \"./cifar10/data\"\n",
    "SAMPLES_PER_CLASS = 100\n",
    "BATCH_SIZE = 128\n",
    "OUT_DIR = \"./results_geom\"\n",
    "\n",
    "ENC_CFG_BASE = {\n",
    "    \"type\": \"vit\",\n",
    "    \"size\": \"large\",               \n",
    "    \"patch\": 16,                   \n",
    "    \"from_pretrained\": False,      \n",
    "}\n",
    "\n",
    "def set_seed(seed=0):\n",
    "    random.seed(seed); np.random.seed(seed); torch.manual_seed(seed)\n",
    "    if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed)\n",
    "\n",
    "def build_encoder(enc_cfg, device):\n",
    "    cfg = {\"encoder\": enc_cfg}\n",
    "    enc = EncoderManager(cfg, device)\n",
    "    return enc\n",
    "\n",
    "def cifar10_subset(transform, per_class=SAMPLES_PER_CLASS, seed=0):\n",
    "    ds = torchvision.datasets.CIFAR10(root=DATA_ROOT, train=False, download=True, transform=transform)\n",
    "    targets = np.array(ds.targets, dtype=np.int64)\n",
    "    rng = np.random.RandomState(seed)\n",
    "\n",
    "    indices = []\n",
    "    for c in range(10):\n",
    "        idx = np.where(targets == c)[0]\n",
    "        rng.shuffle(idx)\n",
    "        take = idx[:per_class]\n",
    "        indices.extend(take.tolist())\n",
    "    indices = sorted(indices)\n",
    "\n",
    "    subset = Subset(ds, indices)\n",
    "    loader = DataLoader(subset, batch_size=len(subset), shuffle=False, num_workers=2, pin_memory=True)\n",
    "    x, _ = next(iter(loader))\n",
    "    y = torch.tensor([ds.targets[i] for i in indices], dtype=torch.long)\n",
    "    return x, y\n",
    "\n",
    "@torch.no_grad()\n",
    "def encode_all(encoder, x, device, bs=BATCH_SIZE):\n",
    "    feats = []\n",
    "    for i in range(0, x.size(0), bs):\n",
    "        xb = x[i:i+bs].to(device, non_blocking=True)\n",
    "        fb = encoder.encode(xb).detach().cpu()\n",
    "        feats.append(fb)\n",
    "    z = torch.cat(feats, dim=0).float()\n",
    "    z = z - z.mean(dim=0, keepdim=True)\n",
    "    z = F.normalize(z, dim=1)\n",
    "    return z\n",
    "\n",
    "@torch.no_grad()\n",
    "def alignment_same_class(z, y):\n",
    "    \"\"\"\n",
    "    Alignment (Wang & Isola, 2020)\n",
    "    Align = E_{i,j: yi=yj} ||z_i - z_j||^2  (정규화 임베딩 → 2-2*cos)\n",
    "    \"\"\"\n",
    "    z = z.cpu(); y = y.cpu()\n",
    "    n = y.numel()\n",
    "    sim = z @ z.t()\n",
    "    yi = y.view(-1,1).expand(n,n)\n",
    "    same = (yi == yi.t())\n",
    "    iu = np.triu_indices(n, k=1)\n",
    "    same_upper = same[iu]\n",
    "    d2 = (2.0 - 2.0 * sim)[iu].numpy()\n",
    "    if same_upper.sum() == 0: return float('nan')\n",
    "    return d2[same_upper.numpy()].mean().item()\n",
    "\n",
    "@torch.no_grad()\n",
    "def intra_inter_and_ratio(z, y):\n",
    "    \"\"\"\n",
    "    Intra: 각 샘플 → 자기 class centroid L2 평균\n",
    "    Inter: centroid 쌍 L2 평균\n",
    "    Ratio: Intra / Inter\n",
    "    \"\"\"\n",
    "    z = z.cpu(); y = y.cpu()\n",
    "    K = int(y.max().item() + 1)\n",
    "    cents = []\n",
    "    for c in range(K):\n",
    "        idx = (y == c).nonzero(as_tuple=True)[0]\n",
    "        cents.append(z.index_select(0, idx).mean(dim=0) if idx.numel() > 0 else torch.zeros(z.size(1)))\n",
    "    C = torch.stack(cents, dim=0)\n",
    "\n",
    "    dsum, n = 0.0, 0\n",
    "    for c in range(K):\n",
    "        idx = (y == c).nonzero(as_tuple=True)[0]\n",
    "        if idx.numel() == 0: continue\n",
    "        diff = z.index_select(0, idx) - C[c].unsqueeze(0)\n",
    "        dsum += torch.norm(diff, dim=1).sum().item()\n",
    "        n += idx.numel()\n",
    "    intra = dsum / max(1, n)\n",
    "\n",
    "    if K > 1:\n",
    "        D = torch.cdist(C, C, p=2.0)\n",
    "        iu = np.triu_indices(K, k=1)\n",
    "        inter = D[iu].mean().item()\n",
    "    else:\n",
    "        inter = float('nan')\n",
    "\n",
    "    ratio = intra / inter if (inter is not None and inter > 0) else float('nan')\n",
    "    return intra, inter, ratio\n",
    "\n",
    "def _pca_then_tsne(Z, seed=0):\n",
    "    X = Z.cpu().numpy()\n",
    "    if X.shape[1] > 50:\n",
    "        X = PCA(n_components=50, random_state=0).fit_transform(X)\n",
    "    perplexity = max(5, min(50, X.shape[0] // 100))\n",
    "    tsne = TSNE(\n",
    "        n_components=2,\n",
    "        perplexity=perplexity,\n",
    "        init=\"pca\",\n",
    "        random_state=seed,\n",
    "        learning_rate=\"auto\",\n",
    "        max_iter=1500,\n",
    "    )\n",
    "    return tsne.fit_transform(X)\n",
    "\n",
    "def tsne_pair_figure(z_pre, z_rnd, y, out_path, seed=0):\n",
    "    emb_pre = _pca_then_tsne(z_pre, seed=seed)\n",
    "    emb_rnd = _pca_then_tsne(z_rnd, seed=seed)\n",
    "    y_np = y.cpu().numpy()\n",
    "\n",
    "    fig, axes = plt.subplots(1, 2, figsize=(6, 3))\n",
    "\n",
    "    for emb, title, ax in zip([emb_pre, emb_rnd],\n",
    "                              [\"pretrained True\", \"pretrained False\"],\n",
    "                              axes):\n",
    "        ax.scatter(\n",
    "            emb[:, 0], emb[:, 1],\n",
    "            c=y_np, s=10, alpha=0.85,\n",
    "            cmap=\"tab10\", linewidths=0, rasterized=True\n",
    "        )\n",
    "        ax.set_title(title, fontsize=12, pad=6)\n",
    "\n",
    "        ax.set_xticks([]); ax.set_yticks([])\n",
    "\n",
    "        for side in [\"left\", \"right\", \"top\", \"bottom\"]:\n",
    "            ax.spines[side].set_visible(True)\n",
    "            ax.spines[side].set_color(\"black\")\n",
    "            ax.spines[side].set_linewidth(1.0)\n",
    "\n",
    "    plt.tight_layout()\n",
    "    plt.savefig(out_path, dpi=300, bbox_inches=\"tight\", facecolor=\"white\")\n",
    "    plt.close()\n",
    "    print(f\"saved: {out_path}\")\n",
    "\n",
    "\n",
    "\n",
    "def run_one_seed(seed, device, out_dir):\n",
    "    set_seed(seed)\n",
    "\n",
    "    enc_pre_cfg = json.loads(json.dumps(ENC_CFG_BASE));  enc_pre_cfg[\"from_pretrained\"] = True\n",
    "    enc_rnd_cfg = json.loads(json.dumps(ENC_CFG_BASE));  enc_rnd_cfg[\"from_pretrained\"] = False\n",
    "    enc_pre = build_encoder(enc_pre_cfg, device)\n",
    "    enc_rnd = build_encoder(enc_rnd_cfg, device)\n",
    "\n",
    "    transform = enc_pre.image_transform() if hasattr(enc_pre, \"image_transform\") \\\n",
    "        else T.Compose([T.Resize((224,224)), T.ToTensor()])\n",
    "\n",
    "    x, y = cifar10_subset(transform, per_class=SAMPLES_PER_CLASS, seed=seed)\n",
    "    assert x.size(0) == 10 * SAMPLES_PER_CLASS, \"항상 1000개 사용\"\n",
    "\n",
    "    z_pre = encode_all(enc_pre, x, device, bs=BATCH_SIZE)\n",
    "    z_rnd = encode_all(enc_rnd, x, device, bs=BATCH_SIZE)\n",
    "\n",
    "    align_pre = alignment_same_class(z_pre, y)\n",
    "    align_rnd = alignment_same_class(z_rnd, y)\n",
    "    intra_p, inter_p, ratio_p = intra_inter_and_ratio(z_pre, y)\n",
    "    intra_r, inter_r, ratio_r = intra_inter_and_ratio(z_rnd, y)\n",
    "\n",
    "    if seed == 0:\n",
    "        os.makedirs(out_dir, exist_ok=True)\n",
    "        pair_path = os.path.join(out_dir, \"tsne_pretrained_true_false_seed0.png\")\n",
    "        tsne_pair_figure(z_pre, z_rnd, y, pair_path, seed=0)\n",
    "\n",
    "    print(f\"[seed={seed}] Alignment  : pre={align_pre:.4f} | rnd={align_rnd:.4f} (↓)\")\n",
    "    print(f\"[seed={seed}] Intra      : pre={intra_p:.4f} | rnd={intra_r:.4f} (↓)\")\n",
    "    print(f\"[seed={seed}] Inter      : pre={inter_p:.4f} | rnd={inter_r:.4f} (↑)\")\n",
    "    print(f\"[seed={seed}] Ratio      : pre={ratio_p:.4f} | rnd={ratio_r:.4f} (↓)\")\n",
    "\n",
    "    return [\n",
    "        {\"seed\":seed, \"mode\":\"pretrained\", \"metric\":\"alignment\", \"value\":align_pre},\n",
    "        {\"seed\":seed, \"mode\":\"random\",     \"metric\":\"alignment\", \"value\":align_rnd},\n",
    "        {\"seed\":seed, \"mode\":\"pretrained\", \"metric\":\"intra\",     \"value\":intra_p},\n",
    "        {\"seed\":seed, \"mode\":\"random\",     \"metric\":\"intra\",     \"value\":intra_r},\n",
    "        {\"seed\":seed, \"mode\":\"pretrained\", \"metric\":\"inter\",     \"value\":inter_p},\n",
    "        {\"seed\":seed, \"mode\":\"random\",     \"metric\":\"inter\",     \"value\":inter_r},\n",
    "        {\"seed\":seed, \"mode\":\"pretrained\", \"metric\":\"ratio\",     \"value\":ratio_p},\n",
    "        {\"seed\":seed, \"mode\":\"random\",     \"metric\":\"ratio\",     \"value\":ratio_r},\n",
    "    ]\n",
    "\n",
    "def main():\n",
    "    os.makedirs(OUT_DIR, exist_ok=True)\n",
    "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "    all_rows = []\n",
    "    for s in SEEDS:\n",
    "        rows = run_one_seed(s, device, OUT_DIR)\n",
    "        all_rows.extend(rows)\n",
    "\n",
    "    df_long = pd.DataFrame(all_rows)\n",
    "    df_long.to_csv(os.path.join(OUT_DIR, \"geom_metrics_all_seeds_long.csv\"), index=False)\n",
    "\n",
    "    summary = (df_long.groupby([\"mode\",\"metric\"])[\"value\"]\n",
    "               .agg([\"mean\",\"std\",\"count\"])\n",
    "               .reset_index())\n",
    "    summary.to_csv(os.path.join(OUT_DIR, \"geom_metrics_summary.csv\"), index=False)\n",
    "\n",
    "    print(\"\\n=== Summary (mean ± std over seeds) ===\")\n",
    "    for _, row in summary.iterrows():\n",
    "        print(f\"[{row['mode']:10s}] {row['metric']:10s} : {row['mean']:.4f} ± {row['std']:.4f}  (n={int(row['count'])})\")\n",
    "    print(f\"\\nSaved under: {OUT_DIR}\")\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    main()\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pdj_dt",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
