{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import os\n",
    "import sys\n",
    "from collections import Counter\n",
    "sys.path.append('..')\n",
    "from unlabeled_extrapolation.datasets import domainnet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# DomainNet domains covered by this analysis.\n",
    "domains = [\n",
    "    'real',\n",
    "    'sketch',\n",
    "    'painting',\n",
    "    'clipart',\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Checkpoint directories for the same-domain (different-class) connectivity runs.\n",
    "# NOTE(review): the 'real' and 'sketch' paths contain 'simclr--unlabeled' (double dash),\n",
    "# unlike 'painting'/'clipart' -- confirm this matches the directory names on disk.\n",
    "# The cell marked 'After fine-tuning' later rebinds this name.\n",
    "single_paths = dict(\n",
    "    real='connectivity_checkpoints/single-domain/domainnet/real-sentry-swav-simclr--unlabeled_extrapolation-swav-checkpoints-domainnet_sentrytrue_real-clipart_queue500_epochs400_batchsize128_epsilon0.03_archresnet50_prototypes400-399',\n",
    "    sketch='connectivity_checkpoints/single-domain/domainnet/sketch-sentry-swav-simclr--unlabeled_extrapolation-swav-checkpoints-domainnet_sentrytrue_sketch-painting_queue500_epochs400_batchsize128_epsilon0.03_archresnet50_prototypes400-399',\n",
    "    painting='connectivity_checkpoints/single-domain/domainnet/painting-sentry-swav-simclr-unlabeled_extrapolation-swav-checkpoints-domainnet_sentrytrue_sketch-painting_queue500_epochs400_batchsize128_epsilon0.03_archresnet50_prototypes400-399',\n",
    "    clipart='connectivity_checkpoints/single-domain/domainnet/clipart-sentry-swav-simclr-unlabeled_extrapolation-swav-checkpoints-domainnet_sentrytrue_real-clipart_queue500_epochs400_batchsize128_epsilon0.03_archresnet50_prototypes400-399',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Checkpoint directories for the between-domain runs (real/clipart and sketch/painting pairs).\n",
    "real_clipart, sketch_painting = (\n",
    "    'connectivity_checkpoints/between-domains/domainnet/real-clipart-sentry-swav-simclr-unlabeled_extrapolation-swav-checkpoints-domainnet_sentrytrue_real-clipart_queue500_epochs400_batchsize128_epsilon0.03_archresnet50_prototypes400-399',\n",
    "    'connectivity_checkpoints/between-domains/domainnet/sketch-painting-sentry-swav-simclr-unlabeled_extrapolation-swav-checkpoints-domainnet_sentrytrue_sketch-painting_queue500_epochs400_batchsize128_epsilon0.03_archresnet50_prototypes400-399',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get the off-limits classes for each domain: any class with fewer than\n",
    "# `size_cutoff[domain]` training examples is excluded from the averages below.\n",
    "# These numbers were selected visually by looking at the notebook `check_dataset_imbalance`\n",
    "size_cutoff = {\n",
    "    'real': 100,\n",
    "    'sketch': 100,\n",
    "    'painting': 100,\n",
    "    'clipart': 50\n",
    "}\n",
    "off_limits = {}\n",
    "for domain in domains:\n",
    "    domainnet_ds = domainnet.DomainNet(domain, split='train', version='sentry', verbose=False)\n",
    "    ys = [item[1] for item in domainnet_ds.data]\n",
    "    cntr = Counter(ys)\n",
    "    # Store labels as strings. `item[1]` is presumably an int class index (TODO confirm),\n",
    "    # but later cells test membership with tokens from `file_name.split('-')`, which are\n",
    "    # str -- and '3' != 3, so an int list would never exclude anything.\n",
    "    off = {str(cls) for cls in cntr if cntr[cls] < size_cutoff[domain]}\n",
    "    off_limits[domain] = off"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "********** real **********\n",
      "Counter is: 92!\n",
      "********** sketch **********\n",
      "Counter is: 54!\n",
      "********** painting **********\n",
      "Counter is: 39!\n",
      "********** clipart **********\n",
      "Counter is: 67!\n"
     ]
    }
   ],
   "source": [
    "'''This is the same-domain-different-class connectivity.'''\n",
    "\n",
    "single_domain = {}\n",
    "for domain in domains:\n",
    "    print('*' * 10, domain, '*' * 10)\n",
    "    # Compare as strings: class tokens parsed from file names are str, while\n",
    "    # `off_limits` may hold the dataset's (presumably int) labels.\n",
    "    excluded = {str(cls) for cls in off_limits[domain]}\n",
    "    counter = 0\n",
    "    data = []\n",
    "    # os.listdir order is arbitrary; sort for reproducible traversal and printed skips.\n",
    "    for file_name in sorted(os.listdir(single_paths[domain])):\n",
    "        if file_name.endswith('-final'):\n",
    "            # Assumes the 2nd and 3rd '-'-separated tokens are the two class labels.\n",
    "            parts = file_name.split('-')\n",
    "            cls_1, cls_2 = parts[1], parts[2]\n",
    "            if (cls_1 in excluded) or (cls_2 in excluded):\n",
    "                continue\n",
    "            # Counts eligible pairs, including incomplete runs skipped just below.\n",
    "            counter += 1\n",
    "            data_dict = torch.load(os.path.join(single_paths[domain], file_name))\n",
    "            # Fewer than 16 recorded test accuracies: presumably an unfinished run; report and skip.\n",
    "            if len(data_dict['test_accs']) < 16:\n",
    "                print(file_name, len(data_dict['test_accs']))\n",
    "                continue\n",
    "            # Keep the last recorded test accuracy of the run.\n",
    "            data.append(data_dict['test_accs'][-1])\n",
    "    single_domain[domain] = data\n",
    "    print(f'Counter is: {counter}!')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'real': 96.92102261569703,\n",
       " 'sketch': 91.12436894784815,\n",
       " 'painting': 90.01875004356786,\n",
       " 'clipart': 93.88644212928429}"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Average final test accuracy per domain.\n",
    "single_domain_means = dict((name, np.mean(accs)) for name, accs in single_domain.items())\n",
    "single_domain_means"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Counter is: 30!\n"
     ]
    }
   ],
   "source": [
    "'''\n",
    "real-clipart\n",
    "'''\n",
    "\n",
    "# Same-class connectivity across the real and clipart domains; a class is skipped if it\n",
    "# is off-limits in either domain. Labels are compared as strings because the class\n",
    "# token parsed from the file name is a str, while `off_limits` may hold int labels.\n",
    "excluded = {str(cls) for cls in off_limits['clipart']} | {str(cls) for cls in off_limits['real']}\n",
    "counter = 0\n",
    "data = []\n",
    "# os.listdir order is arbitrary; sort for reproducible traversal and printed skips.\n",
    "for file_name in sorted(os.listdir(real_clipart)):\n",
    "    if file_name.startswith('same-class') and file_name.endswith('-final'):\n",
    "        # For 'same-class-...' names the third '-'-separated token is taken as the class.\n",
    "        cls = file_name.split('-')[2]\n",
    "        if cls in excluded:\n",
    "            continue\n",
    "        # Counts eligible runs, including incomplete ones skipped just below.\n",
    "        counter += 1\n",
    "        data_dict = torch.load(os.path.join(real_clipart, file_name))\n",
    "        # Fewer than 16 recorded test accuracies: presumably an unfinished run; report and skip.\n",
    "        if len(data_dict['test_accs']) < 16:\n",
    "            print(file_name, len(data_dict['test_accs']))\n",
    "            continue\n",
    "        # Keep the last recorded test accuracy of the run.\n",
    "        data.append(data_dict['test_accs'][-1])\n",
    "print(f'Counter is: {counter}!')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "92.21181864314364"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Mean of the final test accuracies collected by the previous cell.\n",
    "np.asarray(data).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Counter is: 17!\n"
     ]
    }
   ],
   "source": [
    "'''\n",
    "sketch-painting\n",
    "'''\n",
    "\n",
    "# Same-class connectivity across the sketch and painting domains; a class is skipped if\n",
    "# it is off-limits in either domain. Labels are compared as strings because the class\n",
    "# token parsed from the file name is a str, while `off_limits` may hold int labels.\n",
    "excluded = {str(cls) for cls in off_limits['painting']} | {str(cls) for cls in off_limits['sketch']}\n",
    "counter = 0\n",
    "data = []\n",
    "# os.listdir order is arbitrary; sort for reproducible traversal and printed skips.\n",
    "for file_name in sorted(os.listdir(sketch_painting)):\n",
    "    if file_name.startswith('same-class') and file_name.endswith('-final'):\n",
    "        # For 'same-class-...' names the third '-'-separated token is taken as the class.\n",
    "        cls = file_name.split('-')[2]\n",
    "        if cls in excluded:\n",
    "            continue\n",
    "        # Counts eligible runs, including incomplete ones skipped just below.\n",
    "        counter += 1\n",
    "        data_dict = torch.load(os.path.join(sketch_painting, file_name))\n",
    "        # Fewer than 16 recorded test accuracies: presumably an unfinished run; report and skip.\n",
    "        if len(data_dict['test_accs']) < 16:\n",
    "            print(file_name, len(data_dict['test_accs']))\n",
    "            continue\n",
    "        # Keep the last recorded test accuracy of the run.\n",
    "        data.append(data_dict['test_accs'][-1])\n",
    "print(f'Counter is: {counter}!')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "93.95930848663164"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Mean of the final test accuracies collected by the previous cell.\n",
    "np.asarray(data).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "# After fine-tuning: rebind `single_paths` and `real_clipart` to the fine-tuned\n",
    "# checkpoint directories (only the 'real' and 'clipart' domains are listed here).\n",
    "# The cells below reuse these names, so run this cell before them.\n",
    "# NOTE(review): the saved execution count of this cell (24) is later than those of the\n",
    "# following cells (20, 21) -- re-run top to bottom to make sure outputs match these paths.\n",
    "single_paths = dict(\n",
    "    real='connectivity_checkpoints/single-domain/domainnet/real-sentry-swav-simclr-unlabeled_extrapolation-swav-checkpoints-domainnet_sentrytrue_real-clipart_queue500_epochs400_batchsize128_epsilon0.03_archresnet50_prototypes400-source-real_ResNet50FS_16_source.pth/',\n",
    "    clipart='connectivity_checkpoints/single-domain/domainnet/clipart-sentry-swav-simclr-unlabeled_extrapolation-swav-checkpoints-domainnet_sentrytrue_real-clipart_queue500_epochs400_batchsize128_epsilon0.03_archresnet50_prototypes400-source-real_ResNet50FS_16_source.pth/',\n",
    ")\n",
    "real_clipart = (\n",
    "    'connectivity_checkpoints/between-domains/domainnet/real-clipart-sentry-swav-simclr-unlabeled_extrapolation-swav-checkpoints-domainnet_sentrytrue_real-clipart_queue500_epochs400_batchsize128_epsilon0.03_archresnet50_prototypes400-source-real_ResNet50FS_16_source.pth/'\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "********** real **********\n",
      "Counter is: 33!\n",
      "********** clipart **********\n",
      "Counter is: 67!\n"
     ]
    }
   ],
   "source": [
    "'''This is the same-domain-different-class connectivity.'''\n",
    "\n",
    "single_domain = {}\n",
    "for domain in ['real', 'clipart']:\n",
    "    print('*' * 10, domain, '*' * 10)\n",
    "    # Compare as strings: class tokens parsed from file names are str, while\n",
    "    # `off_limits` may hold the dataset's (presumably int) labels.\n",
    "    excluded = {str(cls) for cls in off_limits[domain]}\n",
    "    counter = 0\n",
    "    data = []\n",
    "    # os.listdir order is arbitrary; sort for reproducible traversal and printed skips.\n",
    "    for file_name in sorted(os.listdir(single_paths[domain])):\n",
    "        if file_name.endswith('-final'):\n",
    "            # Assumes the 2nd and 3rd '-'-separated tokens are the two class labels.\n",
    "            parts = file_name.split('-')\n",
    "            cls_1, cls_2 = parts[1], parts[2]\n",
    "            if (cls_1 in excluded) or (cls_2 in excluded):\n",
    "                continue\n",
    "            # Counts eligible pairs, including incomplete runs skipped just below.\n",
    "            counter += 1\n",
    "            data_dict = torch.load(os.path.join(single_paths[domain], file_name))\n",
    "            # Fewer than 16 recorded test accuracies: presumably an unfinished run; report and skip.\n",
    "            if len(data_dict['test_accs']) < 16:\n",
    "                print(file_name, len(data_dict['test_accs']))\n",
    "                continue\n",
    "            # Keep the last recorded test accuracy of the run.\n",
    "            data.append(data_dict['test_accs'][-1])\n",
    "    single_domain[domain] = data\n",
    "    print(f'Counter is: {counter}!')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'real': 96.45111714205225, 'clipart': 93.44073274871053}"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Average final test accuracy per domain (fine-tuned runs).\n",
    "single_domain_means = dict((name, np.mean(accs)) for name, accs in single_domain.items())\n",
    "single_domain_means"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Counter is: 30!\n"
     ]
    }
   ],
   "source": [
    "'''\n",
    "real-clipart\n",
    "'''\n",
    "\n",
    "# Same-class connectivity across the real and clipart domains (fine-tuned runs); a class\n",
    "# is skipped if it is off-limits in either domain. Labels are compared as strings because\n",
    "# the class token parsed from the file name is a str, while `off_limits` may hold int labels.\n",
    "excluded = {str(cls) for cls in off_limits['clipart']} | {str(cls) for cls in off_limits['real']}\n",
    "counter = 0\n",
    "data = []\n",
    "# os.listdir order is arbitrary; sort for reproducible traversal and printed skips.\n",
    "for file_name in sorted(os.listdir(real_clipart)):\n",
    "    if file_name.startswith('same-class') and file_name.endswith('-final'):\n",
    "        # For 'same-class-...' names the third '-'-separated token is taken as the class.\n",
    "        cls = file_name.split('-')[2]\n",
    "        if cls in excluded:\n",
    "            continue\n",
    "        # Counts eligible runs, including incomplete ones skipped just below.\n",
    "        counter += 1\n",
    "        data_dict = torch.load(os.path.join(real_clipart, file_name))\n",
    "        # Fewer than 16 recorded test accuracies: presumably an unfinished run; report and skip.\n",
    "        if len(data_dict['test_accs']) < 16:\n",
    "            print(file_name, len(data_dict['test_accs']))\n",
    "            continue\n",
    "        # Keep the last recorded test accuracy of the run.\n",
    "        data.append(data_dict['test_accs'][-1])\n",
    "print(f'Counter is: {counter}!')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "90.88728390117633"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Mean of the final test accuracies collected by the previous cell.\n",
    "np.asarray(data).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
