{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "eea95ae0-a730-4745-9d21-dffe683a20c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.optim as optim\n",
    "import random\n",
    "import seaborn as sns\n",
    "from sklearn.preprocessing import MinMaxScaler\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "from sklearn.decomposition import PCA\n",
    "from sklearn.manifold import TSNE\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from sklearn.cluster import KMeans\n",
    "import os\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "92a6fd2a",
   "metadata": {},
   "source": [
    "# Visualization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "d77559b4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_dots_pic(ax, data, target, color, classes=10):\n",
    "    markers = ['o', 's', 'D', 'v', '^', 'p', '*', 'H', 'd', 'X']\n",
    "    for i in range(classes):\n",
    "        data_i = data[target == i]\n",
    "        ax.scatter(data_i[:, 0], data_i[:, 1], marker=markers[i], label=f'{i}', color=color, alpha=0.5)\n",
    "\n",
    "\n",
    "\n",
    "def generate_dots_pic_3d(ax, data, target, facecolor='c', edgecolor='c', classes=10, alpha=0.5):\n",
    "    markers = ['o', 's', 'D', 'v', '^', 'p', '*', 'H', 'd', 'X']\n",
    "    for i in range(classes):\n",
    "        data_i = data[target == i]\n",
    "        ax.scatter(\n",
    "            data_i[:100, 0],\n",
    "            data_i[:100, 1],\n",
    "            data_i[:100, 2],\n",
    "            marker=markers[i],\n",
    "            label=f'{i}',\n",
    "            facecolor=facecolor,\n",
    "            edgecolor=edgecolor,\n",
    "            alpha=alpha\n",
    "        )\n",
    "\n",
    "\n",
    "def convert_data_to_original(data, std, mean, bias, inverse_matrix):\n",
    "    data = data * std + mean\n",
    "    data = torch.mm(data - bias, inverse_matrix)\n",
    "    return data\n",
    "\n",
    "def decompose(data, compressed_dim=1, method='pca'):\n",
    "    if len(data.shape) == 2 and data.shape[1] == 1:\n",
    "        data = data.reshape(-1)\n",
    "    elif len(data.shape) == 2 and data.shape[1] > 1:\n",
    "        if method == 'pca':\n",
    "            pca = PCA(n_components=compressed_dim)\n",
    "            data = pca.fit_transform(data).reshape(-1)\n",
    "        elif method == 'tsne':\n",
    "            if data.shape[1] == 2:\n",
    "                return data\n",
    "            elif data.shape[1] > 2:\n",
    "                tsne = TSNE(n_components=compressed_dim)\n",
    "                data = tsne.fit_transform(data)\n",
    "            else:\n",
    "                raise ValueError(\"Decompose error: the dimension of data is not supported\")\n",
    "    return data\n",
    "\n",
    "def visualize_encoded_data(ax, data, target):\n",
    "    if len(data.shape) == 1:\n",
    "        data = data.reshape(-1, 1)\n",
    "\n",
    "    markers = ['o', 's', 'D', 'v', '^', 'p', '*', 'H', 'd', 'X']\n",
    "    # ratios = [target / len(target) for i in range(len(set(target)))]\n",
    "    data_part = [data[target == i] for i in range(len(set(target)))]\n",
    "\n",
    "    for i, data_i in enumerate(data_part):\n",
    "        ax.scatter(data_i[:50, 0], - (i + 1) * 0.25 * np.ones_like(data_i[:50, 0]) - 0.5, label=f'{i}', alpha=0.3, marker=markers[i])\n",
    "\n",
    "def visualize(ori_data, recon_data, enc_output, codes, iter, informative_dim, type):\n",
    "    \"\"\"Plot reconstruction quality (left panel) and encoder output vs. codebook (right panel).\n",
    "\n",
    "    Args:\n",
    "        ori_data: Original samples. When ``informative_dim > 1`` this is a\n",
    "            ``(data, target)`` pair of tensors, otherwise a single tensor.\n",
    "        recon_data: Reconstructions, in the same layout as ``ori_data``.\n",
    "        enc_output: 2-D tensor of encoder outputs.\n",
    "        codes: Codebook tensor, or ``None`` when there is no codebook.\n",
    "        iter: Epoch / iteration number, used only in the titles.\n",
    "        informative_dim: Dimensionality of the informative part of the data.\n",
    "            The left panel is a 3-D scatter for 3, a 2-D scatter for 2 or for\n",
    "            more than 3 (after t-SNE), and a 1-D KDE plot otherwise.\n",
    "        type: ``'test'`` selects the test titles and the per-class strip plot;\n",
    "            any other value is labelled as training.\n",
    "\n",
    "    Raises:\n",
    "        NotImplementedError: if the right panel has no plotting branch for the\n",
    "            given inputs. This includes every call with ``informative_dim <= 1``.\n",
    "    \"\"\"\n",
    "    fig = plt.figure(figsize=(22, 11))\n",
    "    ax2 = fig.add_subplot(1, 2, 2)\n",
    "\n",
    "    # ---- left panel: original data vs. reconstruction ----\n",
    "    if informative_dim > 1:\n",
    "        ori_data_numpy = ori_data[0].numpy()\n",
    "        recon_data_numpy = recon_data[0].numpy()\n",
    "\n",
    "        ori_target_numpy = ori_data[1].numpy()\n",
    "        recon_target_numpy = recon_data[1].numpy()\n",
    "\n",
    "        classes = len(set(ori_target_numpy))\n",
    "\n",
    "        if informative_dim > 3:\n",
    "            # Originals and reconstructions are embedded in one t-SNE fit\n",
    "            # and split apart again afterwards.\n",
    "            print('begin decompose')\n",
    "            tsne = TSNE(n_components=2)\n",
    "\n",
    "            num_ori = ori_data_numpy.shape[0]\n",
    "            all_data = np.concatenate((ori_data_numpy, recon_data_numpy))\n",
    "            print('all_data shape{}'.format(all_data.shape))\n",
    "            all_data = tsne.fit_transform(all_data)\n",
    "            ori_data_numpy = all_data[:num_ori, :]\n",
    "            recon_data_numpy = all_data[num_ori:, :]\n",
    "\n",
    "        if informative_dim == 3:\n",
    "            ax1 = fig.add_subplot(1, 2, 1, projection='3d')\n",
    "            print('begin to plot 3d data')\n",
    "            generate_dots_pic_3d(\n",
    "                ax1,\n",
    "                ori_data_numpy,\n",
    "                ori_target_numpy,\n",
    "                classes=classes,\n",
    "                alpha=0.3,\n",
    "                facecolor='none',\n",
    "                edgecolor='c'\n",
    "            )\n",
    "            generate_dots_pic_3d(\n",
    "                ax1,\n",
    "                recon_data_numpy,\n",
    "                recon_target_numpy,\n",
    "                classes=classes,\n",
    "                alpha=0.3,\n",
    "                facecolor='b',\n",
    "                edgecolor='b'\n",
    "            )\n",
    "        else:\n",
    "            ax1 = fig.add_subplot(1, 2, 1)\n",
    "            generate_dots_pic(ax1, ori_data_numpy, ori_target_numpy, 'c', classes=classes)\n",
    "            generate_dots_pic(ax1, recon_data_numpy, recon_target_numpy, 'b', classes=classes)\n",
    "    else:\n",
    "        # 1-D case: compare the two distributions with a KDE plot.\n",
    "        ax1 = fig.add_subplot(1, 2, 1)\n",
    "        num_ori = ori_data.shape[0]\n",
    "        all_data = np.concatenate((ori_data.numpy(), recon_data.numpy()))\n",
    "        all_data = decompose(all_data, compressed_dim=1, method='pca')\n",
    "        ori_data = all_data[:num_ori]\n",
    "        recon_data = all_data[num_ori:]\n",
    "        plot_data = {\n",
    "            'Ground Truth': ori_data,\n",
    "            'Reconstruction': recon_data\n",
    "        }\n",
    "        print('compressed ori and recon to 1 D')\n",
    "        sns.kdeplot(plot_data, bw_adjust=0.2, color='green', ax=ax1,)\n",
    "\n",
    "    # ---- bring encoder output and codebook into a plottable shape ----\n",
    "    if enc_output.shape[1] > 1:\n",
    "        enc_output_numpy = enc_output.numpy()\n",
    "        print('decompose encoder output and codebook used tsne')\n",
    "        # compress codebook and encoder output used TSNE\n",
    "        if codes is not None:\n",
    "            codes_numpy = codes.numpy()\n",
    "            num_enc = enc_output_numpy.shape[0]\n",
    "            all_data = np.concatenate((enc_output_numpy, codes_numpy))\n",
    "\n",
    "            all_data = decompose(all_data, 2, method='tsne')\n",
    "            enc_output_numpy = all_data[:num_enc, :]\n",
    "            codes_numpy = all_data[num_enc:, :]\n",
    "        else:\n",
    "            codes_numpy = codes\n",
    "            enc_output_numpy = decompose(enc_output_numpy, 2, method='tsne')\n",
    "    else:\n",
    "        print('reshape codes and encoder output directly')\n",
    "        enc_output_numpy = decompose(enc_output.numpy())\n",
    "        codes_numpy = decompose(codes.numpy()) if codes is not None else None\n",
    "\n",
    "    # ---- right panel: encoder output and codebook ----\n",
    "    if informative_dim > 1 and len(enc_output_numpy.shape) > 1 and enc_output_numpy.shape[1] > 1:\n",
    "        # here have some problems. we can not plot 2D dimension plots because we do not pass targets when training\n",
    "        generate_dots_pic(ax2, enc_output_numpy, ori_target_numpy, '#6495ED', classes=classes)\n",
    "        if codes_numpy is not None:\n",
    "            ax2.scatter(codes_numpy[:, 0], codes_numpy[:, 1], color='#DC143C', alpha=0.8, label='Code')\n",
    "    elif informative_dim > 1 and len(enc_output_numpy.shape) == 1:\n",
    "        sns.kdeplot(enc_output_numpy, bw_adjust=0.2, color='black', label='encoder_output', ax=ax2)\n",
    "        # Plot the codebook distribution using histplot\n",
    "        if codes is not None:\n",
    "            print(\"codebook max: {}, codebook min: {}\".format(codes_numpy.max(), codes_numpy.min()))\n",
    "            sns.kdeplot(codes_numpy, bw_adjust=0.2, color='red', label='codebook', ax=ax2)\n",
    "            sns.histplot(x=codes_numpy, bins=48, kde=False, fill=True, alpha=0.1, stat='density', ax=ax2, label='codebook', color='red')\n",
    "\n",
    "            # Mark every code on the x-axis (y = 0). The flattened numpy copy is\n",
    "            # used instead of the raw tensor, like the other plots in this branch.\n",
    "            ax2.scatter(x=codes_numpy, y=np.zeros_like(codes_numpy), color='red', alpha=0.4)\n",
    "\n",
    "        if type == 'test':\n",
    "            visualize_encoded_data(ax2, enc_output_numpy, ori_target_numpy)\n",
    "    else:\n",
    "        raise NotImplementedError(\n",
    "            'no encoder-output plot for informative_dim={} with encoder output of shape {}'.format(\n",
    "                informative_dim, enc_output_numpy.shape\n",
    "            )\n",
    "        )\n",
    "\n",
    "    print(\"encoder output max: {}, encoder output min: {}\".format(enc_output_numpy.max(), enc_output_numpy.min()))\n",
    "\n",
    "    # Add legend and labels\n",
    "    # ax1.legend()\n",
    "    ax2.legend(\n",
    "        loc='upper right',\n",
    "        bbox_to_anchor=(1.05, 1),\n",
    "        ncol=2,\n",
    "        fontsize='small',\n",
    "        markerscale=0.5,\n",
    "        frameon=False,\n",
    "    )\n",
    "\n",
    "    # Set title\n",
    "    title = 'Plot of Reconstruction Test epoch:{}'.format(iter) if type == 'test' else 'KDE Plot of Reconstruction Train iter:{}'.format(iter)\n",
    "    title2 = 'KDE plot of Encoder Output and Codebook Test epoch:{}'.format(iter) if type == 'test' else 'KDE Plot of Encoder Output and Codebook Train iter:{}'.format(iter)\n",
    "    ax1.set_title(title)\n",
    "    ax2.set_title(title2)\n",
    "\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f66ec32c",
   "metadata": {},
   "source": [
    "# Data Construction (High Dimensional)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "7e130cd5-16d1-4f00-b93d-08d9114ab727",
   "metadata": {},
   "outputs": [],
   "source": [
    "# generate clusters\n",
    "def generate_pairs(informative_dim=2, lower_bound = -1, upper_bound = 1, num_pairs=10, threshold=0.3):\n",
    "    pairs = []\n",
    "    \n",
    "    while len(pairs) < num_pairs:\n",
    "        # new_pair = (np.random.uniform(lower_bound, upper_bound), np.random.uniform(lower_bound, upper_bound))\n",
    "\n",
    "        new_pair = []\n",
    "        for j in range(informative_dim):\n",
    "            new_pair.append(np.random.uniform(lower_bound, upper_bound))\n",
    "        \n",
    "        # check if the new pair is valid\n",
    "        valid = True\n",
    "        for pair in pairs:\n",
    "            dist = np.linalg.norm(np.array(new_pair) - np.array(pair))\n",
    "            if dist < threshold:\n",
    "                valid = False\n",
    "                break\n",
    "        \n",
    "        if valid:\n",
    "            pairs.append(new_pair)\n",
    "    \n",
    "    return pairs\n",
    "\n",
    "def generate_random_matrix(n):\n",
    "    random_matrix = torch.randn(n, n)\n",
    "    b = torch.randn( n)\n",
    "    return random_matrix, b\n",
    "\n",
    "\n",
    "def test_model(model, test_loader, epoch, informative_dim, inverse_matrix, bias, mean, std):\n",
    "    \"\"\"Run ``model`` over ``test_loader`` in eval mode and plot the results with ``visualize``.\n",
    "\n",
    "    Args:\n",
    "        model: Module whose forward pass returns ``(recon, enc_output, loss)``\n",
    "            and which exposes ``get_codebook()``.\n",
    "        test_loader: Iterable of ``(batch, target)`` tensor pairs.\n",
    "        epoch: Epoch number, forwarded to the plot titles.\n",
    "        informative_dim: When greater than 1, targets are collected and both\n",
    "            originals and reconstructions are mapped back with\n",
    "            ``convert_data_to_original`` before plotting.\n",
    "        inverse_matrix: Inverse of the affine matrix applied to the data.\n",
    "        bias: Offset of the affine map applied to the data.\n",
    "        mean: Mean used to standardise the data.\n",
    "        std: Standard deviation used to standardise the data.\n",
    "    \"\"\"\n",
    "    # Remember the current mode so the model is handed back the way it came in,\n",
    "    # instead of always being forced into train mode.\n",
    "    was_training = model.training\n",
    "    model.eval()\n",
    "\n",
    "    recon_data = []\n",
    "    ori_data = []\n",
    "    enc_output_data = []\n",
    "    targets = []\n",
    "    device = next(model.parameters()).device\n",
    "\n",
    "    try:\n",
    "        with torch.no_grad():\n",
    "            for batch, target in test_loader:\n",
    "                recon, enc_output, _ = model(batch.to(device))\n",
    "\n",
    "                # record the data\n",
    "                if informative_dim > 1:\n",
    "                    targets.append(target.cpu())\n",
    "\n",
    "                ori_data.append(batch.cpu())\n",
    "                recon_data.append(recon.cpu().detach())\n",
    "                enc_output_data.append(enc_output.cpu().detach())\n",
    "\n",
    "        codes = model.get_codebook()\n",
    "        if codes is not None:\n",
    "            codes = codes.cpu().detach()\n",
    "\n",
    "        # merge the per-batch tensors\n",
    "        recon_data = torch.cat(recon_data)\n",
    "        ori_data = torch.cat(ori_data)\n",
    "        enc_output_data = torch.cat(enc_output_data)\n",
    "\n",
    "        if informative_dim > 1:\n",
    "            # Reconstructions carry the same labels as the inputs they came from.\n",
    "            ori_target = torch.cat(targets)\n",
    "            recon_target = torch.cat(targets)\n",
    "\n",
    "            ori_data = convert_data_to_original(ori_data, std, mean, bias, inverse_matrix)\n",
    "            recon_data = convert_data_to_original(recon_data, std, mean, bias, inverse_matrix)\n",
    "\n",
    "            ori_data = (ori_data, ori_target)\n",
    "            recon_data = (recon_data, recon_target)\n",
    "\n",
    "        visualize(ori_data, recon_data, enc_output_data, codes, epoch, informative_dim, 'test')\n",
    "    finally:\n",
    "        model.train(was_training)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "25e7abb8",
   "metadata": {},
   "source": [
    "# Load Configs for Pretrained Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "07877847",
   "metadata": {},
   "outputs": [],
   "source": [
    "def find_cluster_center_by_kmeans(test_loader, model, classes=10):\n",
    "    \"\"\"Cluster the encoder outputs of ``test_loader`` with k-means and return the centres.\n",
    "\n",
    "    Args:\n",
    "        test_loader: Iterable of ``(batch, target)`` pairs; targets are ignored.\n",
    "        model: Module exposing ``encode(batch)``.\n",
    "        classes: Number of k-means clusters.\n",
    "\n",
    "    Returns:\n",
    "        Tensor holding one cluster centre per row.\n",
    "    \"\"\"\n",
    "    # Feed the batches on the model's own device, as test_model does;\n",
    "    # otherwise a model that lives on the GPU cannot encode CPU batches.\n",
    "    device = next(model.parameters()).device\n",
    "\n",
    "    with torch.no_grad():\n",
    "        encoded = [\n",
    "            model.encode(batch.to(device)).cpu().numpy()\n",
    "            for batch, _ in test_loader\n",
    "        ]\n",
    "\n",
    "        encoded = np.concatenate(encoded, axis=0)\n",
    "        kmeans = KMeans(n_clusters=classes, random_state=0).fit(encoded)\n",
    "        return torch.tensor(kmeans.cluster_centers_)\n",
    "    \n",
    "def get_codebook_based_encoder_output(cluster_centers, feature_size, num_tokens, std=0.2):\n",
    "    with torch.no_grad():\n",
    "        std_matrix = torch.eye(feature_size) * std\n",
    "        classes = cluster_centers.shape[0]\n",
    "\n",
    "        per_class_num = num_tokens // classes\n",
    "        remaining = num_tokens % classes\n",
    "        codes = []\n",
    "\n",
    "        index = 0\n",
    "        for i in range(classes):\n",
    "            mean = cluster_centers[i]\n",
    "            distribution = torch.distributions.MultivariateNormal(mean, std_matrix)\n",
    "            \n",
    "            # allocate extra codes for first classes\n",
    "            num_samples = per_class_num + (1 if i < remaining else 0)\n",
    "            samples = distribution.sample((num_samples,))\n",
    "            codes.append(samples)\n",
    "            \n",
    "            index += num_samples\n",
    "\n",
    "        assert index == num_tokens, \"All codebook entries should be initialized\"\n",
    "        \n",
    "        return torch.cat(codes, dim=0)\n",
    "\n",
    "\n",
    "def load_model_state_dict(model, input_size, encode_size, hidden_encoder_size, hidden_decoder_size):\n",
    "    \"\"\"Load auto-encoder weights from ``./Auto-Encoders`` into ``model`` and return it.\n",
    "\n",
    "    The checkpoint file name is built from the four size arguments; tensors\n",
    "    are mapped to the CPU while loading.\n",
    "    \"\"\"\n",
    "    checkpoint_name = 'AutoEncoder_{}_{}_{}_{}.pth'.format(\n",
    "        input_size, encode_size, hidden_encoder_size, hidden_decoder_size\n",
    "    )\n",
    "    Autoencoder_path = os.path.join('./Auto-Encoders', checkpoint_name)\n",
    "    state_dict = torch.load(Autoencoder_path, map_location='cpu', weights_only=True)\n",
    "    model.load_state_dict(state_dict)\n",
    "    print(\"Load Autoencoder from {}\".format(Autoencoder_path))\n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "92486d8a",
   "metadata": {},
   "outputs": [],
   "source": [
    "seed = 4568\n",
    "\n",
    "torch.manual_seed(seed)\n",
    "torch.cuda.manual_seed(seed)\n",
    "np.random.seed(seed)\n",
    "random.seed(seed)\n",
    "torch.backends.cudnn.enable = False\n",
    "torch.backends.cudnn.deterministic = True\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "\n",
    "# dimension of data\n",
    "informative_dim = 3\n",
    "dim = 3\n",
    "\n",
    "# one Gaussian cluster per entry: sample count, centre and spread\n",
    "n_samples = [1000] * 10\n",
    "means = generate_pairs(informative_dim, -10, 10, 10, 5.0)\n",
    "print(means)\n",
    "# NOTE: each value scales the identity covariance below, i.e. it acts as a variance.\n",
    "std = [0.3] * 10\n",
    "\n",
    "\n",
    "# per-cluster samples and labels, concatenated once the loop is done\n",
    "data_axis_train = []\n",
    "data_axis_test = []\n",
    "target_train = []\n",
    "target_test = []\n",
    "\n",
    "cluster_center_ori = []\n",
    "\n",
    "for i in range(len(n_samples)):\n",
    "    mean_axis = torch.tensor(\n",
    "        [means[i][j] for j in range(informative_dim)],\n",
    "        dtype=torch.get_default_dtype(),\n",
    "    )\n",
    "    cluster_center_ori.append(mean_axis.unsqueeze(0))\n",
    "\n",
    "    # get distribution\n",
    "    cov_axis = torch.eye(informative_dim) * std[i]\n",
    "    dist_axis = torch.distributions.MultivariateNormal(mean_axis, cov_axis)\n",
    "\n",
    "    # 80 % of the cluster for training, drawn first; informative columns come\n",
    "    # first and the remaining columns stay zero\n",
    "    num_train = int(n_samples[i] * 0.8)\n",
    "    data_train = torch.zeros(num_train, dim)\n",
    "    data_train[:, :informative_dim] = dist_axis.sample((num_train,))\n",
    "    target_train_i = torch.full((num_train,), i, dtype=torch.long)\n",
    "\n",
    "    # 20 % of the cluster for testing, drawn second\n",
    "    num_test = int(n_samples[i] * 0.2)\n",
    "    data_test = torch.zeros(num_test, dim)\n",
    "    data_test[:, :informative_dim] = dist_axis.sample((num_test,))\n",
    "    target_test_i = torch.full((num_test,), i, dtype=torch.long)\n",
    "\n",
    "    # append data\n",
    "    data_axis_train.append(data_train)\n",
    "    data_axis_test.append(data_test)\n",
    "\n",
    "    # append target\n",
    "    target_train.append(target_train_i)\n",
    "    target_test.append(target_test_i)\n",
    "\n",
    "\n",
    "data_train_ori = torch.cat(data_axis_train, dim=0)\n",
    "data_test_ori = torch.cat(data_axis_test, dim=0)\n",
    "target_train = torch.cat(target_train, dim=0)\n",
    "target_test = torch.cat(target_test, dim=0)\n",
    "\n",
    "# Affine map into the ambient space: the identity when every dimension is\n",
    "# informative, otherwise a random matrix and offset.\n",
    "if dim != informative_dim:\n",
    "    rotation_matrix, b = generate_random_matrix(dim)\n",
    "else:\n",
    "    rotation_matrix = torch.eye(dim)\n",
    "    b = torch.zeros(dim)\n",
    "\n",
    "data_train = torch.mm(data_train_ori, rotation_matrix) + b\n",
    "data_test = torch.mm(data_test_ori, rotation_matrix) + b\n",
    "\n",
    "# standardise every column with statistics pooled over train and test\n",
    "data_all = torch.cat([data_train, data_test], dim=0)\n",
    "data_mean = data_all.mean(dim=0)\n",
    "data_std = data_all.std(dim=0)\n",
    "\n",
    "data_train, data_test = (\n",
    "    (split - data_mean) / data_std\n",
    "    for split in (data_train, data_test)\n",
    ")\n",
    "\n",
    "# record what is needed to map samples back to the original coordinates\n",
    "inverse_matrix = torch.inverse(rotation_matrix)\n",
    "bias = b\n",
    "\n",
    "# Preview the test split; the kind of plot depends on the informative dimension.\n",
    "if informative_dim > 3:\n",
    "\n",
    "    _, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(16, 4))\n",
    "\n",
    "    # t-sne\n",
    "    tsne = TSNE(n_components=2)\n",
    "    data_test_tsne = tsne.fit_transform(data_test)\n",
    "\n",
    "    # The embedded points are the test split, so label them as such.\n",
    "    ax3.scatter(data_test_tsne[:, 0], data_test_tsne[:, 1], color='red', label='test', alpha=0.5)\n",
    "    ax3.set_title('t-SNE to 2 Dimension')\n",
    "elif informative_dim == 3:\n",
    "    ax = plt.axes(projection='3d')\n",
    "    generate_dots_pic_3d(ax, data_test_ori, target_test, classes=len(n_samples), facecolor='c', edgecolor='c')\n",
    "else:\n",
    "    _, ax1 = plt.subplots(1, 1, figsize=(6, 6))\n",
    "    generate_dots_pic(ax1, data_test_ori, target_test, 'c', classes=len(n_samples))\n",
    "\n",
    "    # Overlay the transformed, standardised test data on the raw samples.\n",
    "    data_test_trans = data_test.clone()\n",
    "\n",
    "    # Use the real number of clusters instead of a hard-coded 10.\n",
    "    generate_dots_pic(ax1, data_test_trans, target_test, 'b', classes=len(n_samples))\n",
    "    ax1.set_title('original data')\n",
    "\n",
    "\n",
    "plt.show()\n",
    "train_dataset = TensorDataset(data_train, target_train)\n",
    "test_dataset = TensorDataset(data_test, target_test)\n",
    "\n",
    "print(cluster_center_ori)\n",
    "cluster_center_ori = torch.cat(cluster_center_ori, dim=0)\n",
    "\n",
    "# Send the centres through the same pipeline as the samples (zero-pad to `dim`,\n",
    "# affine map, standardise). Standardising the raw centres directly is only\n",
    "# valid when dim == informative_dim and the affine map is the identity.\n",
    "cluster_center_padded = torch.zeros(cluster_center_ori.shape[0], dim)\n",
    "cluster_center_padded[:, :informative_dim] = cluster_center_ori\n",
    "cluster_center_ori = torch.mm(cluster_center_padded, rotation_matrix) + b\n",
    "cluster_center_ori = (cluster_center_ori - data_mean) / data_std\n",
    "\n",
    "print(cluster_center_ori.shape)\n",
    "\n",
    "# every sample needs exactly one label\n",
    "assert data_train.shape[0] == target_train.shape[0]\n",
    "assert data_test.shape[0] == target_test.shape[0]\n",
    "\n",
    "\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d217ef82",
   "metadata": {},
   "source": [
    "# Model definition"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "fabebeae-4f66-4cb6-aaa2-4ea3fd423f79",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model definition\n",
    "class MLP(nn.Module):\n",
    "    def __init__(self, input_size, output_size, hidden_size=16):\n",
    "        super().__init__()\n",
    "        self.liner1 = nn.Linear(input_size, hidden_size)\n",
    "        self.liner2 = nn.Linear(hidden_size, hidden_size)\n",
    "        self.liner3 = nn.Linear(hidden_size, output_size)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.liner1(x)\n",
    "        x = F.relu(x)\n",
    "        x = self.liner2(x)\n",
    "        x = F.relu(x)\n",
    "        x = self.liner3(x)\n",
    "        return x\n",
    "\n",
    "class VQ_Tiny(nn.Module):\n",
    "    def __init__(\n",
    "                    self, \n",
    "                    num_tokens, \n",
    "                    feature_size, \n",
    "                    input_size, \n",
    "                    output_size, \n",
    "                    encode_size=1, \n",
    "                    hidden_encoder_size=16,\n",
    "                    hidden_decoder_size=16,\n",
    "                    decay=0.99, \n",
    "                    eps=1e-5, \n",
    "                    kmean_init=False, \n",
    "                    use_quantize=True,\n",
    "                    use_linear_project=False,\n",
    "                    compressed_size = 1\n",
    "                 ):\n",
    "        super().__init__()\n",
    "        self._encoder = MLP(input_size, encode_size, hidden_encoder_size)\n",
    "        self.use_linear_project = use_linear_project\n",
    "        if use_linear_project:\n",
    "            self.project_low_dim = nn.Linear(encode_size, compressed_size)\n",
    "            self.project_high_dim = nn.Linear(compressed_size, encode_size)\n",
    "        \n",
    "        self._decoder = MLP(encode_size, output_size, hidden_decoder_size) \n",
    "\n",
    "        # codebook\n",
    "        self.codebook = nn.Parameter(torch.randn((num_tokens, feature_size)), requires_grad=False)\n",
    "\n",
    "        self.num_tokens = num_tokens\n",
    "        self.feature_size = feature_size\n",
    "        \n",
    "        # EMA parameters\n",
    "        self.register_buffer('cluster_size', torch.ones(num_tokens))\n",
    "        self.register_buffer('embed_avg', self.codebook.clone())\n",
    "\n",
    "        self.use_quantize = use_quantize\n",
    "\n",
    "        self.decay = decay\n",
    "        self.eps = eps\n",
    "\n",
    "        # inilization parameters\n",
    "        self.initialized = False\n",
    "        self.kmean_init = kmean_init\n",
    "\n",
    "    def init_codebook(self, x):\n",
    "        if self.training and self.use_quantize:\n",
    "            if self.initialized:\n",
    "                print('Codebook already initialized')\n",
    "                return\n",
    "            else:\n",
    "                self._initialize_codebook(x, kmean_init=self.kmean_init)\n",
    "                self.initialized = True\n",
    "        else:\n",
    "            print('You do not need to initialize codebook')\n",
    "\n",
    "    def encode(self, x):\n",
    "        output = self._encoder(x)\n",
    "        if self.use_linear_project:\n",
    "            output =self.project_low_dim(output)\n",
    "        # print(output.shape)\n",
    "        return output\n",
    "\n",
    "    def forward(self, x):\n",
    "        enc_output = self._encoder(x)\n",
    "        if self.use_linear_project:\n",
    "            enc_output = self.project_low_dim(enc_output)\n",
    "        if self.use_quantize:\n",
    "            z, loss = self.quantize(enc_output)\n",
    "        else:\n",
    "            z = enc_output\n",
    "            loss = 0.0\n",
    "\n",
    "        if self.use_linear_project:\n",
    "            z = self.project_high_dim(z)\n",
    "        x_recon = self._decoder(z)\n",
    "        return x_recon, enc_output, loss\n",
    "\n",
    "\n",
    "    def quantize(self, z_e):\n",
    "\n",
    "        # first foward in training, initialize codebook\n",
    "        if self.training and not self.initialized:\n",
    "            self._initialize_codebook(z_e, kmean_init=self.kmean_init)\n",
    "            self.initialized = True\n",
    "\n",
    "        # if high dimension, flatten the input like (B, H, W, C) -> (B*H*W, C) but in our setting, it is not necessary\n",
    "        z_flatten = z_e.view(-1, self.feature_size)\n",
    "\n",
    "        # calculate the distance between the input and the codebook\n",
    "        dis = z_flatten.pow(2).sum(dim=1, keepdim=True) - 2 * torch.einsum('ik,jk->ij', z_flatten, self.codebook) + self.codebook.pow(2).sum(dim=1, keepdim=True).T\n",
    "        \n",
    "        # find the nearest neighbor's index\n",
    "        indices = torch.argmin(dis, dim=1)\n",
    "\n",
    "        # get the nearest neighbor's feature\n",
    "        z_q = F.embedding(indices, self.codebook)\n",
    "        \n",
    "\n",
    "        # get the one hot encoding # (B, ) -> (B, num_tokens) \n",
    "        encodings = F.one_hot(indices, num_classes=self.num_tokens) #\n",
    "\n",
    "        # if training, update the codebook\n",
    "        if self.training and self.use_quantize:\n",
    "            with torch.no_grad():\n",
    "                # update the cluster size and the embed_avg\n",
    "                \n",
    "                # cluster_size = cluster_size * decay + encodings.sum(0) * (1 - decay)\n",
    "                self.cluster_size.data.mul_(self.decay)\n",
    "                self.cluster_size.data.add_(encodings.sum(0), alpha=1 - self.decay)\n",
    "\n",
    "                # embed_avg = embed_avg * decay + encodings.transpose(0, 1) X  z_flatten * (1 - decay)\n",
    "                self.embed_avg.data.mul_(self.decay)\n",
    "                new_sum_embed = encodings.transpose(0, 1).type(z_flatten.dtype) @ z_flatten\n",
    "\n",
    "                self.embed_avg.data.add_(new_sum_embed, alpha=1 - self.decay)\n",
    "\n",
    "                # smooth the cluster size\n",
    "                n = self.cluster_size.sum()\n",
    "                smoothed_cluster_size = ((self.cluster_size + self.eps) / \n",
    "                                         (n + self.num_tokens * self.eps) ) * n\n",
    "                \n",
    "\n",
    "                # compute the codebook information and update it\n",
    "                codebook_updated = self.embed_avg / smoothed_cluster_size.unsqueeze(1)\n",
    "                self.codebook.data.copy_(codebook_updated)\n",
    "    \n",
    "\n",
    "\n",
    "        z_q = z_q.reshape(-1, self.feature_size)\n",
    "\n",
    "        # loss\n",
    "        loss = F.mse_loss(z_e,z_q.detach())\n",
    "\n",
    "\n",
    "        # Straight-through estimator\n",
    "        z_q = z_e + (z_q - z_e).detach()\n",
    "\n",
    "        return z_q, loss\n",
    "\n",
    "    # initialize codebook with first batch data\n",
    "    def _initialize_codebook(self, x, kmean_init=False):\n",
    "        with torch.no_grad():\n",
    "            # flatten the input\n",
    "            flat_x = x.reshape(-1, self.feature_size)\n",
    "            print('initialization shape {}'.format(flat_x.shape))\n",
    "            if kmean_init:\n",
    "                # use kmean to initialize the codebook\n",
    "                print('kmeans initialization!')\n",
    "                kmeans = KMeans(n_clusters=self.num_tokens, random_state=0).fit(flat_x.cpu().numpy())\n",
    "                centers = torch.tensor(\n",
    "                    kmeans.cluster_centers_,\n",
    "                    dtype=self.codebook.dtype,\n",
    "                )\n",
    "                \n",
    "                self.codebook.data.copy_(centers)\n",
    "                self.embed_avg.data.copy_(self.codebook.data)\n",
    "\n",
    "            else:\n",
    "                # randomly sampling from encoder output and update the codebook\n",
    "                indices = torch.randperm(flat_x.shape[0])[:self.num_tokens]\n",
    "                self.codebook.data.copy_(flat_x[indices])\n",
    "                self.embed_avg.data.copy_(self.codebook.data)\n",
    "        self.initialized = True\n",
    "        print(\"Codebook initialized with first batch data.\")\n",
    "\n",
    "\n",
    "    def get_codebook(self):\n",
    "        if self.use_quantize:\n",
    "            return self.codebook.data\n",
    "        else:\n",
    "            return None\n",
    "\n",
    "\n",
    "    def set_codebook_align_distribution(self, codebook_distribution):\n",
    "        with torch.no_grad():\n",
    "            print('Init codebook with manually set distribution!')\n",
    "            self.codebook.data.copy_(codebook_distribution)\n",
    "            self.embed_avg.data.copy_(codebook_distribution)\n",
    "            self.initialized = True\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a2d35ceb",
   "metadata": {},
   "source": [
    "# Hyperparameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "b04be5fb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training schedule and optimizer settings.\n",
    "batch_size = 256\n",
    "learning_rate = 0.01\n",
    "num_epochs = 500\n",
    "\n",
    "# Weights applied to the two terms of the training loss\n",
    "# (cmt_loss returned by the model, and the MSE reconstruction loss).\n",
    "cmt_loss_weight = 0.25\n",
    "recon_loss_weight = 1.0\n",
    "\n",
    "# Codebook seeding switches (all disabled by default).\n",
    "pretrained = False\n",
    "mannual_set_codebook = False\n",
    "mannual_find_cluster_center_kmeans = False  # find cluster by kmeans\n",
    "\n",
    "# Number of training batches whose encoder outputs are passed to\n",
    "# model.init_codebook before training; values below 1 skip that step.\n",
    "init_batch = 0\n",
    "\n",
    "# Keyword arguments forwarded verbatim to VQ_Tiny.\n",
    "configs = dict(\n",
    "    num_tokens=128,\n",
    "    feature_size=1,\n",
    "    input_size=dim,\n",
    "    output_size=dim,\n",
    "    hidden_encoder_size=4,\n",
    "    hidden_decoder_size=32,\n",
    "    encode_size=1,\n",
    "    kmean_init=True,\n",
    "    use_quantize=True,\n",
    "    decay=0.9,\n",
    "    eps=1e-4,\n",
    "    use_linear_project=False,\n",
    "    compressed_size=1,\n",
    ")\n",
    "\n",
    "# Without a linear projection the encoder width must match the code width.\n",
    "assert (\n",
    "    configs[\"use_linear_project\"] or configs[\"encode_size\"] == configs[\"feature_size\"]\n",
    "), \"encode_size must be equal to feature_size\"\n",
    "\n",
    "model = VQ_Tiny(**configs)\n",
    "optimizer = optim.Adam(model.parameters(), lr=learning_rate)\n",
    "\n",
    "\n",
    "\n",
    "# load data\n",
    "train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n",
    "test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)\n",
    "\n",
    "# Optional manual codebook seeding. init_codebook stays None unless one of\n",
    "# the manual paths below builds a codebook; every path that builds one also\n",
    "# installs it on the model.\n",
    "init_codebook = None\n",
    "with torch.no_grad():\n",
    "    if pretrained and configs[\"use_quantize\"]:\n",
    "        # load pretrained model if applicable\n",
    "        model = load_model_state_dict(\n",
    "            model=model,\n",
    "            input_size=configs[\"input_size\"],\n",
    "            encode_size=configs[\"encode_size\"],\n",
    "            hidden_encoder_size=configs[\"hidden_encoder_size\"],\n",
    "            hidden_decoder_size=configs[\"hidden_decoder_size\"]\n",
    "        )\n",
    "        if mannual_find_cluster_center_kmeans:\n",
    "            cluster_center = find_cluster_center_by_kmeans(test_loader, model)\n",
    "            print('cluster center(kmeans):{}'.format(cluster_center))\n",
    "            init_codebook = get_codebook_based_encoder_output(cluster_center, configs[\"feature_size\"], configs[\"num_tokens\"])\n",
    "            # Fix: this branch used to build the k-means codebook and then\n",
    "            # discard it; install it like the other manual branches do.\n",
    "            model.set_codebook_align_distribution(init_codebook)\n",
    "        elif mannual_set_codebook:\n",
    "            cluster_center = model.encode(cluster_center_ori)\n",
    "            print('cluster center:{}'.format(cluster_center))\n",
    "            init_codebook = get_codebook_based_encoder_output(cluster_center, configs[\"feature_size\"], configs[\"num_tokens\"], std=0.01)\n",
    "            print('Initialization using manually set codebook based on pretrained encoder output!')\n",
    "            model.set_codebook_align_distribution(init_codebook)\n",
    "    elif not pretrained and configs[\"use_quantize\"] and mannual_set_codebook:\n",
    "        # No pretrained weights: encode the reference centers with the\n",
    "        # freshly constructed model and seed the codebook from them.\n",
    "        print(cluster_center_ori)\n",
    "        cluster_center = model.encode(cluster_center_ori)\n",
    "        print(cluster_center)\n",
    "        init_codebook = get_codebook_based_encoder_output(cluster_center, configs[\"feature_size\"], configs[\"num_tokens\"], std=1e-7)\n",
    "        model.set_codebook_align_distribution(init_codebook)\n",
    "\n",
    "\n",
    "# move to GPU\n",
    "model = model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3702c50e-a336-4f88-a3bd-235c75b80146",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# train and test\n",
    "# Optional warm start: when init_batch >= 1, encode the first init_batch\n",
    "# training batches and hand the concatenated outputs to model.init_codebook.\n",
    "model.train()\n",
    "init = []\n",
    "if not init_batch < 1:\n",
    "    with torch.no_grad():\n",
    "        for i, (inputs, target) in enumerate(train_loader):\n",
    "            if i < init_batch:\n",
    "                init.append(model.encode(inputs.to(device)))\n",
    "                continue\n",
    "            break\n",
    "        print('init here')\n",
    "        # Stack the per-batch encoder outputs along the first dimension.\n",
    "        init = torch.cat(init)\n",
    "        model.init_codebook(init)\n",
    "\n",
    "for epoch in range(num_epochs):\n",
    "    model.train()\n",
    "    running_loss = 0.0\n",
    "    running_recon_loss = 0.0\n",
    "\n",
    "    # Per-step records, kept on the CPU and flushed at every visualization.\n",
    "    ori_data = []\n",
    "    ori_target = []\n",
    "    recon_data = []\n",
    "    recon_target = []\n",
    "    enc_output_data = []\n",
    "\n",
    "    for i, (inputs, target) in enumerate(train_loader):\n",
    "        inputs = inputs.to(device)\n",
    "        ori_data.append(inputs.cpu())\n",
    "        ori_target.append(target.cpu())\n",
    "        recon_target.append(target.cpu())\n",
    "\n",
    "        # forward\n",
    "        recon, enc_output, cmt_loss = model(inputs)\n",
    "\n",
    "        # Total loss is the weighted sum of cmt_loss and the reconstruction MSE.\n",
    "        recon_loss = F.mse_loss(inputs, recon)\n",
    "        loss = cmt_loss * cmt_loss_weight + recon_loss * recon_loss_weight\n",
    "\n",
    "        # backward\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        running_loss += loss.item()\n",
    "        running_recon_loss += recon_loss.item()\n",
    "\n",
    "        # record the data\n",
    "        recon_data.append(recon.cpu().detach())\n",
    "        enc_output_data.append(enc_output.cpu().detach())\n",
    "\n",
    "        # Log and visualize on the very first step and then every 100 steps.\n",
    "        if (i == 0 and epoch == 0) or (i + 1) % 100 == 0:\n",
    "            print(f'Epoch [{epoch + 1}/{num_epochs}], Step [{i + 1}/{len(train_loader)}], Loss: {loss.item():.8f}')\n",
    "\n",
    "            # merge the recorded batches into single tensors\n",
    "            recon_data = torch.cat(recon_data)\n",
    "            ori_data = torch.cat(ori_data)\n",
    "\n",
    "            # convert the data to original space\n",
    "            ori_data = convert_data_to_original(ori_data, data_std, data_mean, bias, inverse_matrix)\n",
    "            recon_data = convert_data_to_original(recon_data, data_std, data_mean, bias, inverse_matrix)\n",
    "\n",
    "            ori_target = torch.cat(ori_target)\n",
    "            recon_target = torch.cat(recon_target)\n",
    "\n",
    "            # visualize() receives (data, target) pairs\n",
    "            ori_data = (ori_data, ori_target)\n",
    "            recon_data = (recon_data, recon_target)\n",
    "\n",
    "            enc_output_data = torch.cat(enc_output_data)\n",
    "            codes = model.get_codebook()\n",
    "\n",
    "            # get_codebook() returns None when quantization is disabled\n",
    "            if codes is not None:\n",
    "                codes = codes.cpu().detach()\n",
    "            visualize(ori_data, recon_data, enc_output_data, codes, i + 1, informative_dim, 'train')\n",
    "\n",
    "            # reset the record list\n",
    "            recon_data = []\n",
    "            ori_data = []\n",
    "            ori_target = []\n",
    "            recon_target = []\n",
    "            enc_output_data = []\n",
    "\n",
    "    # Epoch summary and evaluation on the first epoch and then every 25 epochs.\n",
    "    if epoch == 0 or (epoch + 1) % 25 == 0:\n",
    "        print(f'Epoch [{epoch + 1}/{num_epochs}], Average Loss: {running_loss / len(train_loader):.4f}')\n",
    "        # Fix: this line reports the reconstruction term only, so label it as\n",
    "        # such instead of repeating the total-loss label.\n",
    "        print(f'Epoch [{epoch + 1}/{num_epochs}], Average Recon Loss: {running_recon_loss / len(train_loader):.4f}')\n",
    "        test_model(model, test_loader, epoch, informative_dim, inverse_matrix, bias, data_mean, data_std)\n",
    "\n",
    "# save model\n",
    "if configs[\"use_quantize\"]:\n",
    "    model_name = 'VQ_MLP_quantized.pth'\n",
    "else:\n",
    "    model_name = './Auto-Encoders/AutoEncoder_{}_{}_{}_{}.pth'.format(configs[\"input_size\"], configs[\"encode_size\"], configs[\"hidden_encoder_size\"], configs[\"hidden_decoder_size\"])\n",
    "\n",
    "# torch.save does not create missing parent directories, so make sure the\n",
    "# target folder exists first (skipped when saving into the working directory,\n",
    "# where dirname is the empty string).\n",
    "save_dir = os.path.dirname(model_name)\n",
    "if save_dir:\n",
    "    os.makedirs(save_dir, exist_ok=True)\n",
    "\n",
    "torch.save(model.state_dict(), model_name)\n",
    "print('Model saved to ' + model_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a2166fc1-2422-4be7-8a4f-ba75392691fd",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8ad1395f-0cb4-432d-b572-7e4cbe505fc5",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "EMA",
   "language": "python",
   "name": "ema"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
