{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "sopHPgEhu4Jo"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "\n",
        "import numpy as np\n",
        "from PIL import Image\n",
        "\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "import torch.optim as optim\n",
        "from torch.autograd import grad\n",
        "from torchvision import transforms\n",
        "from torchvision import datasets\n",
        "import torchvision.datasets.utils as dataset_utils\n",
        "\n",
        "from sklearn.datasets import make_classification\n",
        "from sklearn import tree\n",
        "from sklearn.model_selection import train_test_split\n",
        "\n",
        "import random\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kkTMK-oJveGg"
      },
      "source": [
        "## Prepare the CMNIST and CPMNIST datasets"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "q5MGpR2iAgHa"
      },
      "outputs": [],
      "source": [
        "# Global seed for the torch, NumPy and stdlib RNGs (train_and_test_erm re-seeds with this value).\n",
        "seed=3\n",
        "torch.manual_seed(seed)\n",
        "np.random.seed(seed)\n",
        "random.seed(seed)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "knP-xNzavgAb"
      },
      "outputs": [],
      "source": [
        "# Base code from https://colab.research.google.com/github/reiinakano/invariant-risk-minimization/blob/master/invariant_risk_minimization_colored_mnist.ipynb\n",
        "\n",
        "def color_grayscale_arr(arr, red=True):\n",
        "  \"\"\"Converts grayscale image to either red or green\"\"\"\n",
        "  assert arr.ndim == 2\n",
        "  dtype = arr.dtype\n",
        "  h, w = arr.shape\n",
        "  arr = np.reshape(arr, [h, w, 1])\n",
        "  if red:\n",
        "    arr = np.concatenate([arr,\n",
        "                          np.zeros((h, w, 2), dtype=dtype)], axis=2)\n",
        "  else:\n",
        "    arr = np.concatenate([np.zeros((h, w, 1), dtype=dtype),\n",
        "                          arr,\n",
        "                          np.zeros((h, w, 1), dtype=dtype)], axis=2)\n",
        "  return arr\n",
        "\n",
        "def patch_img(arr, horz=True):\n",
        "  \"\"\"Adds patch to img in top left corner\"\"\"\n",
        "  #assert arr.ndim == 3\n",
        "  if arr.ndim<3:\n",
        "    arr=np.expand_dims(arr,axis=-1)\n",
        "    arr=np.concatenate([arr,arr,arr],axis=2)\n",
        "  #print(d)\n",
        "  dtype = arr.dtype\n",
        "  h, w, d = arr.shape\n",
        "  p = 0.0\n",
        "  k=d\n",
        "  \n",
        "  \n",
        "  if horz:\n",
        "    for i in range(4):\n",
        "      for j in range(4):\n",
        "        arr[i,j,:]=255\n",
        "  else:\n",
        "    for i in range(4):\n",
        "      for j in range(4):\n",
        "        arr[27-i,27-j,:]=255\n",
        "  return arr\n",
        "\n",
        "def line_img(arr, horz=True):\n",
        "  \"\"\"Adds vertical line to img in the right half\"\"\"\n",
        "  #assert arr.ndim == 3\n",
        "  if arr.ndim<3:\n",
        "    arr=np.expand_dims(arr,axis=-1)\n",
        "    arr=np.concatenate([arr,arr,arr],axis=2)\n",
        "  \n",
        "  if horz:\n",
        "    for i in range(25,27):\n",
        "        arr[:,i,:]=255\n",
        "  return arr\n",
        "\n",
        "\n",
        "class ColoredMNIST(datasets.VisionDataset):\n",
        "  \"\"\"\n",
        "  Colored MNIST dataset for testing IRM. Prepared using procedure from https://arxiv.org/pdf/1907.02893.pdf\n",
        "\n",
        "  On first use MNIST is downloaded and converted into several variants, each cached as\n",
        "  ``<root>/ColoredMNIST/<env>.pt`` holding a list of ``(PIL image, label)`` tuples:\n",
        "\n",
        "    train1, train2, test, testd       colored digits; digit label (<5 -> 0, else 1) flipped w.p. 0.25\n",
        "    trainc, testc, testcd             same images, label = color (0 red, 1 green)\n",
        "    train1b, train2b, testb, testl    colored digits plus a 4x4 corner patch, noisy digit label\n",
        "    traincb, testcb, testcl           patched images, label = 1*color + 2*patch\n",
        "    traino, testo                     gray digits on 3 channels, un-flipped binary digit label\n",
        "\n",
        "  Color and patch each disagree with the noisy label w.p. 0.2 for the first 25000 training images\n",
        "  (train1*), 0.1 for the next 25000 (train2*) and for testd/testcd/testl/testcl, and 0.9 for\n",
        "  test/testc/testb/testcb. trainc and traincb cover both training environments.\n",
        "\n",
        "  Args:\n",
        "    root (string): Root directory of dataset where ``ColoredMNIST/*.pt`` will exist.\n",
        "    env (string): Which environment to load: one of the file stems above, or 'all_train' /\n",
        "      'all_trainb' / 'all_trainl' for the concatenation of the matching train1* and train2* files.\n",
        "      NOTE: 'train1l', 'train2l', 'traincl' (and therefore 'all_trainl') are accepted, but their\n",
        "      files are not written by ``prepare_colored_mnist``; loading them fails unless the files\n",
        "      were created by other means.\n",
        "    transform (callable, optional): A function/transform that  takes in an PIL image\n",
        "      and returns a transformed version. E.g, ``transforms.RandomCrop``\n",
        "    target_transform (callable, optional): A function/transform that takes in the\n",
        "      target and transforms it.\n",
        "\n",
        "  Raises:\n",
        "    RuntimeError: if ``env`` is not one of the accepted names.\n",
        "  \"\"\"\n",
        "  # Stems of the cache files written by prepare_colored_mnist.\n",
        "  _SAVED_ENVS = ('train1', 'train2', 'test', 'trainc', 'testc', 'traino', 'testo',\n",
        "                 'train1b', 'train2b', 'testb', 'traincb', 'testcb',\n",
        "                 'testl', 'testcl', 'testd', 'testcd')\n",
        "  # Single-file environments accepted by __init__ (see the NOTE in the class docstring).\n",
        "  _SINGLE_ENVS = _SAVED_ENVS + ('train1l', 'train2l', 'traincl')\n",
        "  # Environments obtained by concatenating two cached training files.\n",
        "  _MERGED_ENVS = {'all_train': ('train1', 'train2'),\n",
        "                  'all_trainb': ('train1b', 'train2b'),\n",
        "                  'all_trainl': ('train1l', 'train2l')}\n",
        "\n",
        "  def __init__(self, root='./data', env='train1', transform=None, target_transform=None):\n",
        "    super(ColoredMNIST, self).__init__(root, transform=transform,\n",
        "                                target_transform=target_transform)\n",
        "\n",
        "    self.prepare_colored_mnist()\n",
        "    colored_mnist_dir = os.path.join(self.root, 'ColoredMNIST')\n",
        "    if env in self._SINGLE_ENVS:\n",
        "      self.data_label_tuples = torch.load(os.path.join(colored_mnist_dir, env + '.pt'))\n",
        "    elif env in self._MERGED_ENVS:\n",
        "      self.data_label_tuples = []\n",
        "      for part in self._MERGED_ENVS[env]:\n",
        "        self.data_label_tuples += torch.load(os.path.join(colored_mnist_dir, part + '.pt'))\n",
        "    else:\n",
        "      valid = ', '.join(self._SINGLE_ENVS + tuple(self._MERGED_ENVS))\n",
        "      raise RuntimeError(f'{env} env unknown. Valid envs are {valid}')\n",
        "\n",
        "  def __getitem__(self, index):\n",
        "    \"\"\"\n",
        "    Args:\n",
        "        index (int): Index\n",
        "\n",
        "    Returns:\n",
        "        tuple: (image, target) where target is index of the target class.\n",
        "    \"\"\"\n",
        "    img, target = self.data_label_tuples[index]\n",
        "\n",
        "    if self.transform is not None:\n",
        "      img = self.transform(img)\n",
        "\n",
        "    if self.target_transform is not None:\n",
        "      target = self.target_transform(target)\n",
        "\n",
        "    return img, target\n",
        "\n",
        "  def __len__(self):\n",
        "    return len(self.data_label_tuples)\n",
        "\n",
        "  def prepare_colored_mnist(self):\n",
        "    \"\"\"Build and cache every dataset variant under ``<root>/ColoredMNIST`` unless all files exist.\n",
        "\n",
        "    The generated data depends on the order of the ``np.random.uniform`` draws below, so that\n",
        "    order must not be changed if results are to stay reproducible for a given NumPy seed.\n",
        "    \"\"\"\n",
        "    colored_mnist_dir = os.path.join(self.root, 'ColoredMNIST')\n",
        "    # Require the complete cache: testing only a few files lets a partially written cache\n",
        "    # pass here and fail later inside torch.load.\n",
        "    if all(os.path.exists(os.path.join(colored_mnist_dir, name + '.pt')) for name in self._SAVED_ENVS):\n",
        "      print('Colored MNIST dataset already exists')\n",
        "      return\n",
        "\n",
        "    print('Preparing Colored MNIST')\n",
        "    train_mnist = datasets.mnist.MNIST(self.root, train=True, download=True)\n",
        "    test_mnist = datasets.mnist.MNIST(self.root, train=False, download=True)\n",
        "\n",
        "    # Colored digits\n",
        "    train1_set = []\n",
        "    train2_set = []\n",
        "    test_set = []\n",
        "    trainc_set = []\n",
        "    testc_set = []\n",
        "\n",
        "    # Colored digits with a corner patch\n",
        "    train1_setb = []\n",
        "    train2_setb = []\n",
        "    test_setb = []\n",
        "    trainc_setb = []\n",
        "    testc_setb = []\n",
        "\n",
        "    # Patched test images generated with the 0.1 flip probability\n",
        "    test_setl = []\n",
        "    testc_setl = []\n",
        "\n",
        "    # Gray digits replicated over three channels\n",
        "    train_seto = []\n",
        "    test_seto = []\n",
        "\n",
        "    # Colored test images generated with the 0.1 flip probability\n",
        "    test_setd = []\n",
        "    testc_setd = []\n",
        "\n",
        "    for idx, (im, label) in enumerate(train_mnist):\n",
        "      if idx % 10000 == 0:\n",
        "        print(f'Converting image {idx}/{len(train_mnist)}')\n",
        "      im_array = np.array(im)\n",
        "      arr=np.expand_dims(im_array,axis=-1)\n",
        "      arr=np.concatenate([arr,arr,arr],axis=2)\n",
        "\n",
        "      # Assign a binary label y to the image based on the digit\n",
        "      binary_label = 0 if label < 5 else 1\n",
        "      label = binary_label\n",
        "\n",
        "      # Flip label with 25% probability (drawn for every image, including the ones with\n",
        "      # idx >= 50000 that are not stored, so the random stream stays unchanged)\n",
        "      if np.random.uniform() < 0.25:\n",
        "        binary_label = binary_label ^ 1\n",
        "\n",
        "      # Color / patch position follow the possibly flipped label\n",
        "      color_red = binary_label == 0\n",
        "      lin = binary_label == 0\n",
        "\n",
        "      # Flip the color and the patch with a probability e that depends on the environment\n",
        "      if idx < 25000:\n",
        "        # 20% in the first training environment\n",
        "        if np.random.uniform() < 0.2:\n",
        "          color_red = not color_red\n",
        "        if np.random.uniform() < 0.2:\n",
        "          lin = not lin\n",
        "      elif idx < 50000:\n",
        "        # 10% in the second training environment\n",
        "        if np.random.uniform() < 0.1:\n",
        "          color_red = not color_red\n",
        "        if np.random.uniform() < 0.1:\n",
        "          lin = not lin\n",
        "\n",
        "      colored_arr = color_grayscale_arr(im_array, red=color_red)\n",
        "      # patch_img writes into its 3-D argument, so hand it a copy of the colored image\n",
        "      im_array = colored_arr.copy()\n",
        "      patch_arr = patch_img(im_array, horz=lin)\n",
        "\n",
        "      if idx < 25000:\n",
        "        train1_set.append((Image.fromarray(colored_arr), binary_label))       #CMNIST\n",
        "        trainc_set.append((Image.fromarray(colored_arr), int(not color_red))) #CMNIST with color labels\n",
        "\n",
        "        train1_setb.append((Image.fromarray(patch_arr), binary_label))        #PCMNIST\n",
        "        trainc_setb.append((Image.fromarray(patch_arr), int(1*int(not color_red)+2*int(not lin))))  #PCMNIST with 1*color+2*patch labels\n",
        "\n",
        "        train_seto.append((Image.fromarray(arr), label))                      #org mnist\n",
        "      elif idx < 50000:\n",
        "        train2_set.append((Image.fromarray(colored_arr), binary_label))\n",
        "        trainc_set.append((Image.fromarray(colored_arr), int(not color_red)))\n",
        "\n",
        "        train2_setb.append((Image.fromarray(patch_arr), binary_label))\n",
        "        trainc_setb.append((Image.fromarray(patch_arr), int(1*int(not color_red)+2*int(not lin))))\n",
        "\n",
        "        train_seto.append((Image.fromarray(arr), label))\n",
        "\n",
        "    for idx, (im, label) in enumerate(test_mnist):\n",
        "      if idx % 10000 == 0:\n",
        "        print(f'Converting image {idx}/{len(test_mnist)}')\n",
        "      im_array = np.array(im)\n",
        "      arr=np.expand_dims(im_array,axis=-1)\n",
        "      arr=np.concatenate([arr,arr,arr],axis=2)\n",
        "\n",
        "      # Assign a binary label y to the image based on the digit\n",
        "      binary_label = 0 if label < 5 else 1\n",
        "      label = binary_label\n",
        "\n",
        "      # Flip label with 25% probability\n",
        "      if np.random.uniform() < 0.25:\n",
        "        binary_label = binary_label ^ 1\n",
        "\n",
        "      # Color / patch position follow the possibly flipped label\n",
        "      color_red = binary_label == 0\n",
        "      lin = binary_label == 0\n",
        "\n",
        "      # Second copy of the attributes, flipped with the 0.1 probability below\n",
        "      color_red2 = binary_label == 0\n",
        "      lin2 = binary_label == 0\n",
        "\n",
        "      # Flip the color and the patch with a probability e that depends on the environment\n",
        "      if idx < 10000:\n",
        "        # 90% in the test environment\n",
        "        if np.random.uniform() < 0.9:\n",
        "          color_red = not color_red\n",
        "        if np.random.uniform() < 0.9:\n",
        "          lin = not lin\n",
        "\n",
        "        # 10% for the second copy\n",
        "        if np.random.uniform() < 0.1:\n",
        "          color_red2 = not color_red2\n",
        "        if np.random.uniform() < 0.1:\n",
        "          lin2 = not lin2\n",
        "\n",
        "      im_array2 = im_array.copy()\n",
        "\n",
        "      colored_arr = color_grayscale_arr(im_array, red=color_red)\n",
        "      im_array = colored_arr.copy()\n",
        "      patch_arr = patch_img(im_array, horz=lin)\n",
        "\n",
        "      colored_arr2 = color_grayscale_arr(im_array2, red=color_red2)\n",
        "      im_array2 = colored_arr2.copy()\n",
        "      patch_arr2 = patch_img(im_array2, horz=lin2)\n",
        "\n",
        "      if idx < 10000:\n",
        "        test_set.append((Image.fromarray(colored_arr), binary_label))\n",
        "        testc_set.append((Image.fromarray(colored_arr), int(not color_red)))\n",
        "\n",
        "        test_setd.append((Image.fromarray(colored_arr2), binary_label))\n",
        "        testc_setd.append((Image.fromarray(colored_arr2), int(not color_red2)))\n",
        "\n",
        "        test_setb.append((Image.fromarray(patch_arr), binary_label))\n",
        "        testc_setb.append((Image.fromarray(patch_arr), int(1*int(not color_red)+2*int(not lin))))\n",
        "\n",
        "        test_setl.append((Image.fromarray(patch_arr2), binary_label))         #PCMNIST-iid\n",
        "        testc_setl.append((Image.fromarray(patch_arr2), int(1*int(not color_red2)+2*int(not lin2))))  #PCMNIST-iid with 1*color+2*patch labels\n",
        "\n",
        "        test_seto.append((Image.fromarray(arr), label))\n",
        "\n",
        "    variants = {\n",
        "      'train1': train1_set, 'train2': train2_set, 'test': test_set,\n",
        "      'trainc': trainc_set, 'testc': testc_set,\n",
        "      'traino': train_seto, 'testo': test_seto,\n",
        "      'train1b': train1_setb, 'train2b': train2_setb, 'testb': test_setb,\n",
        "      'traincb': trainc_setb, 'testcb': testc_setb,\n",
        "      'testl': test_setl, 'testcl': testc_setl,\n",
        "      'testd': test_setd, 'testcd': testc_setd,\n",
        "    }\n",
        "    # exist_ok: the directory may already exist from an earlier, incomplete run\n",
        "    os.makedirs(colored_mnist_dir, exist_ok=True)\n",
        "    # Iterating over _SAVED_ENVS keeps the written files in sync with the cache check above\n",
        "    for name in self._SAVED_ENVS:\n",
        "      torch.save(variants[name], os.path.join(colored_mnist_dir, name + '.pt'))\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XOnjIjK8q7UJ"
      },
      "source": [
        "### Plot the data"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "EwUoQZCyvs6T"
      },
      "outputs": [],
      "source": [
        "def plot_dataset_digits(dataset):\n",
        "  \"\"\"Show the first 18 samples of ``dataset`` in a 3x6 grid, each titled with its label.\n",
        "\n",
        "  ``dataset[i]`` must return an ``(image, label)`` pair whose image ``plt.imshow`` can draw.\n",
        "  \"\"\"\n",
        "  n_rows, n_cols = 3, 6\n",
        "  figure = plt.figure(figsize=(13, 8))\n",
        "\n",
        "  for position in range(1, n_rows * n_cols + 1):\n",
        "    image, label = dataset[position - 1]\n",
        "    # the freshly added subplot becomes the current axes, so plt.imshow draws into it\n",
        "    axis = figure.add_subplot(n_rows, n_cols, position)\n",
        "    axis.set_title(\"Label: \" + str(label))\n",
        "    plt.imshow(image)\n",
        "\n",
        "  plt.show()  # finally, render the plot\n",
        "  \n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "11YqxmhjrSMi"
      },
      "source": [
        "Plotting the train set"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "HDh4qJS_rQwK"
      },
      "outputs": [],
      "source": [
        "# 'train1b' is the first training environment of the patched variant (written from train1_setb).\n",
        "train1_set = ColoredMNIST(root='./data', env='train1b')\n",
        "\n",
        "plot_dataset_digits(train1_set)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LhsLWjOCrqWx"
      },
      "source": [
        "Plotting the test set"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "OdmmW48srQyk"
      },
      "outputs": [],
      "source": [
        "# 'testb' is the patched test split, generated with the 0.9 color/patch flip probability.\n",
        "test_set = ColoredMNIST(root='./data', env='testb')\n",
        "plot_dataset_digits(test_set)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jleIZ9vNv5rV"
      },
      "source": [
        "## Define models"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9hYJRewnv80x"
      },
      "outputs": [],
      "source": [
        "class Net(nn.Module):\n",
        "  def __init__(self):\n",
        "    super(Net, self).__init__()\n",
        "    self.fc1 = nn.Linear(3 * 28 * 28, 512)\n",
        "    self.fc2 = nn.Linear(512, 512)\n",
        "    self.fc3 = nn.Linear(512, 1)\n",
        "\n",
        "  def forward(self, x):\n",
        "    x = x.view(-1, 3 * 28 * 28)\n",
        "    x = F.relu(self.fc1(x))\n",
        "    x = F.relu(self.fc2(x))\n",
        "    logits = self.fc3(x).flatten()\n",
        "    return logits\n",
        "\n",
        "class MLP(nn.Module):\n",
        "  def __init__(self):\n",
        "    super(MLP, self).__init__()\n",
        "    self.fc1 = nn.Linear(3 * 28 * 28, 390)\n",
        "    #self.fc2 = nn.Linear(512, 512)\n",
        "    self.fc3 = nn.Linear(390, 1)\n",
        "\n",
        "  def forward(self, x):\n",
        "    x = x.view(-1, 3 * 28 * 28)\n",
        "    x = F.relu(self.fc1(x))\n",
        "    #x = F.relu(self.fc2(x))\n",
        "    logits = self.fc3(x).flatten()\n",
        "    return logits\n",
        "\n",
        "class LinearReg(nn.Module):\n",
        "  def __init__(self, inputSize=28*28*3, outputSize=1):\n",
        "        super(LinearReg, self).__init__()\n",
        "        self.linear = torch.nn.Linear(inputSize, outputSize, bias = False)\n",
        "        #self.apply(self._init_weights)\n",
        "        \n",
        "  def forward(self, x):\n",
        "        x = x.view(-1, 3 * 28 * 28)\n",
        "        out = self.linear(x)\n",
        "        return out\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "il3cATYxwIyP"
      },
      "source": [
        "## ERM functions"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "L4qZtXx_weBb"
      },
      "outputs": [],
      "source": [
        "def test_model(model, device, test_loader, set_name=\"test set\", pl=None):\n",
        "  model.eval()\n",
        "  test_loss = 0\n",
        "  correct = 0\n",
        "  print(pl)\n",
        "  with torch.no_grad():\n",
        "    for data, target in test_loader:\n",
        "      data, target = data.to(device), target.to(device).float()\n",
        "      if pl is not None:\n",
        "       if pl:\n",
        "        target = (target.int()%2).float()\n",
        "       else:\n",
        "        target = (target.int()//2).float()\n",
        "      output = model(data)\n",
        "      test_loss += F.binary_cross_entropy_with_logits(torch.squeeze(output), target, reduction='sum').item()  # sum up batch loss\n",
        "      pred = torch.where(torch.gt(output, torch.Tensor([0.0]).to(device)),\n",
        "                         torch.Tensor([1.0]).to(device),\n",
        "                         torch.Tensor([0.0]).to(device))  # get the index of the max log-probability\n",
        "      correct += pred.eq(target.view_as(pred)).sum().item()\n",
        "\n",
        "  test_loss /= len(test_loader.dataset)\n",
        "\n",
        "  print('\\nPerformance on {}: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\\n'.format(\n",
        "    set_name, test_loss, correct, len(test_loader.dataset),\n",
        "    100. * correct / len(test_loader.dataset)))\n",
        "\n",
        "  return 100. * correct / len(test_loader.dataset)\n",
        "\n",
        "\n",
        "def erm_train(model, device, train_loader, optimizer, epoch, pl=None, m1=None,lam=0.0):\n",
        "  model.train()\n",
        "  for batch_idx, (data, target) in enumerate(train_loader):\n",
        "    data, target = data.to(device), target.to(device).float()\n",
        "    if pl is not None:\n",
        "      if pl:\n",
        "        target = (target.int()%2).float()\n",
        "      else:\n",
        "        target = (target.int()//2).float()\n",
        "    optimizer.zero_grad()\n",
        "    output = model(data)\n",
        "    if lam==0.0:\n",
        "      loss = F.binary_cross_entropy_with_logits(torch.squeeze(output), target) \n",
        "    else:\n",
        "      loss = F.binary_cross_entropy_with_logits(torch.squeeze(output), target)+ lam*reg_loss(m1,model)\n",
        "    loss.backward()\n",
        "    optimizer.step()\n",
        "    if batch_idx % 200 == 0:\n",
        "      print('Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}'.format(\n",
        "        epoch, batch_idx * len(data), len(train_loader.dataset),\n",
        "               100. * batch_idx / len(train_loader), loss.item()))\n",
        "\n",
        "\n",
        "def train_and_test_erm(ft=True,cnn=True,col=True,org=False,lin=False,pl=None,lr=0.01,wtd=0.05,opt=False,epochs=5):\n",
        "  \"\"\"Train a model with plain ERM on one ColoredMNIST variant, evaluating after every pass.\n",
        "\n",
        "  Args:\n",
        "    ft: True selects the files with digit labels, False the files with spurious-feature labels.\n",
        "    cnn: True builds ``MLP``, False builds ``LinearReg``.\n",
        "    col: True selects the color-only files; False the ``*b`` / ``*l`` files.\n",
        "    org: True selects 'traino'/'testo' and overrides ``ft``, ``col`` and ``lin``.\n",
        "    lin: with ``col=False``, True selects the ``*l`` files and False the ``*b`` files.\n",
        "    pl: forwarded to ``erm_train`` and ``test_model``.\n",
        "    lr: learning rate.\n",
        "    wtd: weight decay, used by SGD only.\n",
        "    opt: True uses Adam, False uses SGD.\n",
        "    epochs: the loop runs over ``range(1, epochs)``, i.e. ``epochs - 1`` passes.\n",
        "\n",
        "  Returns:\n",
        "    The trained model. Training stops early once test accuracy reaches 100 (within 1e-6).\n",
        "  \"\"\"\n",
        "  cuda_ok = torch.cuda.is_available()\n",
        "  device = torch.device(\"cuda\" if cuda_ok else \"cpu\")\n",
        "  loader_opts = {'num_workers': 1, 'pin_memory': True} if cuda_ok else {}\n",
        "\n",
        "  # Re-seed so that every call starts from the same RNG state\n",
        "  torch.manual_seed(seed)\n",
        "  np.random.seed(seed)\n",
        "  random.seed(seed)\n",
        "\n",
        "  if org:\n",
        "    env1, env2 = 'traino', 'testo'\n",
        "  else:\n",
        "    # '' -> color only, 'b' -> color + patch, 'l' -> the *l files\n",
        "    suffix = '' if col else ('l' if lin else 'b')\n",
        "    env1 = ('all_train' if ft else 'trainc') + suffix\n",
        "    env2 = ('test' if ft else 'testc') + suffix\n",
        "  print(env1,env2,pl)\n",
        "\n",
        "  model = (MLP() if cnn else LinearReg()).to(device)\n",
        "\n",
        "  def load_env(env):\n",
        "    # Same preprocessing for the training and the test split\n",
        "    return ColoredMNIST(root='./data', env=env,\n",
        "                        transform=transforms.Compose([\n",
        "                          transforms.ToTensor(),\n",
        "                          transforms.Normalize((0.1307, 0.1307, 0.), (0.3081, 0.3081, 0.3081))\n",
        "                        ]))\n",
        "\n",
        "  all_train_loader = torch.utils.data.DataLoader(\n",
        "    load_env(env1), batch_size=64, shuffle=True, **loader_opts)\n",
        "  test_loader = torch.utils.data.DataLoader(\n",
        "    load_env(env2), batch_size=1000, shuffle=True, **loader_opts)\n",
        "\n",
        "  if opt:\n",
        "    optimizer = optim.Adam(model.parameters(), lr=lr)\n",
        "  else:\n",
        "    optimizer = optim.SGD(model.parameters(), lr=lr,weight_decay=wtd)\n",
        "\n",
        "  for epoch in range(1, epochs):\n",
        "    erm_train(model, device, all_train_loader, optimizer, epoch, pl)\n",
        "    test_model(model, device, all_train_loader, set_name='train set',pl=pl)\n",
        "    acc=test_model(model, device, test_loader,pl=pl)\n",
        "    if acc>=100-1e-6:\n",
        "      break\n",
        "\n",
        "  return model\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PRAXf7ozQYdJ"
      },
      "source": [
        "# Mutual Information Computations"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "gFz8fq5cQbl1"
      },
      "outputs": [],
      "source": [
        "## Information Computations\n",
        "# Base code from https://github.com/anoneurips2019/SGD-learns-functions-of-increasing-complexity\n",
        "\n",
        "# estimate I(M(x); y | L(x))\n",
        "\n",
        "def kernel_val(x, mu, sig):\n",
        "    a=torch.div(torch.exp(torch.div(-torch.square((x-mu)), 2*torch.square(torch.tensor(sig)))), torch.sqrt(torch.tensor(2*torch.pi)))\n",
        "    return a\n",
        "\n",
        "def getbinvals(M, cen_arr, Xv, Yv, device, sig=0.05):\n",
        "    \"\"\"Soft 4x4 bin membership of the scalar pair (Xv, Yv).\n",
        "\n",
        "    Each value gets a kernel weight for the first four centres in ``cen_arr``; the two\n",
        "    weight vectors are normalised to sum to one and their outer product is returned.\n",
        "    ``M`` is ignored.\n",
        "    \"\"\"\n",
        "    n_bins = 4\n",
        "    wx = torch.zeros(n_bins).to(device)\n",
        "    wy = torch.zeros(n_bins).to(device)\n",
        "    for k in range(n_bins):\n",
        "        centre = cen_arr[k]\n",
        "        wx[k] = kernel_val(Xv, centre, sig)\n",
        "        wy[k] = kernel_val(Yv, centre, sig)\n",
        "    return torch.outer(torch.div(wx, torch.sum(wx)), torch.div(wy, torch.sum(wy)))\n",
        "    \n",
        "\n",
        "def est_IMYL(model, linModel, x, y, device, round=True, bin=False):\n",
        "    \"\"\"Estimate I(M(x); y | L(x)) from the joint table over (model output, y, linModel output).\"\"\"\n",
        "    preds_m, preds_l, labels = est_pred(model, linModel, x, y)\n",
        "    return I_XYZ(est_density(preds_m, labels, preds_l, device, round, bin))\n",
        "\n",
        "def est_IMLY(model, linModel, x, y, device, round=True, bin=False):\n",
        "    \"\"\"Estimate I(M(x); L(x) | y) from the joint table over (model output, linModel output, y).\"\"\"\n",
        "    preds_m, preds_l, labels = est_pred(model, linModel, x, y)\n",
        "    return I_XYZ(est_density(preds_m, preds_l, labels, device, round, bin))\n",
        "\n",
        "def est_pred(model, linModel, x, y):\n",
        "    o1 = model(x)\n",
        "    o1 = torch.sigmoid(o1)\n",
        "\n",
        "    o2 = linModel(x)\n",
        "    o2 = torch.squeeze(torch.sigmoid(o2))\n",
        "    Mx = torch.flatten(o1)\n",
        "    Lx = torch.flatten(o2)\n",
        "    y = torch.flatten(y)\n",
        "    \n",
        "    #p = est_density(Mx, y, Lx)\n",
        "    return Mx, Lx, y\n",
        "    \n",
        "def est_density(X, Y, Z, device, round=True, bin=False): # estimate p[x,y,z] \\in R^{{0,1}^3} for samples from X, Y, Z \\in \\N\n",
        "    \n",
        "    if not bin: \n",
        "     if round: \n",
        "      X = torch.round(X).int()\n",
        "      Y = torch.round(Y).int()\n",
        "     else:\n",
        "      X = torch.sigmoid(12.5*(X-0.5))\n",
        "      Y = torch.sigmoid(12.5*(Y-0.5))\n",
        "    Z = torch.round(Z).int()\n",
        "    \n",
        "    n = X.size(dim=0)\n",
        "    if bin:\n",
        "      p = torch.zeros((4,4,2)).to(device)\n",
        "      cen_arr = torch.tensor([0.125, 0.375, 0.625, 0.875])\n",
        "    else:\n",
        "      p = torch.zeros((2, 2, 2)).to(device) # p[x,y,z] is the joint prob density\n",
        "    \n",
        "    flag=True\n",
        "    for i in range(n):\n",
        "      #if i==1:\n",
        "      #  flag=False\n",
        "      if bin:\n",
        "        if round:\n",
        "         #i1 = (torch.round(2*X[i])+torch.round(X[i])).int()\n",
        "         i1 = (1*(X[i]>0.25)+1*(X[i]>0.5)+1*(X[i]>0.75)).int()\n",
        "         i2 = (1*(Y[i]>0.25)+1*(Y[i]>0.5)+1*(Y[i]>0.75)).int()\n",
        "         #i2 = (torch.round(2*Y[i])+torch.round(Y[i])).int()\n",
        "         p[i1, i2, Z[i]]+=1\n",
        "        else:\n",
        "         M = torch.zeros((4,4,2)).to(device)\n",
        "         M = getbinvals(M, cen_arr, X[i], Y[i], device)\n",
        "         #if flag:\n",
        "         #  print(M)\n",
        "         p[:,:,Z[i]] += M #/torch.sum(M, axis=(0,1))\n",
        "      else:\n",
        "        p[0, 0, Z[i]] += (1.0-X[i])*(1.0-Y[i])\n",
        "        p[0, 1, Z[i]] += (1.0-X[i])*(Y[i])\n",
        "        p[1, 0, Z[i]] += (X[i])*(1.0-Y[i])\n",
        "        p[1, 1, Z[i]] += X[i]*Y[i]\n",
        "    \n",
        "    p /= n\n",
        "    if bin and flag:\n",
        "      print(p)\n",
        "    return p\n",
        "\n",
        "def I_XYZ(p): # compute I(X, Y | Z) for joint density p[x, y, z]\n",
        "    pz = torch.sum(p, axis=(0,1), keepdims=True) # the density of z. pz[x,y,z] = p(z)\n",
        "    \n",
        "    p_xy_z = p / pz  # q[x, y, z] = p(x, y | z)\n",
        "    p_x_z =  torch.sum(p, axis=1, keepdims=True) / pz  # p(x | z)\n",
        "    p_y_z =  torch.sum(p, axis=0, keepdims=True) / pz  # p(y | z)\n",
        "    \n",
        "    I = torch.sum(p * torch.nan_to_num(torch.log2( p_xy_z / (p_x_z * p_y_z) )))\n",
        "    return I\n",
        "\n",
        "\n",
        "# returns I(A; B) where A, B \\in {X, Y, Z} spefice by idx\n",
        "# eg, I(X; Y) = I_ab(p, idx=[0, 1])\n",
        "#     I(X; Z) = I_ab(p, idx=[0, 2])\n",
        "def I_ab(p, idx=(0,1)): \n",
        "    exlude = (0+1+2) - np.sum(idx)\n",
        "    p_ab = torch.sum(p, axis=exlude)\n",
        "    p_a = torch.sum(p_ab, axis=1, keepdims=True)\n",
        "    p_b = torch.sum(p_ab, axis=0, keepdims=True)\n",
        "    \n",
        "    I = torch.sum(p_ab * torch.nan_to_num(torch.log2( p_ab / (p_a * p_b) )))\n",
        "    return I\n",
        "\n",
        "def H(q): # binary entropy\n",
        "    return -q*torch.log2(q) - (1-q)*torch.log2(1-q)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZMFqwKY7tjG_"
      },
      "source": [
        "# CMID functions"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "hGHo7x7PtieR"
      },
      "outputs": [],
      "source": [
        "def mi_reg(model, lin, x_train, y_train,device):\n",
        "  I_MLY = est_IMLY(model, lin, x_train, y_train,device,round=False)  # p[M, Y, L]\n",
        "  #I_MYL = I_XYZ(pMYL)\n",
        "  return torch.abs(I_MLY)\n",
        "\n",
        "def erm_train_mi_reg(model, device, train_loader, optimizer, epoch,m1=None,m2=None,lam=0.0):\n",
        "  model.train()\n",
        "  for batch_idx, (data, target) in enumerate(train_loader):\n",
        "    data, target = data.to(device), target.to(device).float()\n",
        "    optimizer.zero_grad()\n",
        "    output = model(data)\n",
        "    \n",
        "    if lam!=0:\n",
        "      l1 = mi_reg(model,m1,data,target,device)\n",
        "    else: \n",
        "      l1 = torch.squeeze(torch.tensor([0]).to(device))\n",
        "    if m2 is not None:\n",
        "      l2 = mi_reg(model,m2,data,target,device)\n",
        "    else:\n",
        "      l2 = torch.squeeze(torch.tensor([0]).to(device))\n",
        "    l1+=l2\n",
        "    if lam==-1:\n",
        "      lam=min(epoch**1.5/15,5)\n",
        "    loss = F.binary_cross_entropy_with_logits(torch.squeeze(output), target)+ lam*l1\n",
        "    loss.backward()\n",
        "    optimizer.step()\n",
        "    if batch_idx % 100 == 0:\n",
        "      print('Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}'.format(\n",
        "        epoch, batch_idx * len(data), len(train_loader.dataset),\n",
        "               100. * batch_idx / len(train_loader), loss.item()))\n",
        "\n",
        "\n",
        "def train_and_test_erm_reg(ft=True,cnn=True,col=True,lin=False,m1=None,m2=None,lam=0,epochs=2,bs=64, model=None,cnn2=False,lr=0.01, scale=100, id=False,opt=False):\n",
        "  use_cuda = torch.cuda.is_available()\n",
        "  device = torch.device(\"cuda\" if use_cuda else \"cpu\")\n",
        "\n",
        "  kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}\n",
        "  torch.manual_seed(seed)\n",
        "  np.random.seed(seed)\n",
        "  random.seed(seed)\n",
        "  if col:\n",
        "    if ft:\n",
        "      env1='all_train'\n",
        "      env2='test'\n",
        "    else:\n",
        "      env1='trainc'\n",
        "      env2='testc'\n",
        "  else:\n",
        "    if not lin: #lin=True means line is present in addition to patch\n",
        "     if ft:\n",
        "      env1='all_trainb'\n",
        "      env2='testb'\n",
        "     else:\n",
        "      env1='traincb'\n",
        "      env2='testcb'\n",
        "    else:\n",
        "     if ft:\n",
        "      env1='all_trainl'\n",
        "      env2='testl'\n",
        "     else:\n",
        "      env1='traincl'\n",
        "      env2='testcl'\n",
        "      #pl=pl #when ft is False, use pl=True to train erm for patch, and False for line\n",
        "  if model is None:\n",
        "    if cnn:\n",
        "      if col or cnn2:\n",
        "        model = MLP().to(device)\n",
        "      else:\n",
        "        model = MLP().to(device)\n",
        "    else:\n",
        "      model = LinearReg().to(device)\n",
        "\n",
        "  all_train_loader = torch.utils.data.DataLoader(\n",
        "    ColoredMNIST(root='./data', env=env1,\n",
        "                 transform=transforms.Compose([\n",
        "                     transforms.ToTensor(),\n",
        "                     transforms.Normalize((0.1307, 0.1307, 0.), (0.3081, 0.3081, 0.3081))\n",
        "                   ])),\n",
        "    batch_size=bs, shuffle=True, **kwargs)\n",
        "\n",
        "  test_loader = torch.utils.data.DataLoader(\n",
        "    ColoredMNIST(root='./data', env=env2, transform=transforms.Compose([\n",
        "      transforms.ToTensor(),\n",
        "      transforms.Normalize((0.1307, 0.1307, 0.), (0.3081, 0.3081, 0.3081))\n",
        "    ])),\n",
        "    batch_size=1000, shuffle=True, **kwargs)\n",
        "  \n",
        "  if id and col:\n",
        "    test_loader2 = torch.utils.data.DataLoader(\n",
        "    ColoredMNIST(root='./data', env='testd', transform=transforms.Compose([\n",
        "      transforms.ToTensor(),\n",
        "      transforms.Normalize((0.1307, 0.1307, 0.), (0.3081, 0.3081, 0.3081))\n",
        "    ])),\n",
        "    batch_size=1000, shuffle=True, **kwargs)\n",
        "  elif id:\n",
        "    test_loader2 = torch.utils.data.DataLoader(\n",
        "    ColoredMNIST(root='./data', env='testl', transform=transforms.Compose([\n",
        "      transforms.ToTensor(),\n",
        "      transforms.Normalize((0.1307, 0.1307, 0.), (0.3081, 0.3081, 0.3081))\n",
        "    ])),\n",
        "    batch_size=1000, shuffle=True, **kwargs)\n",
        "\n",
        "  if opt:\n",
        "    optimizer = optim.Adam(model.parameters(), lr=lr)\n",
        "  else:\n",
        "    optimizer = optim.SGD(model.parameters(), lr=lr)\n",
        "\n",
        "  for epoch in range(1, epochs):\n",
        "    lam1=lam*(1+epoch/scale)\n",
        "    erm_train_mi_reg(model, device, all_train_loader, optimizer, epoch,m1,m2,lam1)\n",
        "    test_model(model, device, all_train_loader, set_name='train set')\n",
        "    if id:\n",
        "      test_model(model, device, test_loader2)\n",
        "    test_model(model, device, test_loader)\n",
        "    \n",
        "  \n",
        "  return model\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Train Simple Models"
      ],
      "metadata": {
        "id": "jjw8L-aX2VHE"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_3eLUmMsgye2"
      },
      "outputs": [],
      "source": [
        "lp_inv=train_and_test_erm(ft=True,cnn=False,col=True,lr=0.01,epochs=5,wtd=0.005) #linear clf for digit clfn  on cmnist\n",
        "lpl_inv=train_and_test_erm(ft=True,cnn=False,col=False,lr=0.01,epochs=5,wtd=0.001) #linear clf for digit clfn on data with both color and patch"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3HBW9D5TBrue"
      },
      "source": [
        "# CMID on CMNIST"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vzGP0s6kI0uC"
      },
      "outputs": [],
      "source": [
        "wregc=train_and_test_erm_reg(ft=True,cnn=True,col=True,m1=lp_inv,lam=4,epochs=20,lr=0.001,scale=4, id=True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "j4dkeahdCtFu"
      },
      "source": [
        "# CMID on CPMNIST"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "PIbNCOnBgmiB"
      },
      "outputs": [],
      "source": [
        "wregc=train_and_test_erm_reg(ft=True,cnn=True,col=False,m1=lpl_inv,lam=5,epochs=20,lr=0.005,scale=3, id=True)"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [
        "kkTMK-oJveGg",
        "jleIZ9vNv5rV",
        "PRAXf7ozQYdJ",
        "ux9V7mPdIaEY",
        "AwNXBlWoKjzp",
        "qvOrZQ2XmD2I",
        "d9r5uMWWZAuL",
        "_YLmxmeTZGTn",
        "QE1rz8QeZCKp",
        "oTQukPxOjPPn"
      ],
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}