{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "791feaec",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 35
    },
    "executionInfo": {
     "elapsed": 775,
     "status": "ok",
     "timestamp": 1725262324908,
     "user": {
      "displayName": "",
      "userId": "02254144524654019702"
     },
     "user_tz": -120
    },
    "id": "791feaec",
    "outputId": "201ee417-4779-4abb-ee95-98080b8eb42d"
   },
   "outputs": [
    {
     "data": {
      "application/vnd.google.colaboratory.intrinsic+json": {
       "type": "string"
      },
      "text/plain": [
       "'3.10.12 (main, Jul 29 2024, 16:56:48) [GCC 11.4.0]'"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "# show the interpreter version of this runtime (last expression is displayed)\n",
    "sys.version"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "55c30a4d",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 90308,
     "status": "ok",
     "timestamp": 1725262415594,
     "user": {
      "displayName": "",
      "userId": "02254144524654019702"
     },
     "user_tz": -120
    },
    "id": "55c30a4d",
    "outputId": "4c7e51a6-f64f-46ec-fafc-eae724266a36"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Mounted at /content/drive\n",
      "Pytorch version: 2.4.0+cu121\n",
      "Device name: Tesla T4\n"
     ]
    }
   ],
   "source": [
    "from google.colab import drive\n",
    "import torch\n",
    "\n",
    "# mount Google Drive and work from the project folder (Colab-specific path)\n",
    "drive.mount('/content/drive', force_remount=True)\n",
    "path = \"/content/drive/My Drive/EC-GitHub/\"\n",
    "os.chdir(path)\n",
    "\n",
    "# pick the compute device; versions are only reported when a GPU is present\n",
    "has_gpu = torch.cuda.is_available()\n",
    "device = torch.device(\"cuda\" if has_gpu else \"cpu\")\n",
    "if has_gpu:\n",
    "    print(f\"Pytorch version: {torch.__version__}\")\n",
    "    print(f\"Device name: {torch.cuda.get_device_name(0)}\")\n",
    "else:\n",
    "    print(\"No GPU available.\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "FnO7dCj-EaD2",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "collapsed": true,
    "executionInfo": {
     "elapsed": 5413,
     "status": "ok",
     "timestamp": 1725262421001,
     "user": {
      "displayName": "",
      "userId": "02254144524654019702"
     },
     "user_tz": -120
    },
    "id": "FnO7dCj-EaD2",
    "outputId": "d01bc096-d073-4728-f36e-4eedcd4420a2"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Collecting typo\n",
      "  Downloading typo-0.1.7.tar.gz (7.3 kB)\n",
      "  Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
      "Building wheels for collected packages: typo\n",
      "  Building wheel for typo (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
      "  Created wheel for typo: filename=typo-0.1.7-py3-none-any.whl size=7111 sha256=ce076108d6d062d6ff071355454df642a7d7dfa806a5d439e8ba5a7c8bf8660b\n",
      "  Stored in directory: /root/.cache/pip/wheels/80/f2/47/161501eb72b5d8e9e81221005e53be7e7a7f500b4ba96cb400\n",
      "Successfully built typo\n",
      "Installing collected packages: typo\n",
      "Successfully installed typo-0.1.7\n"
     ]
    }
   ],
   "source": [
    "# typo library (pinned to the version recorded in this cell's output; %pip targets the kernel's environment)\n",
    "%pip install typo==0.1.7"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "69c873e7",
   "metadata": {
    "id": "69c873e7"
   },
   "outputs": [],
   "source": [
    "# import libraries\n",
    "import random\n",
    "import time\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from scipy.stats import pointbiserialr, spearmanr, pearsonr\n",
    "from sklearn.model_selection import train_test_split\n",
    "# used by the evaluation loop below; previously only available implicitly through the star imports\n",
    "from sklearn.preprocessing import OrdinalEncoder, StandardScaler\n",
    "from torch.utils.data import Dataset, DataLoader, TensorDataset, RandomSampler, SequentialSampler\n",
    "# NOTE(review): transformers.AdamW is deprecated upstream in favour of torch.optim.AdamW; kept for compatibility\n",
    "from transformers import AdamW, DistilBertTokenizerFast, DistilBertModel, BertTokenizerFast, BertModel\n",
    "\n",
    "# useful .py (star imports: any name they define takes precedence over the explicit imports above)\n",
    "from settings import * # settings\n",
    "from dataset import * # data pre-processing\n",
    "from model import * # models\n",
    "from optimization import * # model training, evaluation\n",
    "from shift_generate import * # shift generation\n",
    "from shift_evaluate import * # shift evaluation\n",
    "\n",
    "import warnings\n",
    "# silences every warning for the rest of the session (including deprecation warnings)\n",
    "warnings.simplefilter('ignore')\n",
    "pd.set_option('display.max_columns', 500)\n",
    "pd.set_option('display.width', 100)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "JJ6_Cjf8ym8M",
   "metadata": {
    "id": "JJ6_Cjf8ym8M"
   },
   "source": [
    "## Tests: shift generation and evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "jl_6K_lDySPt",
   "metadata": {
    "executionInfo": {
     "elapsed": 657,
     "status": "ok",
     "timestamp": 1726132841085,
     "user": {
      "displayName": "",
      "userId": "02254144524654019702"
     },
     "user_tz": -120
    },
    "id": "jl_6K_lDySPt"
   },
   "outputs": [],
   "source": [
    "# number of rows requested for each Target sample (passed as target_size / test_size below)\n",
    "TARGET_SIZE = 1000\n",
    "\n",
    "# severity levels per shift type, indexed by shift_intensity in {0, 1, 2}\n",
    "shift_severity = {\n",
    "    'orderSplit1': [0.1, 0.5, 0.9],\n",
    "    'orderSplit2': [0.1, 0.5, 0.9],\n",
    "    'orderSplit3': [0.1, 0.5, 0.9],\n",
    "    'emptyCategory': [0.1, 0.5, 0.9],\n",
    "    'typos': [5, 25, 50],\n",
    "    'seqLengthSplit1': [0.1, 0.5, 0.9],\n",
    "    'seqLengthSplit2': [0.1, 0.5, 0.9],\n",
    "    'cutText': [0.1, 0.5, 0.9],\n",
    "    'abbrev': [0.1, 0.5, 0.9],\n",
    "    'newClass': [0.1, 0.5, 0.9],\n",
    "}\n",
    "\n",
    "# newClass shift: dataset -> [label column name, dataset name passed to newClass as new_dataset_name]\n",
    "newClassData = {\n",
    "    \"cloth_4\": [\"rating\", \"cloth_5\"],\n",
    "    \"pet_4\": [\"adoption speed\", \"pet_5\"],\n",
    "    \"salary_5\": [\"Salary\", \"salary_6\"],\n",
    "    \"wine_10\": [\"Variety\", \"wine_100\"],\n",
    "    \"wine_100\": [\"Variety\", \"wine_200\"],\n",
    "}\n",
    "\n",
    "# select DATASET\n",
    "for DATASET in [\"wine_100\"]: # choose in {\"cloth_4\", \"airbnb\", \"kick\", \"pet_4\", \"salary_5, \"wine_10\", \"wine_100\"}\n",
    "    FILENAME, categorical_var, numerical_var, text_var, MAX_LEN_QUANTILE, N_CLASSES, WEIGHT_DECAY, FACTOR, N_EPOCHS, split_val, CRITERION, N_SEED, DROPOUT= load_settings(dataset = DATASET)\n",
    "\n",
    "\n",
    "    # run for every architecture\n",
    "    for MODEL_TYPE in [\"AllTextBERT\", \"LateFuseBERT\"]:\n",
    "\n",
    "          # performance records\n",
    "          perf_results = pd.DataFrame()\n",
    "          i =   0\n",
    "          local_perf_data = pd.DataFrame()\n",
    "\n",
    "          # load and prepare dataset\n",
    "          df_original = preprocess_dataset(DATASET, MODEL_TYPE)\n",
    "\n",
    "          # run for every seed\n",
    "          for SEED in range(N_SEED):\n",
    "\n",
    "                print(\"SEED:\", SEED)\n",
    "\n",
    "                # original dataset\n",
    "                df = df_original.copy()\n",
    "\n",
    "                # GPU or CPU\n",
    "                device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "                # control randomness\n",
    "                random.seed(SEED)\n",
    "                np.random.seed(SEED)\n",
    "                torch.manual_seed(SEED)\n",
    "                torch.cuda.manual_seed(SEED)\n",
    "\n",
    "                # Train/Test split\n",
    "                df, target = train_test_split(df, test_size = split_val, random_state = SEED)\n",
    "\n",
    "                # Source: text cleaning (keep only words and numbers)\n",
    "                df['clean_text'] = df[text_var].apply(lambda row:clean_text(row))\n",
    "\n",
    "                # Load the specific tokenizer\n",
    "                tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased', do_lower_case=True)\n",
    "\n",
    "                # text max length\n",
    "                MAX_LEN = int(np.quantile(df.apply(lambda row : len(tokenizer(row['clean_text']).input_ids), axis=1).values, q = [MAX_LEN_QUANTILE]).item())\n",
    "                MAX_LEN = min(MAX_LEN, 512) # maximum sequence length is 512 for BERT family\n",
    "\n",
    "                # Numerical variables pre-processing (Source)\n",
    "                numerical_var_scaled = [var + \" - scaled\" for var in numerical_var]\n",
    "                sc = StandardScaler()\n",
    "                df[numerical_var_scaled] = pd.DataFrame(sc.fit_transform(df[numerical_var]), columns = numerical_var_scaled).values\n",
    "\n",
    "                # Categorical variables pre-processing (Source)\n",
    "                oe = OrdinalEncoder(handle_unknown='use_encoded_value', unknown_value=-1)\n",
    "                reviews_cat_encoded = oe.fit_transform(df[categorical_var])\n",
    "                categorical_var_oe = []\n",
    "                for idx, var in enumerate(categorical_var):\n",
    "                    df[var+' - oe'] = reviews_cat_encoded[:,idx].astype(int) + 1 # add 1 so that unknown token is 0\n",
    "                    categorical_var_oe.append(var+' - oe')\n",
    "\n",
    "                # train / validation split\n",
    "                df_train, df_validation = train_test_split(df, test_size = split_val, random_state = SEED)\n",
    "                _, df_validation = train_test_split(df, test_size = min(len(df_validation), 5000), random_state = SEED) # maximum 5000 rows for Source\n",
    "\n",
    "                # prepare dataloader (Source)\n",
    "                _, BATCH_SIZE, _, _, _, _ = load_pretrained_settings()\n",
    "                dataset_validation = prepareTensorDatasetWithTokenizer(df_validation, \"clean_text\", categorical_var_oe, numerical_var_scaled, 'Y', tokenizer, MAX_LEN, special_tokens=True)\n",
    "                loader_validation = DataLoader(dataset_validation, sampler = SequentialSampler(dataset_validation),batch_size = BATCH_SIZE)\n",
    "\n",
    "                # load checkpoint (already trained and saved models from fine-tuning notebook)\n",
    "                model = torch.load('trained_models/'+MODEL_TYPE+'_Bert_'+DATASET+'_'+str(SEED)+'_checkpoint.pt')\n",
    "\n",
    "                # Temperature scaling\n",
    "                scaled_model = ModelWithTemperature(model)\n",
    "                scaled_model = scaled_model.set_temperature(loader_validation, device)\n",
    "\n",
    "                # extract CLS representation and softmax probabilities from classifier (Source)\n",
    "                source_cls, source_logits, source_softmax, source_scaled_softmax, source_y = cls_softmax_representations(model, scaled_model, MODEL_TYPE, loader_validation, device)\n",
    "                cls_var = [\"cls\"+str(i) for i in range(source_cls.shape[1])]\n",
    "                logits_var = [\"logits\"+str(i) for i in range(source_logits.shape[1])]\n",
    "                softmax_var = [\"p\"+str(i) for i in range(source_softmax.shape[1])]\n",
    "                scaled_softmax_var = [\"p_scaled\"+str(i) for i in range(source_softmax.shape[1])]\n",
    "                source_repres = pd.concat((pd.DataFrame(source_cls.cpu().numpy(), columns = cls_var),\n",
    "                                           pd.DataFrame(source_logits.cpu().numpy(), columns = logits_var),\n",
    "                                           pd.DataFrame(source_softmax.cpu().numpy(), columns = softmax_var),\n",
    "                                           pd.DataFrame(source_scaled_softmax.cpu().numpy(), columns = scaled_softmax_var),\n",
    "                                           pd.DataFrame(source_y.cpu().numpy(), columns=['y'])), axis=1)\n",
    "\n",
    "                # run for every type of shift for datasets \"cloth_4\", \"pet_4\", \"salary_5, \"wine_10\", \"wine_100\"\n",
    "                for shift_type in ['noShift', 'orderSplit1', 'orderSplit2', 'orderSplit3', 'emptyCategory', 'typos', 'seqLengthSplit1', 'seqLengthSplit2', 'cutText', 'abbrev', 'newClass',\n",
    "                                   'orderSplit1_typos', 'emptyCategory_typos', 'orderSplit2_seqLengthSplit2', 'orderSplit1_cutText', 'orderSplit3_abbrev']:\n",
    "\n",
    "                # # run for every type of shift for datasets \"airbnb\", \"kick\"\n",
    "                # for shift_type in ['noShift', 'orderSplit1', 'orderSplit2', 'orderSplit3', 'emptyCategory', 'typos', 'seqLengthSplit1', 'seqLengthSplit2', 'cutText', 'abbrev',\n",
    "                #                    'orderSplit1_typos', 'emptyCategory_typos', 'orderSplit2_seqLengthSplit2', 'orderSplit1_cutText', 'orderSplit3_abbrev']:\n",
    "\n",
    "                    # run for severity of shift\n",
    "                    for shift_intensity in [0, 1, 2]:\n",
    "\n",
    "                        # original Target dataset\n",
    "                        shift_target = target.copy()\n",
    "\n",
    "                        # reset text field as the original one\n",
    "                        if MODEL_TYPE == \"AllTextBERT\":\n",
    "                            shift_target[text_var] = shift_target[text_var+'_original'].copy()\n",
    "\n",
    "                        if shift_type == 'noShift':\n",
    "                            _, shift_target = train_test_split(shift_target, test_size = TARGET_SIZE, random_state = SEED)\n",
    "                            shift_target = shift_target.reset_index(drop=True)\n",
    "\n",
    "                        if shift_type in ['orderSplit1', 'orderSplit2', 'orderSplit3']:\n",
    "                            model_var_list = categorical_var + numerical_var\n",
    "                            np.random.seed(SEED + int(shift_type[-1]))\n",
    "                            variable = np.random.choice(model_var_list, size = 1, replace = False)[0]\n",
    "                            sample_rate = shift_severity[shift_type][shift_intensity]\n",
    "                            shift_target = orderSplit(shift_target, variable=variable, seed=SEED, target_size=TARGET_SIZE, sample_rate=sample_rate)\n",
    "\n",
    "                        if shift_type == 'emptyCategory':\n",
    "                            threshold = shift_severity[shift_type][shift_intensity]\n",
    "                            shift_target = emptyCategory(shift_target, variables = categorical_var, threshold = threshold, target_size = TARGET_SIZE, seed = SEED)\n",
    "\n",
    "                        if shift_type == 'typos':\n",
    "                            num_typos = shift_severity[shift_type][shift_intensity]\n",
    "                            shift_target = typos(shift_target, variable = text_var, num_typos = num_typos, target_size = TARGET_SIZE, seed = SEED)\n",
    "\n",
    "                        if shift_type in ['seqLengthSplit1', 'seqLengthSplit2']:\n",
    "                            sample_rate = shift_severity[shift_type][shift_intensity]\n",
    "                            shift_target = seqLengthSplit(data = shift_target, variable = text_var, seed = SEED, target_size = TARGET_SIZE, sample_rate = sample_rate, ascending = shift_type[-1] == '1')\n",
    "\n",
    "                        if shift_type == 'cutText':\n",
    "                            threshold = shift_severity[shift_type][shift_intensity]\n",
    "                            shift_target = cutText(shift_target, variable = text_var, cut_proportion = threshold, threshold = threshold, target_size = TARGET_SIZE, seed = SEED)\n",
    "\n",
    "                        if shift_type == 'abbrev':\n",
    "                          threshold = shift_severity[shift_type][shift_intensity]\n",
    "                          shift_target = abbrev(shift_target, variable=text_var, threshold = threshold, target_size=TARGET_SIZE, seed=SEED)\n",
    "\n",
    "                        if shift_type == 'newClass':\n",
    "                            sample_rate_new = shift_severity[shift_type][shift_intensity]\n",
    "                            label_name = newClassData[DATASET][0]\n",
    "                            new_dataset_name = newClassData[DATASET][1]\n",
    "                            shift_target = newClass(shift_target, label_name=label_name, text_var = text_var, new_dataset_name = new_dataset_name, model_type = MODEL_TYPE, seed=SEED, target_size=TARGET_SIZE, sample_rate_new=sample_rate_new)\n",
    "\n",
    "                        if shift_type == 'orderSplit1_typos':\n",
    "                            model_var_list = categorical_var + numerical_var\n",
    "                            np.random.seed(SEED + 1)\n",
    "                            variable = np.random.choice(model_var_list, size = 1, replace = False)[0]\n",
    "                            sample_rate = shift_severity['orderSplit1'][shift_intensity]\n",
    "                            shift_target = orderSplit(shift_target, variable=variable, seed=SEED, target_size=TARGET_SIZE + 1, sample_rate=sample_rate)\n",
    "                            num_typos = shift_severity['typos'][shift_intensity]\n",
    "                            shift_target = typos(shift_target, variable = text_var, num_typos = num_typos, target_size = TARGET_SIZE, seed = SEED)\n",
    "\n",
    "                        if shift_type == 'emptyCategory_typos':\n",
    "                            threshold = shift_severity['emptyCategory'][shift_intensity]\n",
    "                            shift_target = emptyCategory(shift_target, variables = categorical_var, threshold = threshold, target_size = TARGET_SIZE + 1, seed = SEED)\n",
    "                            num_typos = shift_severity['typos'][shift_intensity]\n",
    "                            shift_target = typos(shift_target, variable = text_var, num_typos = num_typos, target_size = TARGET_SIZE, seed = SEED)\n",
    "\n",
    "                        if shift_type == 'orderSplit2_seqLengthSplit2':\n",
    "                            model_var_list = categorical_var + numerical_var\n",
    "                            np.random.seed(SEED + 2)\n",
    "                            variable = np.random.choice(model_var_list, size = 1, replace = False)[0]\n",
    "                            sample_rate = shift_severity['orderSplit2'][shift_intensity]\n",
    "                            shift_target = orderSplit(shift_target, variable=variable, seed=SEED, target_size=TARGET_SIZE + 1, sample_rate=sample_rate)\n",
    "                            sample_rate = shift_severity['seqLengthSplit2'][shift_intensity]\n",
    "                            shift_target = seqLengthSplit(data = shift_target, variable = text_var, seed = SEED, target_size = TARGET_SIZE, sample_rate = sample_rate, ascending = False)\n",
    "\n",
    "                        if shift_type == 'orderSplit1_cutText':\n",
    "                            model_var_list = categorical_var + numerical_var\n",
    "                            np.random.seed(SEED + 1)\n",
    "                            variable = np.random.choice(model_var_list, size = 1, replace = False)[0]\n",
    "                            sample_rate = shift_severity['orderSplit1'][shift_intensity]\n",
    "                            shift_target = orderSplit(shift_target, variable=variable, seed=SEED, target_size=TARGET_SIZE + 1, sample_rate=sample_rate)\n",
    "                            threshold = shift_severity['cutText'][shift_intensity]\n",
    "                            shift_target = cutText(shift_target, variable = text_var, cut_proportion = threshold, threshold = threshold, target_size = TARGET_SIZE, seed = SEED)\n",
    "\n",
    "                        if shift_type == 'orderSplit3_abbrev':\n",
    "                            model_var_list = categorical_var + numerical_var\n",
    "                            np.random.seed(SEED + 3)\n",
    "                            variable = np.random.choice(model_var_list, size = 1, replace = False)[0]\n",
    "                            sample_rate = shift_severity['orderSplit3'][shift_intensity]\n",
    "                            shift_target = orderSplit(shift_target, variable=variable, seed=SEED, target_size=TARGET_SIZE + 1, sample_rate=sample_rate)\n",
    "                            threshold = shift_severity['abbrev'][shift_intensity]\n",
    "                            shift_target = abbrev(shift_target, variable=text_var, threshold = threshold, target_size=TARGET_SIZE, seed=SEED)\n",
    "\n",
    "                        # concatenate variables after implementing shifts\n",
    "                        if MODEL_TYPE == \"AllTextBERT\":\n",
    "                           shift_target = concat_tab_txt(shift_target, categorical_var + numerical_var, text_var)\n",
    "\n",
    "                        # text cleaning (keep only words and numbers)\n",
    "                        shift_target['clean_text'] = shift_target[text_var].apply(lambda row:clean_text(row))\n",
    "\n",
    "                        # Numerical variables pre-processing (Target)\n",
    "                        shift_target[numerical_var_scaled] = pd.DataFrame(sc.transform(shift_target[numerical_var]), columns = numerical_var_scaled).values\n",
    "\n",
    "                        # Categorical variables pre-processing (Target)\n",
    "                        target_cat_encoded = oe.transform(shift_target[categorical_var])\n",
    "                        for idx, var in enumerate(categorical_var):\n",
    "                            shift_target[var+' - oe'] = target_cat_encoded[:,idx].astype(int) + 1 # add 1 so that unknown token is 0\n",
    "\n",
    "                        # prepare dataloader (Target)\n",
    "                        dataset_target = prepareTensorDatasetWithTokenizer(shift_target, \"clean_text\", categorical_var_oe, numerical_var_scaled, 'Y', tokenizer, MAX_LEN, special_tokens=True)\n",
    "                        loader_target = DataLoader(dataset_target, sampler = SequentialSampler(dataset_target),batch_size = BATCH_SIZE)\n",
    "\n",
    "                        # model evaluation\n",
    "                        model.eval()\n",
    "                        target_perf = performance_pretrained(model, loader_target, MODEL_TYPE, SEED, device)\n",
    "\n",
    "                        # extract CLS representation and softmax probabilities from classifier (Target)\n",
    "                        target_cls, target_logits, target_softmax, target_scaled_softmax, target_y = cls_softmax_representations(model, scaled_model, MODEL_TYPE, loader_target, device)\n",
    "                        target_repres = pd.concat((pd.DataFrame(target_cls.cpu().numpy(), columns = cls_var),\n",
    "                                                   pd.DataFrame(target_logits.cpu().numpy(), columns = logits_var),\n",
    "                                                   pd.DataFrame(target_softmax.cpu().numpy(), columns = softmax_var),\n",
    "                                                   pd.DataFrame(target_scaled_softmax.cpu().numpy(), columns = scaled_softmax_var),\n",
    "                                                   pd.DataFrame(target_y.cpu().numpy(), columns=['y'])), axis=1)\n",
    "\n",
    "\n",
    "                        ## METHODS\n",
    "                        # 1. Jensen Shannon Distance (Lin, 1991)\n",
    "                        start = time.time()\n",
    "                        perf_results.loc[i,[\"dataset\", \"Target size\", \"model type\", \"seed\",\"shift_type\",\"shift_intensity\", \"error rate (Target)\"]] = [DATASET, shift_target.shape[0], MODEL_TYPE, SEED, shift_type,shift_intensity, 1-target_perf]\n",
    "                        perf_results.loc[i,\"method\"] = \"jsd\"\n",
    "                        perf_results.loc[i,\"global metric\"] = max_proba_jsd(source_repres[softmax_var], target_repres[softmax_var], num_bins = 10)\n",
    "                        perf_results.loc[i,\"time\"] = time.time() - start\n",
    "                        i+=1\n",
    "\n",
    "                        #2. Average confidence without probability scaling\n",
    "                        start = time.time()\n",
    "                        perf_results.loc[i,[\"dataset\", \"Target size\", \"model type\", \"seed\",\"shift_type\",\"shift_intensity\", \"error rate (Target)\"]] = [DATASET, shift_target.shape[0], MODEL_TYPE, SEED, shift_type,shift_intensity, 1-target_perf]\n",
    "                        perf_results.loc[i,\"method\"] = \"ac\"\n",
    "                        ac_local_target, mean_uncertainty = average_confidence(target_repres, var_list = softmax_var)\n",
    "                        perf_results.loc[i,\"global metric\"] = mean_uncertainty\n",
    "                        local_error = np.argmax(ac_local_target[softmax_var], axis=1)!=ac_local_target[\"y\"]\n",
    "                        perf_results.loc[i,\"local metric\"] = pointbiserialr(local_error, ac_local_target[\"one_minus_max_proba\"])[0]\n",
    "                        perf_results.loc[i,\"time\"] = time.time() - start\n",
    "                        i+=1\n",
    "                        local_data = pd.DataFrame(local_error).astype(int)\n",
    "                        local_data[[\"dataset\",\"model type\",\"seed\",\"shift_type\", \"shift_intensity\"]] = [DATASET, MODEL_TYPE, SEED, shift_type,shift_intensity]\n",
    "                        local_data = local_data.rename({\"y\":\"error rate\"}, axis=1)\n",
    "                        local_data[\"local ac\"] =  ac_local_target[\"one_minus_max_proba\"]\n",
    "\n",
    "                        # 3. Average confidence with probability scaling\n",
    "                        start = time.time()\n",
    "                        perf_results.loc[i,[\"dataset\", \"Target size\", \"model type\", \"seed\",\"shift_type\",\"shift_intensity\", \"error rate (Target)\"]] = [DATASET, shift_target.shape[0], MODEL_TYPE, SEED, shift_type,shift_intensity, 1-target_perf]\n",
    "                        perf_results.loc[i,\"method\"] = \"ac-scaled\"\n",
    "                        sac_local_target, sac_mean_uncertainty = average_confidence(target_repres, var_list = scaled_softmax_var)\n",
    "                        perf_results.loc[i,\"global metric\"] = sac_mean_uncertainty\n",
    "                        perf_results.loc[i,\"local metric\"] = pointbiserialr(local_error,sac_local_target[\"one_minus_max_proba\"])[0]\n",
    "                        perf_results.loc[i,\"time\"] = time.time() - start\n",
    "                        i+=1\n",
    "                        local_data[\"local ac-scaled\"] =  sac_local_target[\"one_minus_max_proba\"]\n",
    "\n",
    "                        # 4. Maximum Mean Discrepancy (Gretton et al., JMLR 2012)\n",
    "                        start = time.time()\n",
    "                        perf_results.loc[i,[\"dataset\", \"Target size\", \"model type\", \"seed\",\"shift_type\",\"shift_intensity\", \"error rate (Target)\"]] = [DATASET, shift_target.shape[0], MODEL_TYPE, SEED, shift_type,shift_intensity, 1-target_perf]\n",
    "                        perf_results.loc[i,\"method\"] = \"mmd\"\n",
    "                        mmd_stat = mmd(source_cls, target_cls, device)\n",
    "                        perf_results.loc[i,\"global metric\"] = mmd_stat.item()\n",
    "                        perf_results.loc[i,\"time\"] = time.time() - start\n",
    "                        i+=1\n",
    "\n",
    "                        # 5. Difference Of Confidence (DOC) (Guillory et al., ICCV 2021)\n",
    "                        start = time.time()\n",
    "                        perf_results.loc[i,[\"dataset\", \"Target size\", \"model type\", \"seed\",\"shift_type\",\"shift_intensity\", \"error rate (Target)\"]] = [DATASET, shift_target.shape[0], MODEL_TYPE, SEED, shift_type,shift_intensity, 1-target_perf]\n",
    "                        perf_results.loc[i,\"method\"] = \"doc\"\n",
    "                        perf_results.loc[i,\"global metric\"] = doc(model, loader_validation, MODEL_TYPE, source_repres[softmax_var], target_repres[softmax_var], SEED, device)\n",
    "                        perf_results.loc[i,\"time\"] = time.time() - start\n",
    "                        i+=1\n",
    "\n",
    "                        # 6. Average Thresholded Confidence (ATC) (Garg et al., 2022)\n",
    "                        start = time.time()\n",
    "                        perf_results.loc[i,[\"dataset\", \"Target size\", \"model type\", \"seed\",\"shift_type\",\"shift_intensity\", \"error rate (Target)\"]] = [DATASET, shift_target.shape[0], MODEL_TYPE, SEED, shift_type,shift_intensity, 1-target_perf]\n",
    "                        perf_results.loc[i,\"method\"] = \"atc\"\n",
    "                        perf_results.loc[i,\"global metric\"] = atc(source_repres['y'], source_repres[softmax_var], target_repres[softmax_var])\n",
    "                        perf_results.loc[i,\"time\"] = time.time() - start\n",
    "                        i+=1\n",
    "\n",
    "                        # 7. Mandoline (Chen et al., ICML 2021)\n",
    "                        start = time.time()\n",
    "                        perf_results.loc[i,[\"dataset\", \"Target size\", \"model type\", \"seed\",\"shift_type\",\"shift_intensity\", \"error rate (Target)\"]] = [DATASET, shift_target.shape[0], MODEL_TYPE, SEED, shift_type,shift_intensity, 1-target_perf]\n",
    "                        perf_results.loc[i,\"method\"] = \"mandoline\"\n",
    "                        perf_results.loc[i,\"global metric\"] = mandoline_performance(source_repres['y'], source_repres[softmax_var], target_repres[softmax_var])\n",
    "                        perf_results.loc[i,\"time\"] = time.time() - start\n",
    "                        i+=1\n",
    "\n",
    "                        # 8. Monte Carlo Dropout (Gal et al., ICML 2016)\n",
    "                        start = time.time()\n",
    "                        perf_results.loc[i,[\"dataset\", \"Target size\", \"model type\", \"seed\",\"shift_type\",\"shift_intensity\", \"error rate (Target)\"]] = [DATASET, shift_target.shape[0], MODEL_TYPE, SEED, shift_type,shift_intensity, 1-target_perf]\n",
    "                        perf_results.loc[i,\"method\"] = \"mcd\"\n",
    "                        shannon_entropy, shannon_entropy_mean = compute_MCD(model, loader_target, n_simu = 5, seed = SEED, device = device)\n",
    "                        perf_results.loc[i,\"global metric\"] = shannon_entropy_mean\n",
    "                        perf_results.loc[i,\"local metric\"] = pointbiserialr(local_error, shannon_entropy)[0]\n",
    "                        perf_results.loc[i,\"time\"] = time.time() - start\n",
    "                        i+=1\n",
    "                        local_data[\"local mcd\"] =  shannon_entropy\n",
    "\n",
    "                        # 9. Domain classifier on source_repres vs. target_repres (feature list: cls_var).\n",
    "                        # Global score: mean AUROC; local score: the per-sample \"domain proba\" column.\n",
    "                        # source_with_weights / target_with_domain_proba are reused by the steps that follow.\n",
    "                        start = time.time()\n",
    "                        perf_results.loc[i, [\n",
    "                            \"dataset\", \"Target size\", \"model type\", \"seed\",\n",
    "                            \"shift_type\", \"shift_intensity\", \"error rate (Target)\",\n",
    "                        ]] = [\n",
    "                            DATASET, shift_target.shape[0], MODEL_TYPE, SEED,\n",
    "                            shift_type, shift_intensity, 1 - target_perf,\n",
    "                        ]\n",
    "                        perf_results.loc[i, \"method\"] = \"dc\"\n",
    "                        source_with_weights, target_with_domain_proba, mean_auroc = domain_classifier(\n",
    "                            source_repres, target_repres, cls_var, SEED\n",
    "                        )\n",
    "                        perf_results.loc[i, \"global metric\"] = mean_auroc\n",
    "                        perf_results.loc[i, \"local metric\"] = pointbiserialr(\n",
    "                            local_error, target_with_domain_proba[\"domain proba\"]\n",
    "                        )[0]\n",
    "                        perf_results.loc[i, \"time\"] = time.time() - start\n",
    "                        i += 1\n",
    "                        # NOTE(review): if local_data is a DataFrame, this Series assignment aligns on index;\n",
    "                        # assumes both frames share the same row index - confirm.\n",
    "                        local_data[\"local dc\"] = target_with_domain_proba[\"domain proba\"]\n",
    "\n",
    "                        # 10. Conformal prediction (Sadinle et al., 2019) + (Tibshirani et al., 2019).\n",
    "                        # Uses the source softmax columns and labels ('y'), weighted by the domain classifier's\n",
    "                        # \"normalized weights\", with target_coverage=0.9.\n",
    "                        # Global score: mean interval width; local score: per-sample \"interval width\".\n",
    "                        start = time.time()\n",
    "                        perf_results.loc[i, [\n",
    "                            \"dataset\", \"Target size\", \"model type\", \"seed\",\n",
    "                            \"shift_type\", \"shift_intensity\", \"error rate (Target)\",\n",
    "                        ]] = [\n",
    "                            DATASET, shift_target.shape[0], MODEL_TYPE, SEED,\n",
    "                            shift_type, shift_intensity, 1 - target_perf,\n",
    "                        ]\n",
    "                        perf_results.loc[i, \"method\"] = \"cp\"\n",
    "                        weights = source_with_weights[\"normalized weights\"]\n",
    "                        cp_local_target, mean_interval_width = conformal_prediction(\n",
    "                            source_with_weights[softmax_var],\n",
    "                            source_with_weights['y'],\n",
    "                            target_with_domain_proba,\n",
    "                            softmax_var,\n",
    "                            target_coverage=0.9,\n",
    "                            source_weights=weights,\n",
    "                        )\n",
    "                        perf_results.loc[i, \"global metric\"] = mean_interval_width\n",
    "                        perf_results.loc[i, \"local metric\"] = pointbiserialr(local_error, cp_local_target[\"interval width\"])[0]\n",
    "                        perf_results.loc[i, \"time\"] = time.time() - start\n",
    "                        i += 1\n",
    "                        local_data[\"local cp\"] = cp_local_target[\"interval width\"]\n",
    "\n",
    "                        # 11-13. Error classifiers, run in this order on three feature sets:\n",
    "                        #   ec_xy -> cls_var + softmax_var, ec_x -> cls_var, ec_y -> softmax_var.\n",
    "                        # Every variant passes uniform source weights (a Series of ones) and algo_type \"rf\".\n",
    "                        # Global score: mean_error_rate; local score: the per-sample \"pred error\" column.\n",
    "                        # NOTE(review): source_with_weights also carries \"normalized weights\" (used in step 10),\n",
    "                        # but they are not passed here; the ones-Series has a default RangeIndex - confirm intended.\n",
    "                        for ec_method, ec_features, ec_local_col in (\n",
    "                            (\"ec_xy\", cls_var + softmax_var, \"local ec_xy\"),\n",
    "                            (\"ec_x\", cls_var, \"local ec_x\"),\n",
    "                            (\"ec_y\", softmax_var, \"local ec_y\"),\n",
    "                        ):\n",
    "                            start = time.time()\n",
    "                            perf_results.loc[i, [\n",
    "                                \"dataset\", \"Target size\", \"model type\", \"seed\",\n",
    "                                \"shift_type\", \"shift_intensity\", \"error rate (Target)\",\n",
    "                            ]] = [\n",
    "                                DATASET, shift_target.shape[0], MODEL_TYPE, SEED,\n",
    "                                shift_type, shift_intensity, 1 - target_perf,\n",
    "                            ]\n",
    "                            perf_results.loc[i, \"method\"] = ec_method\n",
    "                            ec, ec_target, mean_error_rate = error_classifier(\n",
    "                                source_with_weights, target_with_domain_proba,\n",
    "                                var_list=ec_features, softmax_var=softmax_var,\n",
    "                                source_weights=pd.Series(np.ones(source_with_weights.shape[0])),\n",
    "                                algo_type=\"rf\", seed=SEED,\n",
    "                            )\n",
    "                            perf_results.loc[i, \"global metric\"] = mean_error_rate\n",
    "                            perf_results.loc[i, \"local metric\"] = pointbiserialr(local_error, ec_target[\"pred error\"])[0]\n",
    "                            perf_results.loc[i, \"time\"] = time.time() - start\n",
    "                            i += 1\n",
    "                            local_data[ec_local_col] = ec_target[\"pred error\"]\n",
    "\n",
    "                        # 14-15. Deep Nearest Neighbors (Sun et al., ICML 2022), first with 10 and then with\n",
    "                        # 100 as deep_nearest_neighbors' first argument (the neighbour count, per the labels),\n",
    "                        # on the cls_var columns. Global score: mean_distance; local: \"target distances\".\n",
    "                        for dnn_k, dnn_method, dnn_local_col in (\n",
    "                            (10, \"dnn_10\", \"local dnn_10\"),\n",
    "                            (100, \"dnn_100\", \"local dnn_100\"),\n",
    "                        ):\n",
    "                            start = time.time()\n",
    "                            perf_results.loc[i, [\n",
    "                                \"dataset\", \"Target size\", \"model type\", \"seed\",\n",
    "                                \"shift_type\", \"shift_intensity\", \"error rate (Target)\",\n",
    "                            ]] = [\n",
    "                                DATASET, shift_target.shape[0], MODEL_TYPE, SEED,\n",
    "                                shift_type, shift_intensity, 1 - target_perf,\n",
    "                            ]\n",
    "                            perf_results.loc[i, \"method\"] = dnn_method\n",
    "                            dnn_target, mean_distance = deep_nearest_neighbors(\n",
    "                                dnn_k, source_repres, target_repres, var_list=cls_var\n",
    "                            )\n",
    "                            perf_results.loc[i, \"global metric\"] = mean_distance\n",
    "                            perf_results.loc[i, \"local metric\"] = pointbiserialr(local_error, dnn_target[\"target distances\"])[0]\n",
    "                            perf_results.loc[i, \"time\"] = time.time() - start\n",
    "                            i += 1\n",
    "                            local_data[dnn_local_col] = dnn_target[\"target distances\"]\n",
    "\n",
    "                        # 16. Energy score (Liu et al., NeurIPS 2020) on target_logits with T=1 (the temperature).\n",
    "                        # Global score: mean_score; local score: the per-sample energy_scores.\n",
    "                        start = time.time()\n",
    "                        perf_results.loc[i, [\n",
    "                            \"dataset\", \"Target size\", \"model type\", \"seed\",\n",
    "                            \"shift_type\", \"shift_intensity\", \"error rate (Target)\",\n",
    "                        ]] = [\n",
    "                            DATASET, shift_target.shape[0], MODEL_TYPE, SEED,\n",
    "                            shift_type, shift_intensity, 1 - target_perf,\n",
    "                        ]\n",
    "                        perf_results.loc[i, \"method\"] = \"energy\"\n",
    "                        energy_scores, mean_score = energy_score(target_logits, T=1)\n",
    "                        # Copy the per-sample scores to host memory once and reuse the array for both the\n",
    "                        # correlation and local_data. Assuming energy_scores is a torch tensor (as the\n",
    "                        # .cpu().numpy() chain implies), detach() shares storage, so it costs nothing and\n",
    "                        # keeps .numpy() from raising if the scores ever carry autograd history.\n",
    "                        energy_scores_np = energy_scores.detach().cpu().numpy()\n",
    "                        perf_results.loc[i, \"global metric\"] = mean_score\n",
    "                        perf_results.loc[i, \"local metric\"] = pointbiserialr(local_error, energy_scores_np)[0]\n",
    "                        perf_results.loc[i, \"time\"] = time.time() - start\n",
    "                        i += 1\n",
    "                        local_data[\"local energy\"] = energy_scores_np\n",
    "\n",
    "                        # 17. True Class Probability (Corbiere et al., NeurIPS 2019).\n",
    "                        # true_class_probability receives source cls/softmax/labels and target cls features;\n",
    "                        # per the variable names it returns 1 - TCP per target sample (local score) and its\n",
    "                        # mean (global score) - presumably higher means more likely misclassified.\n",
    "                        start = time.time()\n",
    "                        perf_results.loc[i, [\n",
    "                            \"dataset\", \"Target size\", \"model type\", \"seed\",\n",
    "                            \"shift_type\", \"shift_intensity\", \"error rate (Target)\",\n",
    "                        ]] = [\n",
    "                            DATASET, shift_target.shape[0], MODEL_TYPE, SEED,\n",
    "                            shift_type, shift_intensity, 1 - target_perf,\n",
    "                        ]\n",
    "                        perf_results.loc[i, \"method\"] = \"tcp\"\n",
    "                        one_minus_tcp_preds, one_minus_tcp_mean = true_class_probability(\n",
    "                            source_cls, source_softmax, source_y, target_cls, SEED, device\n",
    "                        )\n",
    "                        perf_results.loc[i, \"global metric\"] = one_minus_tcp_mean\n",
    "                        perf_results.loc[i, \"local metric\"] = pointbiserialr(local_error, one_minus_tcp_preds)[0]\n",
    "                        perf_results.loc[i, \"time\"] = time.time() - start\n",
    "                        i += 1\n",
    "                        local_data[\"local tcp\"] = one_minus_tcp_preds\n",
    "\n",
    "                        # 18. Deep ensembles (Lakshminarayanan et al., NIPS 2017).\n",
    "                        # deep_ensembles is given source_cls/source_y (presumably to train M=5 members) and\n",
    "                        # target_cls; global score: mean_entropy, local score: per-sample shannon_entropy\n",
    "                        # (this rebinds the name used by step 8, whose values are already in local_data).\n",
    "                        start = time.time()\n",
    "                        perf_results.loc[i, [\n",
    "                            \"dataset\", \"Target size\", \"model type\", \"seed\",\n",
    "                            \"shift_type\", \"shift_intensity\", \"error rate (Target)\",\n",
    "                        ]] = [\n",
    "                            DATASET, shift_target.shape[0], MODEL_TYPE, SEED,\n",
    "                            shift_type, shift_intensity, 1 - target_perf,\n",
    "                        ]\n",
    "                        perf_results.loc[i, \"method\"] = \"deep_ens\"\n",
    "                        nn_model_list, shannon_entropy, mean_entropy = deep_ensembles(\n",
    "                            source_cls, source_y, target_cls,\n",
    "                            M=5, output_shape=N_CLASSES, seed=SEED, device=device,\n",
    "                        )\n",
    "                        perf_results.loc[i, \"global metric\"] = mean_entropy\n",
    "                        perf_results.loc[i, \"local metric\"] = pointbiserialr(local_error, shannon_entropy)[0]\n",
    "                        perf_results.loc[i, \"time\"] = time.time() - start\n",
    "                        i += 1\n",
    "                        local_data[\"local deep_ens\"] = shannon_entropy\n",
    "                        # Last per-sample column written in this loop body: append local_data to the running table.\n",
    "                        # NOTE(review): concatenating inside the loop re-copies local_perf_data on every pass;\n",
    "                        # collecting the frames in a list and calling pd.concat once would avoid that.\n",
    "                        local_perf_data = pd.concat((local_perf_data, local_data))\n",
    "\n",
    "                        # 19. Projection norm (Yu et al., ICML 2022): only a global score is recorded\n",
    "                        # for this method (no \"local metric\" and no local_data column).\n",
    "                        start = time.time()\n",
    "                        perf_results.loc[i, [\n",
    "                            \"dataset\", \"Target size\", \"model type\", \"seed\",\n",
    "                            \"shift_type\", \"shift_intensity\", \"error rate (Target)\",\n",
    "                        ]] = [\n",
    "                            DATASET, shift_target.shape[0], MODEL_TYPE, SEED,\n",
    "                            shift_type, shift_intensity, 1 - target_perf,\n",
    "                        ]\n",
    "                        perf_results.loc[i, \"method\"] = \"pnorm\"\n",
    "                        diff = projection_norm(\n",
    "                            source_cls, source_y, target_cls,\n",
    "                            output_shape=N_CLASSES, seed=SEED, device=device,\n",
    "                        )\n",
    "                        perf_results.loc[i, \"global metric\"] = diff\n",
    "                        perf_results.loc[i, \"time\"] = time.time() - start\n",
    "                        i += 1\n",
    "\n",
    "\n"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "T4",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
