{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "53e65d1f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import time\n",
    "import requests\n",
    "import re\n",
    "import random\n",
    "from tqdm import tqdm\n",
    "from collections import defaultdict\n",
    "\n",
    "# ===== API Wrapper =====\n",
    "class LLM_API_Wrapper:\n",
    "    def __init__(self, model, api_key):\n",
    "        self.model = model\n",
    "        self.api_key = api_key\n",
    "\n",
    "    def call_api(self, prompt_text: str, n: int = 1, timeout_s: int = 90):\n",
    "        url = \"\"\n",
    "        payload = {\n",
    "            \"model\": self.model,\n",
    "            \"temperature\": 0.6,                 # 略降温，提升结构化输出概率\n",
    "            \"n\": n,\n",
    "            \"messages\": [{\"role\": \"user\", \"content\": prompt_text}],\n",
    "            \"modalities\": [\"text\"],\n",
    "            \"response_format\": {\"type\": \"text\"},\n",
    "            \"max_completion_tokens\": 512,\n",
    "            \"stream\": False\n",
    "        }\n",
    "        headers = {\n",
    "            \"Content-Type\": \"application/json\",\n",
    "            \"Authorization\": f\"Bearer {self.api_key}\"\n",
    "        }\n",
    "        try:\n",
    "            resp = requests.post(url, headers=headers, json=payload, timeout=timeout_s)\n",
    "            resp.raise_for_status()\n",
    "            data = resp.json()\n",
    "            if \"choices\" in data and data[\"choices\"]:\n",
    "                return [c[\"message\"][\"content\"] for c in data[\"choices\"]]\n",
    "            print(\"⚠️ Missing 'choices' in response.\")\n",
    "            return []\n",
    "        except requests.exceptions.RequestException as e:\n",
    "            print(f\"❌ Request failed: {e}\")\n",
    "            return []\n",
    "\n",
    "# ===== Prompt Builders =====\n",
    "def build_full_analysis_prompt(text: str) -> str:\n",
    "    # 严格要求只返回 JSON\n",
    "    return f\"\"\"\n",
    "You are a psycholinguistics expert. Analyze the following social media post from three perspectives:\n",
    "\n",
    "1) Semantic Summary: main idea or intention.\n",
    "2) Sentiment Analysis: emotions/attitudes.\n",
    "3) Linguistic Style: writing style (e.g., emotional, rational, informal, formal, vague).\n",
    "\n",
    "Return ONLY valid JSON with the exact keys below and no extra text:\n",
    "\n",
    "{{\n",
    "  \"semantic_view\": \"...\",\n",
    "  \"sentiment_view\": \"...\",\n",
    "  \"linguistic_view\": \"...\"\n",
    "}}\n",
    "\n",
    "Post:\n",
    "\\\"\\\"\\\"{(text or '').strip()[:1024]}\\\"\\\"\\\"\n",
    "\"\"\".strip()\n",
    "\n",
    "def build_fallback_prompt(text: str) -> str:\n",
    "    # 退而求其次的更简单提示\n",
    "    return f\"\"\"\n",
    "Provide a STRICT JSON object with three short fields summarizing the post:\n",
    "\n",
    "{{\n",
    "  \"semantic_view\": \"<1-2 sentences>\",\n",
    "  \"sentiment_view\": \"<one or two emotions>\",\n",
    "  \"linguistic_view\": \"<style words>\"\n",
    "}}\n",
    "\n",
    "Post: {(text or '').strip()[:1024]}\n",
    "\"\"\".strip()\n",
    "\n",
    "# ===== JSON Parser =====\n",
    "def safe_extract_json(raw: str) -> dict:\n",
    "    \"\"\"\n",
    "    尽可能从多种常见返回格式中提取 JSON；不合法返回 None。\n",
    "    \"\"\"\n",
    "    if not raw:\n",
    "        return None\n",
    "    txt = raw.strip()\n",
    "\n",
    "    # 去掉代码块包裹\n",
    "    if txt.startswith(\"```\"):\n",
    "        txt = re.sub(r\"^```(json)?\", \"\", txt, flags=re.IGNORECASE).strip()\n",
    "        if txt.endswith(\"```\"):\n",
    "            txt = txt[:-3].strip()\n",
    "\n",
    "    # 如果前后有冗余文字，尝试抽取第一个 {...} 块\n",
    "    # 允许有换行/空白\n",
    "    m = re.search(r\"\\{[\\s\\S]*\\}\", txt)\n",
    "    if m:\n",
    "        txt = m.group(0)\n",
    "\n",
    "    # 单引号 -> 双引号（尽量不误伤）\n",
    "    if txt.count('\"') == 0 and txt.count(\"'\") > 0:\n",
    "        txt = txt.replace(\"’\", \"'\").replace(\"‘\", \"'\")\n",
    "        txt = re.sub(r\"(?<!\\\\)'\", '\"', txt)\n",
    "\n",
    "    # 去掉尾部多余逗号\n",
    "    txt = re.sub(r\",\\s*}\", \"}\", txt)\n",
    "    txt = re.sub(r\",\\s*]\", \"]\", txt)\n",
    "\n",
    "    try:\n",
    "        obj = json.loads(txt)\n",
    "        # 校验字段\n",
    "        if not isinstance(obj, dict):\n",
    "            return None\n",
    "        for k in (\"semantic_view\", \"sentiment_view\", \"linguistic_view\"):\n",
    "            if k not in obj or not isinstance(obj[k], str):\n",
    "                return None\n",
    "        return obj\n",
    "    except Exception:\n",
    "        return None\n",
    "\n",
    "# ===== 兜底生成（确保每条都有）=====\n",
    "def fallback_views_from_text(text: str) -> dict:\n",
    "    \"\"\"\n",
    "    不依赖 LLM 的兜底：生成简短可用的三段字符串，确保字段不为空。\n",
    "    \"\"\"\n",
    "    text = (text or \"\").strip()\n",
    "    # 语义：截取前一句或前 120 字符\n",
    "    semantic = \"\"\n",
    "    if text:\n",
    "        sentence = re.split(r\"(?<=[.!?。！？])\\s+\", text)[0]\n",
    "        semantic = sentence[:180]\n",
    "    if not semantic:\n",
    "        semantic = \"The post contains limited context; the main idea is unclear.\"\n",
    "\n",
    "    # 情感：关键词极简判断\n",
    "    lowered = text.lower()\n",
    "    if any(w in lowered for w in [\"love\", \"great\", \"happy\", \"excited\", \"enjoy\", \"喜欢\", \"高兴\", \"开心\"]):\n",
    "        senti = \"Positive, optimistic.\"\n",
    "    elif any(w in lowered for w in [\"hate\", \"angry\", \"sad\", \"tired\", \"annoyed\", \"讨厌\", \"生气\", \"难过\", \"疲惫\"]):\n",
    "        senti = \"Negative, possibly frustrated or tired.\"\n",
    "    else:\n",
    "        senti = \"Neutral or mixed.\"\n",
    "\n",
    "    # 语言风格：根据表情/大小写/感叹号粗略判断\n",
    "    style_tokens = []\n",
    "    if re.search(r\"[A-Z]{3,}\", text): style_tokens.append(\"emphatic\")\n",
    "    if re.search(r\"[!！]{1,}\", text): style_tokens.append(\"expressive\")\n",
    "    if re.search(r\":[)D]|😂|🤣|😅|🙂|😉\", text): style_tokens.append(\"informal\")\n",
    "    if not style_tokens: style_tokens = [\"conversational\"]\n",
    "    ling = \", \".join(style_tokens)\n",
    "\n",
    "    return {\n",
    "        \"semantic_view\": semantic,\n",
    "        \"sentiment_view\": senti,\n",
    "        \"linguistic_view\": ling\n",
    "    }\n",
    "\n",
    "# ===== 主流程设置 =====\n",
    "MODEL_NAME = \"gpt-4.1-mini\"\n",
    "API_KEY = \"\"               # <<< 用你自己的 key\n",
    "MAX_PER_TYPE = 200\n",
    "INPUT_FILE = \"extended_mbti_dataset_v17.json\"\n",
    "OUTPUT_FILE = \"mbti_sample_with_all_views.json\"\n",
    "MAX_RETRY = 5                                # 最多重试 5 次\n",
    "BASE_SLEEP = 2                               # 初始退避秒\n",
    "\n",
    "# ===== Init =====\n",
    "llm = LLM_API_Wrapper(model=MODEL_NAME, api_key=API_KEY)\n",
    "\n",
    "# ===== Load data =====\n",
    "with open(INPUT_FILE, \"r\", encoding=\"utf-8\") as f:\n",
    "    full_data = json.load(f)\n",
    "\n",
    "# ===== Balance by type =====\n",
    "type_counter = defaultdict(int)\n",
    "selected_samples = []\n",
    "for item in full_data:\n",
    "    mbti = item[\"type\"]\n",
    "    if MAX_PER_TYPE is None or type_counter[mbti] < MAX_PER_TYPE:\n",
    "        selected_samples.append(item.copy())\n",
    "        type_counter[mbti] += 1\n",
    "    if MAX_PER_TYPE and len(selected_samples) >= 16 * MAX_PER_TYPE:\n",
    "        break\n",
    "\n",
    "print(f\"✅ Selected {len(selected_samples)} samples ({MAX_PER_TYPE} per MBTI type)\")\n",
    "\n",
    "# ===== Process =====\n",
    "failed_cases = []\n",
    "for i, item in enumerate(tqdm(selected_samples, desc=\"Generating views\")):\n",
    "    post = item.get(\"posts\") or item.get(\"posts_cleaned\") or \"\"\n",
    "    prompt = build_full_analysis_prompt(post)\n",
    "\n",
    "    views = None\n",
    "    for attempt in range(1, MAX_RETRY + 1):\n",
    "        result = llm.call_api(prompt)\n",
    "        parsed = safe_extract_json(result[0]) if result else None\n",
    "\n",
    "        if parsed:\n",
    "            # 校验非空\n",
    "            if all(isinstance(parsed[k], str) and parsed[k].strip() for k in parsed):\n",
    "                views = parsed\n",
    "                break\n",
    "            else:\n",
    "                # 返回了 JSON 但内容空，换更简单的 backup prompt\n",
    "                prompt = build_fallback_prompt(post)\n",
    "        else:\n",
    "            # 第一次失败后改用 fallback prompt，提高成功率\n",
    "            prompt = build_fallback_prompt(post)\n",
    "\n",
    "        # 指数退避 + 抖动\n",
    "        sleep_s = BASE_SLEEP * attempt + random.uniform(0, 0.5 * attempt)\n",
    "        print(f\"⚠️ Bad/invalid JSON (attempt {attempt}), retry in {sleep_s:.1f}s ...\")\n",
    "        time.sleep(sleep_s)\n",
    "\n",
    "    # 如果还是失败，使用兜底生成（确保每条都有）\n",
    "    if not views:\n",
    "        print(\"❗ Using local fallback generator.\")\n",
    "        views = fallback_views_from_text(post)\n",
    "        failed_cases.append({\"type\": item.get(\"type\"), \"posts\": post})\n",
    "\n",
    "    item[\"semantic_view\"] = views.get(\"semantic_view\", \"\")\n",
    "    item[\"sentiment_view\"] = views.get(\"sentiment_view\", \"\")\n",
    "    item[\"linguistic_view\"] = views.get(\"linguistic_view\", \"\")\n",
    "\n",
    "    # 可选：定期落盘，支持断点续跑\n",
    "    if (i + 1) % 20 == 0:\n",
    "        with open(OUTPUT_FILE, \"w\", encoding=\"utf-8\") as f:\n",
    "            json.dump(selected_samples, f, ensure_ascii=False, indent=2)\n",
    "\n",
    "# ===== Save =====\n",
    "with open(OUTPUT_FILE, \"w\", encoding=\"utf-8\") as f:\n",
    "    json.dump(selected_samples, f, ensure_ascii=False, indent=2)\n",
    "\n",
    "print(f\"\\n✅ Done! Saved enriched samples to {OUTPUT_FILE}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e53b4b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from transformers import AutoTokenizer, AutoModel\n",
    "import json\n",
    "from torch.nn import functional as F\n",
    "from sklearn.manifold import TSNE\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "from tqdm import tqdm\n",
    "from sklearn.model_selection import train_test_split\n",
    "import torch_optimizer as optim "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8fdf99f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def mbti_to_binary(mbti_type):\n",
    "    # INFP → [0, 1, 0, 1] (E/I, S/N, T/F, J/P)\n",
    "    return {\n",
    "        \"ei\": 0 if mbti_type[0] == \"I\" else 1,\n",
    "        \"ns\": 0 if mbti_type[1] == \"S\" else 1,\n",
    "        \"tf\": 0 if mbti_type[2] == \"F\" else 1,\n",
    "        \"jp\": 0 if mbti_type[3] == \"P\" else 1\n",
    "    }\n",
    "class JointMBTIDataset(Dataset):\n",
    "    def __init__(self, data, tokenizer, max_length=256):\n",
    "        self.data = data\n",
    "        self.tokenizer = tokenizer\n",
    "        self.max_length = max_length\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.data)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        item = self.data[idx]\n",
    "        # 基础文本\n",
    "        base_text = item.get(\"posts_cleaned\", item[\"posts\"]) or \"\"\n",
    "\n",
    "        # 取三视角（如果没有就空字符串）\n",
    "        semantic = item.get(\"semantic_view\", \"\")\n",
    "        sentiment = item.get(\"sentiment_view\", \"\")\n",
    "        linguistic = item.get(\"linguistic_view\", \"\")\n",
    "\n",
    "        # 拼接：原文 + 三视角\n",
    "        combined_text = base_text\n",
    "        if any([semantic, sentiment, linguistic]):\n",
    "            combined_text = f\"{base_text} [SEP] {semantic} [SEP] {sentiment} [SEP] {linguistic}\"\n",
    "\n",
    "        mbti_type = item[\"type\"]\n",
    "        mbti_label = mbti_to_binary(mbti_type)\n",
    "\n",
    "        encoding = self.tokenizer(\n",
    "            combined_text,\n",
    "            padding=\"max_length\",\n",
    "            truncation=True,\n",
    "            max_length=self.max_length,\n",
    "            return_tensors=\"pt\"\n",
    "        )\n",
    "\n",
    "        return {\n",
    "            \"input_ids\": encoding[\"input_ids\"].squeeze(0),\n",
    "            \"attention_mask\": encoding[\"attention_mask\"].squeeze(0),\n",
    "            \"ei\": torch.tensor(mbti_label[\"ei\"], dtype=torch.float),\n",
    "            \"ns\": torch.tensor(mbti_label[\"ns\"], dtype=torch.float),\n",
    "            \"tf\": torch.tensor(mbti_label[\"tf\"], dtype=torch.float),\n",
    "            \"jp\": torch.tensor(mbti_label[\"jp\"], dtype=torch.float),\n",
    "            \"label\": mbti_type\n",
    "        }\n",
    "class JointMBTIModel(nn.Module):\n",
    "    def __init__(self, encoder_name=\"microsoft/deberta-v3-base\", dropout=0.1):\n",
    "        super().__init__()\n",
    "        self.encoder = AutoModel.from_pretrained(encoder_name)\n",
    "        hidden_size = self.encoder.config.hidden_size\n",
    "        self.dropout = nn.Dropout(dropout)\n",
    "        self.classifier_ei = nn.Linear(hidden_size, 1)\n",
    "        self.classifier_ns = nn.Linear(hidden_size, 1)\n",
    "        self.classifier_tf = nn.Linear(hidden_size, 1)\n",
    "        self.classifier_jp = nn.Linear(hidden_size, 1)\n",
    "\n",
    "    def forward(self, input_ids, attention_mask):\n",
    "        output = self.encoder(input_ids, attention_mask=attention_mask)\n",
    "        cls_emb = self.dropout(output.last_hidden_state[:, 0, :])\n",
    "        return {\n",
    "            \"embedding\": cls_emb,  # 可导出向量\n",
    "            \"ei\": self.classifier_ei(cls_emb),\n",
    "            \"ns\": self.classifier_ns(cls_emb),\n",
    "            \"tf\": self.classifier_tf(cls_emb),\n",
    "            \"jp\": self.classifier_jp(cls_emb)\n",
    "        }\n",
    "def compute_mbti_loss(preds, labels):\n",
    "    bce = nn.BCEWithLogitsLoss()\n",
    "    return (\n",
    "        bce(preds[\"ei\"].squeeze(), labels[\"ei\"]) +\n",
    "        bce(preds[\"ns\"].squeeze(), labels[\"ns\"]) +\n",
    "        bce(preds[\"tf\"].squeeze(), labels[\"tf\"]) +\n",
    "        bce(preds[\"jp\"].squeeze(), labels[\"jp\"])\n",
    "    ) / 4\n",
    "    \n",
    "def train_model(model, train_loader, val_loader, optimizer, device, epochs=5, save_path=\"kaggle_best_fem.pt\"):\n",
    "    model.to(device)\n",
    "    best_val_acc = 0.0\n",
    "\n",
    "    for epoch in range(epochs):\n",
    "        model.train()\n",
    "        total_loss = 0\n",
    "\n",
    "        for batch in tqdm(train_loader, desc=f\"[Train] Epoch {epoch+1}/{epochs}\"):\n",
    "            input_ids = batch[\"input_ids\"].to(device)\n",
    "            attention_mask = batch[\"attention_mask\"].to(device)\n",
    "            labels = {k: batch[k].to(device) for k in [\"ei\", \"ns\", \"tf\", \"jp\"]}\n",
    "\n",
    "            optimizer.zero_grad()\n",
    "            outputs = model(input_ids=input_ids, attention_mask=attention_mask)\n",
    "            loss = compute_mbti_loss(outputs, labels)\n",
    "            loss.backward()\n",
    "            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)\n",
    "            optimizer.step()\n",
    "            total_loss += loss.item()\n",
    "\n",
    "        avg_loss = total_loss / len(train_loader)\n",
    "        print(f\"\\nEpoch {epoch+1} - Train Loss: {avg_loss:.4f}\")\n",
    "\n",
    "        # 验证阶段\n",
    "        model.eval()\n",
    "        correct_ei = correct_ns = correct_tf = correct_jp = correct_all = 0\n",
    "        total = 0\n",
    "\n",
    "        with torch.no_grad():\n",
    "            for batch in tqdm(val_loader, desc=f\"[Val]   Epoch {epoch+1}\"):\n",
    "                input_ids = batch[\"input_ids\"].to(device)\n",
    "                attention_mask = batch[\"attention_mask\"].to(device)\n",
    "                labels = {k: batch[k].to(device) for k in [\"ei\", \"ns\", \"tf\", \"jp\"]}\n",
    "\n",
    "                outputs = model(input_ids=input_ids, attention_mask=attention_mask)\n",
    "\n",
    "                pred_ei = torch.sigmoid(outputs[\"ei\"]).round()\n",
    "                pred_ns = torch.sigmoid(outputs[\"ns\"]).round()\n",
    "                pred_tf = torch.sigmoid(outputs[\"tf\"]).round()\n",
    "                pred_jp = torch.sigmoid(outputs[\"jp\"]).round()\n",
    "\n",
    "                correct_ei += (pred_ei.squeeze() == labels[\"ei\"]).sum().item()\n",
    "                correct_ns += (pred_ns.squeeze() == labels[\"ns\"]).sum().item()\n",
    "                correct_tf += (pred_tf.squeeze() == labels[\"tf\"]).sum().item()\n",
    "                correct_jp += (pred_jp.squeeze() == labels[\"jp\"]).sum().item()\n",
    "\n",
    "                correct_all += (\n",
    "                    (pred_ei.squeeze() == labels[\"ei\"]) &\n",
    "                    (pred_ns.squeeze() == labels[\"ns\"]) &\n",
    "                    (pred_tf.squeeze() == labels[\"tf\"]) &\n",
    "                    (pred_jp.squeeze() == labels[\"jp\"])\n",
    "                ).sum().item()\n",
    "\n",
    "                total += input_ids.size(0)\n",
    "\n",
    "        acc_ei = correct_ei / total\n",
    "        acc_ns = correct_ns / total\n",
    "        acc_tf = correct_tf / total\n",
    "        acc_jp = correct_jp / total\n",
    "        acc_all = correct_all / total\n",
    "\n",
    "        print(f\"Validation Accuracy:\")\n",
    "        print(f\"  EI: {acc_ei:.2%} | NS: {acc_ns:.2%} | TF: {acc_tf:.2%} | JP: {acc_jp:.2%}\")\n",
    "        print(f\"  4D Match: {acc_all:.2%}\")\n",
    "\n",
    "        # 保存验证集最佳模型\n",
    "        if acc_all > best_val_acc:\n",
    "            best_val_acc = acc_all\n",
    "            torch.save(model.state_dict(), save_path)\n",
    "            print(f\"✅ Best model saved to {save_path} (4D match: {acc_all:.2%})\")\n",
    "def visualize_embeddings(model, dataloader, device):\n",
    "    model.eval()\n",
    "    all_embeds, all_labels = [], []\n",
    "    with torch.no_grad():\n",
    "        for batch in dataloader:\n",
    "            input_ids = batch[\"input_ids\"].to(device)\n",
    "            attention_mask = batch[\"attention_mask\"].to(device)\n",
    "            output = model(input_ids=input_ids, attention_mask=attention_mask)\n",
    "            all_embeds.append(output[\"embedding\"].cpu())\n",
    "            all_labels.extend(batch[\"label\"])\n",
    "    embeddings = torch.cat(all_embeds).numpy()\n",
    "\n",
    "    tsne = TSNE(n_components=2, random_state=42).fit_transform(embeddings)\n",
    "    df = pd.DataFrame(tsne, columns=[\"x\", \"y\"])\n",
    "    df[\"label\"] = all_labels\n",
    "\n",
    "    plt.figure(figsize=(10, 8))\n",
    "    for label in sorted(set(all_labels)):\n",
    "        subset = df[df[\"label\"] == label]\n",
    "        plt.scatter(subset[\"x\"], subset[\"y\"], label=label, alpha=0.6)\n",
    "    plt.legend()\n",
    "    plt.title(\"FEM Personality Embedding Space (t-SNE)\")\n",
    "    plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ee1d146",
   "metadata": {},
   "outputs": [],
   "source": [
    "if __name__ == \"__main__\":\n",
    "    # 加载数据\n",
    "    with open(\"mbti_sample_with_all_views.json\", \"r\", encoding=\"utf-8\") as f:\n",
    "        full_data = json.load(f)\n",
    "\n",
    "    # 拆分训练集 / 验证集\n",
    "    train_data, val_data = train_test_split(full_data, test_size=0.2, random_state=42)\n",
    "\n",
    "    # 初始化 tokenizer\n",
    "    tokenizer = AutoTokenizer.from_pretrained(\"microsoft/deberta-v3-large\")\n",
    "\n",
    "    # 初始化 dataset\n",
    "    train_dataset = JointMBTIDataset(train_data, tokenizer)  # 支持list\n",
    "    val_dataset = JointMBTIDataset(val_data, tokenizer)\n",
    "\n",
    "    # 初始化 dataloader\n",
    "    train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)\n",
    "    val_loader = DataLoader(val_dataset, batch_size=16)\n",
    "\n",
    "    # 模型和优化器\n",
    "    model = JointMBTIModel(\"microsoft/deberta-v3-large\")\n",
    "    #optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)\n",
    "    optimizer = optim.AdamP(model.parameters(), lr=1e-5, weight_decay=0.01)\n",
    "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "    # 训练并保存最佳模型\n",
    "    train_model(model, train_loader, val_loader, optimizer, device, epochs=15)\n",
    "\n",
    "    # 可视化嵌入空间（只用验证集）\n",
    "    visualize_embeddings(model, val_loader, device)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
