\section{Introduction}
Large datasets containing millions of samples have become the standard for obtaining advanced models in many artificial intelligence directions, including natural language processing~\cite{brown2020language}, speech recognition~\cite{baevski2020wav2vec}, and computer vision~\cite{goyal2021self}.
However, large datasets also raise practical issues.
For example, data storage and preprocessing become increasingly difficult.
Training models on these datasets also requires expensive servers, which is impractical in low-resource environments~\cite{li2022tri}.
An effective way to address these problems is data selection (coreset construction), which identifies representative training samples of large datasets~\cite{bachem2017practical}.
However, since some of the original data cannot be discarded, the compression rate of data selection methods has an upper limit.
\par
Recently, dataset distillation has attracted widespread attention as an alternative to data selection~\cite{wang2018datasetdistillation}.
Dataset distillation is the task of synthesizing a small dataset that preserves most of the information in the original large dataset.
The algorithm of dataset distillation takes a sizeable real dataset as input and synthesizes a small distilled dataset.
Unlike data selection methods, which use actual data from the original dataset, dataset distillation generates synthetic data whose distribution differs from that of the original data~\cite{dong2022privacy}.
Therefore, dataset distillation can compress the whole dataset into several images, or even a single image~\cite{sucholutsky2021soft}.
Dataset distillation has many application scenarios, such as privacy protection~\cite{li2020soft, song2022federated}, continual learning~\cite{wiewel2021soft, sangermano2022sample}, and neural architecture search~\cite{such2020generative, zhao2021datasetcondensation}.
\par
Since the dataset distillation task was first introduced in 2018 by Wang et al.~\cite{wang2018datasetdistillation}, it has gained increasing attention in the research community.
The original dataset distillation algorithm is based on meta-learning and optimizes the distilled images with gradient-based hyperparameter optimization.
Subsequently, many works have significantly improved the distillation performance with label distillation~\cite{bohdal2020flexible}, gradient matching~\cite{zhao2021datasetcondensation}, differentiable augmentation~\cite{zhao2021differentiatble}, and distribution/feature matching~\cite{zhao2021distribution, wang2022cafe}.
A recently proposed method that matches network parameters has become the new state of the art (SOTA) on several datasets~\cite{cazenavette2022dataset}.
However, a network usually has a large number of parameters, and we found that a few of them are difficult to match in the distillation process.
These difficult-to-match parameters degrade the distillation performance, which leaves room for improvement.
\par
In this paper, we propose a new dataset distillation method using parameter pruning.
As one of the model pruning approaches, parameter pruning is often used for model compression and accelerated model training.
Here, we introduce parameter pruning into dataset distillation to remove the effect of difficult-to-match parameters.
The proposed method can synthesize more robust distilled datasets by pruning difficult-to-match parameters in the distillation process, improving the distillation and cross-architecture generalization performance.
Experimental results on two benchmark datasets and a real-world COVID-19 chest X-ray (CXR) dataset show the superiority of the proposed method over other SOTA dataset distillation methods.
\par
Our main contributions can be summarized as follows:
\begin{itemize}
    \item We propose a new dataset distillation method using parameter pruning, which can synthesize more robust distilled datasets and improve the distillation performance.
    \item The proposed method outperforms other SOTA dataset distillation methods on two benchmark datasets and achieves better cross-architecture generalization.
    \item We verify the effectiveness of the proposed method in a real-world application using a COVID-19 CXR dataset.
\end{itemize}
\begin{figure*}[t]
        \centering
        \includegraphics[width=15cm]{Image/Method.png}
        \caption{Overview of the proposed method. Our method uses a teacher-student architecture, and the objective is to make the student network parameters $\tilde{\theta}'_{i,J}$ match the teacher network parameters $\theta'_{i+K}$. Our method can avoid the influence of the difficult-to-match parameters on the distilled dataset by pruning the parameters in teacher and student networks.}
        \label{fig1}
\end{figure*}
\section{Methodology}
An overview of the proposed method is shown in Fig.~\ref{fig1}.
Our method is constructed on a teacher-student architecture, and the objective is to make the student network parameters trained on the distilled dataset $\mathcal{D}_\textrm{distill}$ match the teacher network parameters trained on the original large dataset $\mathcal{D}_\textrm{original}$.
Our method consists of three stages: teacher-student architecture training, dataset distillation using parameter pruning, and optimized distilled dataset generation, which are detailed in the following subsections.
\subsection{Teacher-Student Architecture Training}
First, we pre-train $N$ teacher networks on $\mathcal{D}_\textrm{original}$ and save their snapshot parameters at each epoch.
We define teacher parameters as time sequences of parameters $\{\theta_{i}\}^{I}_{0}$.
Meanwhile, student parameters are defined as $\tilde{\theta}_{i}$, which are trained on the distilled dataset $\mathcal{D}_\textrm{distill}$ at each training step $i$.
At each distillation step, we first sample the parameters of one teacher at a random step $i$ and use them to initialize the student parameters as $\tilde{\theta}_{i}=\theta_{i}$.
We set an upper bound $I^{+}$ on the random step $i$ to ignore the less informative later parts of the teacher parameters.
The numbers of updates for the student and teacher parameters are set to $J$ and $K$, respectively, where $J \ll K$.
For each student update $j$, we sample a minibatch $b_{i,j}$ from the distilled dataset as follows:
\begin{equation}
b_{i,j} \thicksim \mathcal{D}_\textrm{distill}.
\end{equation}
We then perform $J$ updates on the student parameters $\tilde{\theta}$ using the cross-entropy loss $\ell$ as follows:
\begin{equation}
\tilde{\theta}_{i,j+1} = \tilde{\theta}_{i,j} - \alpha\nabla\ell(\mathcal{A}(b_{i,j});\tilde{\theta}_{i,j}),
\end{equation}
where $\alpha$ represents the trainable learning rate.
$\mathcal{A}$ represents a differentiable data augmentation module proposed in~\cite{zhao2021differentiatble}, which can improve the distillation performance.
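To make the indexing explicit, the student updates can be unrolled by summing the update rule over $j = 0, \dots, J-1$. This assumes that the inner loop starts from $\tilde{\theta}_{i,0}=\tilde{\theta}_{i}=\theta_{i}$, where $\tilde{\theta}_{i,0}$ is shorthand introduced only for this illustration:
\[
\tilde{\theta}_{i,J} = \theta_{i} - \alpha\sum_{j=0}^{J-1}\nabla\ell(\mathcal{A}(b_{i,j});\tilde{\theta}_{i,j}).
\]
Written this way, the final student parameters $\tilde{\theta}_{i,J}$ depend on the distilled images through every minibatch $b_{i,j}$ and on the learning rate $\alpha$, which is why both can be optimized by gradient descent.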
\begin{algorithm}[t]
    \caption{Dataset Distillation using Parameter Pruning}    
    \label{alg1}
    \begin{algorithmic}[1]
    \REQUIRE 
    $\{\theta_{i}\}^{I}_{0}$: teacher parameters trained on $\mathcal{D}_\textrm{original}$;
    $\alpha_{0}$: initial value for $\alpha$;
    $\mathcal{A}$: differentiable augmentation function;
    $\epsilon$: threshold for pruning;
    $T$: number of distillation steps;
    $J$: number of updates for student network;
    $K$: number of updates for teacher network;
    $I^{+}$: maximum start epoch.
    \ENSURE
    optimized distilled dataset $\mathcal{D}^{\ast}_\textrm{distill}$ and
    learning rate $\alpha^{\ast}$.
    \\
    \STATE
    Initialize distilled dataset:
    $\mathcal{D}_\textrm{distill} \thicksim \mathcal{D}_\textrm{original}$
    \STATE
    Initialize trainable learning rate:
    $\alpha = \alpha_{0}$
    \FOR{each distillation step $t = 0$ to $T - 1$}
    \STATE
    Choose random start epoch $i < I^{+}$
    \STATE
    Initialize student network with teacher parameter: 
    $\tilde{\theta}_{i}=\theta_{i}$
    \FOR{each student update $j = 0$ to $J - 1$}
    \STATE
    Sample a minibatch of distilled dataset:
    $b_{i,j} \thicksim \mathcal{D}_\textrm{distill}$ 
    \STATE
    Update student network with cross-entropy loss:
    \STATE
    $\tilde{\theta}_{i,j+1} = \tilde{\theta}_{i,j} - \alpha\nabla\ell(\mathcal{A}(b_{i,j});\tilde{\theta}_{i,j})$
    \ENDFOR
    \IF{parameter similarity in $\tilde{\theta}_{i,J}$ and $\theta_{i+K}$ is less than $\epsilon$}
    \STATE
    Prune network parameters:
    \STATE
    $\tilde{\theta}'_{i,J}, \theta'_{i+K}, \theta'_{i} = \textrm{Prune}(\tilde{\theta}_{i,J}, \theta_{i+K}, \theta_{i})$
    \ENDIF
    \STATE
    Compute loss between pruned parameters:
    \STATE
    $\mathcal{L} = || \tilde{\theta}'_{i,J}-\theta'_{i+K} ||^{2}_{2} \,\,\,/\,\,\, || \theta'_{i}-\theta'_{i+K} ||^{2}_{2}$
    \STATE
    Update $\mathcal{D}_\textrm{distill}$ and $\alpha$ with respect to $\mathcal{L}$
    \ENDFOR
    \end{algorithmic}
\end{algorithm}
\begin{table*}[t]
    \footnotesize
    \centering
    \caption{Test accuracy (\%, mean$\pm$standard deviation) of different methods on CIFAR-10 and CIFAR-100.}
    \label{tab1}
    \begin{tabular}{l|c|cccccccc|c}
    \hline
    & IPC & Random & Forgetting~\cite{toneva2019empirical} & Herding~\cite{chen2010super} & DSA~\cite{zhao2021differentiatble} & DM~\cite{zhao2021distribution} & CAFE~\cite{wang2022cafe} & MTT~\cite{cazenavette2022dataset} & Ours & Full Dataset\\\hline
    \multirow{3}*{CIFAR-10} & 1 & 14.4$\pm$2.0 & 13.5$\pm$1.2 & 21.5$\pm$1.2 & 28.8$\pm$0.7 &  26.0$\pm$0.8 & 31.6$\pm$0.8 & 46.3$\pm$0.8 & \bfseries{46.4$\pm$0.6} & \multirow{3}*{84.8$\pm$0.1} \\
    & 10 & 26.0$\pm$1.2 & 23.3$\pm$1.0 & 31.6$\pm$0.7 & 52.1$\pm$0.5 & 48.9$\pm$0.6 & 50.9$\pm$0.5 & 65.3$\pm$0.7 & \bfseries{65.5$\pm$0.3} &  \\
    & 50 & 43.4$\pm$1.0 & 23.3$\pm$1.1 & 40.4$\pm$0.6 & 60.6$\pm$0.5 & 63.0$\pm$0.4 & 62.3$\pm$0.4 & 71.6$\pm$0.2 & \bfseries{71.9$\pm$0.2} &  \\\hline
          
    \multirow{3}*{CIFAR-100} & 1 & 4.2$\pm$0.3 & 4.5$\pm$0.2 & 8.4$\pm$0.3 & 13.9$\pm$0.3 & 11.4$\pm$0.3 & 14.0$\pm$0.3 & 24.3$\pm$0.3 & \bfseries{24.6$\pm$0.1} & \multirow{3}*{56.2$\pm$0.3} \\
    & 10 & 14.6$\pm$0.5 & 15.1$\pm$0.3 & 17.3$\pm$0.3 & 32.3$\pm$0.3 & 29.7$\pm$0.3 & 31.5$\pm$0.2 & 40.1$\pm$0.4 & \bfseries{43.1$\pm$0.3} & \\
    & 50 & 30.0$\pm$0.4 & 30.5$\pm$0.3 & 33.7$\pm$0.5 & 42.8$\pm$0.4 & 43.6$\pm$0.4 & 42.9$\pm$0.2 & 47.7$\pm$0.2 & \bfseries{48.4$\pm$0.3} & \\\hline
    \end{tabular}
\end{table*}
\subsection{Dataset Distillation Using Parameter Pruning}
Next, we obtain the student parameters $\tilde{\theta}_{i,J}$ by performing $J$ updates on the distilled dataset $\mathcal{D}_\textrm{distill}$ after initializing the student network.
Meanwhile, the teacher parameters $\theta_{i+K}$, which result from $K$ further updates on the original dataset $\mathcal{D}_\textrm{original}$, are already available from pre-training.
Parameters whose similarity between $\tilde{\theta}_{i,J}$ and $\theta_{i+K}$ is less than a threshold $\epsilon$ are regarded as difficult-to-match parameters and are pruned as follows:
\begin{equation}
\tilde{\theta}'_{i,J}, \theta'_{i+K}, \theta'_{i} = \textrm{Prune}(\tilde{\theta}_{i,J}, \theta_{i+K}, \theta_{i}),
\end{equation}
where $\textrm{Prune}$ represents a function that flattens the parameters into a one-dimensional vector and, after the last student update of each distillation step, prunes the parameters whose similarity is below the threshold.
By pruning difficult-to-match parameters in teacher and student networks, the proposed method can avoid the influence of these parameters on the distilled dataset, which can improve the distillation and cross-architecture generalization performance.
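One possible way to formalize $\textrm{Prune}$, given here only as an illustrative sketch, is with a binary mask. We assume an element-wise similarity score $s_{p}$ between the $p$-th entries of $\tilde{\theta}_{i,J}$ and $\theta_{i+K}$. The score $s_{p}$ and the mask $m$ are notation introduced for this illustration, and the exact similarity measure is not fixed by this sketch:
\[
m_{p} = \left\{
\begin{array}{ll}
1, & s_{p} \geq \epsilon, \\
0, & s_{p} < \epsilon,
\end{array}
\right.
\qquad
\tilde{\theta}'_{i,J} = m \odot \tilde{\theta}_{i,J}, \quad
\theta'_{i+K} = m \odot \theta_{i+K}, \quad
\theta'_{i} = m \odot \theta_{i},
\]
where $\odot$ denotes element-wise multiplication. Because the same mask is applied to all three vectors, the masked entries contribute zero to every difference between them, so zeroing these entries is equivalent to removing them from the flattened vectors.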
The final loss $\mathcal{L}$ calculates the normalized squared $L_{2}$ error between pruned student parameters $\tilde{\theta}'_{i,J}$ and teacher parameters $\theta'_{i+K}$ as follows:
\begin{equation}
\mathcal{L} = \frac{|| \tilde{\theta}'_{i,J}-\theta'_{i+K} ||^{2}_{2}} {|| \theta'_{i}-\theta'_{i+K} ||^{2}_{2}},
\end{equation}
where we normalize the squared $L_{2}$ error by the squared distance $|| \theta'_{i}-\theta'_{i+K} ||^{2}_{2}$ moved by the teacher so that we can still obtain proper supervision from the late training period of the teacher network even if it has converged.
In addition, the normalization removes differences in magnitude across layers and neurons.
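As a toy illustration with hypothetical values (three parameters, not taken from the experiments), let $\theta_{i} = (0, 0, 0)$, $\theta_{i+K} = (1, 2, 2)$, and $\tilde{\theta}_{i,J} = (0.9, -1.0, 1.8)$. Without pruning, the loss is
\[
\mathcal{L} = \frac{0.1^{2} + 3.0^{2} + 0.2^{2}}{1^{2} + 2^{2} + 2^{2}} = \frac{9.05}{9} \approx 1.006,
\]
which is dominated by the second parameter. If that parameter is pruned from all three vectors, the loss becomes
\[
\mathcal{L} = \frac{0.1^{2} + 0.2^{2}}{1^{2} + 2^{2}} = \frac{0.05}{5} = 0.01.
\]
In general, $\mathcal{L} = 1$ when the student has not moved from $\theta'_{i}$, and $\mathcal{L} = 0$ when it matches the teacher exactly.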
\subsection{Optimized Distilled Dataset Generation}
Finally, we minimize the loss $\mathcal{L}$ using stochastic gradient descent (SGD) with momentum and backpropagate the gradients through all $J$ student updates to update the pixels of the distilled dataset $\mathcal{D}_\textrm{distill}$ and the trainable learning rate $\alpha$.
Note that searching for the optimal learning rate $\alpha^{\ast}$ can act as an automatic adjustment for the number of student and teacher updates (i.e., hyperparameters $J$ and $K$).
The distillation process of the proposed method is summarized in Algorithm~\ref{alg1}.
After obtaining the optimized distilled dataset $\mathcal{D}^{\ast}_\textrm{distill}$, we can efficiently train different neural networks on it and use it for downstream tasks, such as continual learning and neural architecture search.
\section{Experiments}
In this section, we conduct three experiments to verify the effectiveness of the proposed method.
The experimental settings are shown in subsection 3.1.
Subsections 3.2, 3.3, and 3.4 show the results of benchmark comparison, cross-architecture generalization, and real-world dataset verification, respectively.
All of our experiments were conducted using the PyTorch framework with an NVIDIA RTX A6000 GPU.
\begin{figure}[t]
        \centering
        \includegraphics[width=8cm]{Image/CIFAR-10.png}
        \caption{Visualization results of the distilled CIFAR-10 dataset.}
        \label{fig2}
\end{figure}
\begin{table}[t]
    \footnotesize
    \centering
    \caption{Test accuracy (\%) of KIP~\cite{nguyen2021kipimprovedresults} with different network widths and our method on CIFAR-10 and CIFAR-100.}
    \label{tab2}
    \begin{tabular}{l|c|c|cc}
    \hline
    & IPC & KIP-1024 & KIP-128 & Ours-128 \\\hline
    \multirow{3}*{CIFAR-10} & 1 & 49.9 &  38.3 & \bfseries{46.4} \\
    & 10 & 62.7 & 57.6 & \bfseries{65.5} \\
    & 50 & 68.6 & 65.8 & \bfseries{71.9} \\\hline
          
    \multirow{3}*{CIFAR-100} & 1 & 15.7 & 18.2 & \bfseries{24.6} \\
    & 10 & 28.3 & 32.8 & \bfseries{43.1}\\
    & 50 & - & - & \bfseries{48.4} \\\hline
    \end{tabular}
\end{table}
\subsection{Experimental Settings}
We used two benchmark datasets (i.e., CIFAR-10 and CIFAR-100) in the experiments for comparison with other methods.
The resolution of images in CIFAR-10 and CIFAR-100 is 32 $\times$ 32.
We also used a COVID-19 CXR dataset~\cite{li2022self} to verify the effectiveness of our method in a real-world application.
The COVID-19 CXR dataset has four classes, including COVID-19 (3616 images), Lung Opacity (6012 images), Normal (10192 images), and Viral Pneumonia (1345 images).
Since CXR images have high resolutions (224 $\times$ 224), they are resized to 112 $\times$ 112 for rapid distillation.
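For reference, the class counts above give a total of
\[
3616 + 6012 + 10192 + 1345 = 21165
\]
images, of which Viral Pneumonia accounts for only $1345/21165 \approx 6.4\%$, so the classes are imbalanced. Resizing from 224 $\times$ 224 to 112 $\times$ 112 reduces the number of pixels per image by a factor of $(224/112)^{2} = 4$.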
\par
For comparative methods, we used three data selection methods, including random selection (Random), example forgetting (Forgetting)~\cite{toneva2019empirical}, and herding method (Herding)~\cite{chen2010super}.
We also used five SOTA dataset distillation methods, including Differentiable Siamese Augmentation (DSA)~\cite{zhao2021differentiatble}, Distribution Matching (DM)~\cite{zhao2021distribution}, Aligning Features (CAFE)~\cite{wang2022cafe}, Matching Training Trajectories (MTT)~\cite{cazenavette2022dataset}, and Kernel Inducing Point (KIP)~\cite{nguyen2021kipimprovedresults}.
The network used in this study is a simple 128-width ConvNet~\cite{gidaris2018dynamic}, which is often used in current dataset distillation methods.
We conducted three experiments to verify the effectiveness of the proposed method, including benchmark comparison, cross-architecture generalization, and real-world dataset verification.
We found that pruning too many parameters caused model training to fail.
Hence, the parameter pruning threshold $\epsilon$ was set to 0.1, which performed well in all experiments.
All reported results are the mean and standard deviation of the test accuracy of five networks trained from scratch on the distilled dataset.
\subsection{Benchmark Comparison}
In this subsection, we verify the effectiveness of the proposed method by comparing it with other SOTA dataset distillation methods on two benchmark datasets, i.e., CIFAR-10 and CIFAR-100.
We employed zero-phase component analysis (ZCA) whitening with default parameters and used the same 3-depth ConvNet as MTT~\cite{cazenavette2022dataset}.
We pre-trained 200 teacher networks (50 epochs per teacher) for the distillation process.
The number of distillation steps was set to 5,000.
The number of images per class (IPC) was set to 1, 10, and 50.
For KIP~\cite{nguyen2021kipimprovedresults}, we report results with both their original 1024-width ConvNet (KIP-1024) and a 128-width ConvNet (KIP-128) for a fair comparison.
Also, we used their custom ZCA implementation for distillation and evaluation.
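For reference, the IPC settings above correspond to the following distilled dataset sizes, shown with their share of the training set in parentheses. The percentages assume the standard training splits of 50,000 images for each of CIFAR-10 and CIFAR-100:
\[
\begin{array}{lccc}
 & \textrm{IPC} = 1 & \textrm{IPC} = 10 & \textrm{IPC} = 50 \\
\textrm{CIFAR-10 (10 classes)} & 10\ (0.02\%) & 100\ (0.2\%) & 500\ (1\%) \\
\textrm{CIFAR-100 (100 classes)} & 100\ (0.2\%) & 1000\ (2\%) & 5000\ (10\%)
\end{array}
\]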
\par
From Table~\ref{tab1}, we can see that the proposed method outperformed data selection methods and SOTA dataset distillation methods in all settings.
In particular, for CIFAR-100 with IPC = 10, our method improved accuracy by 3.0 percentage points over the second-best method, MTT.
As shown in Table~\ref{tab2}, the proposed method substantially outperformed KIP when using the same 128-width ConvNet.
Even compared with KIP using the 1024-width ConvNet, our method achieved higher accuracy except for CIFAR-10 with IPC = 1.
KIP results for CIFAR-100 with IPC = 50 are not available because of the large computational resources and time required, so we report only our results for this setting.
Figure~\ref{fig2} shows visualization results of the distilled CIFAR-10 dataset.
As shown in Fig.~\ref{fig2}, when we set the number of distilled images to 1, the generated images were more abstract but also more information-dense because all information of a class has to be compressed into only one image in the distillation process.
Meanwhile, when the number of distilled images was set to 10, the generated images were more realistic and contained various forms because discriminative features in a class can be compressed into multiple images in the distillation process.
For example, we can see various types of dogs and different colored cars.
\begin{table}[t]
    \footnotesize
    \centering
    \caption{Cross-architecture generalization results on CIFAR-10 dataset with IPC = 10.}
    \label{tab3}
    \begin{tabular}{lcccc}
    \hline
    Architecture & ConvNet & AlexNet & VGG & ResNet \\\hline
    Ours & \bfseries{65.4$\pm$0.4} & \bfseries{35.8$\pm$1.3} & \bfseries{52.9$\pm$0.9} & \bfseries{51.8$\pm$1.1} \\
    MTT~\cite{cazenavette2022dataset} & 64.3$\pm$0.7 & 34.2$\pm$2.6 & 50.3$\pm$0.8 & 46.4$\pm$0.6 \\
   
    KIP~\cite{nguyen2021kipimprovedresults} & 47.6$\pm$0.9 & 24.4$\pm$3.9 & 42.1$\pm$0.4 & 36.8$\pm$1.0 \\\hline
    \end{tabular}
\end{table}
\begin{table}[t]
    \footnotesize
    \centering
    \caption{Test results on a COVID-19 CXR dataset when using different numbers of distilled images.}
    \label{tab4}
    \begin{tabular}{lcccc}
    \hline
    IPC & 1 & 5 & 10 & 20 \\\hline
    Ours & \bfseries{54.2$\pm$3.7} & \bfseries{81.6$\pm$0.3} & \bfseries{83.5$\pm$0.2} & \bfseries{84.1$\pm$0.6} \\
    MTT~\cite{cazenavette2022dataset} & 52.5$\pm$5.5 & 79.3$\pm$0.4 & 82.2$\pm$0.2 & 82.7$\pm$0.5 \\\hline
    \end{tabular}
\end{table}
\begin{figure}[t]
        \centering
        \includegraphics[width=7.8cm]{Image/COVID.png}
        \caption{Visualization results of real and distilled CXR images.}
        \label{fig3}
\end{figure}
\subsection{Cross-Architecture Generalization}
In this subsection, we verify the effectiveness of our method in cross-architecture generalization.
Cross-architecture generalization means that distilled images generated with one architecture are used to train and evaluate other architectures.
The distilled images were generated by ConvNet on CIFAR-10, and the number of distilled images per class was set to 10.
We used the same pre-trained teacher networks used in subsection 3.2 for rapid distillation and experimentation.
For KIP, we used 128-width ConvNet and their custom ZCA implementation for distillation and evaluation.
To evaluate cross-architecture generalization, we tested the accuracy of ConvNet and three cornerstone networks, i.e., AlexNet~\cite{krizhevsky2012imagenet}, VGG~\cite{simonyan2015very}, and ResNet~\cite{he2016deep}.
\par
From Table~\ref{tab3}, we can see that our method outperformed the SOTA methods MTT and KIP with all architectures.
In particular, for ResNet, our method improved accuracy by 5.4 percentage points over MTT (51.8\% vs. 46.4\%).
The results indicate that our method generated more robust distilled images.
By pruning difficult-to-match parameters in teacher and student networks, the proposed method can avoid the influence of these parameters on the distilled dataset, which improves cross-architecture generalization performance.
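Concretely, the gains over MTT computed from the mean accuracies in Table~\ref{tab3} (in percentage points) are
\[
65.4 - 64.3 = 1.1, \quad
35.8 - 34.2 = 1.6, \quad
52.9 - 50.3 = 2.6, \quad
51.8 - 46.4 = 5.4
\]
for ConvNet, AlexNet, VGG, and ResNet, respectively. These are differences of means and do not account for the reported standard deviations.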
\subsection{Real-World Dataset Verification}
In this subsection, we verify the effectiveness of the proposed method in real-world application on a COVID-19 CXR dataset.
We used a 5-depth ConvNet for distillation since the image resolution increases significantly compared to CIFAR-10 and CIFAR-100.
We pre-trained 100 teacher networks (50 epochs per teacher) for the distillation process.
The number of distillation steps was set to 5,000.
We evaluated the test accuracy on the COVID-19 CXR dataset with the IPC set to 1, 5, 10, and 20.
\par
Table~\ref{tab4} shows that the proposed method achieved high test accuracy with only a few distilled CXR images, e.g., 84.1\% with IPC = 20 (80 distilled CXR images).
Furthermore, the proposed method outperformed the SOTA method MTT in all IPC settings, indicating the effectiveness of our method in the real-world application for COVID-19 detection.
Figure~\ref{fig3} shows visualization results of real and distilled CXR images.
The distilled CXR images are generated from noise and have different distributions from the original images.
Compared with the original CXR images, the distilled images are visually different, showing the potential of dataset distillation for anonymization and privacy preservation.
\section{Conclusion}
This paper has proposed a novel dataset distillation method using parameter pruning.
The proposed method can synthesize more robust distilled datasets by pruning difficult-to-match parameters in the distillation process.
Experimental results show that the proposed method outperforms other SOTA dataset distillation methods on two benchmark datasets and achieves better cross-architecture generalization.
We also verified the effectiveness of our method in a real-world application using a COVID-19 CXR dataset.
\newpage
\bibliographystyle{IEEEbib}
