{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "\n",
    "from config import cf\n",
    "from data import get_gradual_domains, get_dataset_shape\n",
    "from model import get_trained_model\n",
    "from utils import get_accuracy\n",
    "from utils.visual import *\n",
    "\n",
    "from method.utils import LabeledDataset, other2LabeledDataset\n",
    "from method.goat import get_encoded_dataset\n",
    "from method.ntk.ntk_mmd import get_ntk_mmd, NTK_MMD\n",
    "from method.ntk.tra_mmd import get_tra_mmd\n",
    "\n",
    "import importlib\n",
    "import utils.visual\n",
    "import method.ntk.ntk_mmd\n",
    "importlib.reload(utils.visual)\n",
    "importlib.reload(method.ntk.ntk_mmd)\n",
    "from utils.visual import dim_reduction, plot_2D_with_g, plot_line\n",
    "from method.ntk.ntk_mmd import get_ntk_mmd"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## feature visualization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "\n",
    "from config import cf, get_config\n",
    "from data import get_gradual_domains\n",
    "from model import get_trained_model\n",
    "from method import get_method\n",
    "\n",
    "print(f\"Using Device: {cf.device}\")\n",
    "from method.goat import get_encoded_dataset\n",
    "from method.ntk.ntk_mmd2 import cal_mmd\n",
    "from utils.visual import dim_reduction, plot_2D_with_g\n",
    "\n",
    "def main(data_name, model_name, method_name, domain_num, corruption=None, draw_init=True):\n",
    "    cf.data_name = data_name\n",
    "    cf.model_name = model_name\n",
    "    cf.method_name = method_name\n",
    "    cf.domain_num = domain_num\n",
    "    \n",
    "    gra_domains = get_gradual_domains(data_name, domain_num, corruption=corruption)\n",
    "    model = get_trained_model(data_name, model_name).to(cf.device)\n",
    "\n",
    "    if draw_init:\n",
    "        encoded_domain_s = get_encoded_dataset(model.encoder, gra_domains[0], cf.device, 64)\n",
    "        encoded_domain_t = get_encoded_dataset(model.encoder, gra_domains[-1], cf.device, 64)\n",
    "        dataset = torch.cat([encoded_domain_s.data, encoded_domain_t.data], dim=0)\n",
    "        labels = np.array([\"Source domain\"] * len(encoded_domain_s.data) +[\"Target domain\"] * len(encoded_domain_t.data))\n",
    "        _, gs, gt = cal_mmd(encoded_domain_s.data, encoded_domain_t.data, \"cpu\")\n",
    "        g = torch.cat([gs,gt])\n",
    "        red_dataset, red_labels, red_indexes = dim_reduction(dataset, labels, \"TSNE\", num_samples=5000, with_indexes=True)\n",
    "        plot_2D_with_g(red_dataset, red_labels, g[red_indexes], title=f\"NTK-MMD: {torch.mean(gs) + torch.mean(gt):.1f}\", \n",
    "                    fig_path=f\"init_embed_{data_name}_{model_name}_{corruption}\", ifsave=True)\n",
    "\n",
    "    method = get_method(method_name, model)\n",
    "    method.gradual_adapt(gra_domains)\n",
    "\n",
    "    encoded_domain_s = get_encoded_dataset(model.encoder, gra_domains[0], cf.device, 64)\n",
    "    encoded_domain_t = get_encoded_dataset(model.encoder, gra_domains[-1], cf.device, 64)\n",
    "    dataset = torch.cat([encoded_domain_s.data, encoded_domain_t.data], dim=0)\n",
    "    labels = np.array([\"Source domain\"] * len(encoded_domain_s.data) +[\"Target domain\"] * len(encoded_domain_t.data))\n",
    "    _, gs, gt = cal_mmd(encoded_domain_s.data, encoded_domain_t.data, \"cpu\")\n",
    "    g = torch.cat([gs,gt])\n",
    "    red_dataset, red_labels, red_indexes = dim_reduction(dataset, labels, \"TSNE\", num_samples=5000, with_indexes=True)\n",
    "    plot_2D_with_g(red_dataset, red_labels, g[red_indexes], title=f\"NTK-MMD: {torch.mean(gs) + torch.mean(gt):.1f}\", \n",
    "                   fig_path=f\"adapt_embed_{data_name}_{model_name}_{method_name}_{domain_num}_{corruption}\", ifsave=True)\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "main(\"rotate_mnist\", \"resnet\", \"GST\", 6, draw_init=True)\n",
    "main(\"rotate_mnist\", \"resnet\", \"NTK\", 6, draw_init=False)\n",
    "\n",
    "main(\"color_mnist\", \"resnet\", \"GST\", 6, draw_init=True)\n",
    "main(\"color_mnist\", \"resnet\", \"NTK\", 6, draw_init=False)\n",
    "\n",
    "main(\"portraits\", \"resnet\", \"GST\", 6, draw_init=True)\n",
    "main(\"portraits\", \"resnet\", \"NTK\", 6, draw_init=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "corruptions = [\"gaussian_noise\", \"shot_noise\", \"impulse_noise\", \"defocus_blur\", \n",
    "\"glass_blur\", \"motion_blur\", \"zoom_blur\", \"snow\", \"frost\", \"fog\", \n",
    "\"brightness\", \"contrast\", \"elastic_transform\", \"pixelate\", \"jpeg_compression\"]\n",
    "methods = [\"GAS\"]\n",
    "datasets = [\"cifar10\", \"cifar100\"]\n",
    "\n",
    "for dataset in datasets:\n",
    "    for corruption in corruptions:\n",
    "        for method in methods:\n",
    "            main(dataset, \"default\", method, 6, corruption, method==\"GST\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Get some exaples of the data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "corruptions = [\"gaussian_noise\", \"shot_noise\", \"impulse_noise\", \"defocus_blur\", \n",
    "\"glass_blur\", \"motion_blur\", \"zoom_blur\", \"snow\", \"frost\", \"fog\", \n",
    "\"brightness\", \"contrast\", \"elastic_transform\", \"pixelate\", \"jpeg_compression\"]\n",
    "\n",
    "indices = [50, 1009, 10000]\n",
    "\n",
    "for corruption in corruptions:\n",
    "    gra_domains = get_gradual_domains(\"imagenet\", 6, corruption=corruption)\n",
    "    for corr_idx, dataset in enumerate(gra_domains):\n",
    "        for idx in indices:\n",
    "            data, label = dataset[idx]\n",
    "            plt.imsave(f\"{idx}_{corruption}_{corr_idx}.png\", data.permute(1, 2, 0).cpu().numpy())"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "torch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
