# -*- coding: utf-8 -*-
"""DDPM and EWC Implementation

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/11QI-LIo7kxIButYp1vHwx7uZl_VEm3zX

# EWC-Guided Diffusion Replay for Exemplar-Free Continual Learning

This notebook implements a hybrid continual learning framework that combines generative replay via Denoising Diffusion Probabilistic Models (DDPMs) and synaptic regularisation via Elastic Weight Consolidation (EWC). The objective is to mitigate catastrophic forgetting across sequential classification tasks on the **BloodMNIST** dataset (from MedMNIST).

### Key Components
- **DDPM-based Replay**: At each new task, we sample from a learned generative model trained on prior tasks, enabling pseudo-rehearsal without storing raw data.
- **EWC Regularisation**: Estimates the importance of network parameters via the Fisher Information Matrix and penalises deviation from previously important weights.
- **Vision Transformer Backbone**: A Vision Transformer (ViT) classifier, with a classification head sized to the label set of the task(s) being trained, is used to evaluate retention and plasticity under continual-learning pressure.

### Task Setup
- Data is split into a sequence of classification tasks.
- After training on each task, DDPM is used to generate samples from past tasks, and EWC constrains weight drift.
- Final evaluation reports per-task accuracy, forgetting on earlier tasks, an ablation over replay and EWC, and replay fidelity (sample grids and an FID score).
"""

from google.colab import drive
drive.mount('/content/drive')

"""Hybrid Continual Learning with DDPM + EWC"""

Setup and Install
!pip install -q medmnist torchvision

import torch, random, numpy as np
from medmnist import INFO

