{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# CUDA reads CUDA_VISIBLE_DEVICES when it initializes, so set it before torch (or anything\n",
    "# that imports torch) is loaded; assigning it afterwards only works while CUDA is untouched.\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '0'\n",
    "\n",
    "import copy\n",
    "import random\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.distributed as dist\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torchvision\n",
    "from PIL import Image\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from torchvision import transforms\n",
    "from torchvision.datasets import ImageFolder\n",
    "import timm\n",
    "from timm.utils import accuracy\n",
    "\n",
    "from utils.data_manager import DataManager\n",
    "\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "# Seed every RNG in use; the class-order seed drawn below comes from the seeded numpy RNG,\n",
    "# so it is also fixed across runs.\n",
    "seed = 1\n",
    "torch.manual_seed(seed)\n",
    "np.random.seed(seed)\n",
    "random.seed(seed)\n",
    "\n",
    "args = dict()\n",
    "args['sample_number'] = 200\n",
    "args['increment'] = 100  # classes per task; also passed as init_cls below\n",
    "\n",
    "BATCH_SIZE = 256\n",
    "NUM_WORKERS = 8\n",
    "\n",
    "class_order_seed = np.random.randint(100)\n",
    "data_manager = DataManager(dataset_name='imagenet1000', shuffle=True, seed=class_order_seed,\n",
    "                           init_cls=args['increment'], increment=args['increment'], args=args)\n",
    "\n",
    "class_order = data_manager._class_order\n",
    "\n",
    "# NOTE(review): mode=\"train\" presumably applies DataManager's training-time augmentation;\n",
    "# confirm that is intended when these images are only used to extract class prototypes.\n",
    "train_dataset = data_manager.get_dataset(class_order, source=\"train\", mode=\"train\")\n",
    "train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the ViT-B/16 backbone from the .npz weights, then overwrite it with the meta-trained\n",
    "# checkpoint, skipping the \"fc.\" classifier-head entries.\n",
    "meta_model = timm.create_model('vit_base_patch16_224_in21k', pretrained=False, checkpoint_path=\"[your path]/ViT-B_16.npz\")\n",
    "# map_location='cpu' lets a checkpoint saved from GPU tensors load on any host; the model is\n",
    "# moved to `device` below. torch.load unpickles the file, so only load trusted checkpoints.\n",
    "model_infos = torch.load(\"[your path]/meta_epoch_10.pth\", map_location='cpu')\n",
    "\n",
    "filtered_model_infos = {k: v for k, v in model_infos.items() if not k.startswith(\"fc.\")}\n",
    "load_result = meta_model.load_state_dict(filtered_model_infos, strict=False)\n",
    "# strict=False skips mismatched keys without complaint; report them so a naming mismatch\n",
    "# (which would silently leave the .npz weights in place) is visible.\n",
    "if load_result.missing_keys or load_result.unexpected_keys:\n",
    "    print(\"Missing keys:\", load_result.missing_keys)\n",
    "    print(\"Unexpected keys:\", load_result.unexpected_keys)\n",
    "meta_model.eval()\n",
    "meta_model.to(device)\n",
    "\n",
    "# Per-class running sum and count of token-0 features. Keeping sums instead of every\n",
    "# per-image feature vector bounds memory by the number of classes, not the number of images.\n",
    "feature_sums = {}\n",
    "feature_counts = {}\n",
    "with torch.no_grad():\n",
    "    for _, images, labels in train_loader:\n",
    "        images = images.to(device)\n",
    "        # Token 0 of forward_features' output is taken as the CLS embedding.\n",
    "        features = meta_model.forward_features(images)[:, 0, :].cpu().numpy()\n",
    "        labels_np = labels.numpy()\n",
    "        for label in np.unique(labels_np):\n",
    "            in_class = labels_np == label\n",
    "            key = int(label)\n",
    "            # Accumulate in float64 so long per-class sums do not lose precision.\n",
    "            batch_sum = features[in_class].sum(axis=0, dtype=np.float64)\n",
    "            if key in feature_sums:\n",
    "                feature_sums[key] += batch_sum\n",
    "                feature_counts[key] += int(in_class.sum())\n",
    "            else:\n",
    "                feature_sums[key] = batch_sum\n",
    "                feature_counts[key] = int(in_class.sum())\n",
    "\n",
    "assert feature_sums, \"train_loader yielded no samples\"\n",
    "\n",
    "# Class prototypes: mean feature per label, rows ordered by ascending label.\n",
    "sorted_labels = sorted(feature_sums)\n",
    "class_avg_features = np.stack([feature_sums[label] / feature_counts[label] for label in sorted_labels])  # expected shape: (1000, 768)\n",
    "\n",
    "# Covariance across prototypes (rows = observations, columns = feature dimensions).\n",
    "# np.cov subtracts the mean prototype first, unlike EmpiricalCovariance(assume_centered=True).\n",
    "cov_matrix_meta10 = np.cov(class_avg_features, rowvar=False)\n",
    "\n",
    "save_path = '[your path]/cov_matrix_backbone10.npy'\n",
    "np.save(save_path, cov_matrix_meta10)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pytorch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
