{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "# config\n",
    "\n",
    "dataset = 'fmnist' # mnist, celeba, fmnist\n",
    "print(\"Dataset: \", dataset)\n",
    "\n",
    "# Root folder of the unbiased training tensors plus the digits used as the\n",
    "# minor / major group, for every dataset this cell knows how to load.\n",
    "# 'celeba' is mentioned above but has no entry here, so it is rejected below.\n",
    "dataset_configs = {\n",
    "    'mnist': {'root': '/home/.../nas/PF-GAN/dataset/rotated/mnist_31_Unbiased', 'minor': 1, 'major': 3},\n",
    "    'fmnist': {'root': '/home/.../nas/PF-GAN/dataset/rotated/fmnist_71_Unbiased', 'minor': 1, 'major': 7},\n",
    "}\n",
    "if dataset not in dataset_configs:\n",
    "    # Fail here with a clear message instead of a NameError on unbiased_x further down.\n",
    "    raise ValueError(f\"Unsupported dataset {dataset!r}; expected one of {sorted(dataset_configs)}\")\n",
    "dataset_cfg = dataset_configs[dataset]\n",
    "\n",
    "# load real data\n",
    "# The division by 1.0 is kept from the original loader; true division yields a\n",
    "# floating-point tensor even if the stored values are integer-typed.\n",
    "unbiased_x = torch.load(dataset_cfg['root'] + '/train_data.pt') / 1.0\n",
    "unbiased_y = torch.load(dataset_cfg['root'] + '/train_Y.pt')\n",
    "unbiased_A = torch.load(dataset_cfg['root'] + '/train_A.pt')\n",
    "group_to_digit = {'minor': dataset_cfg['minor'], 'major': dataset_cfg['major']}\n",
    "# Inverse mapping, derived so the two dictionaries cannot drift apart.\n",
    "digit_to_group = {digit: group for group, digit in group_to_digit.items()}\n",
    "\n",
    "\n",
    "# data group\n",
    "# Samples with A == 1 are stored under the cln_* names and A == 0 under rot_*;\n",
    "# the _1 / _3 suffixes stand for the minor / major digit of the chosen dataset.\n",
    "minor_mask = unbiased_y == group_to_digit['minor']\n",
    "major_mask = unbiased_y == group_to_digit['major']\n",
    "cln_1 = unbiased_x[minor_mask & (unbiased_A == 1)]\n",
    "rot_1 = unbiased_x[minor_mask & (unbiased_A == 0)]\n",
    "cln_3 = unbiased_x[major_mask & (unbiased_A == 1)]\n",
    "rot_3 = unbiased_x[major_mask & (unbiased_A == 0)]\n",
    "\n",
    "\n",
    "# get centroids\n",
    "# Each centroid is the element-wise mean of the flattened samples of one group.\n",
    "centroid_dict = {\n",
    "    name: torch.mean(samples.reshape(samples.size(0), -1), dim=0)\n",
    "    for name, samples in (\n",
    "        ('cln_minor', cln_1),\n",
    "        ('rot_minor', rot_1),\n",
    "        ('cln_major', cln_3),\n",
    "        ('rot_major', rot_3),\n",
    "    )\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# load gen data\n",
    "\n",
    "import joblib\n",
    "import os\n",
    "\n",
    "# config ===\n",
    "gen_data_path = '/home/.../pfgan_hub/DataLens/impsir_fmnist_eps10_batch17_teacher4_try2/eps-9.99.data'\n",
    "target_model = 'gpate'\n",
    "# ===\n",
    "save_dir = os.path.dirname(gen_data_path)\n",
    "savename = os.path.basename(gen_data_path) + '_labeled'\n",
    "\n",
    "\n",
    "def split_generated(gen_data):\n",
    "    \"\"\"Split one generated array into image and label tensors.\n",
    "\n",
    "    The last 10 columns are read as per-class scores (label = argmax over\n",
    "    them); all remaining columns are the sample values, multiplied by 255.\n",
    "\n",
    "    Returns:\n",
    "        (data_x, data_y): a float tensor of samples and a long tensor of labels.\n",
    "    \"\"\"\n",
    "    data_x = torch.tensor(gen_data[:, :-10] * 255).float()\n",
    "    data_y = torch.tensor(np.argmax(gen_data[:, -10:], axis=1)).long()\n",
    "    return data_x, data_y\n",
    "\n",
    "\n",
    "# NOTE(review): joblib.load unpickles the file, which can execute arbitrary\n",
    "# code; only point gen_data_path at files from a trusted source.\n",
    "\n",
    "# G-PATE ver.\n",
    "if target_model == 'gpate':\n",
    "    gen_data = joblib.load(gen_data_path)\n",
    "    gen_data_x, gen_data_y = split_generated(gen_data)\n",
    "\n",
    "elif target_model == 'datalens':\n",
    "    # Ten shards are read by rewriting the \"-0.pkl\" suffix of gen_data_path to\n",
    "    # \"-1.pkl\" ... \"-9.pkl\". gen_data_path itself is left untouched, and the\n",
    "    # shards are concatenated once so labels stay integer-typed (long), matching\n",
    "    # the G-PATE branch.\n",
    "    shard_xs, shard_ys = [], []\n",
    "    for i in range(10):\n",
    "        shard_path = gen_data_path.replace(\"-0.pkl\", f\"-{i}.pkl\")\n",
    "        gen_data = joblib.load(shard_path)\n",
    "        curr_gen_data_x, curr_gen_data_y = split_generated(gen_data)\n",
    "        shard_xs.append(curr_gen_data_x)\n",
    "        shard_ys.append(curr_gen_data_y)\n",
    "\n",
    "    gen_data_x = torch.cat(shard_xs, dim=0)\n",
    "    gen_data_y = torch.cat(shard_ys, dim=0)\n",
    "\n",
    "else:\n",
    "    # NOTE(review): this branch only loads the raw 'data' array; gen_data_x and\n",
    "    # gen_data_y are NOT defined here, so the following cells will fail unless\n",
    "    # they are derived from gen_data first. The array layout is not visible in\n",
    "    # this notebook - TODO confirm and add the split.\n",
    "    gen_data = np.load(gen_data_path)['data']\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Distinct generated labels together with how often each one occurs.\n",
    "torch.unique(gen_data_y, return_counts=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Positions of the generated samples whose label is not 0; only those are kept.\n",
    "indices = np.flatnonzero(gen_data_y != 0)\n",
    "\n",
    "gen_data_x, gen_data_y = gen_data_x[indices], gen_data_y[indices]\n",
    "\n",
    "# Label counts after the filter.\n",
    "torch.unique(gen_data_y, return_counts=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# get z label and save data\n",
    "#\n",
    "# Every generated sample gets z = 1 when the nearest centroid (L2 distance) of\n",
    "# its own label group is a 'cln' centroid, and z = 0 otherwise.\n",
    "\n",
    "# config ==\n",
    "save_z = True\n",
    "# ===\n",
    "\n",
    "z = np.zeros_like(gen_data_y)\n",
    "\n",
    "# find closest centroid\n",
    "for idx in range(0, gen_data_x.shape[0]):\n",
    "    sample = gen_data_x[idx]\n",
    "    group = digit_to_group[int(gen_data_y[idx].item())]\n",
    "\n",
    "    # Rebuilt for every sample: only the centroids of this sample's group may\n",
    "    # compete. A dict shared across iterations would keep distances computed\n",
    "    # for earlier samples of the other group and let them win the min() below.\n",
    "    dist_dict = {\n",
    "        k: torch.dist(sample, v, 2)\n",
    "        for k, v in centroid_dict.items()\n",
    "        if group in k\n",
    "    }\n",
    "\n",
    "    pred = min(dist_dict, key=dist_dict.get)\n",
    "\n",
    "    if 'cln' in pred:\n",
    "        z[idx] = 1\n",
    "\n",
    "# The z value counts are printed and also written next to the generated data.\n",
    "z_counts = np.unique(z, return_counts=True)\n",
    "print(z_counts)\n",
    "with open(os.path.join(save_dir, savename + '_z.txt'), 'w') as f:\n",
    "    f.write(str(z_counts))\n",
    "\n",
    "# save data to npz\n",
    "if save_z:\n",
    "    np.savez(os.path.join(save_dir, savename), data_z = z, data_x =gen_data_x, data_y =gen_data_y)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot test with z\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "\n",
    "def plot_img(sample):\n",
    "    \"\"\"Display one sample, viewed as a 28x28 image, in grayscale.\"\"\"\n",
    "    image = sample.view(28, 28).cpu().numpy()\n",
    "    plt.imshow(image, cmap='gray')\n",
    "    plt.show()\n",
    "    \n",
    "\n",
    "# Show a few randomly chosen generated samples with their label and z value.\n",
    "num_samples = 15\n",
    "preview_indices = np.random.choice(gen_data_x.size(0), num_samples)\n",
    "for idx in preview_indices:\n",
    "    plot_img(gen_data_x[idx])\n",
    "    print(f'label: {gen_data_y[idx]}')\n",
    "    print(f'z: {z[idx]}')\n",
    "    print(\"=====================================\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## test with hand labels\n",
    "#\n",
    "# Measures how often the nearest-centroid rule reproduces the z values stored\n",
    "# alongside a set of generated samples.\n",
    "\n",
    "#  load test data\n",
    "test_datapath = '/home/.../nas/PF-GAN/models/GS-WGAN/results/fmnist/main/cond_A_sir_big/seed_2/gen_data'\n",
    "test_data_x = torch.load(test_datapath + '/gen_data_400_img.pt')\n",
    "test_data_x = test_data_x.view(-1, 784) * 255\n",
    "test_data_y = torch.load(test_datapath + '/gen_data_400_label.pt')\n",
    "test_data_Z = torch.load(test_datapath + '/gen_data_400_z.pt')\n",
    "\n",
    "\n",
    "\n",
    "# find closest centroid\n",
    "cnt = 0\n",
    "for idx in range(0, test_data_x.shape[0]):\n",
    "    sample = test_data_x[idx]\n",
    "    y = test_data_y[idx]\n",
    "    group = digit_to_group[int(y.item())]\n",
    "\n",
    "    # Only the centroids belonging to this sample's label group are candidates.\n",
    "    dist_dict = {\n",
    "        k: torch.dist(sample, v, 2)\n",
    "        for k, v in centroid_dict.items()\n",
    "        if group in k\n",
    "    }\n",
    "\n",
    "    pred = min(dist_dict, key=dist_dict.get)\n",
    "    # Named z_pred (not z) so the per-sample array `z` built by the labeling\n",
    "    # cell is not overwritten with a scalar.\n",
    "    z_pred = 1 if 'cln' in pred else 0\n",
    "\n",
    "    # if idx < 10:\n",
    "    #     plot_img(sample)\n",
    "    #     print(z_pred, test_data_Z[idx])\n",
    "    #     print(\"=====\")\n",
    "    \n",
    "\n",
    "    if z_pred == test_data_Z[idx]:\n",
    "        cnt += 1\n",
    "\n",
    "print(\"Acc: \", cnt/len(test_data_y) * 100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reliability test (Real to Real)\n",
    "#\n",
    "# Applies the nearest-centroid rule to the real samples themselves and reports\n",
    "# how often the predicted (y, A) pair matches the stored one.\n",
    "\n",
    "\n",
    "cnt = 0\n",
    "\n",
    "# find closest centroid\n",
    "for idx in range(0, unbiased_x.shape[0]):\n",
    "    sample = unbiased_x[idx].reshape(-1)  # flatten to match the centroid vectors\n",
    "    true_y = int(unbiased_y[idx].item())\n",
    "    true_A = int(unbiased_A[idx].item())\n",
    "    group = digit_to_group[true_y]\n",
    "\n",
    "    # Distances are collected in the same dict that min() reads; the original\n",
    "    # wrote to one dict and searched another, always empty, one.\n",
    "    candidate_dists = {\n",
    "        k: torch.dist(sample, v, 2)\n",
    "        for k, v in centroid_dict.items()\n",
    "        if group in k\n",
    "    }\n",
    "\n",
    "    pred = min(candidate_dists, key=candidate_dists.get)\n",
    "\n",
    "    A_pred = 1 if \"cln\" in pred else 0\n",
    "    # Centroid keys carry 'minor' / 'major', not the digit itself, so the digit\n",
    "    # is looked up through group_to_digit.\n",
    "    y_pred = group_to_digit['major'] if 'major' in pred else group_to_digit['minor']\n",
    "\n",
    "    if y_pred == true_y and A_pred == true_A:\n",
    "        cnt += 1\n",
    "\n",
    "print(\"Acc: \", cnt / unbiased_x.shape[0] * 100)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Another approach: KMeans\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "class KMeans:\n",
    "    def __init__(self, n_clusters, max_iter=300):\n",
    "        self.n_clusters = n_clusters\n",
    "        self.max_iter = max_iter\n",
    "\n",
    "    def fit(self, X):\n",
    "        # Randomly initialize centroids\n",
    "        self.centroids = X[np.random.choice(X.shape[0], self.n_clusters, replace=False)]\n",
    "        \n",
    "        for _ in range(self.max_iter):\n",
    "            # Assign each data point to the nearest centroid\n",
    "            labels = self._assign_clusters(X)\n",
    "            \n",
    "            # Update centroids based on the mean of data points in each cluster\n",
    "            new_centroids = np.array([X[labels == k].mean(axis=0) for k in range(self.n_clusters)])\n",
    "            \n",
    "            # Check for convergence\n",
    "            if np.allclose(self.centroids, new_centroids):\n",
    "                break\n",
    "            \n",
    "            self.centroids = new_centroids\n",
    "        \n",
    "        return self\n",
    "\n",
    "    def _assign_clusters(self, X):\n",
    "        # Calculate L2 distance between each data point and each centroid\n",
    "        distances = np.sqrt(((X[:, np.newaxis] - self.centroids)**2).sum(axis=2))\n",
    "        \n",
    "        # Assign each data point to the nearest centroid\n",
    "        labels = np.argmin(distances, axis=1)\n",
    "        \n",
    "        return labels\n",
    "\n",
    "    \n",
    "# Flattened NumPy copy of the real data under its own name: rebinding\n",
    "# unbiased_x itself would make this cell fail on a second run and break the\n",
    "# torch-based cells above.\n",
    "unbiased_x_np = unbiased_x.view(-1, 784).numpy()\n",
    "\n",
    "# Initialize KMeans with desired number of clusters\n",
    "kmeans = KMeans(n_clusters=4)\n",
    "\n",
    "# Fit KMeans to the data\n",
    "kmeans.fit(unbiased_x_np)\n",
    "\n",
    "# Get cluster labels for each data point\n",
    "labels = kmeans._assign_clusters(unbiased_x_np)\n",
    "\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pfgan",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
