{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torchvision \n",
    "import torchvision.transforms as transforms\n",
    "import torchvision.transforms.functional as TF\n",
    "\n",
    "import pandas as pd \n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import math\n",
    "import cvxpy as cvx\n",
    "import mosek\n",
    "\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.metrics.pairwise import pairwise_kernels\n",
    "from random import sample\n",
    "\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from group import compute_group_coverages, compute_split_coverages \n",
    "from group import compute_group_qr_coverages, compute_cqr_coverages\n",
    "from Synthetic_data_generation import get_groups, generate_group_synthetic_data, generate_cqr_data\n",
    "from rkhs import compute_shifted_coverage, compute_qr_coverages, compute_adaptive_threshold\n",
    "from sklearn.metrics.pairwise import pairwise_kernels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from wilds import get_dataset\n",
    "from wilds.common.data_loaders import get_train_loader\n",
    "import torchvision.transforms as transforms"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def strip_prefix(state_dict):\n",
    "    return_dict = {}\n",
    "    for key in state_dict:\n",
    "        #new_key = key.removeprefix('model.')\n",
    "        new_key = key[6:]\n",
    "        return_dict[new_key] = state_dict[key]\n",
    "\n",
    "    return return_dict\n",
    "def my_load(module, path, device=None):\n",
    "    if device is not None:\n",
    "        state = torch.load(path, map_location=device)\n",
    "    else:\n",
    "        state = torch.load(path)\n",
    "    state = strip_prefix(state['algorithm'])\n",
    "    \n",
    "    module.load_state_dict(state)\n",
    "    return \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Linear(in_features=2048, out_features=1139, bias=True)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "d_out = 1139\n",
    "\n",
    "constructor = torchvision.models.resnet50\n",
    "model = constructor()\n",
    "d_features = model.fc.in_features\n",
    "#last_layer = nn.Identity(d_features)\n",
    "#model.d_out = d_features\n",
    "last_layer = nn.Linear(d_features,d_out)\n",
    "model.d_out = d_features\n",
    "\n",
    "model.fc = last_layer\n",
    "\n",
    "#featurizer = model\n",
    "#classifier = nn.Linear(featurizer.d_out, d_out)\n",
    "#rx1Model = (featurizer, classifier)\n",
    "#rx1Model = nn.Sequential(*rx1model)\n",
    "\n",
    "rx1Model = model\n",
    "my_load(rx1Model,\n",
    "            '/Users/isaacgibbs/Documents/ConformalGans/Code/rxrx1_seed_0_epoch_best_model.pth',\n",
    "            device=torch.device(\"cpu\"))\n",
    "featurizer = rx1Model\n",
    "classifier = rx1Model.fc\n",
    "featurizer.fc = nn.Identity()\n",
    "\n",
    "rx1Model = nn.Sequential(*(featurizer,classifier))\n",
    "\n",
    "rx1Model.eval()\n",
    "featurizer.eval()\n",
    "classifier.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "rx1Data = get_dataset(dataset=\"rxrx1\", download=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def initialize_rxrx1_transform(is_training):\n",
    "    def standardize(x: torch.Tensor) -> torch.Tensor:\n",
    "        mean = x.mean(dim=(1, 2))\n",
    "        std = x.std(dim=(1, 2))\n",
    "        std[std == 0.] = 1.\n",
    "        return TF.normalize(x, mean, std)\n",
    "    t_standardize = transforms.Lambda(lambda x: standardize(x))\n",
    "\n",
    "    angles = [0, 90, 180, 270]\n",
    "    def random_rotation(x: torch.Tensor) -> torch.Tensor:\n",
    "        angle = angles[torch.randint(low=0, high=len(angles), size=(1,))]\n",
    "        if angle > 0:\n",
    "            x = TF.rotate(x, angle)\n",
    "        return x\n",
    "    t_random_rotation = transforms.Lambda(lambda x: random_rotation(x))\n",
    "\n",
    "    if is_training:\n",
    "        transforms_ls = [\n",
    "            t_random_rotation,\n",
    "            transforms.RandomHorizontalFlip(),\n",
    "            transforms.ToTensor(),\n",
    "            t_standardize,\n",
    "        ]\n",
    "    else:\n",
    "        transforms_ls = [\n",
    "            transforms.ToTensor(),\n",
    "            t_standardize,\n",
    "        ]\n",
    "    transform = transforms.Compose(transforms_ls)\n",
    "    return transform\n",
    "my_transform_train = initialize_rxrx1_transform(True)\n",
    "my_transform_eval = initialize_rxrx1_transform(False)\n",
    "\n",
    "torch.manual_seed(0)\n",
    "np.random.seed(0)\n",
    "\n",
    "test_data_iid = rx1Data.get_subset(\n",
    "    \"id_test\",\n",
    "    transform = my_transform_train\n",
    ")\n",
    "test_data_ood = rx1Data.get_subset(\n",
    "    \"test\",\n",
    "    transform = my_transform_eval\n",
    ")\n",
    "#meta_data = pd.read_csv('data/rxrx1_v1.0/metadata.csv')\n",
    "#meta_test_iid = meta_data[meta_data['dataset'] == 'id_test']\n",
    "#meta_test_ood = meta_data[meta_data['dataset'] == 'test']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "376.54104804992676\n",
      "torch.Size([40612, 3, 256, 256])\n",
      "323.8342959880829\n",
      "torch.Size([34432, 3, 256, 256])\n"
     ]
    }
   ],
   "source": [
    "from wilds.common.grouper import CombinatorialGrouper\n",
    "from wilds.common.data_loaders import get_train_loader, get_eval_loader\n",
    "import time\n",
    "\n",
    "n_iid = 40612\n",
    "n_ood = 34432\n",
    "\n",
    "train_grouper = CombinatorialGrouper(\n",
    "            dataset=rx1Data,\n",
    "            groupby_fields=['experiment']\n",
    "        )\n",
    "myloaderTestIID = get_eval_loader(\n",
    "                loader='standard',\n",
    "                dataset=test_data_iid,\n",
    "                grouper=train_grouper,\n",
    "                batch_size=n_iid)\n",
    "\n",
    "myloaderTestOOD = get_eval_loader(\n",
    "                loader='standard',\n",
    "                dataset=test_data_ood,\n",
    "                grouper=train_grouper,\n",
    "                batch_size=n_ood)\n",
    "\n",
    "t0 = time.time()\n",
    "final_iid_test_data = next(iter(myloaderTestIID))\n",
    "t1 = time.time()\n",
    "print(t1-t0)\n",
    "print(final_iid_test_data[0].shape)\n",
    "\n",
    "t0 = time.time()\n",
    "final_ood_test_data = next(iter(myloaderTestOOD))\n",
    "t1 = time.time()\n",
    "print(t1-t0)\n",
    "print(final_ood_test_data[0].shape)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# def move_to(obj, device):\n",
    "#     if isinstance(obj, dict):\n",
    "#         return {k: move_to(v, device) for k, v in obj.items()}\n",
    "#     elif isinstance(obj, list):\n",
    "#         return [move_to(v, device) for v in obj]\n",
    "#     elif isinstance(obj, float) or isinstance(obj, int):\n",
    "#         return obj\n",
    "#     else:\n",
    "#         # Assume obj is a Tensor or other type\n",
    "#         # (like Batch, for MolPCBA) that supports .to(device)\n",
    "#         return obj.to(device)\n",
    "def softmax(x):\n",
    "    e_x = np.exp(x - np.max(x))\n",
    "    return e_x / e_x.sum()\n",
    "\n",
    "def predictNN2(model,x):\n",
    "    #x = move_to(x, torch.device(\"cpu\"))\n",
    "    #nnOutput = model(x.reshape(1,3,64,64))[0].detach().numpy()\n",
    "#     x = move_to(x, torch.device(\"cpu\"))\n",
    "    nnOutput = model(x).detach().numpy()\n",
    "    return np.apply_along_axis(softmax,1,nnOutput)\n",
    "\n",
    "from random import sample\n",
    "train_points = sample(range(0,n_iid),100)\n",
    "test_points = sample(range(0,n_ood),100)\n",
    "\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "train_x = final_iid_test_data[0][train_points,:,:,:].to(device)\n",
    "test_x = final_ood_test_data[0][test_points,:,:,:].to(device)\n",
    "featurizer.to(device)\n",
    "classifier.to(device)\n",
    "rx1Model.to(device)\n",
    "\n",
    "t0 = time.time()\n",
    "allFeatures_iid = featurizer(train_x)\n",
    "t1 = time.time()\n",
    "print(t1-t0)\n",
    "t0 = time.time()\n",
    "allProbs_iid = np.apply_along_axis(softmax,1,classifier(allFeatures_iid).detach().numpy())\n",
    "t1 = time.time()\n",
    "print(t1-t0)\n",
    "\n",
    "t0 = time.time()\n",
    "allFeatures_ood = featurizer(test_x)\n",
    "t1 = time.time()\n",
    "print(t1-t0)\n",
    "t0 = time.time()\n",
    "allProbs_ood = np.apply_along_axis(softmax,1,classifier(allFeatures_ood).detach().numpy())\n",
    "t1 = time.time()\n",
    "print(t1-t0)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def classConfScore(probs,y):\n",
    "    return sum(probs[probs > probs[y]])\n",
    "\n",
    "scores_train = np.zeros(len(train_points))\n",
    "for i in range(len(train_points)):\n",
    "    scores_train[i] = classConfScore(allProbs_iid[i,:],final_iid_test_data[1][train_points[i]])\n",
    "\n",
    "scores_test = np.zeros(len(test_points))\n",
    "for i in range(len(test_points)):\n",
    "    scores_test[i] = classConfScore(allProbs_ood[i,:],final_ood_test_data[1][test_points[i]])\\\n",
    "    \n",
    "q = np.quantile(scores_train,0.9)\n",
    "print(q)\n",
    "print(sum(scores_test > q)/len(scores_test))\n",
    "\n",
    "print(scores_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2\n",
      "3\n",
      "[4 6 9]\n",
      "[ 6 13]\n"
     ]
    }
   ],
   "source": [
    "a = np.array([[1,2,3],[3,4,6]])\n",
    "print(a[0,1])\n",
    "print(a[1,0])\n",
    "print(np.apply_along_axis(sum,0,a))\n",
    "print(np.apply_along_axis(sum,1,a))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'allFeatures' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "Input \u001b[0;32mIn [12]\u001b[0m, in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      1\u001b[0m npAllFeatures_iid \u001b[38;5;241m=\u001b[39m allFeatures_iid\u001b[38;5;241m.\u001b[39mdetach()\u001b[38;5;241m.\u001b[39mnumpy()\n\u001b[0;32m----> 2\u001b[0m meanMat \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mones(\u001b[43mallFeatures\u001b[49m\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m0\u001b[39m])\u001b[38;5;241m.\u001b[39mreshape(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m,\u001b[38;5;241m1\u001b[39m) \u001b[38;5;241m@\u001b[39m npAllFeatures_iid \u001b[38;5;241m@\u001b[39m np\u001b[38;5;241m.\u001b[39mones(allFeatures\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m1\u001b[39m])\n\u001b[1;32m      3\u001b[0m cented_featured_iid \u001b[38;5;241m=\u001b[39m npAllFeatures_iid \u001b[38;5;241m-\u001b[39m meanMat\n",
      "\u001b[0;31mNameError\u001b[0m: name 'allFeatures' is not defined"
     ]
    }
   ],
   "source": [
    "npAllFeatures_iid = allFeatures_iid.detach().numpy()\n",
    "meanMat = np.ones(allFeatures.shape[0]).reshape(-1,1) @ npAllFeatures_iid @ np.ones(allFeatures.shape[1])\n",
    "cented_featured_iid = npAllFeatures_iid - meanMat"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
