{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "1911ea3f-d916-4d89-9dab-822068b8e2f0",
   "metadata": {},
   "source": [
    "### Visualization of the Synthetic Dataset\n",
    "This notebook visualizes how the synthetic dataset used in the paper is constructed and saves the resulting figures (PDF/PNG) as well as an animated GIF of class samples."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "35c3e2b6-e7b2-46a2-be4f-4ff7c5dbc76c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "from sklearn.manifold import TSNE\n",
    "\n",
    "# ----------- Configuration -----------\n",
    "BIT_LENGTH = 784\n",
    "NUM_CLASSES = 1\n",
    "SAMPLES_PER_CLASS = 10\n",
    "FIXED_BITS_RANGE = (20, 21)  # Min and max number of fixed bits per class\n",
    "\n",
    "# ----------- Synthetic Data Generation -----------\n",
    "def generate_synthetic_data():\n",
    "    class_fixed_bits = {}\n",
    "    data = []\n",
    "    labels = []\n",
    "\n",
    "    for cls in range(NUM_CLASSES):\n",
    "        num_fixed = np.random.randint(*FIXED_BITS_RANGE)\n",
    "        fixed_indices = np.random.choice(BIT_LENGTH, size=num_fixed, replace=False)\n",
    "        fixed_values = np.random.randint(0, 2, size=num_fixed)\n",
    "        class_fixed_bits[cls] = (fixed_indices, fixed_values)\n",
    "\n",
    "        for _ in range(SAMPLES_PER_CLASS):\n",
    "            sample = np.random.randint(0, 2, size=BIT_LENGTH)\n",
    "            sample[fixed_indices] = fixed_values\n",
    "            data.append(sample)\n",
    "            labels.append(cls)\n",
    "\n",
    "    return np.array(data), np.array(labels), class_fixed_bits\n",
    "\n",
    "# ----------- Visualization Functions -----------\n",
    "\n",
    "def visualize_class_sample(sample, class_id, save_as=\"synthetic_sample.pdf\"):\n",
    "    \"\"\"Display one sample as a 28 x 28 binary image and save it.\n",
    "\n",
    "    Args:\n",
    "        sample: 1-D array of 784 bits (reshaped to 28 x 28).\n",
    "        class_id: Class of the sample; only used by the (disabled) title.\n",
    "        save_as: Output path. Defaults to \"synthetic_sample.pdf\", so successive\n",
    "            calls with the default overwrite each other; pass a distinct path per\n",
    "            sample to keep them all. Pass None to skip saving.\n",
    "    \"\"\"\n",
    "    image = sample.reshape(28, 28)\n",
    "    # Draw on a fresh figure instead of whatever pyplot figure happens to be current.\n",
    "    fig, ax = plt.subplots()\n",
    "    ax.imshow(image, cmap='gray', interpolation='nearest')\n",
    "    # ax.set_title(f\"Sample from Class {class_id}\")\n",
    "    ax.axis('off')\n",
    "    fig.tight_layout()\n",
    "    if save_as:\n",
    "        fig.savefig(save_as)\n",
    "    plt.show()\n",
    "\n",
    "def visualize_fixed_positions(fixed_indices, fixed_values, save_as=\"synthetic_pattern.pdf\"):\n",
    "    \"\"\"Show a class's fixed-bit pattern: fixed 1s white, fixed 0s black, unfixed bits light red.\n",
    "\n",
    "    Args:\n",
    "        fixed_indices: Flat positions (0 <= i < BIT_LENGTH) of the fixed bits.\n",
    "        fixed_values: 0/1 value for each entry of fixed_indices.\n",
    "        save_as: Output path; pass None to skip saving.\n",
    "    \"\"\"\n",
    "    mask = np.full(BIT_LENGTH, np.nan)  # NaN for non-fixed bits\n",
    "    mask[fixed_indices] = fixed_values  # fixed bits become 0 or 1\n",
    "    mask = mask.reshape(28, 28)\n",
    "\n",
    "    # Copy first: calling set_bad on plt.cm.gray itself mutates the shared colormap\n",
    "    # object, so later plots reusing it (and, in older Matplotlib, any cmap='gray'\n",
    "    # plot) would inherit the red NaN colour.\n",
    "    cmap = plt.cm.gray.copy()\n",
    "    cmap.set_bad(color='#ffd9d9')  # light red for NaN (non-fixed bits)\n",
    "\n",
    "    fig, ax = plt.subplots()\n",
    "    ax.imshow(mask, cmap=cmap, vmin=0, vmax=1)\n",
    "    ax.axis('off')\n",
    "    fig.tight_layout()\n",
    "    if save_as:\n",
    "        fig.savefig(save_as)\n",
    "    plt.show()\n",
    "\n",
    "def bit_frequency_map(samples):\n",
    "    \"\"\"Heatmap of the per-pixel fraction of samples whose bit is 1.\n",
    "\n",
    "    Fixed bits are identical in every sample, so they show up as exactly 0 or 1,\n",
    "    while random bits take intermediate values. Saved to synthetic_frequency.pdf.\n",
    "    \"\"\"\n",
    "    per_pixel = np.mean(samples, axis=0)\n",
    "    heatmap = plt.imshow(per_pixel.reshape(28, 28), cmap='viridis', interpolation='nearest')\n",
    "    # plt.title(\"Bit Frequency Heatmap – Class Sample Distribution\")\n",
    "\n",
    "    colorbar = plt.colorbar(heatmap)\n",
    "    colorbar.ax.tick_params(labelsize=24)\n",
    "    colorbar.set_label('P(1)', fontsize=26)\n",
    "\n",
    "    plt.axis('off')\n",
    "    plt.tight_layout()\n",
    "    plt.savefig(\"synthetic_frequency.pdf\")\n",
    "    plt.show()\n",
    "\n",
    "def tsne_plot(data, labels, perplexity=30):\n",
    "    \"\"\"Visualize global class structure by projecting the bitstrings to 2-D with t-SNE.\n",
    "\n",
    "    Args:\n",
    "        data: 2-D array, one sample per row.\n",
    "        labels: Class id per row, used as the scatter-plot hue.\n",
    "        perplexity: Requested t-SNE perplexity. scikit-learn rejects values that\n",
    "            are not smaller than the number of samples, and the default config\n",
    "            produces only 10 samples, so it is clipped to n_samples - 1.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If fewer than 2 samples are given.\n",
    "    \"\"\"\n",
    "    n_samples = len(data)\n",
    "    if n_samples < 2:\n",
    "        raise ValueError(f\"t-SNE needs at least 2 samples, got {n_samples}\")\n",
    "    effective_perplexity = min(perplexity, n_samples - 1)\n",
    "\n",
    "    print(\"Running t-SNE... (this may take a moment)\")\n",
    "    tsne = TSNE(n_components=2, perplexity=effective_perplexity, learning_rate=200, init='random', random_state=42)\n",
    "    data_2d = tsne.fit_transform(data)\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=(8, 6))\n",
    "    sns.scatterplot(x=data_2d[:, 0], y=data_2d[:, 1], hue=labels, palette='tab10', s=10, linewidth=0, ax=ax)\n",
    "    # ax.set_title(\"t-SNE Projection of Synthetic Bitstrings\")\n",
    "    ax.axis('off')\n",
    "    ax.legend(title='Class', bbox_to_anchor=(1.05, 1), loc='upper left', ncol=1)\n",
    "    fig.tight_layout()\n",
    "    plt.show()\n",
    "\n",
    "# ----------- Run Everything -----------\n",
    "if __name__ == \"__main__\":\n",
    "    # Generate data\n",
    "    data, labels, fixed_bits_info = generate_synthetic_data()\n",
    "\n",
    "    # Choose a class to visualize and look up its fixed pattern once\n",
    "    class_id = 0\n",
    "    class_samples = data[labels == class_id]\n",
    "    fixed_indices, fixed_values = fixed_bits_info[class_id]\n",
    "\n",
    "    # Paper-relevant plots\n",
    "    # NOTE(review): both sample plots use the function's default output path,\n",
    "    # so the second call overwrites the PDF written by the first.\n",
    "    visualize_class_sample(class_samples[0], class_id)\n",
    "    visualize_class_sample(class_samples[1], class_id)\n",
    "    visualize_fixed_positions(fixed_indices, fixed_values)\n",
    "    bit_frequency_map(class_samples)\n",
    "    # tsne_plot(data, labels)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa1ebffa-0bb4-4388-9487-435cb7240873",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "from sklearn.manifold import TSNE\n",
    "import matplotlib.patches as patches\n",
    "import os\n",
    "\n",
    "# ----------- Configuration -----------\n",
    "BIT_LENGTH = 784\n",
    "NUM_CLASSES = 1\n",
    "SAMPLES_PER_CLASS = 10\n",
    "FIXED_BITS_RANGE = (20, 21)  # Min and max number of fixed bits per class\n",
    "\n",
    "# ----------- Synthetic Data Generation -----------\n",
    "def generate_synthetic_data():\n",
    "    class_fixed_bits = {}\n",
    "    data = []\n",
    "    labels = []\n",
    "\n",
    "    for cls in range(NUM_CLASSES):\n",
    "        num_fixed = np.random.randint(*FIXED_BITS_RANGE)\n",
    "        fixed_indices = np.random.choice(BIT_LENGTH, size=num_fixed, replace=False)\n",
    "        fixed_values = np.random.randint(0, 2, size=num_fixed)\n",
    "        class_fixed_bits[cls] = (fixed_indices, fixed_values)\n",
    "\n",
    "        for _ in range(SAMPLES_PER_CLASS):\n",
    "            sample = np.random.randint(0, 2, size=BIT_LENGTH)\n",
    "            sample[fixed_indices] = fixed_values\n",
    "            data.append(sample)\n",
    "            labels.append(cls)\n",
    "\n",
    "    return np.array(data), np.array(labels), class_fixed_bits\n",
    "\n",
    "# ----------- Visualization Functions -----------\n",
    "\n",
    "def visualize_class_sample(sample, class_id, fixed_indices=None, save_as=None):\n",
    "    \"\"\"Render one bitstring as a 28 x 28 black/white image, optionally outlining fixed bits.\n",
    "\n",
    "    Args:\n",
    "        sample: 1-D array of 784 bits (reshaped to 28 x 28).\n",
    "        class_id: Class of the sample (accepted for interface symmetry; not drawn).\n",
    "        fixed_indices: Optional flat indices of the fixed bits; each gets a red square outline.\n",
    "        save_as: Optional output path; the figure is saved there before being shown.\n",
    "    \"\"\"\n",
    "    pixels = sample.reshape(28, 28)\n",
    "    fig, ax = plt.subplots()\n",
    "    ax.imshow(pixels, cmap='gray', interpolation='nearest')\n",
    "\n",
    "    outlined = [] if fixed_indices is None else fixed_indices\n",
    "    for flat_index in outlined:\n",
    "        r, c = divmod(flat_index, 28)\n",
    "        # imshow centres pixel (r, c) on integer coordinates, so its cell starts at (c - 0.5, r - 0.5).\n",
    "        outline = patches.Rectangle((c - 0.5, r - 0.5), 1, 1, linewidth=1.5, edgecolor='red', facecolor='none')\n",
    "        ax.add_patch(outline)\n",
    "\n",
    "    ax.axis('off')\n",
    "    fig.tight_layout()\n",
    "    if save_as:\n",
    "        fig.savefig(save_as)\n",
    "    plt.show()\n",
    "\n",
    "def visualize_fixed_positions(fixed_indices, fixed_values):\n",
    "    \"\"\"Plot the fixed-bit template of a class and save it as synthetic_pattern.png.\n",
    "\n",
    "    Fixed 1s are drawn white and fixed 0s black; every unfixed position is NaN and\n",
    "    therefore painted with the colormap's light-red 'bad' colour.\n",
    "    \"\"\"\n",
    "    # Private copy, so set_bad does not touch the shared 'gray' colormap.\n",
    "    palette = plt.cm.gray.copy()\n",
    "    palette.set_bad(color='#ffd9d9')\n",
    "\n",
    "    template = np.full(BIT_LENGTH, np.nan)\n",
    "    template[fixed_indices] = fixed_values\n",
    "\n",
    "    plt.imshow(template.reshape(28, 28), cmap=palette, vmin=0, vmax=1)\n",
    "    plt.axis('off')\n",
    "    plt.tight_layout()\n",
    "    plt.savefig(\"synthetic_pattern.png\")\n",
    "    plt.show()\n",
    "\n",
    "def bit_frequency_map(samples):\n",
    "    \"\"\"Show, for each of the 28 x 28 pixels, the fraction of samples with bit 1.\n",
    "\n",
    "    The heatmap (viridis, colorbar labelled P(1)) is saved to synthetic_frequency.pdf and shown.\n",
    "    \"\"\"\n",
    "    ones_rate = np.mean(samples, axis=0)\n",
    "    image = plt.imshow(ones_rate.reshape(28, 28), cmap='viridis', interpolation='nearest')\n",
    "\n",
    "    bar = plt.colorbar(image)\n",
    "    bar.ax.tick_params(labelsize=24)\n",
    "    bar.set_label('P(1)', fontsize=26)\n",
    "\n",
    "    plt.axis('off')\n",
    "    plt.tight_layout()\n",
    "    plt.savefig(\"synthetic_frequency.pdf\")\n",
    "    plt.show()\n",
    "\n",
    "def tsne_plot(data, labels, perplexity=30):\n",
    "    \"\"\"Visualize global class structure by projecting the bitstrings to 2-D with t-SNE.\n",
    "\n",
    "    Args:\n",
    "        data: 2-D array, one sample per row.\n",
    "        labels: Class id per row, used as the scatter-plot hue.\n",
    "        perplexity: Requested t-SNE perplexity. scikit-learn rejects values that\n",
    "            are not smaller than the number of samples, and the default config\n",
    "            produces only 10 samples, so it is clipped to n_samples - 1.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If fewer than 2 samples are given.\n",
    "    \"\"\"\n",
    "    n_samples = len(data)\n",
    "    if n_samples < 2:\n",
    "        raise ValueError(f\"t-SNE needs at least 2 samples, got {n_samples}\")\n",
    "    effective_perplexity = min(perplexity, n_samples - 1)\n",
    "\n",
    "    print(\"Running t-SNE... (this may take a moment)\")\n",
    "    tsne = TSNE(n_components=2, perplexity=effective_perplexity, learning_rate=200, init='random', random_state=42)\n",
    "    data_2d = tsne.fit_transform(data)\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=(8, 6))\n",
    "    sns.scatterplot(x=data_2d[:, 0], y=data_2d[:, 1], hue=labels, palette='tab10', s=10, linewidth=0, ax=ax)\n",
    "    ax.axis('off')\n",
    "    ax.legend(title='Class', bbox_to_anchor=(1.05, 1), loc='upper left', ncol=1)\n",
    "    fig.tight_layout()\n",
    "    plt.show()\n",
    "\n",
    "# ----------- Save Class Samples as PNG Frames -----------\n",
    "def save_class_sample_sequence(class_samples, fixed_indices, class_id, output_dir=\"gif_frames\", num_frames=10):\n",
    "    \"\"\"Save up to num_frames samples of a class as numbered PNG frames for the GIF.\n",
    "\n",
    "    Each frame is drawn by visualize_class_sample with the fixed bits outlined\n",
    "    and written to <output_dir>/sample_XX.png (XX = zero-padded frame index).\n",
    "\n",
    "    Args:\n",
    "        class_samples: Samples of one class, one bitstring per row.\n",
    "        fixed_indices: Flat indices of the class's fixed bits.\n",
    "        class_id: Class id passed through to visualize_class_sample.\n",
    "        output_dir: Target directory; created if it does not exist.\n",
    "        num_frames: Maximum number of frames. Effectively clipped to\n",
    "            len(class_samples), so a class with fewer samples does not raise IndexError.\n",
    "\n",
    "    Returns:\n",
    "        List of the written file paths, in frame order.\n",
    "    \"\"\"\n",
    "    os.makedirs(output_dir, exist_ok=True)\n",
    "\n",
    "    saved_paths = []\n",
    "    # Slicing never indexes past the available samples.\n",
    "    for i, sample in enumerate(class_samples[:num_frames]):\n",
    "        filename = os.path.join(output_dir, f\"sample_{i:02d}.png\")\n",
    "        visualize_class_sample(\n",
    "            sample,\n",
    "            class_id=class_id,\n",
    "            fixed_indices=fixed_indices,\n",
    "            save_as=filename\n",
    "        )\n",
    "        print(f\"Saved: {filename}\")\n",
    "        saved_paths.append(filename)\n",
    "    return saved_paths\n",
    "\n",
    "# ----------- Run Everything -----------\n",
    "if __name__ == \"__main__\":\n",
    "    # Generate data\n",
    "    data, labels, fixed_bits_info = generate_synthetic_data()\n",
    "\n",
    "    # Choose a class to visualize\n",
    "    class_id = 0\n",
    "    class_samples = data[labels == class_id]\n",
    "    fixed_indices, fixed_values = fixed_bits_info[class_id]\n",
    "\n",
    "    # PNG frames consumed by the GIF cell below\n",
    "    save_class_sample_sequence(class_samples, fixed_indices, class_id)\n",
    "\n",
    "    # Paper-relevant plots: set to True to also export the first three samples as PDFs.\n",
    "    export_paper_samples = False\n",
    "    if export_paper_samples:\n",
    "        for i in range(3):\n",
    "            visualize_class_sample(class_samples[i], class_id, fixed_indices, save_as=f\"synthetic_sample_{i}.pdf\")\n",
    "    visualize_fixed_positions(fixed_indices, fixed_values)\n",
    "    # bit_frequency_map(class_samples)\n",
    "    # tsne_plot(data, labels)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1fa39971-0129-43cd-929a-e5b66de34524",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create GIF\n",
    "from PIL import Image\n",
    "import glob\n",
    "\n",
    "def create_gif_from_images(frame_dir=\"gif_frames\", output_file=\"synthetic_animation.gif\", duration=1000):\n",
    "    \"\"\"Combine the PNG frames in frame_dir (sorted by file name) into an infinitely looping GIF.\n",
    "\n",
    "    Args:\n",
    "        frame_dir: Directory containing the *.png frames.\n",
    "        output_file: Path of the GIF to write.\n",
    "        duration: Display time per frame in milliseconds.\n",
    "\n",
    "    Raises:\n",
    "        FileNotFoundError: If frame_dir contains no PNG files.\n",
    "    \"\"\"\n",
    "    frame_paths = sorted(glob.glob(f\"{frame_dir}/*.png\"))\n",
    "    if not frame_paths:\n",
    "        raise FileNotFoundError(f\"No .png frames found in {frame_dir!r}; run the frame-export cell first.\")\n",
    "\n",
    "    # Copy each frame into memory and close its file immediately instead of\n",
    "    # keeping one open file handle per frame until the GIF is written.\n",
    "    frames = []\n",
    "    for path in frame_paths:\n",
    "        with Image.open(path) as img:\n",
    "            frames.append(img.copy())\n",
    "\n",
    "    # Save as .gif\n",
    "    frames[0].save(\n",
    "        output_file,\n",
    "        save_all=True,\n",
    "        append_images=frames[1:],\n",
    "        duration=duration,  # milliseconds per frame\n",
    "        loop=0  # infinite loop\n",
    "    )\n",
    "    print(f\"GIF saved to: {output_file}\")\n",
    "\n",
    "# Example call\n",
    "create_gif_from_images()\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
