{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "gpuType": "L4"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jTbf96rEQ2T6"
      },
      "outputs": [],
      "source": [
        "# from google.colab import drive\n",
        "# drive.mount('/content/drive')"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.optim as optim\n",
        "from torch.utils.data import Subset, DataLoader\n",
        "import sys\n",
        "import os\n",
        "import pickle\n",
        "\n",
        "PROJECT_PATH = '/content/drive/MyDrive/tfc-sr'\n",
        "sys.path.append(PROJECT_PATH)"
      ],
      "metadata": {
        "id": "xPO2xDArRQUi"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "print(f\"Using device: {device}\")"
      ],
      "metadata": {
        "id": "5G16HAJnXDip"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "CONFIG = {\n",
        "    'seed': 42,\n",
        "    'num_tasks': 10,\n",
        "    'epochs_per_task': 20,\n",
        "    'batch_size': 64,\n",
        "    'lr': 0.001,\n",
        "    'num_classes': 100,\n",
        "    'results_path': os.path.join(PROJECT_PATH, 'results'),\n",
        "    'checkpoints_path': os.path.join(PROJECT_PATH, 'checkpoints'),\n",
        "}"
      ],
      "metadata": {
        "id": "bHztDPVbXHd0"
      },
      "execution_count": null,
      "outputs": []
    },
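    {
      "cell_type": "code",
      "source": [
        "# The results/checkpoints paths in CONFIG may not exist yet on a fresh Drive,\n",
        "# so create them up front to avoid FileNotFoundError in the save calls below.\n",
        "os.makedirs(CONFIG['results_path'], exist_ok=True)\n",
        "os.makedirs(CONFIG['checkpoints_path'], exist_ok=True)"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },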
    {
      "cell_type": "code",
      "source": [
        "!pip install avalanche-lib\n",
        "\n",
        "from avalanche.benchmarks.classic import SplitMNIST\n",
        "from avalanche.training import EWC, SynapticIntelligence\n",
        "from avalanche.evaluation.metrics import accuracy_metrics, loss_metrics, forgetting_metrics\n",
        "from avalanche.logging import InteractiveLogger, TextLogger\n",
        "from avalanche.training.plugins import EvaluationPlugin\n"
      ],
      "metadata": {
        "id": "RQHyOWyiZIgY"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "!pip freeze > requirements.txt"
      ],
      "metadata": {
        "id": "U1v2P1xGfldm"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "from data_setup import get_split_mnist_dataloaders\n",
        "from utils import set_seed, save_results, plot_results, load_results, evaluate_on_seen_tasks, ReservoirReplayBuffer\n",
        "\n",
        "set_seed(CONFIG['seed'])"
      ],
      "metadata": {
        "id": "W_CEre6jYuSn"
      },
      "execution_count": null,
      "outputs": []
    },
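    {
      "cell_type": "code",
      "source": [
        "# `utils` lives in the Drive project folder and is not shown in this notebook.\n",
        "# For reference, a minimal sketch of what its set_seed helper presumably does\n",
        "# (illustrative only; named differently so the imported version is not shadowed):\n",
        "import random\n",
        "import numpy as np\n",
        "\n",
        "def set_seed_sketch(seed: int):\n",
        "    random.seed(seed)                 # Python RNG\n",
        "    np.random.seed(seed)              # NumPy RNG\n",
        "    torch.manual_seed(seed)           # PyTorch CPU RNG\n",
        "    torch.cuda.manual_seed_all(seed)  # PyTorch GPU RNGs"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },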
    {
      "cell_type": "code",
      "source": [
        "from avalanche.benchmarks.classic import SplitCIFAR100\n",
        "\n",
        "# Create a benchmark with 10 tasks, each containing 10 new classes.\n",
        "split_cifar100_benchmark = SplitCIFAR100(n_experiences=10, seed=CONFIG['seed'])"
      ],
      "metadata": {
        "id": "WJEtcRDxYw8z"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "from torchvision.models import resnet18\n",
        "\n",
        "def get_resnet18_for_cifar(num_classes=100):\n",
        "    \"\"\"\n",
        "    Returns a ResNet-18 model adapted for the CIFAR dataset.\n",
        "    \"\"\"\n",
        "    model = resnet18(weights=None) # weights=None means training from scratch\n",
        "\n",
        "    model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)\n",
        "    model.maxpool = nn.Identity() # Remove the initial max pooling\n",
        "\n",
        "    # The final layer needs to be replaced for the correct number of classes\n",
        "    num_ftrs = model.fc.in_features\n",
        "    model.fc = nn.Linear(num_ftrs, num_classes)\n",
        "\n",
        "    return model"
      ],
      "metadata": {
        "id": "9ygzlda5Z9Os"
      },
      "execution_count": null,
      "outputs": []
    },
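    {
      "cell_type": "code",
      "source": [
        "# Quick sanity check of the adapted architecture: a dummy CIFAR-sized batch\n",
        "# (3x32x32) should produce one logit per class.\n",
        "_model = get_resnet18_for_cifar(num_classes=CONFIG['num_classes'])\n",
        "_logits = _model(torch.randn(2, 3, 32, 32))\n",
        "assert _logits.shape == (2, CONFIG['num_classes'])\n",
        "print(f\"ResNet-18 adapter OK, output shape: {tuple(_logits.shape)}\")"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },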
    {
      "cell_type": "code",
      "source": [
        "# --- EXPERIMENT: BASELINE ON SPLIT CIFAR-100 (WITH LR TUNING) ---\n",
        "import torchvision.transforms as transforms\n",
        "\n",
        "print(\"\\n\" + \"=\"*20 + \" Starting Baseline Experiment on Split CIFAR-100 \" + \"=\"*20)\n",
        "\n",
        "# --- Hyperparameter Search Setup ---\n",
        "learning_rates_to_try = [0.01, 0.001, 0.0001]\n",
        "all_baseline_results = {}\n",
        "\n",
        "# --- Define the benchmark ONCE ---\n",
        "# This ensures all LR trials use the exact same data splits and order.\n",
        "split_cifar100_benchmark = SplitCIFAR100(\n",
        "    n_experiences=10,\n",
        "    seed=CONFIG['seed'],\n",
        "    # Standard normalization for CIFAR-100\n",
        "    train_transform=transforms.Compose([\n",
        "        transforms.ToTensor(),\n",
        "        transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))\n",
        "    ]),\n",
        "    eval_transform=transforms.Compose([\n",
        "        transforms.ToTensor(),\n",
        "        transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))\n",
        "    ])\n",
        ")\n",
        "\n",
        "for lr in learning_rates_to_try:\n",
        "    print(f\"\\n--- Running Baseline with learning rate = {lr} ---\")\n",
        "\n",
        "    set_seed(CONFIG['seed']) # Reset seed for each trial for fairness\n",
        "\n",
        "    model_baseline = get_resnet18_for_cifar(num_classes=CONFIG['num_classes']).to(device)\n",
        "    optimizer = optim.Adam(model_baseline.parameters(), lr=lr)\n",
        "    criterion = nn.CrossEntropyLoss()\n",
        "\n",
        "    current_lr_accuracies = []\n",
        "\n",
        "    # --- Main Continual Learning Loop ---\n",
        "    for task_id, experience in enumerate(split_cifar100_benchmark.train_stream):\n",
        "        print(f\"--> Training on Task {task_id+1}/{len(split_cifar100_benchmark.train_stream)}\")\n",
        "\n",
        "        train_loader = DataLoader(experience.dataset, batch_size=CONFIG['batch_size'], shuffle=True)\n",
        "\n",
        "        # --- Training loop ---\n",
        "        model_baseline.train()\n",
        "        for epoch in range(CONFIG['epochs_per_task']):\n",
        "            running_loss = 0.0\n",
        "            for data, targets, _ in train_loader:\n",
        "                data, targets = data.to(device), targets.to(device)\n",
        "\n",
        "                optimizer.zero_grad()\n",
        "                outputs = model_baseline(data)\n",
        "                loss = criterion(outputs, targets)\n",
        "                loss.backward()\n",
        "                optimizer.step()\n",
        "                running_loss += loss.item()\n",
        "\n",
        "            print(f\"Task {task_id+1}, Epoch {epoch+1}, Avg Loss: {running_loss / len(train_loader):.4f}\")\n",
        "\n",
        "        # --- Evaluation Step ---\n",
        "        accuracy = evaluate_on_seen_tasks(model_baseline, split_cifar100_benchmark, task_id, device, CONFIG['batch_size'])\n",
        "        current_lr_accuracies.append(accuracy)\n",
        "        print(f\"----- Avg Accuracy after Task {task_id+1}: {accuracy:.2f}% -----\")\n",
        "\n",
        "    all_baseline_results[lr] = current_lr_accuracies\n",
        "\n",
        "# --- Find the best learning rate and save its results ---\n",
        "# We choose the LR that gave the best accuracy after the final task\n",
        "best_lr = max(all_baseline_results, key=lambda k: all_baseline_results[k][-1])\n",
        "best_baseline_accuracies = all_baseline_results[best_lr]\n",
        "\n",
        "print(f\"\\nBest Baseline learning rate was {best_lr} with final accuracy: {best_baseline_accuracies[-1]:.2f}%\")\n",
        "\n",
        "# Save only the results from the BEST run\n",
        "baseline_cifar_results_path = os.path.join(CONFIG['results_path'], 'baseline_cifar_accuracies.pkl')\n",
        "save_results(best_baseline_accuracies, baseline_cifar_results_path)\n",
        "\n",
        "# --- Plot all the trial runs to visualize the tuning process ---\n",
        "plot_results({f'LR={lr}': acc for lr, acc in all_baseline_results.items()},\n",
        "             title=\"Baseline LR Tuning on Split CIFAR-100\")\n",
        "\n",
        "# --- Plot just the best result ---\n",
        "plot_results({'Baseline (Best LR)': best_baseline_accuracies},\n",
        "             title=\"Best Baseline Performance on Split CIFAR-100\")"
      ],
      "metadata": {
        "id": "hv8YVhorbUJt"
      },
      "execution_count": null,
      "outputs": []
    },
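    {
      "cell_type": "code",
      "source": [
        "# evaluate_on_seen_tasks comes from utils and is not shown. A minimal sketch of\n",
        "# the behaviour its call sites imply (overall test accuracy across experiences\n",
        "# 0..task_id, as a percentage); the real helper may average per task instead:\n",
        "@torch.no_grad()\n",
        "def evaluate_on_seen_tasks_sketch(model, benchmark, task_id, device, batch_size):\n",
        "    model.eval()\n",
        "    correct, total = 0, 0\n",
        "    for i, exp in enumerate(benchmark.test_stream):\n",
        "        if i > task_id:  # Only evaluate tasks seen so far\n",
        "            break\n",
        "        loader = DataLoader(exp.dataset, batch_size=batch_size)\n",
        "        for data, targets, _ in loader:\n",
        "            data, targets = data.to(device), targets.to(device)\n",
        "            preds = model(data).argmax(dim=1)\n",
        "            correct += (preds == targets).sum().item()\n",
        "            total += targets.size(0)\n",
        "    model.train()\n",
        "    return 100.0 * correct / total"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },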
    {
      "cell_type": "code",
      "source": [
        "# Best Baseline learning rate was 0.001 with final accuracy: 7.27%"
      ],
      "metadata": {
        "id": "TYuQrJlxcx4v"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# doing the experiment again to save better results data structure which is now a dictionary. Above experiment already did fine tuning for learning rate.\n",
        "\n",
        "# --- EXPERIMENT 1: BASELINE (SEQUENTIAL FINE-TUNING) ON SPLIT CIFAR-100 ---\n",
        "\n",
        "print(\"\\n\" + \"=\"*20 + \" Starting Final Baseline Experiment on Split CIFAR-100 \" + \"=\"*20)\n",
        "\n",
        "# --- Use the best learning rate found from tuning ---\n",
        "CONFIG['lr'] = 0.001\n",
        "\n",
        "set_seed(CONFIG['seed'])\n",
        "\n",
        "model_baseline = get_resnet18_for_cifar(num_classes=CONFIG['num_classes']).to(device)\n",
        "optimizer = optim.Adam(model_baseline.parameters(), lr=CONFIG['lr'])\n",
        "criterion = nn.CrossEntropyLoss()\n",
        "\n",
        "# --- Use a dictionary to store results  ---\n",
        "baseline_results = {\n",
        "    'accuracies': [],\n",
        "    'total_replay_batches': 0  # Allows us to better compare the replay methods. 0 for the baseline.\n",
        "}\n",
        "\n",
        "# --- Main Continual Learning Loop ---\n",
        "for task_id, experience in enumerate(split_cifar100_benchmark.train_stream):\n",
        "    print(f\"\\n--- Training on Task {task_id+1}/{len(split_cifar100_benchmark.train_stream)} ---\")\n",
        "\n",
        "    train_loader = DataLoader(experience.dataset, batch_size=CONFIG['batch_size'], shuffle=True)\n",
        "\n",
        "    # --- Training loop ---\n",
        "    model_baseline.train()\n",
        "    for epoch in range(CONFIG['epochs_per_task']):\n",
        "        running_loss = 0.0\n",
        "        for data, targets, _ in train_loader:\n",
        "            data, targets = data.to(device), targets.to(device)\n",
        "\n",
        "            optimizer.zero_grad()\n",
        "            outputs = model_baseline(data)\n",
        "            loss = criterion(outputs, targets)\n",
        "            loss.backward()\n",
        "            optimizer.step()\n",
        "            running_loss += loss.item()\n",
        "\n",
        "        print(f\"Task {task_id+1}, Epoch {epoch+1}, Avg Loss: {running_loss / len(train_loader):.4f}\")\n",
        "\n",
        "    # --- Evaluation Step ---\n",
        "    accuracy = evaluate_on_seen_tasks(model_baseline, split_cifar100_benchmark, task_id, device, CONFIG['batch_size'])\n",
        "    baseline_results['accuracies'].append(accuracy)\n",
        "    print(f\"----- Accuracy after Task {task_id+1}: {accuracy:.2f}% -----\")\n",
        "\n",
        "# --- Save Results and Checkpoint ---\n",
        "final_model_path = os.path.join(CONFIG['checkpoints_path'], 'baseline_cifar_final_model.pth')\n",
        "torch.save(model_baseline.state_dict(), final_model_path)\n",
        "print(f\"\\nFinal baseline model (CIFAR-100) saved to {final_model_path}\")\n",
        "\n",
        "# Save the entire results dictionary\n",
        "baseline_cifar_results_path = os.path.join(CONFIG['results_path'], 'baseline_cifar_results.pkl')\n",
        "save_results(baseline_results, baseline_cifar_results_path)\n",
        "\n",
        "# --- Plot the Single Result ---\n",
        "plot_results(\n",
        "    {'Baseline (CIFAR-100)': baseline_results['accuracies']},\n",
        "    title=\"Baseline Performance on Split CIFAR-100\"\n",
        ")"
      ],
      "metadata": {
        "id": "dn5HVbWygPuo"
      },
      "execution_count": null,
      "outputs": []
    },
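    {
      "cell_type": "code",
      "source": [
        "# ReservoirReplayBuffer is imported from utils. A minimal sketch of a\n",
        "# reservoir-sampled buffer consistent with the add/sample/len interface used\n",
        "# below (the real class may differ in details):\n",
        "import random\n",
        "\n",
        "class ReservoirReplayBufferSketch:\n",
        "    def __init__(self, capacity):\n",
        "        self.capacity = capacity\n",
        "        self.buffer = []  # list of (data, target) pairs\n",
        "        self.seen = 0     # total items offered to the buffer\n",
        "\n",
        "    def add(self, data, target):\n",
        "        self.seen += 1\n",
        "        if len(self.buffer) < self.capacity:\n",
        "            self.buffer.append((data, target))\n",
        "        else:\n",
        "            # Classic reservoir sampling: the i-th item replaces a random slot\n",
        "            # with probability capacity / i, so every item seen so far is\n",
        "            # retained with equal probability.\n",
        "            idx = random.randrange(self.seen)\n",
        "            if idx < self.capacity:\n",
        "                self.buffer[idx] = (data, target)\n",
        "\n",
        "    def sample(self, batch_size):\n",
        "        batch = random.sample(self.buffer, batch_size)\n",
        "        data = torch.stack([x for x, _ in batch])\n",
        "        targets = torch.tensor([y for _, y in batch])\n",
        "        return data, targets\n",
        "\n",
        "    def __len__(self):\n",
        "        return len(self.buffer)"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },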
    {
      "cell_type": "code",
      "source": [
        "# --- EXPERIMENT: STANDARD EXPERIENCE REPLAY (ER) ON SPLIT CIFAR-100 ---\n",
        "\n",
        "print(\"\\n\" + \"=\"*20 + \" Starting Standard ER Experiment on Split CIFAR-100 \" + \"=\"*20)\n",
        "\n",
        "CONFIG['buffer_capacity'] = 1000\n",
        "CONFIG['replay_batch_size'] = CONFIG['batch_size'] // 2\n",
        "\n",
        "set_seed(CONFIG['seed'])\n",
        "\n",
        "model_er = get_resnet18_for_cifar(num_classes=CONFIG['num_classes']).to(device)\n",
        "optimizer_er = optim.Adam(model_er.parameters(), lr=CONFIG['lr'])\n",
        "criterion_er = nn.CrossEntropyLoss()\n",
        "\n",
        "replay_buffer_er = ReservoirReplayBuffer(capacity=CONFIG['buffer_capacity'])\n",
        "\n",
        "er_results = {\n",
        "    'accuracies': [],\n",
        "    'total_replay_batches': 0\n",
        "}\n",
        "\n",
        "# --- Main Continual Learning Loop ---\n",
        "for task_id, experience in enumerate(split_cifar100_benchmark.train_stream):\n",
        "    print(f\"\\n--- Training on Task {task_id+1}/{len(split_cifar100_benchmark.train_stream)} ---\")\n",
        "\n",
        "    # Populate replay buffer\n",
        "    print(f\"Populating replay buffer from Task {task_id+1}...\")\n",
        "    for data_point, target, _ in experience.dataset:\n",
        "        replay_buffer_er.add(data_point, target)\n",
        "    print(f\"Replay buffer size: {len(replay_buffer_er)}\")\n",
        "\n",
        "    # Create the dataloader for the current task\n",
        "    train_loader = DataLoader(experience.dataset, batch_size=CONFIG['batch_size'], shuffle=True)\n",
        "\n",
        "    # --- Training loop with mixed batches ---\n",
        "    model_er.train()\n",
        "    for epoch in range(CONFIG['epochs_per_task']):\n",
        "        running_loss = 0.0\n",
        "        for new_data, new_targets, _ in train_loader:\n",
        "\n",
        "            if len(replay_buffer_er) >= CONFIG['replay_batch_size']:\n",
        "                old_data, old_targets = replay_buffer_er.sample(CONFIG['replay_batch_size'])\n",
        "                new_data = new_data[:CONFIG['replay_batch_size']]\n",
        "                new_targets = new_targets[:CONFIG['replay_batch_size']]\n",
        "\n",
        "                combined_data = torch.cat((new_data, old_data), dim=0).to(device)\n",
        "                combined_targets = torch.cat((new_targets, old_targets), dim=0).to(device)\n",
        "\n",
        "                optimizer_er.zero_grad()\n",
        "                outputs = model_er(combined_data)\n",
        "                loss = criterion_er(outputs, combined_targets)\n",
        "                loss.backward()\n",
        "                optimizer_er.step()\n",
        "                running_loss += loss.item()\n",
        "\n",
        "                er_results['total_replay_batches'] += 1\n",
        "\n",
        "        print(f\"Task {task_id+1}, Epoch {epoch+1}, Avg Loss: {running_loss / len(train_loader):.4f}\")\n",
        "\n",
        "    # --- Evaluation ---\n",
        "    accuracy = evaluate_on_seen_tasks(model_er, split_cifar100_benchmark, task_id, device, CONFIG['batch_size'])\n",
        "    er_results['accuracies'].append(accuracy)\n",
        "    print(f\"----- Accuracy after Task {task_id+1}: {accuracy:.2f}% -----\")\n",
        "\n",
        "# --- Save and Plot ---\n",
        "er_cifar_results_path = os.path.join(CONFIG['results_path'], 'er_cifar_results.pkl')\n",
        "save_results(er_results, er_cifar_results_path)\n",
        "\n",
        "# Load the dictionaries for plotting\n",
        "baseline_results = load_results(os.path.join(CONFIG['results_path'], 'baseline_cifar_results.pkl'))\n",
        "\n",
        "baseline_accuracy = baseline_results['accuracies']\n",
        "er_results_loaded = load_results(er_cifar_results_path)\n",
        "\n",
        "# Extract the accuracy lists for the plot function\n",
        "results_to_plot = {\n",
        "    'Baseline (CIFAR-100)': baseline_accuracy,\n",
        "    'Standard ER (CIFAR-100)': er_results_loaded['accuracies']\n",
        "}\n",
        "plot_results(results_to_plot, title=\"Standard ER vs. Baseline on Split CIFAR-100\")\n",
        "\n",
        "# print the efficiency metric\n",
        "print(f\"\\nStandard ER performed {er_results_loaded['total_replay_batches']} replay batches.\")"
      ],
      "metadata": {
        "id": "ilTLpP4FZmix"
      },
      "execution_count": null,
      "outputs": []
    },
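    {
      "cell_type": "code",
      "source": [
        "# evaluate_replay_buffer is imported from utils in the next cell. A minimal\n",
        "# sketch matching its call sites (model accuracy on the buffer contents, as a\n",
        "# percentage). It assumes the buffer stores (data, target) pairs in a .buffer\n",
        "# list, as in the sketch class above; the real helper may differ:\n",
        "@torch.no_grad()\n",
        "def evaluate_replay_buffer_sketch(model, buffer, device, batch_size=256):\n",
        "    correct, total = 0, 0\n",
        "    for start in range(0, len(buffer.buffer), batch_size):\n",
        "        chunk = buffer.buffer[start:start + batch_size]\n",
        "        data = torch.stack([x for x, _ in chunk]).to(device)\n",
        "        targets = torch.tensor([y for _, y in chunk]).to(device)\n",
        "        preds = model(data).argmax(dim=1)\n",
        "        correct += (preds == targets).sum().item()\n",
        "        total += targets.size(0)\n",
        "    return 100.0 * correct / total"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },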
    {
      "cell_type": "code",
      "source": [
        "# --- EXPERIMENT: TFC-SR SENSITIVITY ANALYSIS ON SPLIT CIFAR-100 ---\n",
        "from utils import evaluate_replay_buffer\n",
        "\n",
        "print(\"\\n\" + \"=\"*20 + \" Starting TFC-SR Hyperparameter Tuning on Split CIFAR-100 \" + \"=\"*20)\n",
        "\n",
        "# --- Hyperparameter Search Setup ---\n",
        "thresholds_to_try = [10.0, 20.0, 30.0, 50.0, 70.0, 90.0]\n",
        "all_tfc_results = {}\n",
        "\n",
        "CONFIG['buffer_capacity'] = 1000\n",
        "CONFIG['replay_batch_size'] = CONFIG['batch_size'] // 2\n",
        "CONFIG['initial_replay_gap'] = 1   # Start checking after epoch 1\n",
        "CONFIG['replay_gap_multiplier'] = 1.5 # How much to increase the gap\n",
        "\n",
        "# --- Outer loop for tuning the mastery_threshold ---\n",
        "for threshold in thresholds_to_try:\n",
        "    print(f\"\\n--->>> STARTING TRIAL: THRESHOLD = {threshold}% <<<---\")\n",
        "    CONFIG['mastery_threshold'] = threshold\n",
        "\n",
        "    set_seed(CONFIG['seed'])\n",
        "    model_tfc = get_resnet18_for_cifar(num_classes=CONFIG['num_classes']).to(device)\n",
        "    optimizer_tfc = optim.Adam(model_tfc.parameters(), lr=CONFIG['lr'])\n",
        "    criterion_tfc = nn.CrossEntropyLoss()\n",
        "    replay_buffer_tfc = ReservoirReplayBuffer(capacity=CONFIG['buffer_capacity'])\n",
        "\n",
        "    current_run_results = { 'accuracies': [], 'total_replay_batches': 0, 'memory_checks': 0, 'schedule_history': [] }\n",
        "\n",
        "    for task_id, experience in enumerate(split_cifar100_benchmark.train_stream):\n",
        "        print(f\"--> Training on Task {task_id+1}\")\n",
        "\n",
        "        for data_point, target, _ in experience.dataset: replay_buffer_tfc.add(data_point, target)\n",
        "\n",
        "        current_replay_gap = float(CONFIG['initial_replay_gap'])\n",
        "        replay_timer = int(current_replay_gap)\n",
        "        train_loader = DataLoader(experience.dataset, batch_size=CONFIG['batch_size'], shuffle=True)\n",
        "\n",
        "        model_tfc.train()\n",
        "        for epoch in range(CONFIG['epochs_per_task']):\n",
        "            for new_data, new_targets, _ in train_loader:\n",
        "                 if len(replay_buffer_tfc) >= CONFIG['replay_batch_size']:\n",
        "                    old_data, old_targets = replay_buffer_tfc.sample(CONFIG['replay_batch_size'])\n",
        "                    new_data = new_data[:CONFIG['replay_batch_size']]\n",
        "\n",
        "                    combined_data = torch.cat((new_data, old_data), dim=0).to(device)\n",
        "                    combined_targets = torch.cat((new_targets[:len(new_data)], old_targets), dim=0).to(device)\n",
        "\n",
        "                    optimizer_tfc.zero_grad()\n",
        "                    outputs = model_tfc(combined_data)\n",
        "                    loss = criterion_tfc(outputs, combined_targets)\n",
        "                    loss.backward()\n",
        "                    optimizer_tfc.step()\n",
        "\n",
        "                    current_run_results['total_replay_batches'] += 1\n",
        "\n",
        "            # --- Adaptive Replay Scheduling Logic with DIAGNOSTICS ---\n",
        "            if (epoch + 1) == replay_timer and len(replay_buffer_tfc) > 1:\n",
        "                current_run_results['memory_checks'] += 1\n",
        "                model_tfc.eval()\n",
        "\n",
        "                print(f\"\\n  [Epoch {epoch+1}] Memory Check Triggered. Current Timer: {replay_timer}\")\n",
        "\n",
        "                replay_perf = evaluate_replay_buffer(model_tfc, replay_buffer_tfc, device)\n",
        "                print(f\"    Replay Buffer Perf: {replay_perf:.2f}%. Comparing against Threshold: {CONFIG['mastery_threshold']}%\")\n",
        "\n",
        "                if replay_perf >= CONFIG['mastery_threshold']:\n",
        "                    current_replay_gap *= CONFIG['replay_gap_multiplier']\n",
        "                    replay_timer += round(current_replay_gap)\n",
        "                    print(f\"    RESULT: Mastery MET. New timer set to epoch {replay_timer}.\")\n",
        "                else:\n",
        "                    replay_timer += 1\n",
        "                    print(f\"    RESULT: Mastery FAILED. New timer set to epoch {replay_timer}.\")\n",
        "\n",
        "                model_tfc.train()\n",
        "\n",
        "        accuracy = evaluate_on_seen_tasks(model_tfc, split_cifar100_benchmark, task_id, device, CONFIG['batch_size'])\n",
        "        current_run_results['accuracies'].append(accuracy)\n",
        "        print(f\"----- Accuracy after Task {task_id+1}: {accuracy:.2f}% -----\")\n",
        "\n",
        "    all_tfc_results[threshold] = current_run_results\n",
        "\n",
        "# --- ANALYSIS, SAVING, AND PLOTTING ---\n",
        "\n",
        "# 1. Save the FULL dictionary of all trial runs for later analysis\n",
        "all_tfc_results_path = os.path.join(CONFIG['results_path'], 'tfc_sr_cifar_all_trials.pkl')\n",
        "save_results(all_tfc_results, all_tfc_results_path)\n",
        "print(f\"\\nFull TFC-SR tuning results saved to {all_tfc_results_path}\")\n",
        "\n",
        "# 2. Programmatically find the best result and save it separately for convenience\n",
        "best_threshold = max(all_tfc_results, key=lambda k: all_tfc_results[k]['accuracies'][-1])\n",
        "best_tfc_results = all_tfc_results[best_threshold]\n",
        "print(f\"\\nBest TFC-SR mastery threshold was {best_threshold}% with final accuracy {best_tfc_results['accuracies'][-1]:.2f}%\")\n",
        "best_tfc_results_path = os.path.join(CONFIG['results_path'], 'tfc_sr_cifar_best.pkl')\n",
        "save_results(best_tfc_results, best_tfc_results_path)\n",
        "\n",
        "# 3. Plot the Sensitivity Analysis\n",
        "print(\"\\n--- Generating Sensitivity Analysis Plot ---\")\n",
        "# Load ER results to use as a reference line\n",
        "er_cifar_results = load_results(os.path.join(CONFIG['results_path'], 'er_cifar_results.pkl'))\n",
        "sensitivity_plot_data = {f'TFC-SR (Thresh={t}%)': res['accuracies'] for t, res in all_tfc_results.items()}\n",
        "sensitivity_plot_data['Standard ER'] = er_cifar_results['accuracies']\n",
        "plot_results(sensitivity_plot_data, title=\"TFC-SR Sensitivity to Mastery Threshold on Split CIFAR-100\")\n",
        "\n",
        "# 4. Plot the Main Comparison\n",
        "print(\"\\n--- Generating Main Comparison Plot ---\")\n",
        "# Load all the \"best\" results files\n",
        "baseline_cifar_results = load_results(os.path.join(CONFIG['results_path'], 'baseline_cifar_results.pkl'))\n",
        "tfc_cifar_best_results = load_results(best_tfc_results_path)\n",
        "\n",
        "main_plot_data = {\n",
        "    'Baseline': baseline_cifar_results['accuracies'],\n",
        "    'Standard ER': er_cifar_results['accuracies'],\n",
        "    f'TFC-SR (Ours, Thresh={best_threshold}%)': tfc_cifar_best_results['accuracies']\n",
        "}\n",
        "plot_results(main_plot_data, title=\"Main Performance Comparison on Split CIFAR-100\")\n",
        "\n",
        "# 5. Report the efficiency metrics\n",
        "print(f\"\\n--- Efficiency Comparison ---\")\n",
        "print(f\"Standard ER performed {er_cifar_results['total_replay_batches']} replay batches.\")\n",
        "print(f\"Best TFC-SR (Thresh={best_threshold}%) performed {best_tfc_results['total_replay_batches']} replay batches with {best_tfc_results['memory_checks']} memory checks.\")"
      ],
      "metadata": {
        "id": "v3mr3-1UPfBY"
      },
      "execution_count": null,
      "outputs": []
    },
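    {
      "cell_type": "code",
      "source": [
        "# The adaptive schedule above is easiest to see in isolation. If the mastery\n",
        "# check succeeds at every trigger, the gap grows geometrically (x1.5) and the\n",
        "# check epochs spread out; a failed check instead retries on the next epoch.\n",
        "gap = float(CONFIG['initial_replay_gap'])\n",
        "timer = int(CONFIG['initial_replay_gap'])\n",
        "schedule = [timer]\n",
        "for _ in range(6):\n",
        "    gap *= CONFIG['replay_gap_multiplier']\n",
        "    timer += round(gap)\n",
        "    schedule.append(timer)\n",
        "print(\"Check epochs if mastery is always met:\", schedule)\n",
        "# With gap=1 and multiplier=1.5 this prints [1, 3, 5, 8, 13, 21, 32]."
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },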
    {
      "cell_type": "code",
      "source": [
        "# --- EXPERIMENT: STANDARD ER SENSITIVITY TO BUFFER SIZE (OPTIMIZED) ---\n",
        "\n",
        "print(\"\\n\" + \"=\"*20 + \" Starting ER Buffer Size Tuning \" + \"=\"*20)\n",
        "\n",
        "# --- Hyperparameter Search Setup ---\n",
        "new_buffer_sizes_to_try = [100, 500, 2000]\n",
        "er_tuning_results = {}\n",
        "CONFIG['replay_batch_size'] = CONFIG['batch_size'] // 2\n",
        "\n",
        "# --- Outer loop for tuning the buffer_capacity ---\n",
        "for capacity in new_buffer_sizes_to_try:\n",
        "    print(f\"\\n--->>> STARTING TRIAL: BUFFER CAPACITY = {capacity} <<<---\")\n",
        "\n",
        "    set_seed(CONFIG['seed'])\n",
        "    model_er = get_resnet18_for_cifar(num_classes=CONFIG['num_classes']).to(device)\n",
        "    optimizer_er = optim.Adam(model_er.parameters(), lr=CONFIG['lr'])\n",
        "    criterion_er = nn.CrossEntropyLoss()\n",
        "\n",
        "    replay_buffer_er = ReservoirReplayBuffer(capacity=capacity)\n",
        "\n",
        "    current_run_results = { 'accuracies': [], 'total_replay_batches': 0 }\n",
        "\n",
        "    # --- Main Continual Learning Loop ---\n",
        "    for task_id, experience in enumerate(split_cifar100_benchmark.train_stream):\n",
        "        print(f\"--> Training on Task {task_id+1}\")\n",
        "\n",
        "        # Populate replay buffer\n",
        "        for data_point, target, _ in experience.dataset:\n",
        "            replay_buffer_er.add(data_point, target)\n",
        "\n",
        "        train_loader = DataLoader(experience.dataset, batch_size=CONFIG['batch_size'], shuffle=True)\n",
        "\n",
        "        # --- Training loop with mixed batches ---\n",
        "        model_er.train()\n",
        "        for epoch in range(CONFIG['epochs_per_task']):\n",
        "            for new_data, new_targets, _ in train_loader:\n",
        "                if len(replay_buffer_er) >= CONFIG['replay_batch_size']:\n",
        "                    old_data, old_targets = replay_buffer_er.sample(CONFIG['replay_batch_size'])\n",
        "                    new_data = new_data[:CONFIG['replay_batch_size']]\n",
        "\n",
        "                    combined_data = torch.cat((new_data, old_data), dim=0).to(device)\n",
        "                    combined_targets = torch.cat((new_targets[:len(new_data)], old_targets), dim=0).to(device)\n",
        "\n",
        "                    optimizer_er.zero_grad()\n",
        "                    outputs = model_er(combined_data)\n",
        "                    loss = criterion_er(outputs, combined_targets)\n",
        "                    loss.backward()\n",
        "                    optimizer_er.step()\n",
        "                    current_run_results['total_replay_batches'] += 1\n",
        "\n",
        "        # --- Unified Evaluation ---\n",
        "        accuracy = evaluate_on_seen_tasks(model_er, split_cifar100_benchmark, task_id, device, CONFIG['batch_size'])\n",
        "        current_run_results['accuracies'].append(accuracy)\n",
        "        print(f\"----- Accuracy after Task {task_id+1}: {accuracy:.2f}% -----\")\n",
        "\n",
        "    # Store the final results for this buffer size\n",
        "    er_tuning_results[capacity] = current_run_results\n",
        "\n",
        "# --- ANALYSIS, SAVING, AND PLOTTING ---\n",
        "\n",
        "# 1. Load the result for the run we already completed\n",
        "path_to_er_1000_results = os.path.join(CONFIG['results_path'], 'er_cifar_results.pkl')\n",
        "try:\n",
        "    er_1000_results = load_results(path_to_er_1000_results)\n",
        "    # Add the loaded result to our tuning dictionary\n",
        "    er_tuning_results[1000] = er_1000_results\n",
        "    print(\"\\nSuccessfully loaded existing results for buffer size 1000.\")\n",
        "except FileNotFoundError:\n",
        "    print(\"\\nWarning: Could not find existing results for buffer size 1000. It will be missing from the plot.\")\n",
        "\n",
        "\n",
        "# 2. Save the FULL dictionary of all trial runs (new and old) for later analysis\n",
        "all_er_results_path = os.path.join(CONFIG['results_path'], 'er_cifar_buffer_tuning_ALL.pkl')\n",
        "save_results(er_tuning_results, all_er_results_path)\n",
        "print(f\"Full ER buffer tuning results saved to {all_er_results_path}\")\n",
        "\n",
        "# 3. Plot the sensitivity analysis curve\n",
        "print(\"\\n--- Generating Buffer Size Sensitivity Plot for Standard ER ---\")\n",
        "# Sort the dictionary by key (buffer size) for a clean plot\n",
        "sorted_capacities = sorted(er_tuning_results.keys())\n",
        "final_accuracies = [er_tuning_results[cap]['accuracies'][-1] for cap in sorted_capacities]"
      ],
      "metadata": {
        "id": "T9rwTv0e0DoU"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import matplotlib.pyplot as plt\n",
        "\n",
        "\n",
        "plt.figure(figsize=(8, 5))\n",
        "plt.plot(sorted_capacities, final_accuracies, marker='o')\n",
        "plt.title(\"Standard ER Performance vs. Buffer Capacity on Split CIFAR-100\")\n",
        "plt.xlabel(\"Replay Buffer Capacity\")\n",
        "plt.ylabel(\"Final Average Accuracy (%) after 10 Tasks\")\n",
        "plt.xscale('log')\n",
        "plt.grid(True, which='both', linestyle='--')\n",
        "plt.show()\n",
        "\n",
        "# 4. Plot all the learning curves\n",
        "print(\"\\n--- Generating Learning Curves for Each Buffer Size ---\")\n",
        "plot_data_er = {f'ER (Buffer={cap})': er_tuning_results[cap]['accuracies'] for cap in sorted_capacities}\n",
        "plot_results(plot_data_er, title=\"Standard ER Learning Curves by Buffer Size\")"
      ],
      "metadata": {
        "id": "anv8ORwV024R"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# --- EXPERIMENT: TFC-SR SENSITIVITY TO BUFFER SIZE ON SPLIT CIFAR-100 ---\n",
        "\n",
        "print(\"\\n\" + \"=\"*20 + \" Starting TFC-SR Buffer Size Tuning on Split CIFAR-100 \" + \"=\"*20)\n",
        "\n",
        "# --- Hyperparameter Search Setup ---\n",
        "buffer_sizes_to_try = [100, 500, 2000]\n",
        "all_tfc_buffer_results = {}\n",
        "\n",
        "# --- Fixed Hyperparameters for this experiment ---\n",
        "CONFIG['mastery_threshold'] = 10.0 # best threshold we found previously\n",
        "CONFIG['replay_batch_size'] = CONFIG['batch_size'] // 2\n",
        "CONFIG['initial_replay_gap'] = 1\n",
        "CONFIG['replay_gap_multiplier'] = 1.5\n",
        "\n",
        "# --- Outer loop for tuning the buffer_capacity ---\n",
        "for buffer_size in buffer_sizes_to_try:\n",
        "    print(f\"\\n--->>> STARTING TRIAL: BUFFER SIZE = {buffer_size} <<<---\")\n",
        "    # Set the buffer size for this run\n",
        "    CONFIG['buffer_capacity'] = buffer_size\n",
        "\n",
        "    # --- Setup for this specific trial ---\n",
        "    set_seed(CONFIG['seed'])\n",
        "    model_tfc = get_resnet18_for_cifar(num_classes=CONFIG['num_classes']).to(device)\n",
        "    optimizer_tfc = optim.Adam(model_tfc.parameters(), lr=CONFIG['lr'])\n",
        "    criterion_tfc = nn.CrossEntropyLoss()\n",
        "    replay_buffer_tfc = ReservoirReplayBuffer(capacity=CONFIG['buffer_capacity'])\n",
        "\n",
        "    current_run_results = { 'accuracies': [], 'total_replay_batches': 0, 'memory_checks': 0, 'schedule_history': [] }\n",
        "\n",
        "    # --- Main Continual Learning Loop ---\n",
        "    for task_id, experience in enumerate(split_cifar100_benchmark.train_stream):\n",
        "        print(f\"--> Training on Task {task_id+1}\")\n",
        "        for data_point, target, _ in experience.dataset: replay_buffer_tfc.add(data_point, target)\n",
        "\n",
        "        current_replay_gap = float(CONFIG['initial_replay_gap'])\n",
        "        replay_timer = int(current_replay_gap)\n",
        "        train_loader = DataLoader(experience.dataset, batch_size=CONFIG['batch_size'], shuffle=True)\n",
        "\n",
        "        model_tfc.train()\n",
        "        for epoch in range(CONFIG['epochs_per_task']):\n",
        "          for new_data, new_targets, _ in train_loader:\n",
        "                if len(replay_buffer_tfc) >= CONFIG['replay_batch_size']:\n",
        "                  old_data, old_targets = replay_buffer_tfc.sample(CONFIG['replay_batch_size'])\n",
        "                  new_data = new_data[:CONFIG['replay_batch_size']]\n",
        "\n",
        "                  combined_data = torch.cat((new_data, old_data), dim=0).to(device)\n",
        "                  combined_targets = torch.cat((new_targets[:len(new_data)], old_targets), dim=0).to(device)\n",
        "\n",
        "                  optimizer_tfc.zero_grad()\n",
        "                  outputs = model_tfc(combined_data)\n",
        "                  loss = criterion_tfc(outputs, combined_targets)\n",
        "                  loss.backward()\n",
        "                  optimizer_tfc.step()\n",
        "\n",
        "                  current_run_results['total_replay_batches'] += 1\n",
        "\n",
        "          # --- Adaptive Replay Scheduling Logic with DIAGNOSTICS ---\n",
        "          if (epoch + 1) == replay_timer and len(replay_buffer_tfc) > 1:\n",
        "                current_run_results['memory_checks'] += 1\n",
        "                model_tfc.eval()\n",
        "\n",
        "                print(f\"\\n  [Epoch {epoch+1}] Memory Check Triggered. Current Timer: {replay_timer}\")\n",
        "\n",
        "                replay_perf = evaluate_replay_buffer(model_tfc, replay_buffer_tfc, device)\n",
        "                print(f\"    Replay Buffer Perf: {replay_perf:.2f}%. Comparing against Threshold: {CONFIG['mastery_threshold']}%\")\n",
        "\n",
        "                if replay_perf >= CONFIG['mastery_threshold']:\n",
        "                    current_replay_gap *= CONFIG['replay_gap_multiplier']\n",
        "                    replay_timer += round(current_replay_gap)\n",
        "                    print(f\"    RESULT: Mastery MET. New timer set to epoch {replay_timer}.\")\n",
        "                else:\n",
        "                    replay_timer += 1\n",
        "                    print(f\"    RESULT: Mastery FAILED. New timer set to epoch {replay_timer}.\")\n",
        "\n",
        "                model_tfc.train()\n",
        "\n",
        "        accuracy = evaluate_on_seen_tasks(model_tfc, split_cifar100_benchmark, task_id, device, CONFIG['batch_size'])\n",
        "        current_run_results['accuracies'].append(accuracy)\n",
        "        print(f\"----- Accuracy after Task {task_id+1}: {accuracy:.2f}% -----\")\n",
        "\n",
        "    all_tfc_buffer_results[buffer_size] = current_run_results\n",
        "\n",
        "# --- ANALYSIS, SAVING, AND PLOTTING (This is the corrected part) ---\n",
        "\n",
        "# 1. Load the result for the run with buffer size 1000\n",
        "path_to_tfc_1000_results = os.path.join(CONFIG['results_path'], 'tfc_sr_cifar_best.pkl')\n",
        "try:\n",
        "    tfc_1000_results = load_results(path_to_tfc_1000_results)\n",
        "    all_tfc_buffer_results[1000] = tfc_1000_results\n",
        "    print(\"\\nSuccessfully loaded existing results for TFC-SR with buffer size 1000.\")\n",
        "except FileNotFoundError:\n",
        "    print(f\"\\nWarning: Could not find results file at {path_to_tfc_1000_results}. It will be missing.\")\n",
        "\n",
        "# 2. Save the FULL dictionary of all trial runs\n",
        "all_tfc_results_path = os.path.join(CONFIG['results_path'], 'tfc_sr_cifar_buffer_tuning_ALL.pkl')\n",
        "save_results(all_tfc_buffer_results, all_tfc_results_path)\n",
        "print(f\"Full TFC-SR buffer tuning results saved to {all_tfc_results_path}\")\n",
        "\n",
        "# 3. Plot the TFC-SR Sensitivity to Buffer Size\n",
        "print(\"\\n--- Generating Buffer Size Sensitivity Plot for TFC-SR ---\")\n",
        "# Load the corresponding Standard ER results for comparison\n",
        "er_buffer_tuning_results = load_results(os.path.join(CONFIG['results_path'], 'er_cifar_buffer_tuning_ALL.pkl'))\n",
        "\n",
        "# Get the final accuracy for each buffer size for both methods\n",
        "sorted_capacities = sorted(all_tfc_buffer_results.keys())\n",
        "tfc_final_accuracies = [all_tfc_buffer_results[cap]['accuracies'][-1] for cap in sorted_capacities]\n",
        "er_final_accuracies = [er_buffer_tuning_results[cap]['accuracies'][-1] for cap in sorted_capacities]\n",
        "\n",
        "plt.figure(figsize=(8, 5))\n",
        "plt.plot(sorted_capacities, tfc_final_accuracies, marker='o', label='TFC-SR (Ours)')\n",
        "plt.plot(sorted_capacities, er_final_accuracies, marker='o', linestyle='--', label='Standard ER')\n",
        "plt.title(\"Performance vs. Buffer Capacity on Split CIFAR-100\")\n",
        "plt.xlabel(\"Replay Buffer Capacity\")\n",
        "plt.ylabel(\"Final Average Accuracy (%) after 10 Tasks\")\n",
        "plt.xscale('log')\n",
        "plt.grid(True, which='both', linestyle='--')\n",
        "plt.legend()\n",
        "plt.show()\n",
        "\n",
        "# 4. Plot the Main Comparison using the best buffer size for TFC-SR\n",
        "# Find which buffer size gave TFC-SR the best final accuracy\n",
        "best_buffer_size = max(all_tfc_buffer_results, key=lambda k: all_tfc_buffer_results[k]['accuracies'][-1])\n",
        "best_tfc_results = all_tfc_buffer_results[best_buffer_size]\n",
        "\n",
        "# Load other baselines\n",
        "baseline_cifar_results = load_results(os.path.join(CONFIG['results_path'], 'baseline_cifar_results.pkl'))\n",
        "er_best_buffer_results = er_buffer_tuning_results[best_buffer_size] # Compare ER at the same buffer size\n",
        "\n",
        "main_plot_data = {\n",
        "    'Baseline': baseline_cifar_results['accuracies'],\n",
        "    f'Standard ER (Buffer={best_buffer_size})': er_best_buffer_results['accuracies'],\n",
        "    f'TFC-SR (Ours, Buffer={best_buffer_size})': best_tfc_results['accuracies']\n",
        "}\n",
        "plot_results(main_plot_data, title=\"Main Performance Comparison on Split CIFAR-100\")"
      ],
      "metadata": {
        "id": "I3iGgKYTBHlS"
      },
      "execution_count": null,
      "outputs": []
    },
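    {
      "cell_type": "code",
      "source": [
        "# save_results / load_results come from utils; given the .pkl paths used\n",
        "# throughout, they are presumably thin pickle wrappers along these lines:\n",
        "def save_results_sketch(obj, path):\n",
        "    os.makedirs(os.path.dirname(path), exist_ok=True)\n",
        "    with open(path, 'wb') as f:\n",
        "        pickle.dump(obj, f)\n",
        "\n",
        "def load_results_sketch(path):\n",
        "    with open(path, 'rb') as f:\n",
        "        return pickle.load(f)"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },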
    {
      "cell_type": "code",
      "source": [
        "print(\"\\n\" + \"=\"*20 + \" Starting EWC Tuning on Split CIFAR-100 \" + \"=\"*20)\n",
        "\n",
        "# --- Hyperparameter Search Setup ---\n",
        "ewc_lambdas_to_try = [1000.0, 10000.0, 100000.0]\n",
        "all_ewc_cifar_results = {}\n",
        "\n",
        "# Use the best learning rate we found for the baseline\n",
        "current_lr = CONFIG.get('lr', 0.001)\n",
        "\n",
        "# --- Outer loop for tuning lambda ---\n",
        "for lmbda in ewc_lambdas_to_try:\n",
        "    print(f\"\\n--->>> STARTING EWC TRIAL: LAMBDA = {lmbda} <<<---\")\n",
        "\n",
        "    set_seed(CONFIG['seed'])\n",
        "\n",
        "    model_ewc = get_resnet18_for_cifar(num_classes=CONFIG['num_classes']).to(device)\n",
        "    optimizer_ewc = optim.Adam(model_ewc.parameters(), lr=current_lr)\n",
        "\n",
        "    # Instantiate the EWC strategy with the current lambda\n",
        "    ewc_strategy = EWC(\n",
        "        model_ewc, optimizer_ewc, nn.CrossEntropyLoss(),\n",
        "        ewc_lambda=lmbda,\n",
        "        train_mb_size=CONFIG['batch_size'],\n",
        "        train_epochs=CONFIG['epochs_per_task'],\n",
        "        device=device\n",
        "    )\n",
        "\n",
        "    current_lambda_accuracies = []\n",
        "\n",
        "    # --- Training and Evaluation Loop ---\n",
        "    for task_id, experience in enumerate(split_cifar100_benchmark.train_stream):\n",
        "        print(f\"--> Training on Task {task_id+1}\")\n",
        "\n",
        "        ewc_strategy.train(experience)\n",
        "\n",
        "        accuracy = evaluate_on_seen_tasks(\n",
        "            ewc_strategy.model,\n",
        "            split_cifar100_benchmark,\n",
        "            task_id,\n",
        "            device,\n",
        "            CONFIG['batch_size']\n",
        "        )\n",
        "        current_lambda_accuracies.append(accuracy)\n",
        "        print(f\"----- Avg Accuracy after Task {task_id+1}: {accuracy:.2f}% -----\")\n",
        "\n",
        "    all_ewc_cifar_results[lmbda] = current_lambda_accuracies\n",
        "\n",
        "# --- Find the best EWC result and save it ---\n",
        "best_lambda_ewc = max(all_ewc_cifar_results, key=lambda k: all_ewc_cifar_results[k][-1])\n",
        "best_ewc_accuracies = all_ewc_cifar_results[best_lambda_ewc]\n",
        "\n",
        "print(f\"\\nBest EWC lambda for CIFAR-100 was {best_lambda_ewc} with final accuracy: {best_ewc_accuracies[-1]:.2f}%\")\n",
        "\n",
        "# --- SAVE AND PLOT ---\n",
        "# Save the results for the best EWC run\n",
        "ewc_cifar_results_path = os.path.join(CONFIG['results_path'], 'ewc_cifar_best.pkl')\n",
        "save_results({'accuracies': best_ewc_accuracies}, ewc_cifar_results_path)\n",
        "print(f\"Best EWC (CIFAR-100) results saved to {ewc_cifar_results_path}\")\n",
        "\n",
        "# Load all other \"best\" results to create the final comparison plot\n",
        "baseline_cifar_results = load_results(os.path.join(CONFIG['results_path'], 'baseline_cifar_results.pkl'))\n",
        "er_cifar_best_results = load_results(os.path.join(CONFIG['results_path'], 'er_cifar_buffer_tuning_ALL.pkl'))[1000] # Assuming 1000 was best\n",
        "tfc_cifar_best_results = load_results(os.path.join(CONFIG['results_path'], 'tfc_sr_cifar_buffer_tuning_ALL.pkl'))[1000] # Assuming 1000 was best\n",
        "\n",
        "results_to_plot = {\n",
        "    'Baseline': baseline_cifar_results['accuracies'],\n",
        "    'Standard ER': er_cifar_best_results['accuracies'],\n",
        "    f'EWC (Best λ)': best_ewc_accuracies,\n",
        "    'TFC-SR (Ours)': tfc_cifar_best_results['accuracies']\n",
        "}\n",
        "plot_results(results_to_plot, title=\"Main Performance Comparison on Split CIFAR-100\")"
      ],
      "metadata": {
        "id": "GfmAvYe2CttF"
      },
      "execution_count": null,
      "outputs": []
    },
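    {
      "cell_type": "code",
      "source": [
        "# For reference: the regulariser EWC applies internally penalises parameter\n",
        "# drift weighted by Fisher information, L = L_task + (lambda/2) * sum_i F_i *\n",
        "# (theta_i - theta*_i)^2. A sketch of that penalty term, assuming `fisher` and\n",
        "# `old_params` are dicts keyed like model.named_parameters():\n",
        "def ewc_penalty_sketch(model, fisher, old_params, ewc_lambda):\n",
        "    penalty = 0.0\n",
        "    for name, param in model.named_parameters():\n",
        "        penalty = penalty + (fisher[name] * (param - old_params[name]) ** 2).sum()\n",
        "    return 0.5 * ewc_lambda * penalty"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },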
    {
      "cell_type": "code",
      "source": [
        "# --- EXPERIMENT: SI HYPERPARAMETER TUNING ON SPLIT CIFAR-100 ---\n",
        "\n",
        "print(\"\\n\" + \"=\"*20 + \" Starting SI Tuning on Split CIFAR-100 \" + \"=\"*20)\n",
        "\n",
        "# --- Hyperparameter Search Setup ---\n",
        "si_lambdas_to_try = [1.0, 10.0, 100.0, 1000.0]\n",
        "all_si_cifar_results = {}\n",
        "\n",
        "# Use the best learning rate we found for the baseline\n",
        "current_lr = CONFIG.get('lr', 0.001)\n",
        "\n",
        "# --- Outer loop for tuning lambda ---\n",
        "for lmbda in si_lambdas_to_try:\n",
        "    print(f\"\\n--->>> STARTING SI TRIAL: LAMBDA = {lmbda} <<<---\")\n",
        "\n",
        "    set_seed(CONFIG['seed'])\n",
        "\n",
        "    model_si = get_resnet18_for_cifar(num_classes=CONFIG['num_classes']).to(device)\n",
        "    optimizer_si = optim.Adam(model_si.parameters(), lr=current_lr)\n",
        "\n",
        "    # Instantiate the SynapticIntelligence strategy\n",
        "    si_strategy = SynapticIntelligence(\n",
        "        model_si, optimizer_si, nn.CrossEntropyLoss(),\n",
        "        si_lambda=lmbda,\n",
        "        train_mb_size=CONFIG['batch_size'],\n",
        "        train_epochs=CONFIG['epochs_per_task'],\n",
        "        device=device\n",
        "    )\n",
        "\n",
        "    current_lambda_accuracies = []\n",
        "\n",
        "    # --- Training and Evaluation Loop ---\n",
        "    for task_id, experience in enumerate(split_cifar100_benchmark.train_stream):\n",
        "        print(f\"--> Training on Task {task_id+1}\")\n",
        "\n",
        "        si_strategy.train(experience)\n",
        "\n",
        "        accuracy = evaluate_on_seen_tasks(\n",
        "            si_strategy.model,\n",
        "            split_cifar100_benchmark,\n",
        "            task_id,\n",
        "            device,\n",
        "            CONFIG['batch_size']\n",
        "        )\n",
        "        current_lambda_accuracies.append(accuracy)\n",
        "        print(f\"----- Avg Accuracy after Task {task_id+1}: {accuracy:.2f}% -----\")\n",
        "\n",
        "    all_si_cifar_results[lmbda] = current_lambda_accuracies\n",
        "\n",
        "# --- Find the best SI result and save it ---\n",
        "best_lambda_si = max(all_si_cifar_results, key=lambda k: all_si_cifar_results[k][-1])\n",
        "best_si_accuracies = all_si_cifar_results[best_lambda_si]\n",
        "\n",
        "print(f\"\\nBest SI lambda for CIFAR-100 was {best_lambda_si} with final accuracy: {best_si_accuracies[-1]:.2f}%\")\n",
        "\n",
        "# --- SAVE AND PLOT ---\n",
        "# Save the results for the best SI run\n",
        "si_cifar_results_path = os.path.join(CONFIG['results_path'], 'si_cifar_best.pkl')\n",
        "save_results({'accuracies': best_si_accuracies}, si_cifar_results_path)\n",
        "print(f\"Best SI (CIFAR-100) results saved to {si_cifar_results_path}\")\n",
        "\n",
        "# Load all other \"best\" results to create the final comparison plot\n",
        "baseline_cifar_results = load_results(os.path.join(CONFIG['results_path'], 'baseline_cifar_results.pkl'))\n",
        "er_cifar_best_results = load_results(os.path.join(CONFIG['results_path'], 'er_cifar_buffer_tuning_ALL.pkl'))[1000] # Use your best ER run\n",
        "ewc_cifar_best_results = load_results(os.path.join(CONFIG['results_path'], 'ewc_cifar_best.pkl'))\n",
        "tfc_cifar_best_results = load_results(os.path.join(CONFIG['results_path'], 'tfc_sr_cifar_best.pkl'))\n",
        "\n",
        "results_to_plot = {\n",
        "    'Baseline': baseline_cifar_results['accuracies'],\n",
        "    'Standard ER': er_cifar_best_results['accuracies'],\n",
        "    f'EWC (Best λ={best_lambda_ewc})': ewc_cifar_best_results['accuracies'],\n",
        "    f'SI (Best λ={best_lambda_si})': best_si_accuracies,\n",
        "    'TFC-SR': tfc_cifar_best_results['accuracies']\n",
        "}\n",
        "plot_results(results_to_plot, title=\"Main Performance Comparison on Split CIFAR-100\")"
      ],
      "metadata": {
        "id": "9NwHyfl8Rtvo"
      },
      "execution_count": null,
      "outputs": []
    },
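    {
      "cell_type": "code",
      "source": [
        "# For reference: Synaptic Intelligence (also handled internally by the\n",
        "# Avalanche strategy above) accumulates a per-parameter path integral during\n",
        "# training and converts it into an importance weight at task boundaries:\n",
        "#\n",
        "#   after each optimiser step:  w += -grad * (theta - theta_prev)\n",
        "#   at the end of each task:    omega += w / ((theta - theta_start)**2 + xi)\n",
        "#\n",
        "# The surrogate loss added to the task loss is then\n",
        "#   si_lambda * sum_k omega_k * (theta_k - theta*_k)**2\n",
        "# where theta*_k are the parameters at the previous task boundary and xi is a\n",
        "# small damping constant."
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },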
    {
      "cell_type": "code",
      "source": [
        "# --- EXPERIMENT: TFC-SR STESS TEST ON SPLIT CIFAR-100 ---\n",
        "from utils import evaluate_replay_buffer\n",
        "\n",
        "print(\"\\n\" + \"=\"*20 + \" Starting TFC-SR Stress Test on Split CIFAR-100 \" + \"=\"*20)\n",
        "\n",
        "\n",
        "CONFIG['buffer_capacity'] = 1000\n",
        "CONFIG['replay_batch_size'] = CONFIG['batch_size'] // 2\n",
        "CONFIG['initial_replay_gap'] = 1   # Start checking after epoch 1\n",
        "CONFIG['replay_gap_multiplier'] = 1.5 # How much to increase the gap\n",
        "threshold = 99.0\n",
        "\n",
        "print(f\"\\n--->>> STARTING TRIAL: THRESHOLD = {threshold}% <<<---\")\n",
        "CONFIG['mastery_threshold'] = threshold\n",
        "\n",
        "set_seed(CONFIG['seed'])\n",
        "model_tfc = get_resnet18_for_cifar(num_classes=CONFIG['num_classes']).to(device)\n",
        "optimizer_tfc = optim.Adam(model_tfc.parameters(), lr=CONFIG['lr'])\n",
        "criterion_tfc = nn.CrossEntropyLoss()\n",
        "replay_buffer_tfc = ReservoirReplayBuffer(capacity=CONFIG['buffer_capacity'])\n",
        "\n",
        "current_run_results = { 'accuracies': [], 'total_replay_batches': 0, 'memory_checks': 0, 'schedule_history': [] }\n",
        "\n",
        "for task_id, experience in enumerate(split_cifar100_benchmark.train_stream):\n",
        "    print(f\"--> Training on Task {task_id+1}\")\n",
        "\n",
        "    for data_point, target, _ in experience.dataset: replay_buffer_tfc.add(data_point, target)\n",
        "\n",
        "    current_replay_gap = float(CONFIG['initial_replay_gap'])\n",
        "    replay_timer = int(current_replay_gap)\n",
        "    train_loader = DataLoader(experience.dataset, batch_size=CONFIG['batch_size'], shuffle=True)\n",
        "\n",
        "    model_tfc.train()\n",
        "    for epoch in range(CONFIG['epochs_per_task']):\n",
        "        for new_data, new_targets, _ in train_loader:\n",
        "              if len(replay_buffer_tfc) >= CONFIG['replay_batch_size']:\n",
        "                old_data, old_targets = replay_buffer_tfc.sample(CONFIG['replay_batch_size'])\n",
        "                new_data = new_data[:CONFIG['replay_batch_size']]\n",
        "\n",
        "                combined_data = torch.cat((new_data, old_data), dim=0).to(device)\n",
        "                combined_targets = torch.cat((new_targets[:len(new_data)], old_targets), dim=0).to(device)\n",
        "\n",
        "                optimizer_tfc.zero_grad()\n",
        "                outputs = model_tfc(combined_data)\n",
        "                loss = criterion_tfc(outputs, combined_targets)\n",
        "                loss.backward()\n",
        "                optimizer_tfc.step()\n",
        "\n",
        "                current_run_results['total_replay_batches'] += 1\n",
        "\n",
        "        # --- Adaptive Replay Scheduling Logic with DIAGNOSTICS ---\n",
        "        if (epoch + 1) == replay_timer and len(replay_buffer_tfc) > 1:\n",
        "            current_run_results['memory_checks'] += 1\n",
        "            model_tfc.eval()\n",
        "\n",
        "            print(f\"\\n  [Epoch {epoch+1}] Memory Check Triggered. Current Timer: {replay_timer}\")\n",
        "\n",
        "            replay_perf = evaluate_replay_buffer(model_tfc, replay_buffer_tfc, device)\n",
        "            print(f\"    Replay Buffer Perf: {replay_perf:.2f}%. Comparing against Threshold: {CONFIG['mastery_threshold']}%\")\n",
        "\n",
        "            if replay_perf >= CONFIG['mastery_threshold']:\n",
        "                current_replay_gap *= CONFIG['replay_gap_multiplier']\n",
        "                replay_timer += round(current_replay_gap)\n",
        "                print(f\"    RESULT: Mastery MET. New timer set to epoch {replay_timer}.\")\n",
        "            else:\n",
        "                replay_timer += 1\n",
        "                print(f\"    RESULT: Mastery FAILED. New timer set to epoch {replay_timer}.\")\n",
        "\n",
        "            model_tfc.train()\n",
        "\n",
        "    accuracy = evaluate_on_seen_tasks(model_tfc, split_cifar100_benchmark, task_id, device, CONFIG['batch_size'])\n",
        "    current_run_results['accuracies'].append(accuracy)\n",
        "    print(f\"----- Accuracy after Task {task_id+1}: {accuracy:.2f}% -----\")\n",
        "\n",
        "# --- ANALYSIS, SAVING, AND PLOTTING ---\n",
        "tfc_stress_cifar_results_path = os.path.join(CONFIG['results_path'], 'tfc_sr_99_thresh_cifar.pkl')\n",
        "save_results(current_run_results, tfc_stress_cifar_results_path)\n",
        "print(f\"TFC_SR Stress Test (CIFAR-100) results saved to {tfc_stress_cifar_results_path}\")\n",
        "\n",
        "baseline_cifar_results = load_results(os.path.join(CONFIG['results_path'], 'baseline_cifar_results.pkl'))\n",
        "er_cifar_best_results = load_results(os.path.join(CONFIG['results_path'], 'er_cifar_buffer_tuning_ALL.pkl'))[1000] # Use your best ER run\n",
        "tfc_cifar_best_results = load_results(os.path.join(CONFIG['results_path'], 'tfc_sr_cifar_best.pkl'))\n",
        "\n",
        "results_to_plot = {\n",
        "    'Baseline': baseline_cifar_results['accuracies'],\n",
        "    'Standard ER': er_cifar_best_results['accuracies'],\n",
        "    'TFC-SR (threshold = 10.0)': tfc_cifar_best_results['accuracies'],\n",
        "    f'TFC-SR (threshold = {threshold})': current_run_results['accuracies'],\n",
        "}\n",
        "plot_results(results_to_plot, title=\"TFC-SR Stress Test on Split CIFAR-100\")"
      ],
      "metadata": {
        "id": "cJ3yrNuilQju"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# --- EXPERIMENT: TFC-SR with Spaced Replay ON SPLIT CIFAR-100 ---\n",
        "from utils import evaluate_replay_buffer\n",
        "\n",
        "print(\"\\n\" + \"=\"*20 + \" Starting TFC-SR Spaced Replay on Split CIFAR-100 \" + \"=\"*20)\n",
        "\n",
        "\n",
        "CONFIG['buffer_capacity'] = 1000\n",
        "CONFIG['replay_batch_size'] = CONFIG['batch_size'] // 2\n",
        "CONFIG['initial_replay_gap'] = 1   # Start checking after epoch 1\n",
        "CONFIG['replay_gap_multiplier'] = 1.5 # How much to increase the gap\n",
        "threshold = 10.0\n",
        "\n",
        "print(f\"\\n--->>> STARTING TRIAL: THRESHOLD = {threshold}% <<<---\")\n",
        "CONFIG['mastery_threshold'] = threshold\n",
        "\n",
        "set_seed(CONFIG['seed'])\n",
        "model_tfc = get_resnet18_for_cifar(num_classes=CONFIG['num_classes']).to(device)\n",
        "optimizer_tfc = optim.Adam(model_tfc.parameters(), lr=CONFIG['lr'])\n",
        "criterion_tfc = nn.CrossEntropyLoss()\n",
        "replay_buffer_tfc = ReservoirReplayBuffer(capacity=CONFIG['buffer_capacity'])\n",
        "\n",
        "current_run_results = { 'accuracies': [], 'total_replay_batches': 0, 'memory_checks': 0, 'schedule_history': [] }\n",
        "\n",
        "for task_id, experience in enumerate(split_cifar100_benchmark.train_stream):\n",
        "    print(f\"--> Training on Task {task_id+1}\")\n",
        "\n",
        "    for data_point, target, _ in experience.dataset: replay_buffer_tfc.add(data_point, target)\n",
        "\n",
        "    current_replay_gap = float(CONFIG['initial_replay_gap'])\n",
        "    replay_timer = int(current_replay_gap)\n",
        "    train_loader = DataLoader(experience.dataset, batch_size=CONFIG['batch_size'], shuffle=True)\n",
        "\n",
        "    model_tfc.train()\n",
        "    for epoch in range(CONFIG['epochs_per_task']):\n",
        "        # --- Check if this is a Replay Epoch ---\n",
        "        if (epoch + 1) == replay_timer and task_id > 0: # Only replay after the first task\n",
        "            model_tfc.train()\n",
        "            print(f\"\\n--- Epoch {epoch+1}: Performing Spaced Replay & Memory Check ---\")\n",
        "\n",
        "            for new_data, new_targets, _ in train_loader:\n",
        "                if len(replay_buffer_tfc) >= CONFIG['replay_batch_size']:\n",
        "                    old_data, old_targets = replay_buffer_tfc.sample(CONFIG['replay_batch_size'])\n",
        "                    new_data = new_data[:CONFIG['replay_batch_size']]\n",
        "\n",
        "                    combined_data = torch.cat((new_data, old_data), dim=0).to(device)\n",
        "                    combined_targets = torch.cat((new_targets[:len(new_data)], old_targets), dim=0).to(device)\n",
        "\n",
        "                    optimizer_tfc.zero_grad()\n",
        "                    outputs = model_tfc(combined_data)\n",
        "                    loss = criterion_tfc(outputs, combined_targets)\n",
        "                    loss.backward()\n",
        "                    optimizer_tfc.step()\n",
        "\n",
        "                    current_run_results['total_replay_batches'] += 1\n",
        "\n",
        "            # After the replay epoch, we perform the memory check to schedule the NEXT replay\n",
        "            current_run_results['memory_checks'] += 1\n",
        "\n",
        "            model_tfc.eval()\n",
        "            print(f\"\\n  [Epoch {epoch+1}] Memory Check Triggered. Current Timer: {replay_timer}\")\n",
        "            replay_perf = evaluate_replay_buffer(model_tfc, replay_buffer_tfc, device)\n",
        "            # Update replay_timer based on replay_perf...\n",
        "            if replay_perf >= CONFIG['mastery_threshold']:\n",
        "                current_replay_gap *= CONFIG['replay_gap_multiplier']\n",
        "                replay_timer += round(current_replay_gap)\n",
        "                print(f\"    RESULT: Mastery MET. New timer set to epoch {replay_timer}.\")\n",
        "            else:\n",
        "                replay_timer += 1\n",
        "                print(f\"    RESULT: Mastery FAILED. New timer set to epoch {replay_timer}.\")\n",
        "        else:\n",
        "            model_tfc.train()\n",
        "            print(f\"\\n--- Epoch {epoch+1}: Training on New Task Data Only ---\")\n",
        "\n",
        "            for new_data, new_targets, _ in train_loader:\n",
        "                new_data, new_targets = new_data.to(device), new_targets.to(device)\n",
        "\n",
        "                optimizer_tfc.zero_grad()\n",
        "                outputs = model_tfc(new_data)\n",
        "                loss = criterion_tfc(outputs, new_targets)\n",
        "                loss.backward()\n",
        "                optimizer_tfc.step()\n",
        "\n",
        "    accuracy = evaluate_on_seen_tasks(model_tfc, split_cifar100_benchmark, task_id, device, CONFIG['batch_size'])\n",
        "    current_run_results['accuracies'].append(accuracy)\n",
        "    print(f\"----- Accuracy after Task {task_id+1}: {accuracy:.2f}% -----\")\n",
        "\n",
        "# --- ANALYSIS, SAVING, AND PLOTTING ---\n",
        "tfc_spaced_replay_cifar_results_path = os.path.join(CONFIG['results_path'], 'tfc_spaced_replay_cifar.pkl')\n",
        "save_results(current_run_results, tfc_spaced_replay_cifar_results_path)\n",
        "print(f\"TFC_SR with spaced replay (CIFAR-100) results saved to {tfc_spaced_replay_cifar_results_path}\")\n",
        "\n",
        "baseline_cifar_results = load_results(os.path.join(CONFIG['results_path'], 'baseline_cifar_results.pkl'))\n",
        "er_cifar_best_results = load_results(os.path.join(CONFIG['results_path'], 'er_cifar_buffer_tuning_ALL.pkl'))[1000] # Use your best ER run\n",
        "tfc_cifar_best_results = load_results(os.path.join(CONFIG['results_path'], 'tfc_sr_cifar_best.pkl'))\n",
        "\n",
        "results_to_plot = {\n",
        "    'Baseline': baseline_cifar_results['accuracies'],\n",
        "    'Standard ER': er_cifar_best_results['accuracies'],\n",
        "    'TFC-SR': tfc_cifar_best_results['accuracies'],\n",
        "    f'TFC-SR (with Spaced Replay)': current_run_results['accuracies'],\n",
        "}\n",
        "plot_results(results_to_plot, title=\"TFC-SR with Spaced Replay vs Other methods on Split CIFAR-100\")\n",
        "\n",
        "print(\"Results for Spaced Replay:\")\n",
        "print(f\"Total Replay Batches: {current_run_results['total_replay_batches']}\")\n",
        "print(f\"Memory Checks: {current_run_results['memory_checks']}\")"
      ],
      "metadata": {
        "id": "cN0aaGWrMRyI"
      },
      "execution_count": null,
      "outputs": []
    },
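    {
      "cell_type": "code",
      "source": [
        "# utils.evaluate_replay_buffer is imported above but not shown in this notebook.\n",
        "# Below is a minimal sketch of what it likely computes, using only the buffer\n",
        "# API visible here (len() and .sample()); the real helper may differ.\n",
        "def evaluate_replay_buffer_sketch(model, buffer, device):\n",
        "    model.eval()\n",
        "    # Assumed: sampling len(buffer) items returns everything currently stored\n",
        "    data, targets = buffer.sample(len(buffer))\n",
        "    data, targets = data.to(device), targets.to(device)\n",
        "    with torch.no_grad():\n",
        "        predicted = model(data).argmax(dim=1)\n",
        "    return 100.0 * (predicted == targets).float().mean().item()\n",
        "\n",
        "# Sanity check of the spaced-replay schedule used above, assuming every memory\n",
        "# check passes: the gap grows 1.5x each time and the timer jumps ahead.\n",
        "gap, timer, schedule = 1.0, 1, []\n",
        "while timer <= CONFIG['epochs_per_task']:\n",
        "    schedule.append(timer)\n",
        "    gap *= CONFIG['replay_gap_multiplier']\n",
        "    timer += round(gap)\n",
        "print(f\"Replay epochs if mastery is always met: {schedule}\")  # [1, 3, 5, 8, 13]"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },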
    {
      "cell_type": "code",
      "source": [
        "# --- EXPERIMENT: Mastery-Gated Progression (MGP) ---\n",
        "\n",
        "print(\"\\n\" + \"=\"*20 + \" Starting MGP Experiment \" + \"=\"*20)\n",
        "\n",
        "# --- MGP-Specific Hyperparameters in CONFIG ---\n",
        "CONFIG['new_task_mastery_thresh'] = 85.0 # e.g., Must get >90% on current task's test set\n",
        "CONFIG['retention_thresh'] = 15.0      # e.g., Must keep >15% avg accuracy on replay buffer\n",
        "CONFIG['max_epochs_per_task'] = 50    # A safety break to prevent infinite loops\n",
        "\n",
        "set_seed(CONFIG['seed'])\n",
        "\n",
        "# --- Setup for the experiment ---\n",
        "model_mgp = get_resnet18_for_cifar(num_classes=CONFIG['num_classes']).to(device)\n",
        "optimizer_mgp = optim.Adam(model_mgp.parameters(), lr=CONFIG['lr'])\n",
        "criterion_mgp = nn.CrossEntropyLoss()\n",
        "replay_buffer_mgp = ReservoirReplayBuffer(capacity=CONFIG['buffer_capacity'])\n",
        "\n",
        "# --- Results Dictionary ---\n",
        "mgp_results = {\n",
        "    'accuracies': [],\n",
        "    'epochs_per_task': [], # Track how long each task took\n",
        "    'final_accuracy': 0.0\n",
        "}\n",
        "\n",
        "# --- Main Continual Learning Loop ---\n",
        "for task_id, experience in enumerate(split_cifar100_benchmark.train_stream):\n",
        "    print(f\"\\n--- Starting to learn Task {task_id+1} ---\")\n",
        "\n",
        "    # Populate replay buffer with the new task's data\n",
        "    for data, target, _ in experience.dataset:\n",
        "        replay_buffer_mgp.add(data, target)\n",
        "\n",
        "    train_loader = DataLoader(experience.dataset, batch_size=CONFIG['batch_size'], shuffle=True)\n",
        "\n",
        "    # --- The \"Practice Until Mastery\" While Loop ---\n",
        "    epoch_count = 0\n",
        "    mastery_achieved = False\n",
        "    while not mastery_achieved and epoch_count < CONFIG['max_epochs_per_task']:\n",
        "        epoch_count += 1\n",
        "        model_mgp.train()\n",
        "\n",
        "        # --- Training is always on mixed batches (except for task 1) ---\n",
        "        for new_data, new_targets, _ in train_loader:\n",
        "            if task_id > 0 and len(replay_buffer_mgp) >= CONFIG['replay_batch_size']:\n",
        "                # Mixed batch training\n",
        "                old_data, old_targets = replay_buffer_mgp.sample(CONFIG['replay_batch_size'])\n",
        "                new_data = new_data[:CONFIG['replay_batch_size']]\n",
        "                combined_data = torch.cat((new_data, old_data), dim=0).to(device)\n",
        "                combined_targets = torch.cat((new_targets[:len(new_data)], old_targets), dim=0).to(device)\n",
        "\n",
        "                optimizer_mgp.zero_grad()\n",
        "                outputs = model_mgp(combined_data)\n",
        "                loss = criterion_mgp(outputs, combined_targets)\n",
        "                loss.backward()\n",
        "                optimizer_mgp.step()\n",
        "            else:\n",
        "                # For Task 1, train on new data only\n",
        "                new_data, new_targets = new_data.to(device), new_targets.to(device)\n",
        "                optimizer_mgp.zero_grad()\n",
        "                outputs = model_mgp(new_data)\n",
        "                loss = criterion_mgp(outputs, new_targets)\n",
        "                loss.backward()\n",
        "                optimizer_mgp.step()\n",
        "\n",
        "        # --- Mastery Check at the end of each epoch ---\n",
        "        # 1. Check performance on the CURRENT task's test set\n",
        "        current_task_test_loader = DataLoader(split_cifar100_benchmark.test_stream[task_id].dataset, batch_size=CONFIG['batch_size'])\n",
        "        model_mgp.eval()\n",
        "        correct, total = 0, 0\n",
        "        with torch.no_grad():\n",
        "            for data, targets, _ in current_task_test_loader:\n",
        "                data, targets = data.to(device), targets.to(device)\n",
        "                outputs = model_mgp(data)\n",
        "                _, predicted = torch.max(outputs.data, 1)\n",
        "                total += targets.size(0)\n",
        "                correct += (predicted == targets).sum().item()\n",
        "        new_task_perf = 100.0 * correct / total\n",
        "\n",
        "        # 2. Check performance on the replay buffer (if not the first task)\n",
        "        retention_perf = 100.0\n",
        "        if task_id > 0:\n",
        "            retention_perf = evaluate_replay_buffer(model_mgp, replay_buffer_mgp, device)\n",
        "\n",
        "        print(f\"Epoch {epoch_count}: New Task Perf: {new_task_perf:.2f}%, Retention Perf: {retention_perf:.2f}%\")\n",
        "\n",
        "        # 3. Check if both conditions are met\n",
        "        if new_task_perf >= CONFIG['new_task_mastery_thresh'] and retention_perf >= CONFIG['retention_thresh']:\n",
        "            mastery_achieved = True\n",
        "            print(f\"*** Mastery achieved for Task {task_id+1} in {epoch_count} epochs! ***\")\n",
        "\n",
        "    if not mastery_achieved:\n",
        "        print(f\"!!! Max epochs reached for Task {task_id+1}. Moving on without mastery. !!!\")\n",
        "\n",
        "    # --- Record metrics for this task ---\n",
        "    mgp_results['epochs_per_task'].append(epoch_count)\n",
        "    final_task_accuracy = evaluate_on_seen_tasks(model_mgp, split_cifar100_benchmark, task_id, device, CONFIG['batch_size'])\n",
        "    mgp_results['accuracies'].append(final_task_accuracy)\n",
        "    print(f\"----- Overall Accuracy after Task {task_id+1}: {final_task_accuracy:.2f}% -----\")\n",
        "\n",
        "# --- ANALYSIS, SAVING, AND PLOTTING ---\n",
        "mgp_cifar_results_path = os.path.join(CONFIG['results_path'], 'mgp_cifar.pkl')\n",
        "save_results(mgp_results, mgp_cifar_results_path)\n",
        "print(f\"MGP (CIFAR-100) results saved to {mgp_cifar_results_path}\")\n",
        "\n",
        "baseline_cifar_results = load_results(os.path.join(CONFIG['results_path'], 'baseline_cifar_results.pkl'))\n",
        "er_cifar_best_results = load_results(os.path.join(CONFIG['results_path'], 'er_cifar_buffer_tuning_ALL.pkl'))[1000] # Use your best ER run\n",
        "tfc_cifar_best_results = load_results(os.path.join(CONFIG['results_path'], 'tfc_sr_cifar_best.pkl'))\n",
        "\n",
        "results_to_plot = {\n",
        "    'Baseline': baseline_cifar_results['accuracies'],\n",
        "    'Standard ER': er_cifar_best_results['accuracies'],\n",
        "    'TFC-SR': tfc_cifar_best_results['accuracies'],\n",
        "    f'MGP': mgp_results['accuracies'],\n",
        "}\n",
        "plot_results(results_to_plot, title=\"MGP vs Other methods on Split CIFAR-100\")\n",
        "\n",
        "# --- Final Results ---\n",
        "mgp_results['final_accuracy'] = mgp_results['accuracies'][-1]\n",
        "print(\"\\n--- MGP Experiment Finished ---\")\n",
        "print(f\"Final Overall Accuracy: {mgp_results['final_accuracy']:.2f}%\")\n",
        "print(f\"Epochs taken per task: {mgp_results['epochs_per_task']}\")"
      ],
      "metadata": {
        "id": "VSKq2ntLfq0B"
      },
      "execution_count": null,
      "outputs": []
    },
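    {
      "cell_type": "code",
      "source": [
        "# ReservoirReplayBuffer comes from utils and is not shown in this notebook.\n",
        "# A minimal sketch of classic reservoir sampling (Algorithm R) under the API\n",
        "# used above (add / sample / __len__); the project's version may differ.\n",
        "import random\n",
        "\n",
        "class ReservoirReplayBufferSketch:\n",
        "    def __init__(self, capacity):\n",
        "        self.capacity = capacity\n",
        "        self.data, self.targets = [], []\n",
        "        self.n_seen = 0  # total items ever offered to the buffer\n",
        "\n",
        "    def add(self, x, y):\n",
        "        self.n_seen += 1\n",
        "        if len(self.data) < self.capacity:\n",
        "            self.data.append(x)\n",
        "            self.targets.append(y)\n",
        "        else:\n",
        "            # Each item seen so far survives with probability capacity / n_seen\n",
        "            j = random.randrange(self.n_seen)\n",
        "            if j < self.capacity:\n",
        "                self.data[j], self.targets[j] = x, y\n",
        "\n",
        "    def sample(self, batch_size):\n",
        "        idx = random.sample(range(len(self.data)), batch_size)\n",
        "        data = torch.stack([self.data[i] for i in idx])\n",
        "        targets = torch.tensor([self.targets[i] for i in idx])\n",
        "        return data, targets\n",
        "\n",
        "    def __len__(self):\n",
        "        return len(self.data)"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },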
    {
      "cell_type": "code",
      "source": [
        "# 70: 11.21"
      ],
      "metadata": {
        "id": "goV8hS3Pjoj5"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}