{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The autoreload extension is already loaded. To reload it, use:\n",
      "  %reload_ext autoreload\n"
     ]
    }
   ],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "# Standard library\n",
    "import math\n",
    "import os\n",
    "import random\n",
    "\n",
    "# Third-party (each module imported exactly once)\n",
    "import numpy as np\n",
    "import pandas\n",
    "import pandas as pd\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "import torchvision\n",
    "import torchvision.datasets as dataset\n",
    "import torchvision.transforms as transforms\n",
    "from PIL import Image\n",
    "from torch import nn, optim\n",
    "from torch.utils.data import DataLoader, Dataset\n",
    "from torch.utils.data.sampler import SubsetRandomSampler\n",
    "from torchvision import models\n",
    "from tqdm import tqdm\n",
    "\n",
    "# Training hyper-parameters (not referenced by the data-preparation cells in this notebook)\n",
    "learning_rate = 1e-3\n",
    "epochs = 20\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "outputs": [
    {
     "data": {
      "text/plain": "     image_id  5_o_Clock_Shadow  Arched_Eyebrows  Attractive  Bags_Under_Eyes  \\\n0  000001.jpg                 0                1           1                0   \n1  000002.jpg                 0                0           0                1   \n2  000003.jpg                 0                0           0                0   \n3  000004.jpg                 0                0           1                0   \n4  000005.jpg                 0                1           1                0   \n5  000006.jpg                 0                1           1                0   \n6  000007.jpg                 1                0           1                1   \n7  000008.jpg                 1                1           0                1   \n8  000009.jpg                 0                1           1                0   \n9  000010.jpg                 0                0           1                0   \n\n   Bald  Bangs  Big_Lips  Big_Nose  Black_Hair  ...  Sideburns  Smiling  \\\n0     0      0         0         0           0  ...          0        1   \n1     0      0         0         1           0  ...          0        1   \n2     0      0         1         0           0  ...          0        0   \n3     0      0         0         0           0  ...          0        0   \n4     0      0         1         0           0  ...          0        0   \n5     0      0         1         0           0  ...          0        0   \n6     0      0         1         1           1  ...          0        0   \n7     0      0         1         0           1  ...          0        0   \n8     0      1         1         0           0  ...          0        1   \n9     0      0         0         0           0  ...          
0        0   \n\n   Straight_Hair  Wavy_Hair  Wearing_Earrings  Wearing_Hat  Wearing_Lipstick  \\\n0              1          0                 1            0                 1   \n1              0          0                 0            0                 0   \n2              0          1                 0            0                 0   \n3              1          0                 1            0                 1   \n4              0          0                 0            0                 1   \n5              0          1                 1            0                 1   \n6              1          0                 0            0                 0   \n7              0          0                 0            0                 0   \n8              0          0                 1            0                 1   \n9              0          1                 0            0                 1   \n\n   Wearing_Necklace  Wearing_Necktie  Young  \n0                 0                0      1  \n1                 0                0      1  \n2                 0                0      1  \n3                 1                0      1  \n4                 0                0      1  \n5                 0                0      1  \n6                 0                0      1  \n7                 0                0      1  \n8                 0                0      1  \n9                 0                0      1  \n\n[10 rows x 41 columns]",
      "text/html": "<div>\n<style scoped>\n    .dataframe tbody tr th:only-of-type {\n        vertical-align: middle;\n    }\n\n    .dataframe tbody tr th {\n        vertical-align: top;\n    }\n\n    .dataframe thead th {\n        text-align: right;\n    }\n</style>\n<table border=\"1\" class=\"dataframe\">\n  <thead>\n    <tr style=\"text-align: right;\">\n      <th></th>\n      <th>image_id</th>\n      <th>5_o_Clock_Shadow</th>\n      <th>Arched_Eyebrows</th>\n      <th>Attractive</th>\n      <th>Bags_Under_Eyes</th>\n      <th>Bald</th>\n      <th>Bangs</th>\n      <th>Big_Lips</th>\n      <th>Big_Nose</th>\n      <th>Black_Hair</th>\n      <th>...</th>\n      <th>Sideburns</th>\n      <th>Smiling</th>\n      <th>Straight_Hair</th>\n      <th>Wavy_Hair</th>\n      <th>Wearing_Earrings</th>\n      <th>Wearing_Hat</th>\n      <th>Wearing_Lipstick</th>\n      <th>Wearing_Necklace</th>\n      <th>Wearing_Necktie</th>\n      <th>Young</th>\n    </tr>\n  </thead>\n  <tbody>\n    <tr>\n      <th>0</th>\n      <td>000001.jpg</td>\n      <td>0</td>\n      <td>1</td>\n      <td>1</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>...</td>\n      <td>0</td>\n      <td>1</td>\n      <td>1</td>\n      <td>0</td>\n      <td>1</td>\n      <td>0</td>\n      <td>1</td>\n      <td>0</td>\n      <td>0</td>\n      <td>1</td>\n    </tr>\n    <tr>\n      <th>1</th>\n      <td>000002.jpg</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>1</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>1</td>\n      <td>0</td>\n      <td>...</td>\n      <td>0</td>\n      <td>1</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>1</td>\n    </tr>\n    <tr>\n      <th>2</th>\n      <td>000003.jpg</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      
<td>0</td>\n      <td>1</td>\n      <td>0</td>\n      <td>0</td>\n      <td>...</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>1</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>1</td>\n    </tr>\n    <tr>\n      <th>3</th>\n      <td>000004.jpg</td>\n      <td>0</td>\n      <td>0</td>\n      <td>1</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>...</td>\n      <td>0</td>\n      <td>0</td>\n      <td>1</td>\n      <td>0</td>\n      <td>1</td>\n      <td>0</td>\n      <td>1</td>\n      <td>1</td>\n      <td>0</td>\n      <td>1</td>\n    </tr>\n    <tr>\n      <th>4</th>\n      <td>000005.jpg</td>\n      <td>0</td>\n      <td>1</td>\n      <td>1</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>1</td>\n      <td>0</td>\n      <td>0</td>\n      <td>...</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>1</td>\n      <td>0</td>\n      <td>0</td>\n      <td>1</td>\n    </tr>\n    <tr>\n      <th>5</th>\n      <td>000006.jpg</td>\n      <td>0</td>\n      <td>1</td>\n      <td>1</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>1</td>\n      <td>0</td>\n      <td>0</td>\n      <td>...</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>1</td>\n      <td>1</td>\n      <td>0</td>\n      <td>1</td>\n      <td>0</td>\n      <td>0</td>\n      <td>1</td>\n    </tr>\n    <tr>\n      <th>6</th>\n      <td>000007.jpg</td>\n      <td>1</td>\n      <td>0</td>\n      <td>1</td>\n      <td>1</td>\n      <td>0</td>\n      <td>0</td>\n      <td>1</td>\n      <td>1</td>\n      <td>1</td>\n      <td>...</td>\n      <td>0</td>\n      <td>0</td>\n      <td>1</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>1</td>\n    </tr>\n    <tr>\n 
     <th>7</th>\n      <td>000008.jpg</td>\n      <td>1</td>\n      <td>1</td>\n      <td>0</td>\n      <td>1</td>\n      <td>0</td>\n      <td>0</td>\n      <td>1</td>\n      <td>0</td>\n      <td>1</td>\n      <td>...</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>1</td>\n    </tr>\n    <tr>\n      <th>8</th>\n      <td>000009.jpg</td>\n      <td>0</td>\n      <td>1</td>\n      <td>1</td>\n      <td>0</td>\n      <td>0</td>\n      <td>1</td>\n      <td>1</td>\n      <td>0</td>\n      <td>0</td>\n      <td>...</td>\n      <td>0</td>\n      <td>1</td>\n      <td>0</td>\n      <td>0</td>\n      <td>1</td>\n      <td>0</td>\n      <td>1</td>\n      <td>0</td>\n      <td>0</td>\n      <td>1</td>\n    </tr>\n    <tr>\n      <th>9</th>\n      <td>000010.jpg</td>\n      <td>0</td>\n      <td>0</td>\n      <td>1</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>...</td>\n      <td>0</td>\n      <td>0</td>\n      <td>0</td>\n      <td>1</td>\n      <td>0</td>\n      <td>0</td>\n      <td>1</td>\n      <td>0</td>\n      <td>0</td>\n      <td>1</td>\n    </tr>\n  </tbody>\n</table>\n<p>10 rows × 41 columns</p>\n</div>"
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Load the CelebA attribute table and replace every -1 entry with 0 (binary 0/1 labels).\n",
    "root = \"/root/Datasets/CelebA/celeba/\"\n",
    "attr_path = root + 'list_attr_celeba.csv'\n",
    "df = pandas.read_csv(attr_path).replace(-1, 0)\n",
    "\n",
    "\n",
    "def joint(df):\n",
    "    \"\"\"Return the empirical joint distribution over the distinct rows of `df`.\n",
    "\n",
    "    The result has one row per distinct combination of column values plus a\n",
    "    `count` column holding its relative frequency, sorted in ascending order.\n",
    "    \"\"\"\n",
    "    freqs = df.value_counts(normalize=True, ascending=True)\n",
    "    return freqs.reset_index(name='count')\n",
    "\n",
    "\n",
    "df.head(10)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "outputs": [],
   "source": [
    "# Pair up images that share both the Male and the Young label: within each\n",
    "# (Male, Young) group, the first half of the rows is matched position-by-position\n",
    "# with the second half (an odd leftover row is dropped).\n",
    "all_labels = ['image_id', 'Male', 'Young']\n",
    "df = df[all_labels]\n",
    "values = [(1, 1), (1, 0), (0, 1), (0, 0)]\n",
    "\n",
    "num_samples = 30000\n",
    "\n",
    "all_rows = []\n",
    "for male, young in values:\n",
    "    # `selected` holds index *labels*, so select with .loc rather than .iloc.\n",
    "    selected = df.index[(df['Male'] == male) & (df['Young'] == young)].tolist()\n",
    "    mid = len(selected) // 2\n",
    "    # Build fresh frames instead of mutating slices in place.\n",
    "    i1_df = (df.loc[selected[:mid]]\n",
    "             .reset_index(drop=True)\n",
    "             .rename(columns={'image_id': 'image_id1'}))\n",
    "    i2_df = (df.loc[selected[mid:2 * mid], ['image_id']]\n",
    "             .reset_index(drop=True)\n",
    "             .rename(columns={'image_id': 'image_id2'}))\n",
    "    assert len(i1_df) == len(i2_df)\n",
    "    all_rows.append(pandas.concat([i1_df, i2_df], axis=1))\n",
    "\n",
    "my_dataset = pandas.concat(all_rows, axis=0)\n",
    "shuffled_dataset = my_dataset.sample(frac=1, random_state=1)\n",
    "shuffled_dataset = shuffled_dataset.reset_index(drop=True)\n"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "outputs": [
    {
     "data": {
      "text/plain": "30000"
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Keep only the first `num_samples` shuffled pairs (positional slice).\n",
    "shuffled_dataset = shuffled_dataset.iloc[:num_samples]\n",
    "len(shuffled_dataset)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "saved at : /root/PycharmProjects/IDGEN/CelebaExperiment/base_data/frontdoor_celeba.csv\n"
     ]
    }
   ],
   "source": [
    "folder = '/root/PycharmProjects/IDGEN/CelebaExperiment/base_data/'\n",
    "os.makedirs(folder, exist_ok=True)\n",
    "\n",
    "# Write the paired ids and labels as CSV without the pandas index column.\n",
    "file_name = os.path.join(folder, 'frontdoor_celeba.csv')\n",
    "shuffled_dataset.to_csv(file_name, index=False, encoding='utf-8')\n",
    "print('saved at :', file_name)\n"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "outputs": [
    {
     "data": {
      "text/plain": "torch.Size([30000])"
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Ground-truth label tensors, row-aligned with shuffled_dataset.\n",
    "# 'Age' is read from the 'Young' column (it must not duplicate 'Male').\n",
    "true_dataset = {}\n",
    "true_dataset['Sex'] = torch.tensor(shuffled_dataset['Male'].values)\n",
    "true_dataset['Age'] = torch.tensor(shuffled_dataset['Young'].values)\n",
    "assert true_dataset['Age'].shape == true_dataset['Sex'].shape\n",
    "true_dataset['Sex'].shape"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "outputs": [],
   "source": [
    "# Resize to image_size x image_size, convert to a tensor, then normalise each of the\n",
    "# three channels with mean 0.5 and std 0.5, i.e. x -> (x - 0.5) / 0.5.\n",
    "image_size = 64\n",
    "channel_stats = (0.5, 0.5, 0.5)\n",
    "transform = transforms.Compose([\n",
    "    transforms.Resize((image_size, image_size)),\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize(channel_stats, channel_stats),\n",
    "])"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 30000/30000 [00:15<00:00, 1985.73it/s]\n"
     ]
    }
   ],
   "source": [
    "img_folder = \"/root/Datasets/CelebA/celeba/img_align_celeba\"\n",
    "images = []\n",
    "for image_id in tqdm(shuffled_dataset['image_id1']):\n",
    "    img_path = f'{img_folder}/{image_id}'\n",
    "    # `with` closes the file handle; convert('RGB') guarantees the three channels\n",
    "    # that the 3-channel Normalize in `transform` expects.\n",
    "    with Image.open(img_path) as cur_im:\n",
    "        image_tensor = transform(cur_im.convert('RGB'))\n",
    "    images.append(image_tensor)\n",
    "\n",
    "# Stack the per-image tensors along a new leading (sample) dimension.\n",
    "true_dataset['I1'] = torch.stack(images)\n"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 30000/30000 [00:14<00:00, 2089.59it/s]\n"
     ]
    }
   ],
   "source": [
    "img_folder = \"/root/Datasets/CelebA/celeba/img_align_celeba\"\n",
    "images = []\n",
    "for image_id in tqdm(shuffled_dataset['image_id2']):\n",
    "    img_path = f'{img_folder}/{image_id}'\n",
    "    # `with` closes the file handle; convert('RGB') guarantees the three channels\n",
    "    # that the 3-channel Normalize in `transform` expects.\n",
    "    with Image.open(img_path) as cur_im:\n",
    "        image_tensor = transform(cur_im.convert('RGB'))\n",
    "    images.append(image_tensor)\n",
    "\n",
    "# Stack the per-image tensors along a new leading (sample) dimension.\n",
    "true_dataset['I2'] = torch.stack(images)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "outputs": [
    {
     "data": {
      "text/plain": "dict_keys(['Sex', 'Age', 'I1', 'I2'])"
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity check: exactly the expected entries, each with one row per sampled pair.\n",
    "assert set(true_dataset) == {'Sex', 'Age', 'I1', 'I2'}\n",
    "assert all(len(v) == len(shuffled_dataset) for v in true_dataset.values())\n",
    "true_dataset.keys()"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "save_dir = '/root/PycharmProjects/IDGEN/CelebaExperiment/base_data/'\n",
    "save_loc = os.path.join(save_dir, 'ground_truth.pkl')\n",
    "\n",
    "# Make sure the target directory exists, then pickle the whole ground-truth\n",
    "# dict (label tensors and image tensors) into a single file.\n",
    "os.makedirs(os.path.dirname(save_loc), exist_ok=True)\n",
    "with open(save_loc, 'wb') as f:\n",
    "    pickle.dump(true_dataset, f)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Solution design"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Front-door adjustment: P(y|do(x)) = \\sum_z P(z|x) \\sum_{x'} P(y|x',z) P(x')\n",
    "# Here x = I1, z = Sex, y = I2:  P(i2|do(i1)) = \\sum_{sex} P(sex|i1) * \\sum_{i1'} P(i2|i1', sex) P(i1')\n",
    "# IDGEN: Step 1: Z ~ P(Z|I1).  Step 2.1: Sex ~ Uniform{0,1}; I1 ~ P(I1); I2 ~ P(I2|I1, Sex).  Step 2.2: I2 ~ P'(I2|Sex).  Step 3: I1 -> Sex -> I2.\n",
    "# Need to train: 1. a classifier for P(Z|I1); 2. a diffusion model that takes I1 from the dataset and a uniformly sampled Sex as input and generates I2; 3. a diffusion model that takes Sex as input and generates I2, trained on the new data."
   ],
   "metadata": {
    "collapsed": false
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
