{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "background_save": true
        },
        "id": "sUWQedGgAZ9Q",
        "cellView": "form"
      },
      "outputs": [],
      "source": [
        "# @title Block 1: Infrastructure\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "import torch.optim as optim\n",
        "import numpy as np\n",
        "import math\n",
        "import copy\n",
        "import matplotlib.pyplot as plt\n",
        "import pandas as pd\n",
        "import random\n",
        "import itertools\n",
        "import gc\n",
        "import os\n",
        "import seaborn as sns\n",
        "from datetime import datetime\n",
        "from matplotlib.colors import LinearSegmentedColormap\n",
        "from mpl_toolkits.mplot3d import Axes3D\n",
        "\n",
        "!pip install sympy --upgrade\n",
        "import sympy\n",
        "print(f\"SymPy version: {sympy.__version__}\")\n",
        "\n",
        "for i in mlnx:\n",
        "  i = i + 1\n",
        "\n",
        "def phi(t):\n",
        "    \"\"\"\n",
        "    Activation function that transforms sums into parities.\n",
        "    phi(0) = -1, phi(+/-1) = 1, phi'(0) = phi'(+/-1) = 0\n",
        "    \"\"\"\n",
        "    u = 128\n",
        "    result = torch.ones_like(t)\n",
        "\n",
        "    mask1 = (t > -1) & (t < -(u ** -3))\n",
        "    result[mask1] = -2 * t[mask1] - 1\n",
        "    mask2 = (t >= -(u ** -3)) & (t < u ** -3)\n",
        "    result[mask2] = (u ** 3) * (t[mask2] ** 2) + (u ** -3) - 1\n",
        "    mask3 = (t >= u ** -3) & (t < 1)\n",
        "    result[mask3] = 2 * t[mask3] - 1\n",
        "\n",
        "    return result\n",
        "\n",
        "def apply_segmented_flips(label_tensor, flip_config):\n",
        "    \"\"\"\n",
        "    Apply flips to specific segments at specific positions\n",
        "\n",
        "    flip_config: list of tuples (position_idx, sample_indices)\n",
        "    sample_indices can be:\n",
        "      - tuple (start_idx, end_idx) for a continuous range\n",
        "      - list of specific indices\n",
        "      - string \"random:N\" to flip N random samples\n",
        "      - string \"pattern:stride:offset\" for strided selection\n",
        "    \"\"\"\n",
        "    noisy_label = label_tensor.clone()\n",
        "    total_flips = 0\n",
        "\n",
        "    for item in flip_config:\n",
        "        pos_idx = item[0]\n",
        "        sample_spec = item[1]\n",
        "\n",
        "        # Handle different ways of specifying samples\n",
        "        if isinstance(sample_spec, tuple) and len(sample_spec) == 2:\n",
        "            # Continuous range (start_idx, end_idx)\n",
        "            start_idx, end_idx = sample_spec\n",
        "            indices = list(range(start_idx, end_idx))\n",
        "            print(f\"Flipping position {pos_idx} for samples {start_idx}-{end_idx}\")\n",
        "\n",
        "        elif isinstance(sample_spec, list):\n",
        "            # Explicit list of indices\n",
        "            indices = sample_spec\n",
        "            print(f\"Flipping position {pos_idx} for {len(indices)} specific samples\")\n",
        "\n",
        "        elif isinstance(sample_spec, str) and sample_spec.startswith(\"random:\"):\n",
        "            # Random selection of samples\n",
        "            n_samples = int(sample_spec.split(':')[1])\n",
        "            indices = torch.randperm(len(noisy_label))[:n_samples].tolist()\n",
        "            print(f\"Flipping position {pos_idx} for {n_samples} random samples\")\n",
        "\n",
        "        elif isinstance(sample_spec, str) and sample_spec.startswith(\"pattern:\"):\n",
        "            # Pattern-based selection (e.g., every 3rd sample starting from index 5)\n",
        "            parts = sample_spec.split(':')\n",
        "            stride = int(parts[1])\n",
        "            offset = int(parts[2]) if len(parts) > 2 else 0\n",
        "            indices = list(range(offset, len(noisy_label), stride))\n",
        "            print(f\"Flipping position {pos_idx} with pattern (stride={stride}, offset={offset})\")\n",
        "\n",
        "        else:\n",
        "            raise ValueError(f\"Unsupported sample specification: {sample_spec}\")\n",
        "\n",
        "        # Apply flips to the specified indices\n",
        "        noisy_label[indices, pos_idx] = -noisy_label[indices, pos_idx]\n",
        "        total_flips += len(indices)\n",
        "\n",
        "    print(f\"Total flips: {total_flips} out of {label_tensor.shape[0] * label_tensor.shape[1]} values\")\n",
        "    return noisy_label\n",
        "\n",
        "def generate_data(n, d, k, target, uniform_prob=1.0):\n",
        "    \"\"\"\n",
        "    Generate data with controlled uniformity in target positions\n",
        "\n",
        "    Parameters:\n",
        "    n: number of samples\n",
        "    d: input dimension\n",
        "    k: subset size\n",
        "    target: indices of relevant features\n",
        "    uniform_prob: probability of uniform random values for relevant bits\n",
        "\n",
        "    Returns:\n",
        "    dat: generated data tensor\n",
        "    use_uniform: boolean mask of which samples use uniform random values\n",
        "    target_values: values used for non-uniform samples\n",
        "    \"\"\"\n",
        "    # Generate random binary (-1 or 1) input data\n",
        "    dat = torch.randint(0, 2, (n, d), dtype=torch.float32) * 2 - 1\n",
        "\n",
        "    # For each sample, determine if we'll use uniform random or all-same values for target positions\n",
        "    use_uniform = torch.rand(n) < uniform_prob  # Samples that use uniform random values\n",
        "\n",
        "    # For samples that will have all-same values:\n",
        "    # Generate a random value (1 or -1) for each sample\n",
        "    target_values = torch.randint(0, 2, (n,), dtype=torch.float32) * 2 - 1\n",
        "\n",
        "    # Apply the generation rules\n",
        "    for i in range(n):\n",
        "        if not use_uniform[i]:\n",
        "            # Set all target positions to the same value\n",
        "            dat[i, target] = target_values[i]\n",
        "\n",
        "    return dat, use_uniform, target_values"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "5tLvx_MhAd9L"
      },
      "outputs": [],
      "source": [
        "# @title Block 2: Hyperparameters\n",
        "import time\n",
        "from datetime import datetime\n",
        "\n",
        "# HYPERPARAMETERS\n",
        "n = 400000  # sample size\n",
        "d = 128  # input dimension\n",
        "k = 64  # subset size\n",
        "lr = 50  # learning rate for teacher forcing\n",
        "\n",
        "# Training parameters\n",
        "max_epochs = 5000  # You may want to reduce this for quicker experiments\n",
        "x_values = range(1, max_epochs + 1)\n",
        "epoch_ceil = 100\n",
        "\n",
        "# Dynamic seeding function\n",
        "def set_random_seed():\n",
        "    \"\"\"Set different random seeds for each run based on current time\"\"\"\n",
        "    # Get current time in milliseconds for a dynamic seed\n",
        "    current_time = int(time.time() * 1000)\n",
        "    seed = current_time % 100000  # Keep it a reasonable size\n",
        "\n",
        "    # Set all relevant random seeds\n",
        "    random.seed(seed)\n",
        "    np.random.seed(seed)\n",
        "    torch.manual_seed(seed)\n",
        "    if torch.cuda.is_available():\n",
        "        torch.cuda.manual_seed_all(seed)\n",
        "\n",
        "    # Ensure torch operations are non-deterministic for maximum randomness\n",
        "    torch.backends.cudnn.deterministic = False\n",
        "    torch.backends.cudnn.benchmark = True\n",
        "\n",
        "    # For even more randomness, set Python hash seed\n",
        "    os.environ['PYTHONHASHSEED'] = str(seed)\n",
        "\n",
        "    return seed\n",
        "\n",
        "# Set dynamic random seed\n",
        "current_seed = set_random_seed()\n",
        "\n",
        "# Set device\n",
        "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
        "\n",
        "# Select random target indices\n",
        "target = random.sample(range(d), k)\n",
        "\n",
        "# Create weight matrix with appropriate constraint mask\n",
        "w = torch.zeros(d+k-1, d+k-1).to(device)\n",
        "for m in range(1, d+1):\n",
        "    w[:, m-1] = -np.inf  # Mask out input columns\n",
        "for m in range(d+1, d+k):\n",
        "    w[(m-1):, m-1] = -np.inf  # Mask out lower triangle\n",
        "\n",
        "class TF(nn.Module):\n",
        "    \"\"\"\n",
        "    Transformer model for parity learning with teacher forcing\n",
        "    \"\"\"\n",
        "    def __init__(self):\n",
        "        super(TF, self).__init__()\n",
        "        self.w = nn.Parameter(copy.deepcopy(w))\n",
        "\n",
        "        # Add small random noise to initialization for better randomness\n",
        "        with torch.no_grad():\n",
        "            noise = torch.randn_like(self.w) * 1e-4\n",
        "            mask = (self.w != -float('inf'))  # Don't add noise to masked positions\n",
        "            self.w.data = torch.where(mask, self.w + noise, self.w)\n",
        "\n",
        "    def forward(self, x):\n",
        "        softmax = nn.Softmax(dim=0)\n",
        "        return torch.cat((x[:, :d], phi(torch.matmul(x, softmax(self.w)[:, d:]))), dim=1)\n",
        "\n",
        "# Function to reset model with fresh random initialization\n",
        "def get_fresh_model():\n",
        "    \"\"\"Create a new model instance with random initialization\"\"\"\n",
        "    model = TF().to(device)\n",
        "    return model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "sSw2YZ46Aqqs"
      },
      "outputs": [],
      "source": [
        "# @title Block 3: Test size and training algorithm\n",
        "def train_and_evaluate(flip_config_dict, uniform_prob, max_epochs=2000, verbose=False):\n",
        "    \"\"\"\n",
        "    Train model with specified parameters and evaluate performance\n",
        "\n",
        "    Parameters:\n",
        "    flip_config_dict: Dictionary mapping positions to their flip specifications\n",
        "                     {position1: flip_spec1, position2: flip_spec2, ...}\n",
        "                     where flip_spec can be:\n",
        "                     - float: flip rate (0.0-1.0) starting from index 0\n",
        "                     - tuple (start_idx, end_idx): specific continuous range\n",
        "                     - list: specific indices to flip\n",
        "                     - \"random:N\": N random samples\n",
        "                     - \"pattern:stride:offset\": every stride-th sample starting at offset\n",
        "    uniform_prob: Probability of uniform random values for relevant bits\n",
        "    max_epochs: Maximum training epochs\n",
        "    verbose: Whether to print detailed progress\n",
        "\n",
        "    Returns:\n",
        "    test_loss: Mean squared error on test set\n",
        "    test_accuracy: Accuracy on test set\n",
        "    final_train_loss: Loss at the end of training\n",
        "    \"\"\"\n",
        "    # Generate data with specified uniformity\n",
        "    dat, use_uniform, target_values = generate_data(n, d, k, target, uniform_prob)\n",
        "\n",
        "    # Initialize intermediate labels\n",
        "    label = torch.zeros(n, k-1)\n",
        "    temp = copy.deepcopy(dat[:, target])\n",
        "\n",
        "    # Generate intermediate computation steps\n",
        "    for m in range(k-1):\n",
        "        label[:, m] = temp[:, 0] * temp[:, 1]\n",
        "        temp = torch.cat((temp[:, 2:], label[:, m].unsqueeze(1)), dim=1)\n",
        "\n",
        "    # Calculate final target\n",
        "    y = torch.prod(dat[:, target], dim=1)\n",
        "\n",
        "    # Move data to device\n",
        "    dat = dat.to(device)\n",
        "    label = label.to(device)\n",
        "    y = y.to(device)\n",
        "\n",
        "    # Create flip configuration for multiple positions\n",
        "    flip_config = []\n",
        "    for position, flip_spec in flip_config_dict.items():\n",
        "        # Convert rate-based specifications to range tuples\n",
        "        if isinstance(flip_spec, float):\n",
        "            n_flip = int(flip_spec * n)\n",
        "            if n_flip > 0:  # Only add if there are actual flips\n",
        "                flip_config.append((position, (0, n_flip)))\n",
        "        else:\n",
        "            # Pass other specifications directly\n",
        "            flip_config.append((position, flip_spec))\n",
        "\n",
        "    # Apply flips to create noisy labels\n",
        "    noisy_label = apply_segmented_flips(label.clone(), flip_config)\n",
        "\n",
        "    # Create tensors for training\n",
        "    X = torch.cat((dat, torch.zeros((n, k-1), dtype=torch.float32, device=device)), dim=1)\n",
        "    X_true = torch.cat((dat, noisy_label), dim=1)\n",
        "\n",
        "    # Create and initialize model\n",
        "    TF_tf = TF().to(device)\n",
        "    TF_tf_optim = optim.SGD(TF_tf.parameters(), lr=lr)\n",
        "\n",
        "    # Initialize tensor to store training loss\n",
        "    list_tf = torch.zeros(max_epochs)\n",
        "\n",
        "    # Training loop\n",
        "    for epoch in range(max_epochs):\n",
        "        # Zero gradients\n",
        "        TF_tf_optim.zero_grad()\n",
        "\n",
        "        # Compute loss against noisy labels\n",
        "        loss_tf = torch.mean((noisy_label - TF_tf(X_true)[:, d:]) ** 2) / 2\n",
        "\n",
        "        # Backward pass\n",
        "        loss_tf.backward()\n",
        "\n",
        "        # Update model parameters\n",
        "        TF_tf_optim.step()\n",
        "\n",
        "        # Record training loss\n",
        "        list_tf[epoch] = loss_tf\n",
        "\n",
        "    # Generate test data with same uniformity\n",
        "    test_size = 100000\n",
        "    test_dat, _, _ = generate_data(test_size, d, k, target, uniform_prob=1.0)\n",
        "\n",
        "    # Calculate test labels\n",
        "    test_label = torch.zeros(test_size, k-1)\n",
        "    test_temp = copy.deepcopy(test_dat[:, target])\n",
        "\n",
        "    for m in range(k-1):\n",
        "        test_label[:, m] = test_temp[:, 0] * test_temp[:, 1]\n",
        "        test_temp = torch.cat((test_temp[:, 2:], test_label[:, m].unsqueeze(1)), dim=1)\n",
        "\n",
        "    # Calculate final target for test data\n",
        "    test_y = torch.prod(test_dat[:, target], dim=1)\n",
        "\n",
        "    # Move test data to device\n",
        "    test_dat = test_dat.to(device)\n",
        "    test_label = test_label.to(device)\n",
        "    test_y = test_y.to(device)\n",
        "\n",
        "    # Create input tensor for test data\n",
        "    test_X = torch.cat((test_dat, torch.zeros((test_size, k-1), dtype=torch.float32, device=device)), dim=1)\n",
        "\n",
        "    # Evaluate on test data\n",
        "    TF_tf.eval()\n",
        "\n",
        "    # Forward pass through model with test data\n",
        "    test_X_eval = copy.deepcopy(test_X)\n",
        "    for m in range(k-1):\n",
        "        test_X_eval = TF_tf(test_X_eval).detach()\n",
        "\n",
        "    # Calculate test metrics\n",
        "    test_loss = torch.mean((test_y - test_X_eval[:, -1]) ** 2).detach() / 2\n",
        "    test_accuracy = torch.mean((abs(test_X_eval[:, -1] - test_y) <= 0.5).type(torch.float32)).item()\n",
        "\n",
        "    if verbose:\n",
        "        # Create a string representation of the flip configuration\n",
        "        flip_str = []\n",
        "        for pos, spec in flip_config_dict.items():\n",
        "            if isinstance(spec, float):\n",
        "                flip_str.append(f\"Pos {pos}:{spec*100:.1f}%\")\n",
        "            else:\n",
        "                flip_str.append(f\"Pos {pos}:{spec}\")\n",
        "        flip_info = \", \".join(flip_str)\n",
        "        print(f\"Uniformity: {uniform_prob}, Flips: {flip_info}\")\n",
        "        print(f\"Test loss: {test_loss:.4f}, Accuracy: {test_accuracy:.4f}\")\n",
        "\n",
        "    # Return metrics\n",
        "    return test_loss.item(), test_accuracy, list_tf[-1].item()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "9KBf8934Atbl"
      },
      "outputs": [],
      "source": [
        "# @title Block 4: Configurations\n",
        "# Define flip configurations to test\n",
        "flip_configurations = [\n",
        "    {'name': 'No corruption',\n",
        "     'config': {\n",
        "        0: (0, 0),\n",
        "        1: (0, 0)\n",
        "    }},\n",
        "    {'name': 'Steps 1 and 2 have same flips (zero impactful)',\n",
        "     'config': {\n",
        "        0: (0, int(0.2*n)),\n",
        "        1: (0, int(0.2*n))\n",
        "    }},\n",
        "    {'name': 'Steps 1 and 2, (20% impactful)',\n",
        "     'config': {\n",
        "        0: (0, int(0.1*n)),\n",
        "        1: (int(0.1*n), int(0.2*n))\n",
        "    }},\n",
        "    {'name': 'Steps 1 and 2, (40% impactful)',\n",
        "     'config': {\n",
        "        0: (0, int(0.2*n)),\n",
        "        1: (int(0.2*n), int(0.4*n))\n",
        "    }},\n",
        "    {'name': 'Steps 33, 34 and 49 have same flips (25% impactful)',\n",
        "     'config': {\n",
        "        32: (0, int(0.25*n)),\n",
        "        33: (0, int(0.25*n)),\n",
        "        48: (0, int(0.25*n))\n",
        "    }},\n",
        "    {'name': 'Steps 33, 34 and 49 (20% impactful, 33 and 34 identical)',\n",
        "     'config': {\n",
        "        32: (0, int(0.25*n)),\n",
        "        33: (0, int(0.25*n)),\n",
        "        48: (int(0.5*n), int(0.7*n))\n",
        "    }},\n",
        "    {'name': 'Steps 33, 34 and 49 (20% impactful, 33 and 49 identical)',\n",
        "     'config': {\n",
        "        32: (0, int(0.25*n)),\n",
        "        33: (int(0.5*n), int(0.7*n)),\n",
        "        48: (0, int(0.25*n))\n",
        "    }}\n",
        "]\n",
        "\n",
        "\n",
        "# Different uniformity probabilities to test\n",
        "uniform_probs = [0, 0.25, 0.5, 0.75, 1.0]\n",
        "\n",
        "def generate_experiment_groups(configurations):\n",
        "    # Group configurations by the positions they flip\n",
        "    groups = {}\n",
        "\n",
        "    for config in configurations:\n",
        "        # Get the sorted positions being flipped\n",
        "        positions = tuple(sorted(config['config'].keys()))\n",
        "\n",
        "        # Use these positions as the group name\n",
        "        group_name = f\"Positions_{'-'.join(map(str, positions))}\"\n",
        "\n",
        "        # Add this configuration to the appropriate group\n",
        "        if group_name not in groups:\n",
        "            groups[group_name] = []\n",
        "\n",
        "        groups[group_name].append(config['name'])\n",
        "\n",
        "    return groups\n",
        "\n",
        "# Generate experiment_groups automatically\n",
        "experiment_groups = generate_experiment_groups(flip_configurations)\n",
        "\n",
        "def verify_flip_implementation(flip_config_item):\n",
        "    \"\"\"\n",
        "    Verify that the flipping implementation works correctly with your existing configuration\n",
        "\n",
        "    Parameters:\n",
        "    flip_config_item: One item from your flip_configurations list (contains 'name' and 'config')\n",
        "    \"\"\"\n",
        "    config_name = flip_config_item['name']\n",
        "    config = flip_config_item['config']\n",
        "\n",
        "    # Generate training data exactly as in your training process\n",
        "    dat, use_uniform, target_values = generate_data(n=n, d=d, k=k, target=target, uniform_prob=1.0)\n",
        "\n",
        "    # Calculate intermediate labels exactly as in training\n",
        "    label = torch.zeros(n, k-1)\n",
        "    temp = copy.deepcopy(dat[:, target])\n",
        "\n",
        "    for m in range(k-1):\n",
        "        label[:, m] = temp[:, 0] * temp[:, 1]\n",
        "        temp = torch.cat((temp[:, 2:], label[:, m].unsqueeze(1)), dim=1)\n",
        "\n",
        "    # Convert to the format expected by apply_segmented_flips\n",
        "    formatted_config = []\n",
        "    for pos, spec in config.items():\n",
        "        if pos < k-1:  # Only include valid positions\n",
        "            formatted_config.append((pos, spec))\n",
        "\n",
        "    if not formatted_config:\n",
        "        print(\"\\n❌ No valid positions to flip. All positions are out of bounds.\")\n",
        "        return\n",
        "\n",
        "    # Apply the flips\n",
        "    import time\n",
        "    start_time = time.time()\n",
        "    flipped_label = apply_segmented_flips(label.clone(), formatted_config)\n",
        "    elapsed = time.time() - start_time\n",
        "\n",
        "    # Check results\n",
        "    total_flipped = torch.sum(label != flipped_label).item()\n",
        "\n",
        "    # Check each position\n",
        "    for pos, spec in config.items():\n",
        "        if pos >= k-1:\n",
        "            continue  # Skip invalid positions\n",
        "\n",
        "        flips_at_pos = torch.sum(label[:, pos] != flipped_label[:, pos]).item()\n",
        "        expected = spec[1] - spec[0] if isinstance(spec, tuple) else 0\n",
        "\n",
        "        if flips_at_pos == expected:\n",
        "            print(f\"  ✓ Correct number of flips\")\n",
        "        else:\n",
        "            print(f\"  ✗ Incorrect number of flips (expected {expected}, got {flips_at_pos})\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": true,
        "id": "kmo_5aUrJSGS",
        "cellView": "form"
      },
      "outputs": [],
      "source": [
        "# @title Block 5: Experiment details\n",
        "def run_experiments(configurations, uniform_probs):\n",
        "    \"\"\"\n",
        "    Run all experiments and save results\n",
        "    \"\"\"\n",
        "    # Initialize results dataframe\n",
        "    results = []\n",
        "\n",
        "    # Run experiments\n",
        "    for uniform_prob in uniform_probs:\n",
        "        for exp in configurations:\n",
        "            exp_name = exp['name']\n",
        "            flip_config = exp['config']\n",
        "\n",
        "            print(f\"\\n=== Starting experiment '{exp_name}' with uniformity={uniform_prob} ===\\n\")\n",
        "            print(f\"Flip configuration: {flip_config}\")\n",
        "\n",
        "            try:\n",
        "                loss, accuracy, final_train_loss = train_and_evaluate(\n",
        "                    flip_config, uniform_prob, max_epochs=max_epochs, verbose=True\n",
        "                )\n",
        "\n",
        "                # Calculate total corruption\n",
        "                total_corruption = 0\n",
        "                for pos, spec in flip_config.items():\n",
        "                    if isinstance(spec, float):\n",
        "                        total_corruption += spec\n",
        "                    elif isinstance(spec, tuple) and len(spec) == 2:\n",
        "                        total_corruption += (spec[1] - spec[0]) / n\n",
        "                    elif isinstance(spec, str) and spec.startswith(\"random:\"):\n",
        "                        total_corruption += int(spec.split(':')[1]) / n\n",
        "                    elif isinstance(spec, str) and spec.startswith(\"pattern:\"):\n",
        "                        parts = spec.split(':')\n",
        "                        stride = int(parts[1])\n",
        "                        total_corruption += 1/stride\n",
        "\n",
        "                # Save result\n",
        "                results.append({\n",
        "                    'Uniform_Prob': uniform_prob,\n",
        "                    'Experiment': exp_name,\n",
        "                    'Flip_Config': str(flip_config),\n",
        "                    'Test_Loss': loss,\n",
        "                    'Test_Accuracy': accuracy,\n",
        "                    'Final_Train_Loss': final_train_loss,\n",
        "                    'Total_Corruption': total_corruption\n",
        "                })\n",
        "\n",
        "                # Save intermediate results to CSV\n",
        "                results_df = pd.DataFrame(results)\n",
        "\n",
        "            except Exception as e:\n",
        "                print(f\"Error in experiment {exp_name}: {e}\")\n",
        "\n",
        "            # Clear memory\n",
        "            torch.cuda.empty_cache()\n",
        "            gc.collect()\n",
        "\n",
        "    return pd.DataFrame(results)\n",
        "\n",
        "# Run the experiments with configurations and uniform_probs from previous block\n",
        "results_df = run_experiments(flip_configurations, uniform_probs)\n",
        "\n",
        "# Print summary of results\n",
        "print(\"\\n=== Experiment Results ===\")\n",
        "print(results_df)\n",
        "\n",
        "# Plot and save the results\n",
        "def plot_results(results_df, plots_dir):\n",
        "    \"\"\"\n",
        "    Create and save visualization of experiment results\n",
        "    \"\"\"\n",
        "    if len(results_df) == 0:\n",
        "        print(\"No results to plot\")\n",
        "        return\n",
        "\n",
        "    # Set plot style\n",
        "    plt.style.use('seaborn-v0_8-whitegrid')\n",
        "\n",
        "    # Create a heatmap to visualize experiment × uniformity performance\n",
        "    if len(results_df['Experiment'].unique()) > 1 and len(results_df['Uniform_Prob'].unique()) > 1:\n",
        "        # Prepare data for heatmap\n",
        "        pivot_loss = results_df.pivot_table(\n",
        "            values='Test_Loss',\n",
        "            index='Experiment',\n",
        "            columns='Uniform_Prob'\n",
        "        )\n",
        "\n",
        "        # Create figure and axes for more control\n",
        "        fig, ax = plt.subplots(figsize=(20, 12))  # Wider figure\n",
        "\n",
        "        # Generate heatmap with larger numbers\n",
        "        heatmap = sns.heatmap(pivot_loss, annot=True, cmap='YlOrRd', fmt='.3f', linewidths=.5,\n",
        "                  annot_kws={\"size\": 26}, ax=ax, cbar_kws={'shrink': 0.8})  # Larger numbers in grid cells\n",
        "\n",
        "        # Set title and labels with bold text\n",
        "        ax.set_title('Test Loss by Experiment and Uniformity', fontsize=26, fontweight='bold')\n",
        "        ax.set_xlabel('Uniformity Ratio', fontsize=24, fontweight='bold')\n",
        "        ax.set_ylabel('Corruption', fontsize=24, fontweight='bold')\n",
        "\n",
        "        # Make tick labels bold\n",
        "        plt.setp(ax.get_xticklabels(), fontsize=22, fontweight='bold')\n",
        "        plt.setp(ax.get_yticklabels(), fontsize=22, fontweight='bold')\n",
        "\n",
        "        # Increase colorbar text size and make it bold\n",
        "        cbar = heatmap.collections[0].colorbar\n",
        "        cbar.ax.tick_params(labelsize=22)  # Match the size of other tick labels\n",
        "        plt.setp(cbar.ax.get_yticklabels(), fontweight='bold')  # Make colorbar labels bold\n",
        "\n",
        "        plt.tight_layout()\n",
        "        plt.savefig(f\"{plots_dir}/heatmap_loss.png\", dpi=300)\n",
        "        plt.close()\n",
        "\n",
        "        print(f\"Test loss heatmap has been saved to {plots_dir}/heatmap_loss.png\")\n",
        "    else:\n",
        "        print(\"Not enough data to create heatmap (need multiple experiments and uniformity values)\")\n",
        "\n",
        "# Create the plots\n",
        "try:\n",
        "    plot_results(results_df, plots_dir)\n",
        "except Exception as e:\n",
        "    print(f\"Error creating plots: {e}\")\n",
        "    import traceback\n",
        "    traceback.print_exc()"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "gpuType": "L4",
      "machine_shape": "hm",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}