{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "9eb723a6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The autoreload extension is already loaded. To reload it, use:\n",
      "  %reload_ext autoreload\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '1'\n",
    "\n",
    "from da_algo import *\n",
    "from ot_util import generate_domains\n",
    "from dataset import *\n",
    "import clip\n",
    "import copy\n",
    "import argparse\n",
    "import random\n",
    "from tqdm.notebook import tqdm,trange\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "from torchvision.transforms import ToTensor\n",
    "from torch.utils.data import DataLoader\n",
    "from model import ENCODER,MLP,Classifier\n",
    "from torch import nn\n",
    "from torch.nn import functional as F\n",
    "from train_model import test\n",
    "from torch import optim\n",
    "%matplotlib inline\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "1a22822a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the GPU when one is visible to this process, otherwise fall back to the CPU.\n",
    "if torch.cuda.is_available():\n",
    "    device = torch.device(\"cuda\")\n",
    "else:\n",
    "    device = torch.device(\"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "b7bf4022",
   "metadata": {},
   "outputs": [],
   "source": [
    "from experiments import get_source_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "ccd459c0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Parse a fixed argument string (instead of sys.argv) so the notebook runs non-interactively.\n",
    "parser = argparse.ArgumentParser(description=\"GOAT experiments\")\n",
    "\n",
    "# Benchmark to run (positional).\n",
    "parser.add_argument(\"mode\", choices=[\"mnist\", \"cifar\", \"portraits\", \"office31\", \"office_home\"])\n",
    "\n",
    "# Backbone and its pretraining source.\n",
    "parser.add_argument(\"--pretrain\", choices=[\"imagenet\", \"clip\", \"none\"], default=\"clip\")\n",
    "parser.add_argument(\"--model\", choices=[\"RN50\", \"RN101\", \"ViT-B/32\", \"ViT-B/16\"], default=\"RN50\")\n",
    "\n",
    "# Domain-count options.\n",
    "parser.add_argument(\"--gt-domains\", type=int, default=0)\n",
    "parser.add_argument(\"--generated-domains\", type=int, default=0)\n",
    "\n",
    "# Miscellaneous run settings.\n",
    "parser.add_argument(\"--seed\", type=int, default=None)\n",
    "parser.add_argument(\"--mnist-mode\", choices=[\"normal\", \"ablation\"], default=\"normal\")\n",
    "parser.add_argument(\"--rotation-angle\", type=int, default=45)\n",
    "parser.add_argument(\"--batch-size\", type=int, default=128)\n",
    "\n",
    "args = parser.parse_args('portraits --gt-domains 0'.split())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "5f7fad1e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Portraits splits; the meaning of each size argument is defined by make_portraits_data.\n",
    "(src_tr_x, src_tr_y,\n",
    " src_val_x, src_val_y,\n",
    " inter_x, inter_y,\n",
    " dir_inter_x, dir_inter_y,\n",
    " trg_val_x, trg_val_y,\n",
    " trg_test_x, trg_test_y) = make_portraits_data(1000, 1000, 14000, 2000, 1000, 1000)\n",
    "\n",
    "# Source domain = source train + source val.\n",
    "tr_x = np.concatenate((src_tr_x, src_val_x))\n",
    "tr_y = np.concatenate((src_tr_y, src_val_y))\n",
    "\n",
    "# Target domain = target val + target test.\n",
    "ts_x = np.concatenate((trg_val_x, trg_test_x))\n",
    "ts_y = np.concatenate((trg_val_y, trg_test_y))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "c7293d1b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One shared ToTensor transform for both domains.\n",
    "transforms = ToTensor()\n",
    "src_trainset, tgt_trainset = (\n",
    "    EncodeDataset(tr_x, tr_y.astype(int), transforms),\n",
    "    EncodeDataset(ts_x, ts_y.astype(int), transforms),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 436,
   "id": "f9e79d42",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_domains(n_domains, domain_size=2000, transform=None):\n",
    "    \"\"\"Return the given intermediate domains as a list of EncodeDataset objects.\n",
    "\n",
    "    The intermediate pool (module globals ``inter_x`` / ``inter_y``) is cut into\n",
    "    consecutive chunks of ``domain_size`` samples; ``n_domains`` selects which chunk\n",
    "    indices (out of 0..6) are used, spaced roughly evenly.\n",
    "\n",
    "    Args:\n",
    "        n_domains: number of intermediate domains to return; must be 0, 1, 2, 3 or 4.\n",
    "        domain_size: samples per chunk (default 2000, the previously hard-coded value).\n",
    "        transform: transform given to every dataset (default: ``ToTensor()``).\n",
    "\n",
    "    Returns:\n",
    "        list of EncodeDataset in increasing chunk order (empty for ``n_domains == 0``).\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``n_domains`` is unsupported, or a selected chunk would run past\n",
    "            the end of the pool (previously this silently produced a short or empty set).\n",
    "    \"\"\"\n",
    "    # Chunk indices per domain count. The pool presumably holds 7 chunks ordered from\n",
    "    # source-like to target-like (TODO confirm against make_portraits_data).\n",
    "    n2idx = {0: [], 1: [3], 2: [2, 4], 3: [1, 3, 5], 4: [0, 2, 4, 6]}\n",
    "    if n_domains not in n2idx:\n",
    "        raise ValueError(f\"n_domains must be one of {sorted(n2idx)}, got {n_domains!r}\")\n",
    "    if transform is None:\n",
    "        transform = ToTensor()\n",
    "\n",
    "    domain_set = []\n",
    "    for i in n2idx[n_domains]:\n",
    "        start, end = i * domain_size, (i + 1) * domain_size\n",
    "        if end > len(inter_x):\n",
    "            raise ValueError(f\"chunk {i} needs samples [{start}, {end}) but the pool has only {len(inter_x)}\")\n",
    "        domain_set.append(EncodeDataset(inter_x[start:end], inter_y[start:end].astype(int), transform))\n",
    "    return domain_set\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 96,
   "id": "2b2e7687",
   "metadata": {},
   "outputs": [],
   "source": [
    "from CoVi.covi_trainer import train_covi\n",
    "from CoVi.args import get_covi_config\n",
    "from torch.utils.data import DataLoader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "16b452e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# CoVi hyper-parameters at their defaults; an explicit empty argument list is passed\n",
    "# (presumably so the parser does not read the notebook kernel's own command line).\n",
    "config = get_covi_config(args=list())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "babf34a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shuffled loaders with fixed-size batches (drop_last) for the source and target sets.\n",
    "src_train_loader, tgt_train_loader = (\n",
    "    DataLoader(dataset, batch_size=config.batch_size, shuffle=True, drop_last=True, pin_memory=True)\n",
    "    for dataset in (src_trainset, tgt_trainset)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "912c54c2",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 504,
   "id": "b8deffa9",
   "metadata": {},
   "outputs": [],
   "source": [
    "def init_model_optimizer(model=None, encoder=None, lr=1e-4, weight_decay=1e-3):\n",
    "    \"\"\"Build (or reuse) the classifier, a fresh auxiliary ``emp_learner`` head, and an Adam optimizer for each.\n",
    "\n",
    "    Args:\n",
    "        model: existing model to reuse. If None, a new 2-class ``Classifier(encoder, MLP(...))``\n",
    "            is built and moved to ``device``.\n",
    "        encoder: encoder for a newly built model (default: a fresh ``ENCODER()`` on ``device``).\n",
    "            Must be None when ``model`` is given.\n",
    "        lr: Adam learning rate for both optimizers (default 1e-4, the previous hard-coded value).\n",
    "        weight_decay: Adam weight decay for both optimizers (default 1e-3, the previous hard-coded value).\n",
    "\n",
    "    Returns:\n",
    "        (model, emp_learner, optimizer, optimizer_emp). ``emp_learner`` and both optimizers are\n",
    "        newly created on every call, so call this again after replacing the model's modules.\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: if both ``model`` and ``encoder`` are given.\n",
    "    \"\"\"\n",
    "    if model is None:\n",
    "        if encoder is None:\n",
    "            encoder = ENCODER().to(device)\n",
    "        model = Classifier(encoder, MLP(mode=args.mode, n_class=2, hidden=1024)).to(device)\n",
    "    else:\n",
    "        # A reused model already owns its encoder; a second one would be silently ignored.\n",
    "        assert encoder is None, \"pass either `model` or `encoder`, not both\"\n",
    "\n",
    "    # Auxiliary conv head: takes 2 * n_channels input channels (presumably two concatenated\n",
    "    # 32-channel feature maps -- TODO confirm against train_covi) and emits 11 channels per\n",
    "    # position, flattened.\n",
    "    n_channels = 32\n",
    "    emp_learner = nn.Sequential(\n",
    "        nn.Conv2d(n_channels * 2, n_channels, kernel_size=3, stride=1),\n",
    "        nn.ReLU(),\n",
    "        nn.Conv2d(in_channels=n_channels, out_channels=11, kernel_size=1),\n",
    "        nn.Flatten(),\n",
    "    ).to(device)\n",
    "\n",
    "    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)\n",
    "    optimizer_emp = optim.Adam(emp_learner.parameters(), lr=lr, weight_decay=weight_decay)\n",
    "    return model, emp_learner, optimizer, optimizer_emp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 486,
   "id": "e9b87bdb",
   "metadata": {},
   "outputs": [],
   "source": [
    "from CoVi.utils import EntropyLoss\n",
    "\n",
    "# Loss modules passed to train_covi, moved to the same device as the model.\n",
    "entropy = EntropyLoss()\n",
    "entropy = entropy.to(device)\n",
    "cross_entropy = nn.CrossEntropyLoss()\n",
    "cross_entropy = cross_entropy.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 507,
   "id": "3acabd8c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sweep results: number of given intermediate domains -> accuracy DataFrame.\n",
    "dfs_acc = dict()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 513,
   "id": "75e57594",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "130fc51296a34a7d990ea5bc92a9546d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "#Domains:   0%|          | 0/1 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Seed:   0%|          | 0/5 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Sweep: for each number of given intermediate domains, self-train along\n",
    "# source -> intermediate domains -> target with CoVi, over several seeds.\n",
    "config.batch_size = 128\n",
    "n_epochs = 10\n",
    "\n",
    "# Evaluation loader over the whole target set: no shuffling and no dropped tail batch.\n",
    "# tgt_train_loader uses drop_last=True, so evaluating on it skips up to batch_size-1 samples.\n",
    "tgt_eval_loader = DataLoader(tgt_trainset, batch_size=config.batch_size, shuffle=False, drop_last=False)\n",
    "\n",
    "for gt_domains in tqdm([0,1,2,3,4], desc='#Domains'):\n",
    "    all_sets = [src_trainset] + get_domains(gt_domains) + [tgt_trainset]\n",
    "    records = []  # one dict per (seed, domain, evaluated epoch); turned into a DataFrame once\n",
    "    for seed in trange(5, 10, desc='Seed', leave=False):\n",
    "        random.seed(seed)\n",
    "        np.random.seed(seed)\n",
    "        torch.manual_seed(seed)\n",
    "        torch.cuda.manual_seed(seed)\n",
    "\n",
    "        # Build a fresh model, then move the first four MLP layers into model.encoder\n",
    "        # (the module that is also passed to train_covi below).\n",
    "        model, _, _, _ = init_model_optimizer()\n",
    "        encoder = nn.Sequential(*(list(model.encoder.encode) + list(model.mlp.mlp[:4])))\n",
    "        mlp = nn.Sequential(*list(model.mlp.mlp[4:]))\n",
    "        model.encoder = encoder\n",
    "        model.mlp = mlp\n",
    "        # Rebuild the auxiliary head and both optimizers so they track the re-split model.\n",
    "        model, emp_learner, optimizer, optimizer_emp = init_model_optimizer(model=model)\n",
    "\n",
    "        pseudolabel_sets = [None] * len(all_sets)\n",
    "        for i in range(len(all_sets) - 1):\n",
    "            # Step i: train on the labelled source (i == 0) or the pseudo-labelled domain i,\n",
    "            # adapting towards domain i+1, whose labels are not used for training.\n",
    "            set_1 = all_sets[i] if i == 0 else pseudolabel_sets[i]\n",
    "            set_2 = all_sets[i + 1]\n",
    "            loader_1 = DataLoader(set_1, batch_size=config.batch_size, shuffle=True, drop_last=True)\n",
    "            loader_2 = DataLoader(set_2, batch_size=config.batch_size, shuffle=True, drop_last=True)\n",
    "            for epoch in range(1, n_epochs + 1):\n",
    "                train_covi(config, loader_1, loader_2,\n",
    "                           optimizer, optimizer_emp, model,\n",
    "                           model.encoder, emp_learner, entropy, cross_entropy)\n",
    "                if epoch % 2 == 0:  # evaluate target accuracy every second epoch\n",
    "                    acc = test(tgt_eval_loader, model, verbose=False)\n",
    "                    records.append({'seed': seed, 'domain': i + 1, 'epoch': epoch, 'acc': acc})\n",
    "            pseudolabel_sets[i + 1] = get_pseudolabel_set(set_2, model)\n",
    "\n",
    "    # Build the frame once instead of pd.concat inside the loop (quadratic, and concatenating\n",
    "    # onto an empty frame leaves the columns object-typed).\n",
    "    df_acc = pd.DataFrame(records, columns=['seed', 'domain', 'epoch', 'acc'])\n",
    "    dfs_acc[gt_domains] = df_acc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 429,
   "id": "b0dded2e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Give df_acc a clean 0..n-1 index. drop=True discards the old index instead of inserting\n",
    "# it as a new 'index' column, so re-running this cell is harmless.\n",
    "df_acc = df_acc.reset_index(drop=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 456,
   "id": "d0a67c84",
   "metadata": {},
   "outputs": [],
   "source": [
    "import scipy.stats as st\n",
    "\n",
    "\n",
    "def mean_and_conf(a, conf=0.95):\n",
    "    \"\"\"Return the sample mean of ``a`` and the half-width of its two-sided Student-t confidence interval.\n",
    "\n",
    "    The interval itself is ``mean ± half_width``. The half-width is the t critical value with\n",
    "    ``len(a) - 1`` degrees of freedom times ``scipy.stats.sem(a)`` (default ddof=1), so it is\n",
    "    nan when ``a`` has fewer than two values.\n",
    "\n",
    "    Args:\n",
    "        a: 1-D sequence of observations (e.g. per-seed accuracies).\n",
    "        conf: confidence level in (0, 1); default 0.95.\n",
    "\n",
    "    Returns:\n",
    "        (mean, half_width)\n",
    "    \"\"\"\n",
    "    n_obs = len(a)\n",
    "    t_crit = st.t.ppf(0.5 * (1 + conf), df=n_obs - 1)\n",
    "    return np.mean(a), t_crit * st.sem(a)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 518,
   "id": "7b4faef4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "#domains=4 83.1 ± 1.9\n",
      "#domains=0 73.7 ± 3.5\n",
      "#domains=1 75.3 ± 1.8\n",
      "#domains=2 79.8 ± 3.0\n",
      "#domains=3 82.3 ± 1.4\n"
     ]
    }
   ],
   "source": [
    "# Final-epoch target accuracy per sweep setting: mean ± 95% t-interval half-width over seeds.\n",
    "# Iterate in sorted order, and under a local name so the global df_acc is not overwritten.\n",
    "for gt_domains, results in sorted(dfs_acc.items()):\n",
    "    final_domain = gt_domains + 1         # the target is the last step of the chain\n",
    "    final_epoch = results['epoch'].max()  # last evaluated epoch (was hard-coded to 10)\n",
    "    at_target = (results['epoch'] == final_epoch) & (results['domain'] == final_domain)\n",
    "    acc_seeds = results.loc[at_target, 'acc']\n",
    "    mean, conf = mean_and_conf(acc_seeds)\n",
    "    print(f'#given domains={gt_domains}', f'{round(mean,1)} ± {round(conf, 1)}' )"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
