{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils.motif import init_repro\n",
    "\n",
    "init_repro(42, deterministic=True)\n",
    "\n",
    "import sys\n",
    "sys.path.append(\"./utils\")\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import wandb\n",
    "import time\n",
    "import os\n",
    "import pickle\n",
    "import clip\n",
    "\n",
    "from transformers import (\n",
    "    CLIPProcessor,\n",
    "    CLIPModel,\n",
    "    CLIPVisionModelWithProjection,\n",
    "    CLIPTokenizer,\n",
    "    CLIPTextModelWithProjection,\n",
    "    AutoTokenizer,\n",
    "    AutoModel,\n",
    "    AutoProcessor,\n",
    ")\n",
    "\n",
    "from utils.video_embedder import VideoEmbedder, Create_Concepts\n",
    "from utils.motif import MoTIF, CBMTransformer, mean_cbm\n",
    "from utils.explanations import explain_instance\n",
    "import core.vision_encoder.pe as pe\n",
    "import core.vision_encoder.transforms as pe_transformer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training hyperparameters\n",
    "num_epochs = 100\n",
    "batch_size = 32\n",
    "lse_tau = 1.0\n",
    "l1_lambda = 1e-3\n",
    "lambda_sparse = 1e-3\n",
    "lr = 1e-3\n",
    "transformer_layers = 1\n",
    "diagonal_attention = True\n",
    "enforce_nonneg = True\n",
    "class_weights = True\n",
    "weight_decay = 1e-2\n",
    "d = 1\n",
    "\n",
    "# Dataset settings\n",
    "test_split = \"s1\"\n",
    "window_size = 32\n",
    "dataset = \"breakfast\"\n",
    "random = True\n",
    "use_wandb = True\n",
    "clip_model = \"res50\"\n",
    "\n",
    "# Map dataset name\n",
    "dataset_map = {\n",
    "    \"breakfast\": \"Breakfast\",\n",
    "    \"ucf101\": \"UCF101\",\n",
    "    \"hmdb51\": \"HMDB\",\n",
    "    \"something2\": \"Something2\"\n",
    "}\n",
    "dataset_name = dataset_map.get(dataset, dataset)\n",
    "\n",
    "folder_path = [f\"../Datasets/{dataset_name}/Video_data\"]\n",
    "output_dir = \"../Embeddings/Datasets\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load CLIP model and processor\n",
    "if clip_model == \"b32\":\n",
    "    model = CLIPModel.from_pretrained(\"openai/clip-vit-base-patch32\").eval()\n",
    "    processor = CLIPProcessor.from_pretrained(\"openai/clip-vit-base-patch32\", use_fast=False)\n",
    "    embedd_path = f\"../Embeddings/Videos/{dataset_name}/{random}_{window_size}_clip_b32.pkl\"\n",
    "    clip_name = \"clip\"\n",
    "elif clip_model == \"b16\":\n",
    "    model = CLIPModel.from_pretrained(\"openai/clip-vit-base-patch16\").eval()\n",
    "    processor = CLIPProcessor.from_pretrained(\"openai/clip-vit-base-patch16\", use_fast=False)\n",
    "    embedd_path = f\"../Embeddings/Videos/{dataset_name}/{random}_{window_size}_clip_b16.pkl\"\n",
    "    clip_name = \"clip\"\n",
    "elif clip_model == \"l14\":\n",
    "    model = CLIPModel.from_pretrained(\"openai/clip-vit-large-patch14\").eval()\n",
    "    processor = CLIPProcessor.from_pretrained(\"openai/clip-vit-large-patch14\", use_fast=False)\n",
    "    embedd_path = f\"../Embeddings/Videos/{dataset_name}/{random}_{window_size}_clip_l14.pkl\"\n",
    "    clip_name = \"clip\"\n",
    "elif clip_model == \"res50\":\n",
    "    model, preprocess = clip.load(\"RN50\", device=\"cpu\")\n",
    "    processor = preprocess\n",
    "    embedd_path = f\"../Embeddings/Videos/{dataset_name}/{random}_{window_size}_clip_res50.pkl\"\n",
    "    clip_name = \"res50\"\n",
    "elif clip_model == \"clip4clip\":\n",
    "    model = CLIPVisionModelWithProjection.from_pretrained(\"Searchium-ai/clip4clip-webvid150k\").eval()\n",
    "    model_text = CLIPTextModelWithProjection.from_pretrained(\"Searchium-ai/clip4clip-webvid150k\")\n",
    "    processor = CLIPTokenizer.from_pretrained(\"Searchium-ai/clip4clip-webvid150k\")\n",
    "    embedd_path = f\"../Embeddings/Videos/{dataset_name}/{random}_{window_size}_clip_clip4clip.pkl\"\n",
    "    clip_name = \"clip4clip\"\n",
    "elif clip_model == \"siglip\":\n",
    "    model = AutoModel.from_pretrained(\"google/siglip-base-patch16-224\")\n",
    "    processor = AutoProcessor.from_pretrained(\"google/siglip-base-patch16-224\")\n",
    "    embedd_path = f\"../Embeddings/Videos/{dataset_name}/{random}_{window_size}_clip_siglip.pkl\"\n",
    "    clip_name = \"siglip\"\n",
    "elif clip_model == \"siglipl14\":\n",
    "    model = AutoModel.from_pretrained(\"google/siglip-so400m-patch14-384\")\n",
    "    processor = AutoProcessor.from_pretrained(\"google/siglip-so400m-patch14-384\")\n",
    "    embedd_path = f\"../Embeddings/Videos/{dataset_name}/{random}_{window_size}_clip_siglipl14.pkl\"\n",
    "    clip_name = \"siglipl14\"\n",
    "elif clip_model == \"pe-l14\":\n",
    "    model = pe.CLIP.from_config(\"PE-Core-L14-336\", pretrained=True)\n",
    "    processor = pe_transformer.get_image_transform(model.image_size)\n",
    "    tokenizer = pe_transformer.get_text_tokenizer(model.context_length)\n",
    "    clip_name = \"pe-l14\"\n",
    "    embedd_path = f\"../Embeddings/Videos/{dataset_name}/{random}_{window_size}_clip_pe-l14.pkl\"\n",
    "\n",
    "else:\n",
    "    model = None\n",
    "    processor = None\n",
    "    model_text = None\n",
    "\n",
    "# Initialize embedder\n",
    "embedder = VideoEmbedder(clip_name, model, processor)\n",
    "embedder.dataset_name = dataset\n",
    "\n",
    "# Load or create embeddings\n",
    "if os.path.exists(embedd_path):\n",
    "    with open(embedd_path, 'rb') as f:\n",
    "        embedder = pickle.load(f)\n",
    "    print(f\"Loaded existing embedder from {embedd_path}\")\n",
    "else:\n",
    "    embedder.process_data(\n",
    "        folder_path,\n",
    "        window_size=window_size,\n",
    "        output_path=output_dir,\n",
    "        save_intermediate=False,\n",
    "    )\n",
    "    with open(embedd_path, \"wb\") as f:\n",
    "        pickle.dump(embedder, f)\n",
    "\n",
    "# Initialize concepts\n",
    "if clip_model == \"clip4clip\":\n",
    "    concepts = Create_Concepts(clip_name, model_text, processor)\n",
    "elif clip_model == \"pe-l14\":\n",
    "    concepts = Create_Concepts(clip_name, model, tokenizer)\n",
    "else:\n",
    "    concepts = Create_Concepts(clip_name, model, processor)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generated 1989 synthetic samples\n",
      "Sequence shape: (27, 1024)\n",
      "Concept embeddings: (37, 1024)\n",
      "Number of concepts: 37\n",
      "Label distribution: {np.str_('class_0'): np.int64(398), np.str_('class_1'): np.int64(398), np.str_('class_2'): np.int64(398), np.str_('class_3'): np.int64(398), np.str_('class_4'): np.int64(397)}\n"
     ]
    }
   ],
   "source": [
    "\n",
    "def create_artificial_concept_embeddings(dim=1024, num_concepts=None, seed=42):\n",
    "    \"\"\"Create concept embeddings (K, dim) representing temporal patterns.\"\"\"\n",
    "    rng = np.random.default_rng(seed)\n",
    "    \n",
    "    # Comprehensive list of temporal pattern concepts\n",
    "    all_concept_names = [\n",
    "        # Basic trends\n",
    "        \"ascending\", \"descending\", \"constant\",\n",
    "        \n",
    "        # Curved patterns\n",
    "        \"u_shaped\", \"inverted_u\", \"s_shaped\", \"inverted_s\",\n",
    "        \n",
    "        # Rate of change\n",
    "        \"increasing_rate\", \"decreasing_rate\", \"accelerating\", \"decelerating\",\n",
    "        \n",
    "        # Periodic patterns\n",
    "        \"periodic\", \"periodic_fast\", \"periodic_slow\", \"periodic_phase_shift\",\n",
    "        \n",
    "        # Step functions\n",
    "        \"step_up\", \"step_down\", \"step_up_down\", \"step_down_up\",\n",
    "        \n",
    "        # Exponential patterns\n",
    "        \"exponential_growth\", \"exponential_decay\", \"logarithmic\",\n",
    "        \n",
    "        # Multi-peak patterns\n",
    "        \"double_peak\", \"triple_peak\", \"sawtooth\", \"square_wave\",\n",
    "        \n",
    "        # Asymmetric patterns\n",
    "        \"fast_rise_slow_fall\", \"slow_rise_fast_fall\", \"plateau_start\", \"plateau_end\",\n",
    "        \n",
    "        # Complex patterns\n",
    "        \"oscillating_decay\", \"oscillating_growth\", \"chirp\", \"spike\",\n",
    "        \n",
    "        # Edge cases\n",
    "        \"random\", \"noise\", \"sparse\"\n",
    "    ]\n",
    "    \n",
    "    # Use all concepts if num_concepts is None, otherwise use first N\n",
    "    if num_concepts is None:\n",
    "        concept_names = all_concept_names\n",
    "    else:\n",
    "        concept_names = all_concept_names[:num_concepts]\n",
    "    \n",
    "    concept_embeddings = []\n",
    "    \n",
    "    for i, name in enumerate(concept_names):\n",
    "        x = np.linspace(0, 1, dim)\n",
    "        t = np.linspace(0, 4 * np.pi, dim)\n",
    "        \n",
    "        # Basic trends\n",
    "        if name == \"ascending\":\n",
    "            emb = np.linspace(-0.5, 0.5, dim)\n",
    "        elif name == \"descending\":\n",
    "            emb = np.linspace(0.5, -0.5, dim)\n",
    "        elif name == \"constant\":\n",
    "            emb = np.ones(dim) * 0.1\n",
    "        \n",
    "        # Curved patterns\n",
    "        elif name == \"u_shaped\":\n",
    "            mid = dim // 2\n",
    "            emb = np.concatenate([\n",
    "                np.linspace(0.5, -0.5, mid),\n",
    "                np.linspace(-0.5, 0.5, dim - mid)\n",
    "            ])\n",
    "        elif name == \"inverted_u\":\n",
    "            mid = dim // 2\n",
    "            emb = np.concatenate([\n",
    "                np.linspace(-0.5, 0.5, mid),\n",
    "                np.linspace(0.5, -0.5, dim - mid)\n",
    "            ])\n",
    "        elif name == \"s_shaped\":\n",
    "            emb = 0.5 * (np.tanh(5 * (x - 0.5)) + 1) - 0.5\n",
    "        elif name == \"inverted_s\":\n",
    "            emb = -0.5 * (np.tanh(5 * (x - 0.5)) + 1) + 0.5\n",
    "        \n",
    "        # Rate of change\n",
    "        elif name == \"increasing_rate\":\n",
    "            emb = 0.5 * x ** 2 - 0.25\n",
    "        elif name == \"decreasing_rate\":\n",
    "            emb = -0.5 * x ** 2 + 0.25\n",
    "        elif name == \"accelerating\":\n",
    "            emb = 0.4 * x ** 3 - 0.2\n",
    "        elif name == \"decelerating\":\n",
    "            emb = -0.4 * x ** 3 + 0.2\n",
    "        \n",
    "        # Periodic patterns\n",
    "        elif name == \"periodic\":\n",
    "            emb = 0.3 * np.sin(t)\n",
    "        elif name == \"periodic_fast\":\n",
    "            emb = 0.3 * np.sin(2 * t)\n",
    "        elif name == \"periodic_slow\":\n",
    "            emb = 0.3 * np.sin(0.5 * t)\n",
    "        elif name == \"periodic_phase_shift\":\n",
    "            emb = 0.3 * np.sin(t + np.pi / 2)\n",
    "        \n",
    "        # Step functions\n",
    "        elif name == \"step_up\":\n",
    "            emb = np.where(x < 0.5, -0.3, 0.3)\n",
    "        elif name == \"step_down\":\n",
    "            emb = np.where(x < 0.5, 0.3, -0.3)\n",
    "        elif name == \"step_up_down\":\n",
    "            emb = np.where(x < 0.33, -0.3, np.where(x < 0.67, 0.3, -0.3))\n",
    "        elif name == \"step_down_up\":\n",
    "            emb = np.where(x < 0.33, 0.3, np.where(x < 0.67, -0.3, 0.3))\n",
    "        \n",
    "        # Exponential patterns\n",
    "        elif name == \"exponential_growth\":\n",
    "            emb = 0.3 * (np.exp(2 * x) - 1) / (np.exp(2) - 1) - 0.15\n",
    "        elif name == \"exponential_decay\":\n",
    "            emb = 0.3 * (np.exp(-2 * x)) - 0.15\n",
    "        elif name == \"logarithmic\":\n",
    "            emb = 0.3 * np.log1p(10 * x) / np.log(11) - 0.15\n",
    "        \n",
    "        # Multi-peak patterns\n",
    "        elif name == \"double_peak\":\n",
    "            emb = 0.3 * (np.sin(2 * t) + 0.5 * np.sin(4 * t))\n",
    "        elif name == \"triple_peak\":\n",
    "            emb = 0.25 * (np.sin(3 * t) + 0.5 * np.sin(6 * t))\n",
    "        elif name == \"sawtooth\":\n",
    "            emb = 0.3 * (2 * (t / (2 * np.pi) % 1) - 1)\n",
    "        elif name == \"square_wave\":\n",
    "            emb = 0.3 * np.sign(np.sin(t))\n",
    "        \n",
    "        # Asymmetric patterns\n",
    "        elif name == \"fast_rise_slow_fall\":\n",
    "            rise = 0.4 * np.exp(5 * x[:dim//2]) / np.exp(2.5) - 0.2\n",
    "            fall = 0.4 * np.exp(-2 * (x[dim//2:] - 0.5)) - 0.2\n",
    "            emb = np.concatenate([rise, fall])\n",
    "        elif name == \"slow_rise_fast_fall\":\n",
    "            rise = 0.4 * np.exp(2 * x[:dim//2]) / np.exp(1) - 0.2\n",
    "            fall = 0.4 * np.exp(-5 * (x[dim//2:] - 0.5)) - 0.2\n",
    "            emb = np.concatenate([rise, fall])\n",
    "        elif name == \"plateau_start\":\n",
    "            plateau_len = int(0.3 * dim)\n",
    "            emb = np.concatenate([np.full(plateau_len, 0.2), np.linspace(0.2, -0.2, dim - plateau_len)])\n",
    "        elif name == \"plateau_end\":\n",
    "            plateau_len = int(0.7 * dim)\n",
    "            emb = np.concatenate([np.linspace(-0.2, 0.2, plateau_len), np.full(dim - plateau_len, 0.2)])\n",
    "        \n",
    "        # Complex patterns\n",
    "        elif name == \"oscillating_decay\":\n",
    "            emb = 0.3 * np.sin(2 * t) * np.exp(-x)\n",
    "        elif name == \"oscillating_growth\":\n",
    "            emb = 0.3 * np.sin(2 * t) * (1 - np.exp(-x))\n",
    "        elif name == \"chirp\":\n",
    "            freq = 0.5 + 2 * x\n",
    "            emb = 0.3 * np.sin(2 * np.pi * freq * x)\n",
    "        elif name == \"spike\":\n",
    "            emb = 0.5 * np.exp(-((x - 0.5) ** 2) / 0.01) - 0.25\n",
    "        \n",
    "        # Edge cases\n",
    "        elif name == \"random\":\n",
    "            emb = rng.normal(0, 0.3, dim)\n",
    "        elif name == \"noise\":\n",
    "            emb = rng.normal(0, 0.1, dim)\n",
    "        elif name == \"sparse\":\n",
    "            emb = rng.choice([-0.5, 0, 0.5], size=dim, p=[0.1, 0.8, 0.1])\n",
    "        else:\n",
    "            # Default: random pattern\n",
    "            emb = rng.normal(0, 0.3, dim)\n",
    "        \n",
    "        if name != \"noise\" and name != \"random\":\n",
    "            emb = emb + rng.normal(0, 0.05, dim)\n",
    "        \n",
    "        emb = emb / (np.linalg.norm(emb) + 1e-9)\n",
    "        concept_embeddings.append(emb)\n",
    "    \n",
    "    return np.array(concept_embeddings), concept_names\n",
    "\n",
    "\n",
    "def generate_order_sensitive_data(samples, dim=1024, num_classes=5, seed=42, pattern_strength=0.5, shuffled=False):\n",
    "    \"\"\"Generate synthetic sequences with order-dependent temporal patterns.\"\"\"\n",
    "    rng = np.random.default_rng(seed)\n",
    "    X = {}\n",
    "    y = []\n",
    "\n",
    "    data_keys = list(samples.keys())\n",
    "    num_pattern_dims = max(1, int(dim * pattern_strength))  # Use multiple dimensions for patterns\n",
    "    \n",
    "    # Class descriptions:\n",
    "    # class_0: Ascending linear pattern (monotonically increasing from -1 to 1)\n",
    "    # class_1: Descending linear pattern (monotonically decreasing from 1 to -1)\n",
    "    # class_2: U-shaped pattern (decreasing then increasing, valley shape)\n",
    "    # class_3: Inverted U-shaped pattern (increasing then decreasing, peak shape)s\n",
    "    # class_4: Sinusoidal pattern (periodic oscillation with sin function)\n",
    "\n",
    "    for i, key in enumerate(data_keys):\n",
    "        n = len(samples[key])\n",
    "        seq = rng.uniform(-1, 1, (n, dim))\n",
    "        t = np.linspace(0, 2 * np.pi, n)\n",
    "        \n",
    "        if i % num_classes == 0:\n",
    "            pattern = np.linspace(-1, 1, n)\n",
    "            label = \"class_0\"\n",
    "        elif i % num_classes == 1:\n",
    "            pattern = np.linspace(1, -1, n)\n",
    "            label = \"class_1\"\n",
    "        elif i % num_classes == 2:\n",
    "            mid = n // 2\n",
    "            pattern = np.concatenate([\n",
    "                np.linspace(1, -1, mid),\n",
    "                np.linspace(-1, 1, n - mid)\n",
    "            ])\n",
    "            label = \"class_2\"\n",
    "        elif i % num_classes == 3:\n",
    "            mid = n // 2\n",
    "            pattern = np.concatenate([\n",
    "                np.linspace(-1, 1, mid),\n",
    "                np.linspace(1, -1, n - mid)\n",
    "            ])\n",
    "            label = \"class_3\"\n",
    "        else:\n",
    "            pattern = np.sin(t)\n",
    "            label = \"class_4\"\n",
    "\n",
    "        for dim_idx in range(num_pattern_dims):\n",
    "            amplitude = 0.5 + 0.5 * (dim_idx % 3) / 3\n",
    "            seq[:, dim_idx] = pattern * amplitude + 0.1 * rng.normal(0, 1, n)\n",
    "        \n",
    "        seq = seq + rng.normal(0, 0.1, (n, dim))\n",
    "        \n",
    "        if shuffled:\n",
    "            seq = np.random.permutation(seq)\n",
    "        \n",
    "        X[key] = seq.astype(np.float32)\n",
    "        y.append(label)\n",
    "\n",
    "    return X, np.array(y)\n",
    "\n",
    "\n",
    "# Generate synthetic sequences with temporal patterns\n",
    "X, y = generate_order_sensitive_data(\n",
    "    embedder.video_embeddings,\n",
    "    dim=1024,\n",
    "    num_classes=5,\n",
    "    seed=42,\n",
    "    pattern_strength=0.3,\n",
    "    shuffled = False,\n",
    ")\n",
    "\n",
    "# Create concept embeddings for temporal patterns\n",
    "concept_embeddings, concept_names = create_artificial_concept_embeddings(\n",
    "    dim=1024,\n",
    "    num_concepts=None,\n",
    "    seed=42\n",
    ")\n",
    "\n",
    "print(f\"Generated {len(X)} synthetic samples\")\n",
    "print(f\"Sequence shape: {list(X.values())[0].shape}\")\n",
    "print(f\"Concept embeddings: {concept_embeddings.shape}\")\n",
    "print(f\"Number of concepts: {len(concept_names)}\")\n",
    "print(f\"Label distribution: {dict(zip(*np.unique(y, return_counts=True)))}\")\n",
    "\n",
    "# Replace embedder data with synthetic data\n",
    "embedder.video_embeddings = X\n",
    "embedder.labels = y\n",
    "\n",
    "# Set synthetic concept embeddings\n",
    "concepts.text_embeddings = torch.tensor(concept_embeddings, dtype=torch.float32)\n",
    "concepts.text_concepts = concept_names"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize and train MoTIF\n",
    "cbm_model = MoTIF(embedder, concepts)\n",
    "cbm_model.preprocess(dataset, info=test_split)\n",
    "\n",
    "if use_wandb:\n",
    "    time_now = time.strftime(\"%Y-%m-%d_%H-%M-%S\", time.localtime())\n",
    "    name = f\"{embedder.dataset_name}_clip_basic_{time_now}\"\n",
    "    wandb_run = wandb.init(project=\"motif\", name=name)\n",
    "    cbm_model.zero_shot(concepts, wandb_run=wandb_run)\n",
    "    wandb_run.log({\n",
    "        \"backbone\": clip_name,\n",
    "        \"window_size\": window_size,\n",
    "        \"dataset\": dataset,\n",
    "        \"random\": random,\n",
    "        \"test_split\": test_split,\n",
    "        \"d\": d,\n",
    "    })\n",
    "else:\n",
    "    wandb_run = None\n",
    "    cbm_model.zero_shot(concepts)\n",
    "\n",
    "cbm_model.model = CBMTransformer(\n",
    "    cbm_model.num_concepts,\n",
    "    num_classes=cbm_model.num_classes,\n",
    "    transformer_layers=transformer_layers,\n",
    "    lse_tau=lse_tau,\n",
    "    dimension=d,\n",
    "    diagonal_attention=diagonal_attention,\n",
    ")\n",
    "\n",
    "cbm_model.train_model(\n",
    "    num_epochs=num_epochs,\n",
    "    l1_lambda=l1_lambda,\n",
    "    lambda_sparse=lambda_sparse,\n",
    "    lr=lr,\n",
    "    batch_size=batch_size,\n",
    "    enforce_nonneg=enforce_nonneg,\n",
    "    class_weights=class_weights,\n",
    "    wandb_run=wandb_run,\n",
    ")\n",
    "\n",
    "mean_cbm(cbm_model, wandb_run=wandb_run)\n",
    "\n",
    "if use_wandb:\n",
    "    wandb_run.finish()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Trial 1: Original=0.8697, Shuffled=0.2077\n",
      "Trial 2: Original=0.8697, Shuffled=0.2007\n",
      "Trial 3: Original=0.8697, Shuffled=0.2148\n",
      "Trial 4: Original=0.8697, Shuffled=0.2148\n",
      "Trial 5: Original=0.8697, Shuffled=0.2148\n",
      "\n",
      "Average: Original=0.8697, Shuffled=0.2106\n"
     ]
    }
   ],
   "source": [
    "# Evaluate temporal learning: compare accuracy on original vs shuffled sequences\n",
    "results = {}\n",
    "for trial in range(5):\n",
    "    correct_original = 0\n",
    "    correct_shuffled = 0\n",
    "    \n",
    "    for i in range(len(cbm_model.X_test)):\n",
    "        video = cbm_model.X_test[i]\n",
    "        true_idx = cbm_model.y_test[i]\n",
    "        \n",
    "        res = explain_instance(cbm_model.model, video)\n",
    "        class_pred = res[\"target_class\"]\n",
    "        \n",
    "        shuffled_video = torch.from_numpy(np.random.permutation(video)).to(\"cuda:0\")\n",
    "        logits, _, _, _ = cbm_model.model(shuffled_video)\n",
    "        pred_shuffled = int(logits.argmax(dim=1).item())\n",
    "        \n",
    "        if true_idx == class_pred:\n",
    "            correct_original += 1\n",
    "        if true_idx == pred_shuffled:\n",
    "            correct_shuffled += 1\n",
    "    \n",
    "    acc_original = correct_original / len(cbm_model.X_test)\n",
    "    acc_shuffled = correct_shuffled / len(cbm_model.X_test)\n",
    "    \n",
    "    print(f\"Trial {trial+1}: Original={acc_original:.4f}, Shuffled={acc_shuffled:.4f}\")\n",
    "    results[trial] = [acc_original, acc_shuffled]\n",
    "\n",
    "print(f\"\\nAverage: Original={np.mean([r[0] for r in results.values()]):.4f}, \"\n",
    "      f\"Shuffled={np.mean([r[1] for r in results.values()]):.4f}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "motif",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