def set_seed(seed=42):
    """Seed Python, NumPy and PyTorch RNGs and force deterministic cuDNN.

    Args:
        seed: integer seed applied to every random number generator.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    # Deterministic kernels and no autotuning, so repeated runs reproduce.
    cudnn = torch.backends.cudnn
    cudnn.deterministic, cudnn.benchmark = True, False

set_seed(42)  # fix all RNGs before any data or model is built

# Prefer the GPU whenever PyTorch can see one.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# SECTION 2: Configuration
# Hyper-parameters shared by the training / sampling cells below.
config = dict(
    batch_size=128,       # DataLoader batch size
    epochs=30,            # classifier epochs per task
    lr=3e-4,              # Adam learning rate
    lambda_ewc=100,       # multiplier on the EWC penalty
    replay_samples=256,   # synthetic samples generated per replay buffer
    ddpm_steps=1000,      # not passed to DDPM(...) below, which uses its own noise_steps default
    replay_ratio=0.5,     # replayed samples added per real sample in a batch
    img_size=224,         # images are resized to img_size x img_size
    patch_size=16,        # ViT patch size; must divide img_size
)

# SECTION 3: Dataset Loader (e.g., BloodMNIST)
from medmnist import BloodMNIST
from torchvision import transforms
from torch.utils.data import DataLoader

def _medmnist_label_to_int(label):
    """Turn a single-entry target array into a plain ``int`` class index.

    MedMNIST datasets return each target as a small array (one entry for
    single-label tasks, per the medmnist docs). Batching those yields ``(B, 1)``
    targets, which ``CrossEntropyLoss`` rejects and which broadcast to ``(B, B)``
    in ``preds == y`` comparisons. Multi-entry (multi-label) targets are returned
    unchanged.
    """
    arr = np.asarray(label).reshape(-1)
    return int(arr[0]) if arr.size == 1 else arr


def _to_three_channels(img):
    """Repeat a 1-channel image tensor to 3 channels; other tensors pass through."""
    return img.repeat(3, 1, 1) if img.shape[0] == 1 else img


def get_loader(dataset_cls, split, img_size, batch_size=None):
    """Build a DataLoader for one split of a MedMNIST-style dataset class.

    Args:
        dataset_cls: dataset class constructed as ``dataset_cls(split=..., download=True, ...)``.
        split: split name, e.g. ``'train'`` or ``'test'``; only ``'train'`` is shuffled.
        img_size: images are resized to ``(img_size, img_size)``.
        batch_size: loader batch size; defaults to ``config['batch_size']``.

    Returns:
        ``(loader, num_classes)`` where ``num_classes`` is ``len(dataset.info['label'])``.
    """
    if batch_size is None:
        batch_size = config['batch_size']
    transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        # The ViT and DDPM in this notebook take 3-channel input; grayscale
        # datasets produce 1-channel tensors after ToTensor.
        transforms.Lambda(_to_three_channels),
    ])
    dataset = dataset_cls(split=split, download=True, transform=transform,
                          target_transform=_medmnist_label_to_int)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=(split == "train"))
    return loader, len(dataset.info['label'])

# Example: Load BloodMNIST (get_loader shuffles only the train split)
train_loader, num_classes = get_loader(BloodMNIST, split='train', img_size=config['img_size'])
test_loader, _ = get_loader(BloodMNIST, split='test', img_size=config['img_size'])
print(f"Loaded BloodMNIST with {num_classes} classes")

"""Vision Transformer Classifier"""

import torch.nn as nn
from einops import rearrange

class _PatchRearrange(nn.Module):
    """Split images into flattened, non-overlapping patches.

    Pure-PyTorch equivalent of the einops pattern
    ``'b c (h p1) (w p2) -> b (h w) (p1 p2 c)'``: an input of shape
    ``(B, C, H, W)`` becomes ``(B, (H/p)*(W/p), p*p*C)``.
    """

    def __init__(self, patch_size):
        super().__init__()
        self.patch_size = patch_size

    def forward(self, x):
        b, c, height, width = x.shape
        p = self.patch_size
        h, w = height // p, width // p
        x = x.reshape(b, c, h, p, w, p)
        # (b, c, h, p1, w, p2) -> (b, h, w, p1, p2, c)
        x = x.permute(0, 2, 4, 3, 5, 1)
        return x.reshape(b, h * w, p * p * c)


class ViTClassifier(nn.Module):
    """Minimal Vision Transformer classifier.

    Images are cut into ``patch_size`` x ``patch_size`` patches, linearly
    embedded, prefixed with a learned CLS token, given learned positional
    embeddings, passed through a Transformer encoder, and the CLS output is
    classified by a LayerNorm + Linear head.

    Args:
        img_size: side length of the (square) input images.
        patch_size: patch side length; must divide ``img_size``.
        in_channels: number of image channels.
        num_classes: size of the output logit vector.
        dim: token embedding width.
        depth: number of encoder layers.
        heads: attention heads per layer.
        mlp_dim: hidden width of each encoder layer's feed-forward block.
    """

    def __init__(self, img_size=224, patch_size=16, in_channels=3, num_classes=10, dim=256, depth=4, heads=4, mlp_dim=512):
        super().__init__()
        assert img_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
        num_patches = (img_size // patch_size) ** 2
        patch_dim = in_channels * patch_size * patch_size

        self.patch_size = patch_size
        # einops.rearrange is a function, not a layer, so it cannot sit inside
        # nn.Sequential; _PatchRearrange performs the same reshaping as a module.
        self.to_patch_embedding = nn.Sequential(
            _PatchRearrange(patch_size),
            nn.Linear(patch_dim, dim),
        )

        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        # Tokens are laid out as (batch, seq, dim). Without batch_first=True the
        # encoder would treat dim 0 as the sequence and attend across samples.
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=dim, nhead=heads, dim_feedforward=mlp_dim, batch_first=True),
            num_layers=depth,
        )
        self.to_cls_token = nn.Identity()
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes),
        )

    def forward(self, x):
        """Map ``(B, in_channels, img_size, img_size)`` images to ``(B, num_classes)`` logits."""
        tokens = self.to_patch_embedding(x)
        b, n, _ = tokens.shape

        cls_tokens = self.cls_token.expand(b, -1, -1)
        tokens = torch.cat((cls_tokens, tokens), dim=1)
        tokens = tokens + self.pos_embedding[:, :(n + 1)]
        tokens = self.transformer(tokens)
        # Classify from the CLS token (position 0).
        return self.mlp_head(self.to_cls_token(tokens[:, 0]))

# Instantiate a classifier sized to the label set of the BloodMNIST loader above
model = ViTClassifier(
    num_classes=num_classes,
    img_size=config['img_size'],
    patch_size=config['patch_size'],
).to(device)
print(model)

"""Class-Conditional DDPM"""

class DDPM(nn.Module):
    """Minimal class-conditional DDPM noise predictor with a linear beta schedule.

    The denoiser is a 3-layer conv net. Its input is the noisy image, one
    constant timestep channel (``t / noise_steps``) and a one-hot class map of
    ``num_classes`` channels broadcast over the image.

    Args:
        num_classes: width of the one-hot conditioning map.
        img_size: accepted for interface compatibility; the convolutional
            denoiser does not depend on it.
        noise_steps: number of diffusion steps T.
        beta_start, beta_end: endpoints of the linear beta schedule.
        in_channels: number of image channels (default 3).
    """

    def __init__(self, num_classes, img_size=224, noise_steps=1000, beta_start=1e-4, beta_end=0.02, in_channels=3):
        super().__init__()
        self.num_classes = num_classes
        self.noise_steps = noise_steps
        self.in_channels = in_channels
        beta = torch.linspace(beta_start, beta_end, noise_steps)
        alpha = 1. - beta
        # Buffers follow .to(device); persistent=False keeps them out of state_dict.
        self.register_buffer("beta", beta, persistent=False)
        self.register_buffer("alpha", alpha, persistent=False)
        self.register_buffer("alpha_hat", torch.cumprod(alpha, dim=0), persistent=False)

        # Input channels: image + 1 timestep channel + one-hot class channels.
        self.model = nn.Sequential(
            nn.Conv2d(in_channels + 1 + num_classes, 64, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, in_channels, 3, padding=1)
        )

    def forward(self, x, t, y_onehot):
        """Predict the noise in ``x`` at integer timesteps ``t`` given class maps ``y_onehot``."""
        t_embed = t[:, None, None, None].float() / self.noise_steps
        t_embed = t_embed.expand_as(x[:, :1, :, :])
        x_input = torch.cat([x, t_embed, y_onehot], dim=1)
        return self.model(x_input)

    def training_loss(self, x0, y_onehot):
        """Return the simple DDPM objective: MSE between the true and predicted noise.

        A timestep is drawn uniformly per sample, ``x0`` is noised in closed form
        ``x_t = sqrt(alpha_hat_t) * x0 + sqrt(1 - alpha_hat_t) * eps``, and the
        network is asked to recover ``eps``.
        """
        t = torch.randint(0, self.noise_steps, (x0.shape[0],), device=x0.device)
        noise = torch.randn_like(x0)
        a_hat = self.alpha_hat[t][:, None, None, None]
        x_t = torch.sqrt(a_hat) * x0 + torch.sqrt(1 - a_hat) * noise
        return ((self.forward(x_t, t, y_onehot) - noise) ** 2).mean()

    @torch.no_grad()
    def sample(self, y_onehot, shape):
        """Run the reverse diffusion chain from Gaussian noise and return images of ``shape``.

        Wrapped in ``no_grad`` because nothing back-propagates through sampling;
        otherwise every one of the ``noise_steps`` network calls would be kept
        for autograd.
        """
        x = torch.randn(shape, device=y_onehot.device)
        for t in reversed(range(self.noise_steps)):
            t_tensor = torch.full((shape[0],), t, device=x.device, dtype=torch.long)
            predicted_noise = self.forward(x, t_tensor, y_onehot)
            alpha = self.alpha[t]
            alpha_hat = self.alpha_hat[t]
            beta = self.beta[t]
            # No noise is added on the final (t == 0) step.
            noise = torch.randn_like(x) if t > 0 else torch.zeros_like(x)
            x = 1 / torch.sqrt(alpha) * (x - (1 - alpha) / torch.sqrt(1 - alpha_hat) * predicted_noise) + torch.sqrt(beta) * noise
        return x

print("DDPM class loaded.")

"""Fisher Information and EWC"""

import copy

# Module-level EWC state. train_on_task passes these to ewc_loss whenever
# task_id > 0; later cells overwrite them with compute_fisher(...) and
# save_prev_params(...) after a task has been trained.
fisher_info: dict = {}
prev_params: dict = {}

def compute_fisher(model, data_loader, criterion, device, max_batches=None):
    """Estimate a diagonal (empirical) Fisher information for ``model``'s parameters.

    For each batch, the gradient of the batch-mean loss is squared and
    accumulated. The sum is then averaged over the number of batches used. This
    is the per-batch approximation: squares of batch-averaged gradients rather
    than averages of per-sample squared gradients.

    Args:
        model: network whose ``named_parameters()`` are scored.
        data_loader: iterable of ``(x, y)`` batches. ``y`` is flattened to 1-D,
            so ``(B, 1)`` label columns become ``(B,)`` class indices.
        criterion: loss called as ``criterion(logits, y)``.
        device: device the batches are moved to.
        max_batches: optional cap on the number of batches used (``None`` = all).

    Returns:
        dict mapping parameter name -> tensor shaped like that parameter; all
        zeros (not NaN) when the loader yields no batches.
    """
    was_training = model.training
    model.eval()
    fisher = {n: torch.zeros_like(p) for n, p in model.named_parameters()}
    num_batches = 0

    for x, y in data_loader:
        if max_batches is not None and num_batches >= max_batches:
            break
        x = x.to(device)
        y = y.to(device).reshape(-1)
        model.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()

        for n, p in model.named_parameters():
            if p.grad is not None:
                fisher[n] += p.grad.detach() ** 2
        num_batches += 1

    # Don't leave Fisher-pass gradients on the parameters, and restore the mode.
    model.zero_grad()
    model.train(was_training)

    if num_batches:
        for n in fisher:
            fisher[n] /= num_batches
    return fisher

def save_prev_params(model):
    """Snapshot every named parameter as a detached copy, to serve as the EWC anchor.

    Returns:
        dict mapping parameter name -> cloned tensor, independent of later updates.
    """
    snapshot = {}
    for name, param in model.named_parameters():
        snapshot[name] = param.detach().clone()
    return snapshot

def ewc_loss(model, fisher, prev_params, lambda_ewc):
    """Return the EWC penalty ``lambda_ewc * sum_i F_i * (theta_i - theta*_i)^2``.

    A parameter is penalised only if it appears in both ``fisher`` and
    ``prev_params`` and its current shape matches both stored tensors. Other
    parameters are skipped rather than raising. Examples: a missing anchor
    entry, or a classifier head re-sized for a task with a different number of
    classes, which would otherwise fail to broadcast.

    Returns:
        scalar tensor, or ``0`` (scaled by ``lambda_ewc``) when no parameter qualifies.
    """
    loss = 0
    for n, p in model.named_parameters():
        f = fisher.get(n)
        anchor = prev_params.get(n)
        if f is None or anchor is None:
            continue
        if f.shape != p.shape or anchor.shape != p.shape:
            continue
        loss = loss + (f * (p - anchor) ** 2).sum()
    return lambda_ewc * loss

""" Continual Training Loop with DDPM + EWC"""

import torch.nn.functional as F
from torchvision.utils import save_image

def train_on_task(task_id, model, ddpm, train_loader, replay_data, optimizer, criterion, config, device):
    """Train ``model`` on one task, optionally mixing in replay samples and an EWC penalty.

    Args:
        task_id: task index. For ``task_id > 0`` the loss adds
            ``ewc_loss(model, fisher_info, prev_params, config['lambda_ewc'])``
            using the module-level ``fisher_info`` / ``prev_params``.
        model: classifier being trained.
        ddpm: not used by this function; kept for call-site compatibility.
        train_loader: iterable of ``(x, y)`` batches; ``y`` is flattened to 1-D.
        replay_data: falsy for no replay, else an ``(images, labels)`` tuple.
            Each step appends ``int(replay_ratio * batch_size)`` randomly chosen
            replay samples, capped at the buffer size.
        optimizer: optimizer stepping ``model``'s parameters.
        criterion: loss called as ``criterion(logits, y)``.
        config: dict with ``'epochs'``, ``'replay_ratio'`` (and ``'lambda_ewc'`` when ``task_id > 0``).
        device: device for batches and replay tensors.
    """
    model.train()

    use_replay = bool(replay_data) and config['replay_ratio'] > 0
    if use_replay:
        replay_x, replay_y = replay_data
        # Move the buffer once, not on every batch.
        replay_x = replay_x.to(device)
        replay_y = replay_y.to(device).reshape(-1)
        use_replay = len(replay_x) > 0

    for epoch in range(config['epochs']):
        loss = None
        for x, y in train_loader:
            x = x.to(device)
            y = y.to(device).reshape(-1)  # (B, 1) label columns -> (B,)
            optimizer.zero_grad()

            if use_replay:
                k = min(int(config['replay_ratio'] * len(x)), len(replay_x))
                if k > 0:
                    # A fresh random subset each step, so the whole buffer is rehearsed.
                    idx = torch.randperm(len(replay_x), device=replay_x.device)[:k]
                    x = torch.cat([x, replay_x[idx]], dim=0)
                    y = torch.cat([y, replay_y[idx]], dim=0)

            loss = criterion(model(x), y)

            # Add EWC loss
            if task_id > 0:
                loss = loss + ewc_loss(model, fisher_info, prev_params, config['lambda_ewc'])

            loss.backward()
            optimizer.step()

        if loss is None:
            print(f"[Task {task_id} | Epoch {epoch+1}] No batches in train_loader.")
        else:
            print(f"[Task {task_id} | Epoch {epoch+1}] Loss: {loss.item():.4f}")

def generate_replay_samples(ddpm, num_samples, num_classes, img_size, device):
    """Draw ``num_samples`` class-conditional pseudo-samples from ``ddpm``.

    Labels are drawn uniformly from ``[0, num_classes)``. Each label becomes a
    one-hot map broadcast to ``(num_classes, img_size, img_size)`` and passed to
    ``ddpm.sample``. Sampling runs under ``torch.no_grad()`` because the replay
    images are only used as fixed training inputs.

    Returns:
        ``(samples, y)``: the tensor returned by ``ddpm.sample`` for shape
        ``(num_samples, 3, img_size, img_size)``, and the int64 labels
        ``(num_samples,)`` used to condition it.
    """
    y = torch.randint(0, num_classes, (num_samples,), device=device)
    y_onehot = F.one_hot(y, num_classes).float().view(num_samples, num_classes, 1, 1)
    y_onehot = y_onehot.expand(-1, -1, img_size, img_size)
    with torch.no_grad():
        samples = ddpm.sample(y_onehot, shape=(num_samples, 3, img_size, img_size))
    return samples, y

# Example training flow
from torch.optim import Adam

criterion = nn.CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=config['lr'])

replay_buffer = None  # No replay initially

for task_id, dataset_cls in enumerate([BloodMNIST]):  # Extend to task list
    print(f"\n--- Training on Task {task_id}: {dataset_cls.__name__} ---")

    train_loader, num_classes = get_loader(dataset_cls, 'train', config['img_size'])

    # `ddpm` is otherwise first created in a later cell, so build this task's
    # generator here, conditioned on this task's label set.
    # NOTE(review): nothing in this cell fits the DDPM, so replay comes from an untrained network.
    ddpm = DDPM(num_classes=num_classes, img_size=config['img_size']).to(device)

    # Train main model (for task_id > 0 this also adds the EWC penalty toward
    # the fisher_info / prev_params stored at the end of the previous iteration)
    train_on_task(task_id, model, ddpm, train_loader, replay_buffer, optimizer, criterion, config, device)

    # Update replay buffer
    replay_x, replay_y = generate_replay_samples(ddpm, config['replay_samples'], num_classes, config['img_size'], device)
    replay_buffer = (replay_x.detach(), replay_y.detach())

    # Update Fisher & Anchor
    fisher_info = compute_fisher(model, train_loader, criterion, device)
    prev_params = save_prev_params(model)

"""Task Definitions and Data Preparation"""

from medmnist import BloodMNIST, PathMNIST

# The two sequential tasks, with display names taken from the class names.
task_classes = [BloodMNIST, PathMNIST]
task_names = [cls.__name__ for cls in task_classes]
task_loaders = {}
task_num_classes = {}

# Prepare loaders and class counts for both tasks
for i, (task_cls, task_name) in enumerate(zip(task_classes, task_names)):
    print(f"Preparing loaders for Task {i}: {task_name}")
    train_loader, num_classes = get_loader(task_cls, 'train', config['img_size'])
    test_loader, _ = get_loader(task_cls, 'test', config['img_size'])
    task_loaders[i] = dict(train=train_loader, test=test_loader)
    task_num_classes[i] = num_classes

"""Train on Task 1 (BloodMNIST) + Save DDPM + Fisher"""

task_id = 0  # Task 1: BloodMNIST

# Fresh classifier and DDPM, both sized to Task 1's label set
# (no cell here fits the DDPM; train_on_task ignores it).
model = ViTClassifier(
    num_classes=task_num_classes[task_id],
    img_size=config['img_size'],
    patch_size=config['patch_size'],
).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'])
ddpm = DDPM(num_classes=task_num_classes[task_id], img_size=config['img_size']).to(device)
criterion = nn.CrossEntropyLoss()

# Task 1 has no predecessor, so there is nothing to replay.
train_on_task(task_id, model, ddpm, task_loaders[task_id]['train'], None,
              optimizer, criterion, config, device)

# Anchor EWC at the Task 1 solution: diagonal Fisher + parameter snapshot.
fisher_info = compute_fisher(model, task_loaders[task_id]['train'], criterion, device)
prev_params = save_prev_params(model)

# Pre-generate the Task 1 replay buffer consumed by Task 2 training.
replay_x, replay_y = generate_replay_samples(
    ddpm, config['replay_samples'], task_num_classes[task_id], config['img_size'], device,
)
replay_buffers = {task_id: (replay_x.detach(), replay_y.detach())}

"""Train on Task 2 (PathMNIST) with Replay from Task 1 (BloodMNIST)"""

task_id = 1  # Task 2: PathMNIST
print(f"\n Training on Task 2: {task_names[task_id]} with replay from Task 1 ({task_names[0]})")

# Fresh classifier + optimizer sized to Task 2's label set.
model = ViTClassifier(
    num_classes=task_num_classes[task_id],
    img_size=config['img_size'],
    patch_size=config['patch_size'],
).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'])

# Replay comes from the Task 1 buffer; its labels are Task 1 class indices.
replay_data = replay_buffers[0]
replay_x, replay_y = replay_data

# task_id > 0, so train_on_task also adds the EWC penalty toward the Task 1
# anchor (fisher_info / prev_params), here applied to a freshly initialised model.
train_on_task(task_id, model, ddpm, task_loaders[task_id]['train'], replay_data,
              optimizer, criterion, config, device)
# Evaluate Task 1 forgetting after Task 2 training
def evaluate(model, loader, name="", device=None):
    """Return the top-1 accuracy of ``model`` on ``loader`` and print it.

    Args:
        model: classifier producing ``(B, num_classes)`` logits.
        loader: iterable of ``(x, y)`` batches; ``y`` is flattened to 1-D so
            ``(B, 1)`` label columns compare element-wise with predictions.
        name: label used in the printed message.
        device: device for the batches; defaults to the device of the model's
            first parameter (CPU for a parameter-less model).

    Returns:
        accuracy in ``[0, 1]``; ``0.0`` when the loader yields no samples.
    """
    if device is None:
        first = next(model.parameters(), None)
        device = first.device if first is not None else torch.device("cpu")
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            y = y.to(device).reshape(-1)
            preds = torch.argmax(model(x), dim=1)
            correct += (preds == y).sum().item()
            total += y.numel()
    acc = correct / total if total else 0.0
    print(f" Accuracy on {name}: {acc*100:.2f}%")
    return acc

# Accuracy on both tasks, measured with the current (Task 2) model
acc_task1, acc_task2 = (
    evaluate(model, task_loaders[k]['test'], name=task_names[k]) for k in (0, 1)
)

# Forgetting on Task 1 = accuracy right after Task 1 minus accuracy now.
# TODO: acc_task1_star is a hard-coded placeholder, not a measurement. Record the
# real Task 1 test accuracy before `model` is replaced for Task 2 and use it here.
acc_task1_star = 0.85  # ← replace with real value
forgetting = acc_task1_star - acc_task1
print(f" Forgetting on Task 1 after Task 2: {forgetting:.4f}")

"""Cross-Domain Replay Transfer Evaluation"""

print("Cross-Domain Replay Transfer: T1 → T2")

# Step 1: Set up a Task 1 (BloodMNIST) DDPM.
# NOTE: train_on_task only optimises the classifier passed to it (the current
# `model`); it never updates ddpm_t1, so the DDPM stays untrained here.
task1_id = 0
ddpm_t1 = DDPM(num_classes=task_num_classes[task1_id], img_size=config['img_size']).to(device)
train_on_task(task1_id, model, ddpm_t1, task_loaders[task1_id]['train'], None,
              optimizer, criterion, config, device)

# Step 2: Generate synthetic samples from the Task 1 DDPM (for use in Task 2).
# The one-hot conditioning width must equal ddpm_t1's class count (Task 1's),
# so the labels live in Task 1's label space.
# NOTE(review): they are used as Task 2 targets below, which assumes Task 1 has
# no more classes than Task 2.
task2_id = 1
n_cond = task_num_classes[task1_id]
replay_y_fake = torch.randint(0, n_cond, (config['replay_samples'],), device=device)
replay_y_onehot = F.one_hot(replay_y_fake, n_cond).float().view(-1, n_cond, 1, 1)
replay_y_onehot = replay_y_onehot.expand(-1, -1, config['img_size'], config['img_size'])
with torch.no_grad():  # sampling only; don't record the reverse chain for autograd
    replay_x_fake = ddpm_t1.sample(replay_y_onehot, shape=(config['replay_samples'], 3, config['img_size'], config['img_size']))

# Step 3: Train classifier on Task 2 with/without cross-domain replay
def train_task2_with_cross_domain_replay(use_replay=True):
    """Train a fresh Task 2 ViT and return its Task 2 test accuracy.

    Args:
        use_replay: when True, the synthetic Task 1 DDPM samples
            (``replay_x_fake``, ``replay_y_fake``) are mixed into training.

    NOTE(review): task2_id > 0, so train_on_task also applies the module-level
    EWC penalty, anchoring this freshly initialised model to whatever weights
    ``prev_params`` currently holds.
    """
    model_task2 = ViTClassifier(
        num_classes=task_num_classes[task2_id],
        patch_size=config['patch_size'],
        img_size=config['img_size'],
    ).to(device)
    optimizer2 = torch.optim.Adam(model_task2.parameters(), lr=config['lr'])

    if use_replay:
        replay, tag = (replay_x_fake, replay_y_fake), "with"
    else:
        replay, tag = None, "without"

    train_on_task(task2_id, model_task2, ddpm_t1, task_loaders[task2_id]['train'], replay,
                  optimizer2, criterion, config, device)
    return evaluate(model_task2, task_loaders[task2_id]['test'], f"T2 {tag} Replay")

# Baseline (no replay) first, then the same recipe with cross-domain DDPM replay;
# the printed difference is the gain attributable to the replay.
acc_no_replay = train_task2_with_cross_domain_replay(use_replay=False)
acc_with_replay = train_task2_with_cross_domain_replay(use_replay=True)

print(f" Accuracy with Cross-Domain Replay: {acc_with_replay*100:.2f}% − {acc_no_replay*100:.2f}% = {(acc_with_replay - acc_no_replay)*100:.2f}%")

"""All 8 Tasks (standalone per-task baselines: no DDPM replay, no EWC)"""

from medmnist import BreastMNIST, PneumoniaMNIST, RetinaMNIST, DermaMNIST, BloodMNIST, PathMNIST, AdrenalMNIST3D, SynapseMNIST3D

task_classes_full = [BreastMNIST, PneumoniaMNIST, RetinaMNIST, DermaMNIST,
                     BloodMNIST, PathMNIST, AdrenalMNIST3D, SynapseMNIST3D]

results = []

# Every task below gets an independent, freshly initialised model and no replay,
# so there is no earlier solution to anchor to. train_on_task adds ewc_loss with
# the module-level fisher_info / prev_params whenever task_id > 0. Clear them so
# these fresh models are not pulled toward weights left over from earlier cells.
fisher_info, prev_params = {}, {}

for task_id, dataset_cls in enumerate(task_classes_full):
    print(f"\n Task {task_id}: {dataset_cls.__name__}")

    # get_loader and ViTClassifier form a 2D image pipeline (Resize + ToTensor,
    # HxW patches). The *3D datasets presumably yield volumes rather than
    # images, so skip them instead of crashing inside the transforms.
    if dataset_cls.__name__.endswith("3D"):
        print(f" Skipping {dataset_cls.__name__}: 3D volumes are not supported by this 2D pipeline.")
        continue

    train_loader, num_classes = get_loader(dataset_cls, 'train', config['img_size'])
    test_loader, _ = get_loader(dataset_cls, 'test', config['img_size'])

    model = ViTClassifier(img_size=config['img_size'], patch_size=config['patch_size'],
                          num_classes=num_classes).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'])

    # Train on task (standalone, no replay). train_on_task does not use its
    # ddpm argument, so no generator is built here.
    train_on_task(task_id, model, None, train_loader, None,
                  optimizer, criterion, config, device)

    # Store result
    acc = evaluate(model, test_loader, f"Task {task_id} - {dataset_cls.__name__}")
    results.append((dataset_cls.__name__, acc))

# Print final accuracy per evaluated task
print(f"\n Final Accuracy on {len(results)} Tasks:")
for name, acc in results:
    print(f"{name}: {acc*100:.2f}%")

"""Ablation Study"""

# Three 2D tasks used for the replay/EWC ablation; display names drop the "MNIST" suffix.
ablation_tasks = [BreastMNIST, PneumoniaMNIST, RetinaMNIST]
ablation_names = [cls.__name__[:-len("MNIST")] for cls in ablation_tasks]

def run_ablation(setting_name, use_replay=True, use_ewc=True):
    """Run one ablation setting over ``ablation_tasks`` with a single shared model.

    Args:
        setting_name: label used in printed messages.
        use_replay: if True, after each task a DDPM replay buffer is generated
            and mixed into the next task's training.
        use_ewc: if True, the Fisher/anchor of the previous task regularises
            training on later tasks.

    Returns:
        list of test accuracies, one per task, each measured right after that
        task's training.

    Notes:
        ``train_on_task`` reads the module-level ``fisher_info`` /
        ``prev_params`` for its EWC term. This function therefore installs its
        own values there for the duration of the run and restores the previous
        ones afterwards.
    """
    global fisher_info, prev_params
    saved_fisher, saved_prev = fisher_info, prev_params

    print(f"\n Running Setting: {setting_name}")
    results = []
    replay_buffer = None

    # Build loaders first so the shared head can cover the largest label set;
    # the ablation tasks can have different numbers of classes.
    loaders = []
    for dataset_cls in ablation_tasks:
        train_loader, n_cls = get_loader(dataset_cls, 'train', config['img_size'])
        test_loader, _ = get_loader(dataset_cls, 'test', config['img_size'])
        loaders.append((train_loader, test_loader, n_cls))
    head_classes = max(n for _, _, n in loaders)

    model = ViTClassifier(img_size=config['img_size'], patch_size=config['patch_size'],
                          num_classes=head_classes).to(device)
    ddpm = DDPM(num_classes=head_classes, img_size=config['img_size']).to(device)

    # No EWC penalty until a task has been learned in this run.
    fisher_info, prev_params = {}, {}
    try:
        for task_id, (train_loader, test_loader, num_classes) in enumerate(loaders):
            optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'])

            # Train with or without replay
            replay_data = replay_buffer if use_replay else None
            train_on_task(task_id, model, ddpm, train_loader, replay_data, optimizer,
                          criterion, config, device)

            # Store accuracy
            acc = evaluate(model, test_loader, f"{setting_name} - {ablation_names[task_id]}")
            results.append(acc)

            # Update Fisher + anchor (globals, so train_on_task sees them)
            if use_ewc:
                fisher_info = compute_fisher(model, train_loader, criterion, device)
                prev_params = save_prev_params(model)

            # Generate replay: labels from the task just learned, one-hot width
            # matching the DDPM's conditioning size (head_classes).
            if use_replay:
                n = config['replay_samples']
                img = config['img_size']
                replay_y = torch.randint(0, num_classes, (n,), device=device)
                y_onehot = F.one_hot(replay_y, head_classes).float().view(n, head_classes, 1, 1)
                y_onehot = y_onehot.expand(-1, -1, img, img)
                with torch.no_grad():
                    replay_x = ddpm.sample(y_onehot, shape=(n, 3, img, img))
                replay_buffer = (replay_x.detach(), replay_y.detach())
    finally:
        fisher_info, prev_params = saved_fisher, saved_prev

    return results

# Run all three settings (replay only via DDPM, EWC only, and both)
acc_ewc_only = run_ablation("EWC only", use_replay=False, use_ewc=True)
acc_ddpm_only = run_ablation("DDPM only", use_replay=True, use_ewc=False)
acc_full = run_ablation("DDPM + EWC", use_replay=True, use_ewc=True)

#  Print table: one row per ablation task
print("\n Ablation Study Results:")
for task_name, a_ewc, a_ddpm, a_full in zip(ablation_names, acc_ewc_only, acc_ddpm_only, acc_full):
    print(f"{task_name} | EWC only: {a_ewc*100:.2f}% | DDPM only: {a_ddpm*100:.2f}% | Full: {a_full*100:.2f}%")

"""Qualitative Replay Grid + FID Estimation"""

from torchvision.utils import make_grid, save_image
import matplotlib.pyplot as plt

def save_replay_grid(samples, title="Replay Samples", fname="replay_grid.png"):
    """Save and display an 8-column grid of (at most) the first 32 samples.

    The grid is min-max normalised by ``make_grid(normalize=True)``, written to
    ``fname`` with ``save_image``, and shown with matplotlib under ``title``.
    """
    grid = make_grid(samples[:32], nrow=8, normalize=True)
    save_image(grid, fname)

    image_hwc = grid.permute(1, 2, 0).cpu().numpy()  # CHW -> HWC for imshow
    fig, ax = plt.subplots(figsize=(8, 8))
    ax.imshow(image_hwc)
    ax.set_axis_off()
    ax.set_title(title)
    plt.show()

# Visualise the module-level replay buffer. run_ablation keeps its replay buffer
# local, so `replay_buffer` here is the one left by the single-task BloodMNIST
# example training flow, not Task 3 of the ablation.
save_replay_grid(replay_buffer[0], title="DDPM Replay (BloodMNIST example flow)", fname="bloodmnist_replay.png")

!pip install pytorch-fid

from pytorch_fid import fid_score
import os
from torchvision.utils import save_image

def export_images_to_dir(images, outdir):
    """Write each image in ``images`` to ``outdir/<index>.png``, creating ``outdir`` if needed."""
    os.makedirs(outdir, exist_ok=True)
    for index, image in enumerate(images):
        target_path = os.path.join(outdir, f"{index}.png")
        save_image(image, target_path)

# Example usage: FID between DDPM replay and real images of the SAME dataset.
# The module-level `replay_buffer` was last filled by the single-task BloodMNIST
# example flow. task_loaders only has keys 0 (BloodMNIST) and 1 (PathMNIST), so
# the real reference set is taken from task 0.
fake_images = replay_buffer[0]
export_images_to_dir(fake_images, "fake_images/")

# Collect as many real images as there are fakes (a single batch may be smaller).
real_chunks, n_real = [], 0
for real_batch, _ in task_loaders[0]['train']:
    real_chunks.append(real_batch)
    n_real += len(real_batch)
    if n_real >= len(fake_images):
        break
real_images = torch.cat(real_chunks)[:len(fake_images)]
export_images_to_dir(real_images, "real_images/")

# Compute FID
# NOTE: with only a few hundred images per side, the 2048-d feature covariance
# is rank-deficient, so this FID is a rough indicator at best.
fid_value = fid_score.calculate_fid_given_paths(["real_images", "fake_images"],
                                                batch_size=32, device=device, dims=2048)
print(f" FID between real and DDPM-replayed samples: {fid_value:.2f}")

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy import stats
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score

# Load your CSV file (expects columns "kl", "forgetting" and "drift")
df = pd.read_csv("results.csv")  # path to your data
kl_vals = df["kl"].values
forgetting_vals = df["forgetting"].values
drift_vals = df["drift"].values

# --- 1. Correlation statistics ---
pearson_kl = stats.pearsonr(kl_vals, forgetting_vals)
pearson_drift = stats.pearsonr(drift_vals, forgetting_vals)

# --- 2. Joint regression model: F_k ≈ a * KL + b * Drift + c ---
# LinearRegression fits the intercept c by default (fit_intercept=True).
X = np.column_stack([kl_vals, drift_vals])
y = forgetting_vals
reg = LinearRegression().fit(X, y)
pred_forgetting = reg.predict(X)
r2 = r2_score(y, pred_forgetting)  # in-sample R^2 of the joint fit

# --- 3. Plot setup ---
plt.rcParams.update({
    "font.size": 10,
    "axes.titlesize": 11,
    "axes.labelsize": 10,
    "axes.linewidth": 0.8,
    "xtick.labelsize": 9,
    "ytick.labelsize": 9,
    "legend.fontsize": 9,
    "figure.dpi": 200,
})


def _scatter_with_fit(ax, x, y_obs, color, **scatter_kwargs):
    """Scatter ``y_obs`` against ``x`` and overlay the ordinary least-squares line."""
    ax.scatter(x, y_obs, c=color, s=22, alpha=0.7, **scatter_kwargs)
    slope, intercept, *_ = stats.linregress(x, y_obs)
    x_line = np.linspace(np.min(x), np.max(x), 100)
    ax.plot(x_line, slope * x_line + intercept, color="black", lw=1)


def _annotate(ax, text):
    """Place a small white-boxed annotation in the upper-left corner of ``ax``."""
    ax.text(0.05, 0.92, text, transform=ax.transAxes, fontsize=8,
            bbox=dict(facecolor="white", alpha=0.7, edgecolor="none"))


fig, axes = plt.subplots(1, 3, figsize=(10.5, 3.4), constrained_layout=True)

# --- Panel 1: Replay divergence vs Forgetting ---
ax = axes[0]
_scatter_with_fit(ax, kl_vals, forgetting_vals, "tab:blue", label="Tasks")
ax.set_xlabel(r"$\widehat{\mathrm{KL}}(p_k \,\|\, \hat{p}_k)$")
ax.set_ylabel(r"Forgetting $F_k$")
ax.set_title("Replay divergence vs. forgetting")
_annotate(ax, f"Pearson r={pearson_kl[0]:.2f}\n(p={pearson_kl[1]:.3g})")

# --- Panel 2: Drift vs Forgetting ---
ax = axes[1]
_scatter_with_fit(ax, drift_vals, forgetting_vals, "tab:green")
ax.set_xlabel(r"$D_k = \sum_i F^{(k)}_i(\theta_{K,i}-\theta^\star_{k,i})^2$")
ax.set_ylabel(r"Forgetting $F_k$")
ax.set_title("Fisher-weighted drift vs. forgetting")
_annotate(ax, f"Pearson r={pearson_drift[0]:.2f}\n(p={pearson_drift[1]:.3g})")

# --- Panel 3: Predicted vs Observed (dashed line = perfect prediction) ---
ax = axes[2]
ax.scatter(pred_forgetting, forgetting_vals, c="tab:red", s=22, alpha=0.7)
lims = [min(min(pred_forgetting), min(forgetting_vals)),
        max(max(pred_forgetting), max(forgetting_vals))]
ax.plot(lims, lims, linestyle="--", color="black", lw=1)
ax.set_xlim(lims)
ax.set_ylim(lims)
ax.set_xlabel(r"Predicted $\hat{F}_k$")
ax.set_ylabel(r"Observed $F_k$")
ax.set_title("Joint model")
_annotate(ax, f"$R^2$={r2:.2f}")

fig.suptitle("Empirical validation of the forgetting bound", y=1.02)

# Save outputs
fig.savefig("bound_validation.png", bbox_inches="tight")
fig.savefig("bound_validation.pdf", bbox_inches="tight")
plt.show()