{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup\n",
    "\n",
    "#### Basic Assumptions\n",
    "- There are 2 domains (source, target) and $K$ classes.\n",
    "- Each (class, domain) pair $(c, d)$ has vertices $\\{ x_{c, d, i} \\}_{i=1}^n$, so that the union constitutes the entire input space: $\\{ x_{c, d, i} \\} = \\mathcal{X}$.\n",
    "    - Note that each vertex has a ground-truth class, but with augmentation, examples may augment into a different class.\n",
    "- For simplicity we assume a uniform distribution over the inputs, i.e. $P(x) = \\frac{1}{\\left|\\mathcal{X}\\right|}$.\n",
    "\n",
    "#### \"Stochastic block\" data model (quotes because it's not really stochastic block since it's a deterministic matrix)\n",
    "- The data augmentation of vertex $x$ is a distribution $A(\\cdot | x)$ over $\\mathcal{X}$ (so the natural and augmented data spaces are the same space).\n",
    "    - Note this means the graph is fully-connected.\n",
    "- Each $x \\in \\mathcal{X}$ has the same distribution of augmentation over the other _vertex types_:\n",
    "    - Each $A(x | x)$ is the same (\n",
    "- For each $x_1, x_2 \\in \\mathcal{X}$, the edge weight $w_{x_1 x_2} = E_{x \\sim P}\\left[A(x_1 | x) A(x_2 | x)\\right]$.\n",
    "- This is the ideal graph model, as it is perfectly deterministic (the randomness is baked into the augmentation distribution) and there's lots of symmetry.\n",
    "\n",
    "#### Random graph data model\n",
    "- Each graph is a random instantiation of the \"stochastic block\" data model. Let $M$ be the number of augmentations per vertex; then, select $\\{ x_m \\}_{m=1}^M$ i.i.d. from $A(\\cdot | x)$ and let $x$'s augmentations be a uniform distribution over those $M$ points.\n",
    "- I think in expectation, this random graph is equivalent to the stochastic block model and can be interpreted as a random instantiation.\n",
    "\n",
    "#### This notebook implements:\n",
    "- Generates an adjacency matrix according to either the stochastic block or random graph model.\n",
    "- Computes the SVD / features for that graph.\n",
    "    - Let W be the edge weight matrix, and let D be the diagonal matrix such that $D_{xx} = \\sum_{x' \\in \\mathcal{X}} w_{x x'}$. Then, the normalized adjacency matrix is $\\bar{W} = D^{-\\frac{1}{2}} W D^{-\\frac{1}{2}}$.\n",
    "    - Compute the top $4K$ eigenvectors of this adjacency matrix (at least 2x the number of clusters in the data distribution, according to spectral contrastive theory)\n",
    "        - If the graph has a single connected component, discard the top eigenvector since it is all-ones.\n",
    "- Trains a logistic classifier on the source features\n",
    "    - Teatures of augmented views for each $x$ are averaged according to $x$'s data augmentation, and that is considered an example (along with $x$'s class label).\n",
    "- Test on the target via the same ensembling method"
   ]
  },
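  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A short derivation of why the top eigenvector of $\\bar{W}$ is trivial, using the definitions of $W$, $D$ and $\\bar{W}$ from the setup. The all-ones vector $\\mathbf{1}$ is notation introduced here. Since $D_{xx}$ is the $x$-th row sum of $W$, we have $W \\mathbf{1} = D \\mathbf{1}$, and therefore\n",
    "\n",
    "$$\\bar{W} D^{\\frac{1}{2}} \\mathbf{1} = D^{-\\frac{1}{2}} W \\mathbf{1} = D^{-\\frac{1}{2}} D \\mathbf{1} = D^{\\frac{1}{2}} \\mathbf{1}.$$\n",
    "\n",
    "So $D^{\\frac{1}{2}} \\mathbf{1}$ is an eigenvector with eigenvalue 1, the largest possible for a normalized adjacency matrix with nonnegative weights. When every vertex has the same degree, as the symmetry of the stochastic block model with equal cluster sizes suggests, this eigenvector is constant, and every other eigenvector is orthogonal to it, so its entries sum to 0. In the random graph model the degrees generally differ, so the top eigenvector is proportional to the square roots of the degrees rather than exactly constant."
   ]
  },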
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "from sklearn.manifold import TSNE\n",
    "from sklearn.decomposition import PCA\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "import matplotlib.pyplot as plt\n",
    "import plotly.express as px\n",
    "import plotly.graph_objects as go\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Simulation functions (can skip all the way to \"illustrative example\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "'''\n",
    "Utility functions for all graph simulation types\n",
    "'''\n",
    "\n",
    "def init_vertices(\n",
    "    num_classes,\n",
    "    num_domains,\n",
    "    vert_per_clus\n",
    "):\n",
    "    '''\n",
    "    This function returns 2 dictionaries:\n",
    "    1. vertex number -> (class, domain) membership\n",
    "    2. (class, domain) -> vertices that belong to that cluster\n",
    "    '''\n",
    "    vert_to_label, label_to_vert = {}, {}\n",
    "    idx = 0\n",
    "    for c in range(num_classes):\n",
    "        for d in range(num_domains):\n",
    "            label = (c, d)\n",
    "            verts = range(idx, idx + vert_per_clus)\n",
    "            label_to_vert[label] = list(verts)\n",
    "            vert_to_label.update({ i: label for i in verts})\n",
    "            idx += vert_per_clus\n",
    "    return vert_to_label, label_to_vert\n",
    "\n",
    "def get_connectivity(\n",
    "    vert_to_label,\n",
    "    self, # prob. weight for transitioning to self\n",
    "    same_cls_same_dom, # these 4 are probability weights for different types of transitions\n",
    "    same_cls_diff_dom,\n",
    "    diff_cls_same_dom,\n",
    "    diff_cls_diff_dom\n",
    "):\n",
    "    '''\n",
    "    This function returns a list of N = len(vert_to_label) probability distributions,\n",
    "    where the P_i(j) denotes the probability that vertex i transitions to vertex j\n",
    "    '''\n",
    "    \n",
    "    result = []\n",
    "    for vert, (cls, dom) in sorted(vert_to_label.items()): # \"current\" vertex\n",
    "        dist = []\n",
    "        # iterate through all possible neighbor vertices\n",
    "        for other_vert, (other_cls, other_dom) in sorted(vert_to_label.items()):\n",
    "            if vert == other_vert: # transition to self\n",
    "                dist.append(self)\n",
    "            elif cls == other_cls:\n",
    "                if dom == other_dom:\n",
    "                    dist.append(same_cls_same_dom)\n",
    "                else:\n",
    "                    dist.append(same_cls_diff_dom)\n",
    "            else:\n",
    "                if dom == other_dom:\n",
    "                    dist.append(diff_cls_same_dom)\n",
    "                else:\n",
    "                    dist.append(diff_cls_diff_dom)\n",
    "        dist = np.array(dist)\n",
    "        dist = dist / dist.sum()\n",
    "        result.append(dist)\n",
    "    return result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "'''\n",
    "Stochastic block model graph simulation\n",
    "'''\n",
    "\n",
    "def get_stochastic_block_adjacency(prob_dists):\n",
    "    '''\n",
    "    This function takes in a list of transition functions for each vertex and computes\n",
    "    the adjacency matrix\n",
    "    '''\n",
    "    N = len(prob_dists)\n",
    "    W = np.zeros((N, N))\n",
    "    for prob_dist in prob_dists:\n",
    "        for v_1 in range(N):\n",
    "            for v_2 in range(v_1, N):\n",
    "                W[v_1, v_2] += prob_dist[v_1] * prob_dist[v_2]\n",
    "    # Uniform over N data points, so divide\n",
    "    W = W / N\n",
    "    assert np.all(W == np.triu(W))\n",
    "    for row in range(N):\n",
    "        for col in range(row):\n",
    "            W[row, col] = W[col, row]\n",
    "    assert np.all(W.T == W) # Make sure it's symmetric\n",
    "    return W"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "'''\n",
    "Random Graph Simulation\n",
    "'''\n",
    "\n",
    "def get_random_adjacency(prob_dists, num_augs_per_vert):\n",
    "    '''\n",
    "    This function takes in a list of transition functions for each vertex and computes\n",
    "    the adjacency matrix by randomly selecting `num_augs_per_vert` augmentations for each\n",
    "    vertex (selected according to the transition probability function)\n",
    "    '''\n",
    "    N = len(prob_dists)\n",
    "    attempts = max_attempts = 0\n",
    "    while True:\n",
    "        W = np.zeros((N, N))\n",
    "        all_neighbors = []\n",
    "        for idx, prob_dist in enumerate(prob_dists):\n",
    "            neighbors = sorted(np.random.choice(len(prob_dist), size=num_augs_per_vert, p=prob_dist))\n",
    "            for i in range(num_augs_per_vert):\n",
    "                for j in range(i, num_augs_per_vert):\n",
    "                    W[neighbors[i], neighbors[j]] += 1\n",
    "            all_neighbors.append(neighbors)\n",
    "        # Uniform over num_augs_per_vert augmentations, so divide by that\n",
    "        W = W / (num_augs_per_vert ** 2)\n",
    "        # Uniform over the vertices so divide again\n",
    "        W = W / N\n",
    "        assert np.all(W == np.triu(W))\n",
    "        for row in range(N):\n",
    "            for col in range(row):\n",
    "                W[row, col] = W[col, row]\n",
    "        assert np.all(W.T == W) # Make sure it's symmetric\n",
    "        if np.all(W.sum(axis=1) > 0):\n",
    "            break\n",
    "        attempts += 1\n",
    "        if attempts == max_attempts:\n",
    "            raise RuntimeError\n",
    "    return W, all_neighbors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_graph(\n",
    "    num_classes,\n",
    "    num_domains,\n",
    "    vert_per_clus,\n",
    "    self,\n",
    "    same_cls_same_dom,\n",
    "    same_cls_diff_dom,\n",
    "    diff_cls_same_dom,\n",
    "    diff_cls_diff_dom,\n",
    "    graph_type,\n",
    "    num_augs_per_vert=None\n",
    "):\n",
    "    vert_to_label, label_to_vert = init_vertices(num_classes, num_domains, vert_per_clus)\n",
    "    prob_dists = get_connectivity(\n",
    "        vert_to_label, self, same_cls_same_dom, same_cls_diff_dom, diff_cls_same_dom, diff_cls_diff_dom)\n",
    "    if graph_type == 'stochastic_block':\n",
    "        W = get_stochastic_block_adjacency(prob_dists)\n",
    "        return W, prob_dists, vert_to_label, label_to_vert\n",
    "    elif graph_type == 'random':\n",
    "        assert num_augs_per_vert is not None\n",
    "        W, neighbors = get_random_adjacency(prob_dists, num_augs_per_vert)\n",
    "        return W, neighbors, prob_dists, vert_to_label, label_to_vert\n",
    "    else:\n",
    "        raise NotImplementedError"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "'''\n",
    "SVD, Linear probing, Evaluation\n",
    "'''\n",
    "\n",
    "def get_normalized_adjacency(W):\n",
    "    '''\n",
    "    Computes a row- and col-normalized adjacency matrix from W\n",
    "    '''\n",
    "    D_neg_half = np.diag(W.sum(axis=1) ** (-1/2))\n",
    "    W_bar = D_neg_half @ W @ D_neg_half\n",
    "    # The rows don't actually sum to 1\n",
    "    # assert np.all(abs(W_bar.sum(axis=0) - 1) <= 1e-5), W_bar.sum(axis=0)\n",
    "    # assert np.all(abs(W_bar.sum(axis=1) - 1) <= 1e-5), W_bar.sum(axis=1)\n",
    "    return W_bar\n",
    "\n",
    "def get_svd_features(W_bar, num_features, ignore_first=True):\n",
    "    '''\n",
    "    Computes the SVD of the normalized adjacency matrix\n",
    "    '''\n",
    "    U, S, _ = np.linalg.svd(W_bar, hermitian=True)\n",
    "    if ignore_first:\n",
    "        feat = U[:, 1:1 + num_features]\n",
    "        sing = S[1:1 + num_features]\n",
    "    else:\n",
    "        feat = U[:, :num_features]\n",
    "        sing = S[:num_features]\n",
    "    return feat, sing\n",
    "\n",
    "def get_acc(clf, xs, ys):\n",
    "    preds = clf.predict(xs)\n",
    "    return np.mean(preds == ys)\n",
    "\n",
    "def train_linear_probe(X_train, y_train, X_test, y_test, random_state=0, should_print=True):\n",
    "    clf = LogisticRegression(\n",
    "        C=0.2,\n",
    "        random_state=random_state).fit(X_train, y_train)\n",
    "    train_acc = get_acc(clf, X_train, y_train)\n",
    "    test_acc = get_acc(clf, X_test, y_test)\n",
    "    if should_print:\n",
    "        print(f'Fitted the train set to accuracy: {train_acc} and the test set to accuracy: {test_acc}!')\n",
    "    return clf, train_acc, test_acc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "'''\n",
    "Stochastic block linear probing\n",
    "'''\n",
    "\n",
    "def get_stochastic_block_linear_probe_features(features, prob_dists, vert_to_label, label_to_vert, method):\n",
    "    if method == 'ensemble':\n",
    "        '''\n",
    "        For each train data point, ensemble the features of its neighbors according to\n",
    "        its probability distribution\n",
    "        '''\n",
    "        X_train, y_train, X_test, y_test = [], [], [], []\n",
    "        for i, prob_dist in enumerate(prob_dists):\n",
    "            cls, dom = vert_to_label[i]\n",
    "            ensembled_features = features * prob_dist[:, None]\n",
    "            ensembled_features = ensembled_features.sum(axis=0)\n",
    "            if dom == 0: # source\n",
    "                X_train.append(ensembled_features)\n",
    "                y_train.append(cls)\n",
    "            else: # target\n",
    "                X_test.append(ensembled_features)\n",
    "                y_test.append(cls)\n",
    "        return X_train, y_train, X_test, y_test\n",
    "    else:\n",
    "        raise NotImplementedError\n",
    "\n",
    "def run_stochastic_block_decomp_and_eval(\n",
    "    W, num_features, prob_dists, vert_to_label, label_to_vert, method, should_print=True\n",
    "):\n",
    "    W_bar = get_normalized_adjacency(W)\n",
    "    features, singular_values = get_svd_features(W_bar, num_features)\n",
    "    X_train, y_train, X_test, y_test = get_stochastic_block_linear_probe_features(\n",
    "        features, prob_dists, vert_to_label, label_to_vert, method)\n",
    "    clf, train_acc, test_acc = train_linear_probe(X_train, y_train, X_test, y_test, should_print=should_print)\n",
    "    return W_bar, features, singular_values, clf, train_acc, test_acc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "'''\n",
    "Random graph linear probing\n",
    "'''\n",
    "\n",
    "def get_random_linear_probe_features(features, neighbors, vert_to_label, label_to_vert, method):\n",
    "    if method == 'ensemble':\n",
    "        '''\n",
    "        For each train data point, ensemble the features of its neighbors according to\n",
    "        its probability distribution\n",
    "        '''\n",
    "        X_train, y_train, X_test, y_test = [], [], [], []\n",
    "        for i, neighbor_list in enumerate(neighbors):\n",
    "            cls, dom = vert_to_label[i]\n",
    "            ensembled_features = features[neighbor_list] * (1 / len(neighbor_list))\n",
    "            ensembled_features = ensembled_features.sum(axis=0)\n",
    "            if dom == 0: # source\n",
    "                X_train.append(ensembled_features)\n",
    "                y_train.append(cls)\n",
    "            else: # target\n",
    "                X_test.append(ensembled_features)\n",
    "                y_test.append(cls)\n",
    "        return X_train, y_train, X_test, y_test\n",
    "    else:\n",
    "        raise NotImplementedError\n",
    "\n",
    "def run_random_decomp_and_eval(\n",
    "    W, num_features, neighbors, vert_to_label, label_to_vert, method, should_print=True\n",
    "):\n",
    "    W_bar = get_normalized_adjacency(W)\n",
    "    features, singular_values = get_svd_features(W_bar, num_features)\n",
    "    X_train, y_train, X_test, y_test = get_random_linear_probe_features(\n",
    "        features, neighbors, vert_to_label, label_to_vert, method)\n",
    "    clf, train_acc, test_acc = train_linear_probe(X_train, y_train, X_test, y_test, should_print=should_print)\n",
    "    return W_bar, features, singular_values, clf, train_acc, test_acc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "'''\n",
    "Wrapper functions\n",
    "'''\n",
    "\n",
    "def do_stochastic_block(\n",
    "    num_classes,\n",
    "    num_domains,\n",
    "    vert_per_clus,\n",
    "    self,\n",
    "    same_cls_same_dom,\n",
    "    same_cls_diff_dom,\n",
    "    diff_cls_same_dom,\n",
    "    diff_cls_diff_dom,\n",
    "    should_print=False\n",
    "):\n",
    "    W, prob_dists, vert_to_label, label_to_vert = create_graph(\n",
    "        num_classes, num_domains, vert_per_clus, self,\n",
    "        same_cls_same_dom, same_cls_diff_dom, diff_cls_same_dom, diff_cls_diff_dom,\n",
    "        'stochastic_block')\n",
    "    W_bar, features, singular_values, clf, train_acc, test_acc = run_stochastic_block_decomp_and_eval(\n",
    "        W, 2 * num_classes * num_domains, prob_dists,\n",
    "        vert_to_label, label_to_vert, 'ensemble', should_print=should_print)\n",
    "    return W, prob_dists, vert_to_label, label_to_vert, \\\n",
    "        W_bar, features, singular_values, clf, train_acc, test_acc\n",
    "\n",
    "def do_random_graph(\n",
    "    num_classes,\n",
    "    num_domains,\n",
    "    vert_per_clus,\n",
    "    self,\n",
    "    same_cls_same_dom,\n",
    "    same_cls_diff_dom,\n",
    "    diff_cls_same_dom,\n",
    "    diff_cls_diff_dom,\n",
    "    num_augs_per_vert,\n",
    "    should_print=False\n",
    "):\n",
    "    W, neighbors, prob_dists, vert_to_label, label_to_vert = create_graph(\n",
    "        num_classes, num_domains, vert_per_clus, self,\n",
    "        same_cls_same_dom, same_cls_diff_dom, diff_cls_same_dom, diff_cls_diff_dom,\n",
    "        'random', num_augs_per_vert)\n",
    "    W_bar, features, singular_values, clf, train_acc, test_acc = run_random_decomp_and_eval(\n",
    "        W, 2 * num_classes * num_domains, neighbors,\n",
    "        vert_to_label, label_to_vert, 'ensemble', should_print=should_print)\n",
    "    return W, neighbors, prob_dists, vert_to_label, label_to_vert, \\\n",
    "        W_bar, features, singular_values, clf, train_acc, test_acc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%store -r df_stochastic"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%store -r df_random"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Illustrative example\n",
    "\n",
    "- In the super simple stochastic block model, whenever the connectivity satisfies `same_class_diff_domain > diff_class_diff_domain`, the test accuracy is 100%.\n",
    "- Each eigenvector has entries that sum to 1 (since it has to be orthogonal to the leading all-ones eigenvector). However, some eigenvectors contain domain information and others contain class information, and the linear probe trained on the source data places 0 weight on the main eigenvector (largest eigenvalue) that contains domain information.\n",
    "- Thus, under the stochastic block model, as long as the \"good connectivity\" is satisfied then the linear classifier avoids using the domain eigenvector.\n",
    "- Is this behavior replicated in the random graph model?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "W, prob_dists, vert_to_label, label_to_vert, W_bar, features, singular_values, clf, train_acc, test_acc = do_stochastic_block(\n",
    "    num_classes=3,\n",
    "    num_domains=2,\n",
    "    vert_per_clus=5,\n",
    "    self=10,\n",
    "    same_cls_same_dom=8,\n",
    "    same_cls_diff_dom=4,\n",
    "    diff_cls_same_dom=6,\n",
    "    diff_cls_diff_dom=3.9,\n",
    "    should_print=True\n",
    ")\n",
    "N = 3 * 2 * 5\n",
    "\n",
    "df = pd.DataFrame(np.around(features, 3))\n",
    "df.columns = list(map(str, range(len(df.columns))))\n",
    "df['cls'] = [vert_to_label[i][0] for i in range(N)]\n",
    "df['dom'] = [vert_to_label[i][1] for i in range(N)]\n",
    "# low_dim_feat = TSNE(n_components=3).fit_transform(np.around(features, 3))\n",
    "# low_dim_feat = PCA(n_components=3).fit_transform(np.around(features, 3))\n",
    "# df['tsne0'] = low_dim_feat[:, 0]\n",
    "# df['tsne1'] = low_dim_feat[:, 1]\n",
    "# df['tsne2'] = low_dim_feat[:, 2]\n",
    "# fig = px.scatter_3d(df, x='tsne0', y='tsne1', z='tsne2', color='cls', symbol='dom')\n",
    "fig = px.scatter_3d(df, x='0', y='1', z='2', color='cls', symbol='dom')\n",
    "fig.update_traces(marker=dict(size=8), selector=dict(mode='markers'))\n",
    "fig.show()\n",
    "# display(df)\n",
    "\n",
    "normal_vectors = pd.DataFrame(np.around(clf.coef_, 3))\n",
    "normal_vectors.columns = list(map(str, range(len(normal_vectors.columns))))\n",
    "normal_vectors['idx'] = normal_vectors.index\n",
    "fig_2 = px.scatter_3d(normal_vectors, x='0', y='1', z='2', color='idx')\n",
    "fig_2.update_traces(marker=dict(size=8), selector=dict(mode='markers'))\n",
    "fig_2.show()\n",
    "# display(normal_vectors)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "W, neighbors, prob_dists, vert_to_label, label_to_vert, W_bar, features, singular_values, clf, train_acc, test_acc = do_random_graph(\n",
    "    num_classes=3,\n",
    "    num_domains=2,\n",
    "    vert_per_clus=20,\n",
    "    self=10,\n",
    "    same_cls_same_dom=8,\n",
    "    same_cls_diff_dom=4,\n",
    "    diff_cls_same_dom=6,\n",
    "    diff_cls_diff_dom=3.9,\n",
    "    num_augs_per_vert=10,\n",
    "    should_print=True\n",
    ")\n",
    "N = 3 * 2 * 20\n",
    "\n",
    "df = pd.DataFrame(np.around(features, 3))\n",
    "df.columns = list(map(str, range(len(df.columns))))\n",
    "df['cls'] = [vert_to_label[i][0] for i in range(N)]\n",
    "df['dom'] = [vert_to_label[i][1] for i in range(N)]\n",
    "# low_dim_feat = TSNE(n_components=3).fit_transform(np.around(features, 3))\n",
    "# low_dim_feat = PCA(n_components=3).fit_transform(np.around(features, 3))\n",
    "# df['tsne0'] = low_dim_feat[:, 0]\n",
    "# df['tsne1'] = low_dim_feat[:, 1]\n",
    "# df['tsne2'] = low_dim_feat[:, 2]\n",
    "# fig = px.scatter_3d(df, x='tsne0', y='tsne1', z='tsne2', color='cls', symbol='dom')\n",
    "fig = px.scatter_3d(df, x='0', y='1', z='2', color='cls', symbol='dom')\n",
    "fig.update_traces(marker=dict(size=8), selector=dict(mode='markers'))\n",
    "fig.show()\n",
    "display(df)\n",
    "\n",
    "normal_vectors = pd.DataFrame(np.around(clf.coef_, 3))\n",
    "normal_vectors.columns = list(map(str, range(len(normal_vectors.columns))))\n",
    "normal_vectors['idx'] = normal_vectors.index\n",
    "fig_2 = px.scatter_3d(normal_vectors, x='0', y='1', z='2', color='idx')\n",
    "fig_2.update_traces(marker=dict(size=8), selector=dict(mode='markers'))\n",
    "fig_2.show()\n",
    "display(normal_vectors)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Connectivity threshold plotting\n",
    "\n",
    "- In this section, the 3 primary forms of connectivity are swept over and plotted together to show the thresholding behavior"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "'''\n",
    "Run simulation for several settings and save it in a dataframe (stochastic block model)\n",
    "'''\n",
    "\n",
    "NUM_CLASSES = 10\n",
    "NUM_DOMAINS = 2\n",
    "VERT_PER_CLUS = 8\n",
    "SELF = 10\n",
    "SAME_CLS_SAME_DOM = 8\n",
    "ITERATIONS = 1\n",
    "\n",
    "df_stochastic = pd.DataFrame()\n",
    "for it in range(ITERATIONS):\n",
    "    for SAME_CLS_DIFF_DOM in tqdm(range(0, 11, 2), leave=False):\n",
    "        for DIFF_CLS_SAME_DOM in range(0, 11, 2):\n",
    "            for DIFF_CLS_DIFF_DOM in range(0, 11, 2):\n",
    "                result = do_stochastic_block(\n",
    "                    NUM_CLASSES,\n",
    "                    NUM_DOMAINS,\n",
    "                    VERT_PER_CLUS,\n",
    "                    SELF,\n",
    "                    SAME_CLS_SAME_DOM,\n",
    "                    SAME_CLS_DIFF_DOM,\n",
    "                    DIFF_CLS_SAME_DOM,\n",
    "                    DIFF_CLS_DIFF_DOM\n",
    "                )\n",
    "                df_stochastic = df_stochastic.append({\n",
    "                    'num_classes': NUM_CLASSES,\n",
    "                    'vert_per_clus': VERT_PER_CLUS,\n",
    "                    'same_cls_same_dom': SAME_CLS_SAME_DOM,\n",
    "                    'same_cls_diff_dom': SAME_CLS_DIFF_DOM,\n",
    "                    'diff_cls_same_dom': DIFF_CLS_SAME_DOM,\n",
    "                    'diff_cls_diff_dom': DIFF_CLS_DIFF_DOM,\n",
    "                    'train_acc': result[-2],\n",
    "                    'test_acc': result[-1],\n",
    "                    'iteration': it\n",
    "                }, ignore_index=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "'''\n",
    "This cell contains visualization code for sensitivity / cross-sectional analyses\n",
    "'''\n",
    "\n",
    "fig = px.scatter_3d(\n",
    "    df_stochastic[\n",
    "        (df_stochastic['diff_cls_same_dom'] == 4)\n",
    "    ],\n",
    "    x='same_cls_diff_dom', y='diff_cls_diff_dom', z='test_acc'\n",
    ")\n",
    "fig.update_traces(marker=dict(size=8), selector=dict(mode='markers'))\n",
    "fig.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "'''\n",
    "Run simulation for several settings and save it in a dataframe (random graph model)\n",
    "'''\n",
    "\n",
    "NUM_CLASSES = 10\n",
    "NUM_DOMAINS = 2\n",
    "VERT_PER_CLUS = 8\n",
    "SELF = 10\n",
    "SAME_CLS_SAME_DOM = 8\n",
    "NUM_AUGS_PER_VERT = 10\n",
    "ITERATIONS = 5\n",
    "\n",
    "df_random = pd.DataFrame()\n",
    "for it in range(ITERATIONS):\n",
    "    for SAME_CLS_DIFF_DOM in tqdm(range(0, 11, 2), leave=False):\n",
    "        for DIFF_CLS_SAME_DOM in range(0, 11, 2):\n",
    "            for DIFF_CLS_DIFF_DOM in range(0, 11, 2):\n",
    "                result = do_random_graph(\n",
    "                    NUM_CLASSES,\n",
    "                    NUM_DOMAINS,\n",
    "                    VERT_PER_CLUS,\n",
    "                    SELF,\n",
    "                    SAME_CLS_SAME_DOM,\n",
    "                    SAME_CLS_DIFF_DOM,\n",
    "                    DIFF_CLS_SAME_DOM,\n",
    "                    DIFF_CLS_DIFF_DOM,\n",
    "                    NUM_AUGS_PER_VERT\n",
    "                )\n",
    "                df_random = df_random.append({\n",
    "                    'num_classes': NUM_CLASSES,\n",
    "                    'vert_per_clus': VERT_PER_CLUS,\n",
    "                    'same_cls_same_dom': SAME_CLS_SAME_DOM,\n",
    "                    'same_cls_diff_dom': SAME_CLS_DIFF_DOM,\n",
    "                    'diff_cls_same_dom': DIFF_CLS_SAME_DOM,\n",
    "                    'diff_cls_diff_dom': DIFF_CLS_DIFF_DOM,\n",
    "                    'train_acc': result[-2],\n",
    "                    'test_acc': result[-1],\n",
    "                    'iteration': it\n",
    "                }, ignore_index=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%store df_stochastic"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%store df_random"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "'''\n",
    "This cell contains visualization code for sensitivity / cross-sectional analyses\n",
    "'''\n",
    "\n",
    "avg_df_random_orig = df_random_orig.groupby(['num_classes', 'vert_per_clus', 'same_cls_same_dom', 'same_cls_diff_dom',\n",
    "                                   'diff_cls_same_dom', 'diff_cls_diff_dom']).mean().reset_index()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "'''\n",
    "This cell contains visualization code for sensitivity / cross-sectional analyses\n",
    "'''\n",
    "\n",
    "avg_df_random = df_random.groupby(['num_classes', 'vert_per_clus', 'same_cls_same_dom', 'same_cls_diff_dom',\n",
    "                                   'diff_cls_same_dom', 'diff_cls_diff_dom']).mean().reset_index()\n",
    "avg_df_random['acc_ratio'] = avg_df_random['test_acc'] / avg_df_random['train_acc']\n",
    "avg_df_random['test_acc_diff'] = avg_df_random['test_acc'] / avg_df_random_orig['test_acc']\n",
    "\n",
    "fig = px.scatter_3d(\n",
    "    avg_df_random[\n",
    "        (avg_df_random['diff_cls_same_dom'] == 2)\n",
    "    ],\n",
    "    x='same_cls_diff_dom', y='diff_cls_diff_dom', z='test_acc'\n",
    ")\n",
    "fig.update_traces(marker=dict(size=8), selector=dict(mode='markers'))\n",
    "fig.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "'''\n",
    "This cell contains visualization code for features / centroids\n",
    "'''\n",
    "\n",
    "df = pd.DataFrame(np.around(features, 3))\n",
    "df.columns = list(map(str, range(len(df.columns))))\n",
    "df['cls'] = [vert_to_label[i][0] for i in range(N)]\n",
    "df['dom'] = [vert_to_label[i][1] for i in range(N)]\n",
    "# low_dim_feat = TSNE(n_components=3).fit_transform(np.around(features, 3))\n",
    "# low_dim_feat = PCA(n_components=3).fit_transform(np.around(features, 3))\n",
    "# df['tsne0'] = low_dim_feat[:, 0]\n",
    "# df['tsne1'] = low_dim_feat[:, 1]\n",
    "# df['tsne2'] = low_dim_feat[:, 2]\n",
    "# fig = px.scatter_3d(df, x='tsne0', y='tsne1', z='tsne2', color='cls', symbol='dom')\n",
    "fig = px.scatter_3d(df, x='0', y='1', z='2', color='cls', symbol='dom')\n",
    "fig.update_traces(marker=dict(size=8), selector=dict(mode='markers'))\n",
    "fig.show()\n",
    "display(df)\n",
    "\n",
    "normal_vectors = pd.DataFrame(np.around(clf.coef_, 3))\n",
    "normal_vectors.columns = list(map(str, range(len(normal_vectors.columns))))\n",
    "normal_vectors['idx'] = normal_vectors.index\n",
    "fig_2 = px.scatter_3d(normal_vectors, x='0', y='1', z='2', color='idx')\n",
    "fig_2.update_traces(marker=dict(size=8), selector=dict(mode='markers'))\n",
    "fig_2.show()\n",
    "display(normal_vectors)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## &#x1F53B; Note: cells below this are outdated &#x1F53B;"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "'''\n",
    "Run simulation for several settings and save it in a dataframe\n",
    "'''\n",
    "\n",
    "NUM_DOMAINS = 2\n",
    "VERT_PER_CLUS = 2\n",
    "SELF = 10\n",
    "\n",
    "df = pd.DataFrame()\n",
    "for NUM_CLASSES in [3, 10, 20]:\n",
    "    N = NUM_CLASSES * NUM_DOMAINS * VERT_PER_CLUS\n",
    "    for SAME_CLS_SAME_DOM_WT in [8, 12]:\n",
    "        for SAME_CLS_SAME_DOM_MULT in [1, 1 / VERT_PER_CLUS]:\n",
    "            SAME_CLS_SAME_DOM = SAME_CLS_SAME_DOM_WT * SAME_CLS_SAME_DOM_MULT\n",
    "            for SAME_CLS_DIFF_DOM_WT in [6, 10]:\n",
    "                for SAME_CLS_DIFF_DOM_MULT in [1, 1 / VERT_PER_CLUS]:\n",
    "                    SAME_CLS_DIFF_DOM = SAME_CLS_DIFF_DOM_WT * SAME_CLS_DIFF_DOM_MULT\n",
    "                    for DIFF_CLS_SAME_DOM_WT in [6, 10]:\n",
    "                        for DIFF_CLS_SAME_DOM_MULT in [1, 1 / ((NUM_CLASSES - 1) * VERT_PER_CLUS)]:\n",
    "                            DIFF_CLS_SAME_DOM = DIFF_CLS_SAME_DOM_WT * DIFF_CLS_SAME_DOM_MULT\n",
    "                            for DIFF_CLS_DIFF_DOM_WT in [5, 8, 12]:\n",
    "                                for DIFF_CLS_DIFF_DOM_MULT in [1, 1 / ((NUM_CLASSES - 1) * VERT_PER_CLUS)]:\n",
    "                                    DIFF_CLS_DIFF_DOM = DIFF_CLS_DIFF_DOM_WT * DIFF_CLS_DIFF_DOM_MULT\n",
    "                                    result = do_stochastic_block(\n",
    "                                        NUM_CLASSES,\n",
    "                                        NUM_DOMAINS,\n",
    "                                        VERT_PER_CLUS,\n",
    "                                        SELF,\n",
    "                                        SAME_CLS_SAME_DOM,\n",
    "                                        SAME_CLS_DIFF_DOM,\n",
    "                                        DIFF_CLS_SAME_DOM,\n",
    "                                        DIFF_CLS_DIFF_DOM\n",
    "                                    )\n",
    "                                    df = df.append({\n",
    "                                        'num_classes': NUM_CLASSES,\n",
    "                                        'vert_per_clus': VERT_PER_CLUS,\n",
    "                                        'same_cls_same_dom': SAME_CLS_SAME_DOM,\n",
    "                                        'same_cls_diff_dom': SAME_CLS_DIFF_DOM,\n",
    "                                        'diff_cls_same_dom': DIFF_CLS_SAME_DOM,\n",
    "                                        'diff_cls_diff_dom': DIFF_CLS_DIFF_DOM,\n",
    "                                        'train_acc': result[-2],\n",
    "                                        'test_acc': result[-1]\n",
    "                                    }, ignore_index=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# try out a bunch of different configurations\n",
    "fraction = 0.2\n",
    "num_per_axis = 6\n",
    "for it in range(iters):\n",
    "    res = {}\n",
    "    for btwn_class in tqdm(np.linspace(1, 3, num_per_axis)):\n",
    "        for btwn_domain in np.linspace(1, 3, num_per_axis):\n",
    "            for btwn_both in np.linspace(1, 3, num_per_axis):\n",
    "                key = (btwn_class, btwn_domain, btwn_both)\n",
    "                natural, _, reverse, _, _, _, features = stochastic_block_model(\n",
    "                    fraction, btwn_class, btwn_domain, btwn_both, NUM_CLUS * 2)\n",
    "                res[key] = (\n",
    "                    linear_probe(features, reverse, print_acc=False),\n",
    "                    linear_probe_ensemble(features, reverse, natural, print_acc=False)\n",
    "                )\n",
    "    results.append(res)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%store results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%store -r results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# check the locations of the centroids\n",
    "def get_centroids():\n",
    "    centroids = {}\n",
    "    for dom in range(NUM_DOMAINS):\n",
    "        for cls in range(NUM_CLASSES):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# if desired, can plot the tsne embeddings of the features\n",
    "def visualize(features, n_components=2):\n",
    "    # dimensionality reduction\n",
    "    tsne = TSNE(n_components=n_components)\n",
    "    low_dim_feat = tsne.fit_transform(features)\n",
    "    # plot with colors\n",
    "    c = ['Red', 'Orange', 'Yellow', 'Green', 'Blue', 'Indigo', 'Violet', 'Black']\n",
    "    colors = [[c[i]] * AUG_PER_CLUS for i in range(NUM_CLUS)]\n",
    "    colors = [j for sub in colors for j in sub]\n",
    "    plt.scatter(low_dim_feat[:, 0], low_dim_feat[:, 1], c=colors)\n",
    "    plt.show()\n",
    "    # also report the color used for each (class, domain) cluster\n",
    "    idx = 0\n",
    "    for dom in range(NUM_DOMAINS):\n",
    "        for cls in range(NUM_CLASSES):\n",
    "            print('class', cls, 'domain', dom, 'is', c[idx])\n",
    "            idx += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot various things with ensembling too\n",
    "df = pd.DataFrame()\n",
    "for btwn_class in np.linspace(1, 3, num_per_axis):\n",
    "    for btwn_domain in np.linspace(1, 3, num_per_axis):\n",
    "        for btwn_both in np.linspace(1, 3, num_per_axis):\n",
    "            key = (btwn_class, btwn_domain, btwn_both)\n",
    "            data = {\n",
    "                'btwn_class': btwn_class,\n",
    "                'btwn_domain': btwn_domain,\n",
    "                'btwn_both': btwn_both,\n",
    "                'avg_source_acc': np.mean([res[key][0][0] for res in results]),\n",
    "                'avg_target_acc': np.mean([res[key][0][1] for res in results]),\n",
    "                'avg_source_ensemble_acc': np.mean([res[key][1][0] for res in results]),\n",
    "                'avg_target_ensemble_acc': np.mean([res[key][1][1] for res in results])\n",
    "            }\n",
    "            df = df.append(data, ignore_index=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = px.scatter_3d(df[df['btwn_both'] == 1.0], x='btwn_class', y='btwn_domain', z='avg_target_ensemble_acc')\n",
    "fig.update_traces(marker=dict(size=4), selector=dict(mode='markers'))\n",
    "fig.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot various things\n",
    "df = pd.DataFrame()\n",
    "for (btwn_class, btwn_domain, btwn_both), (source_acc, target_acc) in results.items():\n",
    "    data = {\n",
    "        'btwn_class': btwn_class,\n",
    "        'btwn_domain': btwn_domain,\n",
    "        'btwn_both': btwn_both,\n",
    "        'target_acc': target_acc\n",
    "    }\n",
    "    df = df.append(data, ignore_index=True)\n",
    "# fig = px.scatter_3d(df, x='btwn_class', y='btwn_domain', z='btwn_both', color='target_acc')\n",
    "fig = px.scatter_3d(df, x='btwn_class', y='btwn_domain', z='target_acc')\n",
    "fig.update_traces(marker=dict(size=4), selector=dict(mode='markers'))\n",
    "fig.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "btwn_class = 1.0\n",
    "btwn_both = 1.0\n",
    "rows = df[(df['btwn_class'] == btwn_class) & (df['btwn_both'] == btwn_both)]\n",
    "plt.plot(rows['btwn_domain'], rows['target_acc'])\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "btwn_class = 1.0\n",
    "btwn_both = 1.0\n",
    "rows = df[(df['btwn_class'] == btwn_class) & (df['btwn_both'] == btwn_both)]\n",
    "plt.plot(rows['btwn_domain'], rows['avg_target_ensemble_acc'])\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# thought experiment: each (class, domain) pair is a singleton\n",
    "# define the connectivity via class_domain\n",
    "diff_same = 1.\n",
    "same_diff = 2.\n",
    "diff_diff = 0.2\n",
    "W = np.array([\n",
    "    [0, diff_same, diff_same, same_diff, diff_diff, diff_diff],\n",
    "    [diff_same, 0, diff_same, diff_diff, same_diff, diff_diff],\n",
    "    [diff_same, diff_same, 0, diff_diff, diff_diff, same_diff],\n",
    "    [same_diff, diff_diff, diff_diff, 0, diff_same, diff_same],\n",
    "    [diff_diff, same_diff, diff_diff, diff_same, 0, diff_same],\n",
    "    [diff_diff, diff_diff, same_diff, diff_same, diff_same, 0]\n",
    "])\n",
    "assert np.all(W == W.T)\n",
    "D = np.diag(np.sum(W, axis=1) ** (-1/2))\n",
    "W_bar = D @ W @ D\n",
    "U, S, VT = np.linalg.svd(W_bar)\n",
    "colors = ['purple', 'red', 'green']\n",
    "markers = ['o', '^']\n",
    "for i in range(6):\n",
    "    plt.scatter(U[i, 1], U[i, 2], c=colors[i % 3], marker=markers[i // 3])\n",
    "plt.show()\n",
    "df = pd.DataFrame(np.around(U, 3))\n",
    "df.columns = list(map(str, range(6)))\n",
    "df['c'] = colors * 2\n",
    "df['d'] = [markers[0]] * 3 + [markers[1]] * 3\n",
    "fig = px.scatter_3d(df, x='1', y='2', z='3', color='c', symbol='d')\n",
    "fig.update_traces(marker=dict(size=8), selector=dict(mode='markers'))\n",
    "fig.show()\n",
    "df"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
