{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "7db7dSTR85sI"
   },
   "source": [
    "# Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "id": "YtmEPqCJ860F"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from sklearn.metrics import accuracy_score\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.optim.lr_scheduler import ReduceLROnPlateau"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "4k4MtU8r879E"
   },
   "source": [
    "# Config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "The dictionary represents a mapping where the keys represent the number of labeled data.\n",
    "The corresponding values are lists indicating the number of unlabeled data for each fixed number of labeled data.\n",
    "\"\"\" \n",
    "# Each key N is a labeled-set size; 'M_BUCKET' lists the unlabeled-set sizes M tried for that N.\n",
    "TRAINING_CONFIG = {\n",
    "    10: {\n",
    "        \"M_BUCKET\": [10, 100, 1000, 10000],\n",
    "    },\n",
    "    20: {\n",
    "        \"M_BUCKET\": [20, 200, 2000, 10000],\n",
    "    },\n",
    "    40: {\n",
    "        \"M_BUCKET\": [40, 400, 4000, 10000],\n",
    "    }\n",
    "}\n",
    "\n",
    "# Random-search spaces of the linear model and the robust linear model.\n",
    "# Every entry is [upper, lower]: get_tuned_model samples random.uniform(lower, upper).\n",
    "\n",
    "LINEAR_PARAM_GRID = {\n",
    "    'gamma': [15, 0.001],\n",
    "    'weight_decay': [1, 0],\n",
    "}\n",
    "\n",
    "# 'gamma_unlabeled' is resampled until it is smaller than 'gamma';\n",
    "# 'l' is the weight of the unlabeled term in RobustTrainer.loss_fn.\n",
    "ROBUST_LINEAR_PARAM_GRID = {\n",
    "    'gamma': [20, 0.001],\n",
    "    'gamma_unlabeled': [20, 0.0001],\n",
    "    'l': [20, 0],\n",
    "    'weight_decay': [1, 0],\n",
    "}\n",
    "\n",
    "# The rate of perturbation: the unlabeled-data mean is mu + ALPHA_RATIO * ||mu|| * v\n",
    "# for a random unit vector v (see generate_unlabeled_data).\n",
    "ALPHA_RATIO = 0.5\n",
    "\n",
    "# Number of labeled samples used to train the upper-bound (reference) linear model\n",
    "N_MAX = 10000\n",
    "\n",
    "# Dimension of data\n",
    "DIM = 200\n",
    "\n",
    "# Per-coordinate standard deviation of the Gaussian samples\n",
    "STD = 1\n",
    "# Number of samples in the test set\n",
    "TEST_SIZE = 10000\n",
    "\n",
    "# Number of random hyperparameter sets evaluated per tuning run, for each model\n",
    "ROBUST_LINEAR_NUM_COMBINATION = 4000\n",
    "LINEAR_NUM_COMBINATION = 500"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "wJRcmZvpRbzB"
   },
   "source": [
    "# Hyperparameter Tuning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "id": "8k-uWOqqRd57"
   },
   "outputs": [],
   "source": [
    "import itertools\n",
    "import random\n",
    "\n",
    "def get_tuned_model(model_class, X_train, y_train, X_test, y_test, param_grid, d, num_combinations, X_unlabeled=None):\n",
    "    best_params = None\n",
    "    best_model = None\n",
    "    best_accuracy = 0\n",
    "    param_combinations = []\n",
    "    for i in range(num_combinations):\n",
    "        params = None\n",
    "        if 'gamma_unlabeled' in param_grid: # To ensure that the unlabeled gamma does not exceed the gamma value for a robust linear model\n",
    "            while True: \n",
    "                params = {k: random.uniform(v[1], v[0])  for k, v in param_grid.items()}\n",
    "                if params['gamma_unlabeled'] < params['gamma']:\n",
    "                    break\n",
    "        else: # For linear model\n",
    "            params = {k: random.uniform(v[1], v[0])  for k, v in param_grid.items()}\n",
    "\n",
    "        param_combinations.append(params)\n",
    "    \n",
    "    for params in param_combinations:\n",
    "        model = model_class(d, **params)\n",
    "        if X_unlabeled is not None:\n",
    "            model.fit(X_train, y_train, X_unlabeled)\n",
    "        else:\n",
    "            model.fit(X_train, y_train)\n",
    "\n",
    "        accuracy = model.score(X_test, y_test)\n",
    "\n",
    "        if accuracy > best_accuracy:\n",
    "            best_params = params\n",
    "            best_accuracy = accuracy\n",
    "            best_model = model\n",
    "\n",
    "    print(\"Best Parameters: \", best_params)\n",
    "    return best_model, best_params, best_accuracy"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "id": "L9ZMXQbUdQP1"
   },
   "outputs": [],
   "source": [
    "class LinearModel(torch.nn.Module):\n",
    "    def __init__(self, input_size):\n",
    "        super(LinearModel, self).__init__()\n",
    "        self.linear = nn.Linear(input_size, 1, bias=False)\n",
    "        self.init_weights()\n",
    "\n",
    "    def init_weights(self):\n",
    "        torch.nn.init.xavier_uniform_(self.linear.weight)\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.linear(x)\n",
    "\n",
    "class Trainer():\n",
    "    def __init__(self, num_features, learning_rate=0.8, gamma=2, weight_decay=0.0001, num_epochs=10, num_iters=5):\n",
    "        self.learning_rate = learning_rate\n",
    "        self.num_features = num_features\n",
    "        self.gamma = gamma\n",
    "        self.weight_decay = weight_decay\n",
    "        self.num_epochs = num_epochs\n",
    "        self.num_iters = num_iters\n",
    "        self.best_model = None\n",
    "    \n",
    "    def get_model_and_optimizers(self):\n",
    "        model = LinearModel(self.num_features)\n",
    "\n",
    "        optimizer = optim.Adam(\n",
    "            model.parameters(),\n",
    "            lr=self.learning_rate,\n",
    "            weight_decay=self.weight_decay,\n",
    "        )\n",
    "        scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3, verbose=False)\n",
    "        return model, optimizer, scheduler\n",
    "\n",
    "    def loss_fn(self, *args):\n",
    "        pass\n",
    "\n",
    "    def run(self, *args):\n",
    "        pass\n",
    "\n",
    "    def fit(self, *args):\n",
    "        pass\n",
    "\n",
    "    def predict(self, X):\n",
    "        self.best_model.eval()\n",
    "        X = torch.tensor(X, dtype=torch.float32, requires_grad=False)\n",
    "        y_pred = torch.sign(self.best_model(X)).squeeze().detach()\n",
    "        return y_pred.numpy()\n",
    "\n",
    "    def score(self, X, y):\n",
    "        y_pred = self.predict(X)\n",
    "        return (y_pred == y).mean()\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "id": "Dm2_f7HAZhhS"
   },
   "outputs": [],
   "source": [
    "class LinearTrainer(Trainer):\n",
    "    \"\"\"Trains a LinearModel on labeled data only, using a clipped hinge loss.\n",
    "\n",
    "    ``fit`` repeats ``run`` ``num_iters`` times from fresh random initialisations and\n",
    "    keeps the model with the lowest training loss in ``self.best_model``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, num_features, learning_rate=0.07, gamma=2, weight_decay=0.0001, num_epochs=20, num_iters=15):\n",
    "        super().__init__(\n",
    "            num_features,\n",
    "            learning_rate,\n",
    "            gamma,\n",
    "            weight_decay,\n",
    "            num_epochs,\n",
    "            num_iters,\n",
    "        )\n",
    "\n",
    "    def loss_fn(self, signed_dists):\n",
    "        \"\"\"Return the mean of ``clamp(1 - gamma * signed_dists, 0, 1)``.\"\"\"\n",
    "        return torch.clamp(1 - self.gamma * signed_dists, min=0, max=1).mean()\n",
    "\n",
    "    @staticmethod\n",
    "    def _clone_state(model):\n",
    "        \"\"\"Copy the model weights; ``state_dict()`` alone only returns live references.\"\"\"\n",
    "        return {name: tensor.detach().clone() for name, tensor in model.state_dict().items()}\n",
    "\n",
    "    def run(self, X, y):\n",
    "        \"\"\"Run one full-batch training pass from a fresh random initialisation.\n",
    "\n",
    "        Args:\n",
    "            X: float tensor of labeled inputs.\n",
    "            y: float tensor of labels, shaped to broadcast against ``model(X)``.\n",
    "\n",
    "        Returns:\n",
    "            ``(best_model, best_loss)``: a LinearModel loaded with the lowest-loss weights\n",
    "            seen during the run, and that loss as a Python float.\n",
    "        \"\"\"\n",
    "        model, optimizer, scheduler = self.get_model_and_optimizers()\n",
    "        model.train()\n",
    "        best_checkpoint = None\n",
    "        best_loss = np.inf\n",
    "        for _ in range(self.num_epochs):\n",
    "            optimizer.zero_grad()\n",
    "            loss = self.loss_fn(y * model(X))\n",
    "            loss_value = loss.item()\n",
    "\n",
    "            # Snapshot BEFORE the optimizer step: these are the weights that produced\n",
    "            # `loss_value`, and step() is about to overwrite them in place.\n",
    "            if loss_value <= best_loss:\n",
    "                best_loss = loss_value\n",
    "                best_checkpoint = self._clone_state(model)\n",
    "\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            scheduler.step(loss_value)\n",
    "\n",
    "        # The weights left by the last step were never scored inside the loop.\n",
    "        with torch.no_grad():\n",
    "            final_loss = self.loss_fn(y * model(X)).item()\n",
    "        if best_checkpoint is None or final_loss <= best_loss:\n",
    "            best_loss = final_loss\n",
    "            best_checkpoint = self._clone_state(model)\n",
    "\n",
    "        # Load the best checkpoint\n",
    "        best_model = LinearModel(self.num_features)\n",
    "        best_model.load_state_dict(best_checkpoint)\n",
    "        return best_model, best_loss\n",
    "\n",
    "    def fit(self, X, y):\n",
    "        \"\"\"Train ``num_iters`` times on ``(X, y)`` and keep the lowest-loss model.\"\"\"\n",
    "        X = torch.as_tensor(X, dtype=torch.float32)\n",
    "        y = torch.as_tensor(y, dtype=torch.float32).unsqueeze(1)\n",
    "        best_loss = np.inf\n",
    "        for _ in range(self.num_iters):\n",
    "            model, loss = self.run(X, y)\n",
    "            if loss < best_loss:\n",
    "                best_loss = loss\n",
    "                self.best_model = model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "id": "SqCePToKAIVA"
   },
   "outputs": [],
   "source": [
    "class RobustTrainer(Trainer):\n",
    "    \"\"\"Trains a LinearModel on labeled data plus an unlabeled-data margin term.\n",
    "\n",
    "    The loss is the clipped hinge loss on the labeled samples plus ``l`` times the same\n",
    "    clipped hinge applied to the absolute model outputs on the unlabeled samples, with\n",
    "    its own slope ``gamma_unlabeled``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, num_features, learning_rate=0.07, gamma=2, weight_decay=0.7, num_epochs=20, num_iters=20, gamma_unlabeled=1, l=0.01):\n",
    "        super().__init__(\n",
    "            num_features,\n",
    "            learning_rate,\n",
    "            gamma,\n",
    "            weight_decay,\n",
    "            num_epochs,\n",
    "            num_iters,\n",
    "        )\n",
    "        self.gamma_unlabeled = gamma_unlabeled\n",
    "        self.l = l\n",
    "\n",
    "    def loss_fn(self, signed_dists, unlabeled_dists):\n",
    "        \"\"\"Return the labeled loss plus, if any unlabeled rows exist, the weighted unlabeled loss.\"\"\"\n",
    "        loss = torch.clamp(1 - self.gamma * signed_dists, min=0, max=1).mean()\n",
    "        if len(unlabeled_dists):\n",
    "            loss = loss + self.l * torch.clamp(1 - self.gamma_unlabeled * unlabeled_dists, min=0, max=1).mean()\n",
    "        return loss\n",
    "\n",
    "    @staticmethod\n",
    "    def _clone_state(model):\n",
    "        \"\"\"Copy the model weights; ``state_dict()`` alone only returns live references.\"\"\"\n",
    "        return {name: tensor.detach().clone() for name, tensor in model.state_dict().items()}\n",
    "\n",
    "    def _batch_loss(self, model, X, y, X_unlabeled):\n",
    "        \"\"\"Evaluate ``loss_fn`` for ``model`` on the labeled and unlabeled batches.\"\"\"\n",
    "        signed_distances = y * model(X)\n",
    "        abs_unlabeled_distances = torch.abs(model(X_unlabeled))\n",
    "        return self.loss_fn(signed_distances, abs_unlabeled_distances)\n",
    "\n",
    "    def run(self, X, y, X_unlabeled):\n",
    "        \"\"\"Run one full-batch training pass from a fresh random initialisation.\n",
    "\n",
    "        Args:\n",
    "            X: float tensor of labeled inputs.\n",
    "            y: float tensor of labels, shaped to broadcast against ``model(X)``.\n",
    "            X_unlabeled: float tensor of unlabeled inputs.\n",
    "\n",
    "        Returns:\n",
    "            ``(best_model, best_loss)``: a LinearModel loaded with the lowest-loss weights\n",
    "            seen during the run, and that loss as a Python float.\n",
    "        \"\"\"\n",
    "        model, optimizer, scheduler = self.get_model_and_optimizers()\n",
    "        model.train()\n",
    "        best_checkpoint = None\n",
    "        best_loss = np.inf\n",
    "        for _ in range(self.num_epochs):\n",
    "            optimizer.zero_grad()\n",
    "            loss = self._batch_loss(model, X, y, X_unlabeled)\n",
    "            loss_value = loss.item()\n",
    "\n",
    "            # Snapshot BEFORE the optimizer step: these are the weights that produced\n",
    "            # `loss_value`, and step() is about to overwrite them in place.\n",
    "            if loss_value <= best_loss:\n",
    "                best_loss = loss_value\n",
    "                best_checkpoint = self._clone_state(model)\n",
    "\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            scheduler.step(loss_value)\n",
    "\n",
    "        # The weights left by the last step were never scored inside the loop.\n",
    "        with torch.no_grad():\n",
    "            final_loss = self._batch_loss(model, X, y, X_unlabeled).item()\n",
    "        if best_checkpoint is None or final_loss <= best_loss:\n",
    "            best_loss = final_loss\n",
    "            best_checkpoint = self._clone_state(model)\n",
    "\n",
    "        # Load the best checkpoint\n",
    "        best_model = LinearModel(self.num_features)\n",
    "        best_model.load_state_dict(best_checkpoint)\n",
    "        return best_model, best_loss\n",
    "\n",
    "    def fit(self, X, y, X_unlabeled):\n",
    "        \"\"\"Train ``num_iters`` times on the labeled and unlabeled data; keep the lowest-loss model.\"\"\"\n",
    "        X = torch.as_tensor(X, dtype=torch.float32)\n",
    "        X_unlabeled = torch.as_tensor(X_unlabeled, dtype=torch.float32)\n",
    "        y = torch.as_tensor(y, dtype=torch.float32).unsqueeze(1)\n",
    "\n",
    "        best_loss = np.inf\n",
    "        for _ in range(self.num_iters):\n",
    "            model, loss = self.run(X, y, X_unlabeled)\n",
    "            if loss < best_loss:\n",
    "                best_loss = loss\n",
    "                self.best_model = model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_mu(d, std):\n",
    "    # mu is normalized\n",
    "    unit_vector = np.random.randn(d)\n",
    "    unit_vector /= np.linalg.norm(unit_vector)\n",
    "    mu = unit_vector * 1\n",
    "    return mu\n",
    "\n",
    "def generate_multivariate_gaussian_samples(n_samples, mu, std):\n",
    "    d = len(mu)\n",
    "    covariance_matrix = np.identity(d) * (std ** 2)\n",
    "    samples = np.random.multivariate_normal(mu, covariance_matrix, n_samples)\n",
    "    return samples\n",
    "\n",
    "\n",
    "def generate_data(n_samples, mu, std):\n",
    "    mu_x = mu\n",
    "    neg_mu_x = -1 * mu_x\n",
    "    x_pos = generate_multivariate_gaussian_samples(int(n_samples/2), mu_x, std)\n",
    "    X_neg = generate_multivariate_gaussian_samples(int(n_samples/2), neg_mu_x, std)\n",
    "    X = np.concatenate((x_pos, X_neg))\n",
    "\n",
    "    y = np.ones((n_samples))\n",
    "    y[n_samples//2:] = -1\n",
    "    y = np.array(y, dtype=int)\n",
    "    return X, y\n",
    "\n",
    "\n",
    "def generate_unlabeled_data(n_samples, mu, std, alpha_ratio=ALPHA_RATIO):\n",
    "    \"\"\"Generate two-class Gaussian data around a randomly perturbed mean.\n",
    "\n",
    "    A fresh random unit vector ``v`` is drawn on every call and the class means become\n",
    "    ``+/-(mu + alpha_ratio * ||mu|| * v)``. The first ``n_samples // 2`` rows belong to\n",
    "    the positive mean; for an odd ``n_samples`` the extra row goes to the negative one,\n",
    "    so ``X`` and ``y`` always have exactly ``n_samples`` entries.\n",
    "\n",
    "    Returns:\n",
    "        ``(X, y)``: samples of shape ``(n_samples, len(mu))`` and integer markers ``2``\n",
    "        (positive half) / ``-2`` (negative half) of shape ``(n_samples,)``.\n",
    "    \"\"\"\n",
    "    mu = np.asarray(mu)\n",
    "    d = len(mu)\n",
    "\n",
    "    v = np.random.randn(d)\n",
    "    v /= np.linalg.norm(v)\n",
    "\n",
    "    mu_x = mu + alpha_ratio * np.linalg.norm(mu) * v\n",
    "    n_pos = n_samples // 2\n",
    "    n_neg = n_samples - n_pos  # not int(n_samples / 2): keeps X and y the same length\n",
    "\n",
    "    X_pos = generate_multivariate_gaussian_samples(n_pos, mu_x, std)\n",
    "    X_neg = generate_multivariate_gaussian_samples(n_neg, -1 * mu_x, std)\n",
    "    X = np.concatenate((X_pos, X_neg))\n",
    "\n",
    "    y = 2 * np.ones(n_samples, dtype=int)\n",
    "    y[n_pos:] = -2\n",
    "    return X, y"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Best Parameters:  {'gamma': 9.601166867580856, 'weight_decay': 0.3589276260764547}\n",
      "linear model accuracy at N=10 = 0.6174\n"
     ]
    }
   ],
   "source": [
    "# Seed every random source used below (hyperparameter sampling, data generation,\n",
    "# weight initialisation) so a fresh Restart & Run All reproduces the same numbers.\n",
    "SEED = 0\n",
    "random.seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "mu = generate_mu(DIM, STD)\n",
    "X_test, y_test = generate_data(TEST_SIZE, mu, STD)\n",
    "\n",
    "# training_outputs[n] holds the tuned linear accuracy and one tuned robust accuracy\n",
    "# per unlabeled-set size in TRAINING_CONFIG[n]['M_BUCKET'].\n",
    "training_outputs = {}\n",
    "for n, value in TRAINING_CONFIG.items():\n",
    "    X_train, y_train = generate_data(n, mu, STD)\n",
    "    m_bucket = value['M_BUCKET']\n",
    "\n",
    "    _, _, linear_accuracy = get_tuned_model(\n",
    "        LinearTrainer,\n",
    "        X_train,\n",
    "        y_train,\n",
    "        X_test,\n",
    "        y_test,\n",
    "        LINEAR_PARAM_GRID,\n",
    "        DIM,\n",
    "        LINEAR_NUM_COMBINATION,\n",
    "    )\n",
    "\n",
    "    print(f'linear model accuracy at N={n} = {round(linear_accuracy, 4)}')\n",
    "    robust_linear_accuracies = []\n",
    "    for m in m_bucket:\n",
    "        # Each call draws a new perturbation direction for the unlabeled distribution.\n",
    "        X_unlabeled, _ = generate_unlabeled_data(m, mu, STD)\n",
    "        _, _, robust_linear_accuracy = get_tuned_model(\n",
    "            RobustTrainer,\n",
    "            X_train,\n",
    "            y_train,\n",
    "            X_test,\n",
    "            y_test,\n",
    "            ROBUST_LINEAR_PARAM_GRID,\n",
    "            DIM,\n",
    "            ROBUST_LINEAR_NUM_COMBINATION,\n",
    "            X_unlabeled\n",
    "        )\n",
    "        robust_linear_accuracies.append(robust_linear_accuracy)\n",
    "        print(f'robust linear model accuracy at N={n}, M={m} is equal to {round(robust_linear_accuracy, 4)}')\n",
    "\n",
    "    training_outputs[n] = {\n",
    "        'linear_model_accuracy': linear_accuracy,\n",
    "        'robust_linear_model_accuracies': robust_linear_accuracies\n",
    "    }\n",
    "    print('***********************************')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Calculate Upper Bound Accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reference point: the plain linear model tuned on N_MAX labeled samples.\n",
    "X_train_max, y_train_max = generate_data(N_MAX, mu, STD)\n",
    "_, _, upper_bound_accuracy = get_tuned_model(\n",
    "    LinearTrainer,\n",
    "    X_train_max,\n",
    "    y_train_max,\n",
    "    X_test,\n",
    "    y_test,\n",
    "    LINEAR_PARAM_GRID,\n",
    "    DIM,\n",
    "    LINEAR_NUM_COMBINATION,  # was misspelled LINEAR_NUM_COMINATION, which raised NameError\n",
    ")\n",
    "print(f'upper bound accuracy  = {round(upper_bound_accuracy, 4)}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import seaborn as sns\n",
    "\n",
    "def plot_result(training_outputs, upper_bound_accuracy, xscale=None):\n",
    "    \"\"\"Plot accuracy against the number of unlabeled samples for every labeled-set size.\n",
    "\n",
    "    Args:\n",
    "        training_outputs: maps each labeled-set size ``n`` to a dict with keys\n",
    "            ``'linear_model_accuracy'`` and ``'robust_linear_model_accuracies'``; the\n",
    "            matching x values are read from ``TRAINING_CONFIG[n]['M_BUCKET']``.\n",
    "        upper_bound_accuracy: accuracy drawn as a black horizontal reference line.\n",
    "        xscale: x-axis scale passed to ``ax.set_xscale`` (e.g. ``'linear'`` or ``'log'``);\n",
    "            ``None`` keeps the default scale.\n",
    "    \"\"\"\n",
    "    fig, ax = plt.subplots(figsize=(12, 8))\n",
    "    colors = ['blue', 'red', 'green']\n",
    "    # Horizontal reference lines span 0 .. largest plotted unlabeled count.\n",
    "    x_max = max((max(TRAINING_CONFIG[n]['M_BUCKET']) for n in training_outputs), default=10000)\n",
    "    whole_interval = [0, x_max]\n",
    "\n",
    "    for i, (n, value) in enumerate(training_outputs.items()):\n",
    "        y, y_0 = value['robust_linear_model_accuracies'], value['linear_model_accuracy']\n",
    "        x = TRAINING_CONFIG[n]['M_BUCKET']\n",
    "        # wrap around the palette so more than three configurations cannot raise IndexError\n",
    "        color = f'tab:{colors[i % len(colors)]}'\n",
    "        sns.lineplot(x=whole_interval, y=[y_0] * len(whole_interval), ax=ax, color=color, label=f'N={n}, linear')\n",
    "        sns.lineplot(x=x, y=y, ax=ax, linestyle='dashed', color=color, label=f'N={n}, robust linear')\n",
    "        sns.scatterplot(x=x, y=y, ax=ax, color=color, s=10)\n",
    "\n",
    "    sns.lineplot(x=whole_interval, y=[upper_bound_accuracy] * len(whole_interval), ax=ax, color='black', label=f'N={N_MAX}, linear')\n",
    "    ax.set(xlabel='Number OF Unlabeled', ylabel='accuracy', title=f'accuracy vs unlabeled count for dim={DIM}, ε=0')\n",
    "    if xscale is not None:\n",
    "        # only touch the scale when one was requested; None is not a valid scale name\n",
    "        ax.set_xscale(xscale)\n",
    "    plt.legend(loc='best')\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "    \n",
    "# Render the comparison twice: once on a linear x-axis, once on a log x-axis.\n",
    "for axis_scale in ('linear', 'log'):\n",
    "    plot_result(training_outputs, upper_bound_accuracy, xscale=axis_scale)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "colab": {
   "provenance": []
  },
  "gpuClass": "standard",
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
