{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "f7f6cbee",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import random\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.linear_model import SGDClassifier\n",
    "import pandas as pd\n",
    "import copy\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.metrics import accuracy_score\n",
    "from sklearn.cluster import KMeans\n",
    "from matplotlib.colors import ListedColormap\n",
    "\n",
    "\n",
    "### Data Generation -- Gaussian Mixture\n",
    "def gen_data(\n",
    "    k=2,\n",
    "    dim=2,\n",
    "    points_per_cluster=1000,\n",
    "    means=None,\n",
    "    covariances=None\n",
    "):\n",
    "    # Generate data from specified or random Gaussians\n",
    "    if means is not None and covariances is None:\n",
    "        raise ValueError(\"Covariances must be provided if means are provided.\")\n",
    "    if means is None:\n",
    "        means = [np.random.rand(dim) for _ in range(k)]\n",
    "    if covariances is None:\n",
    "        covariances = [np.eye(dim) for _ in range(k)]\n",
    "    x = []\n",
    "    y = []\n",
    "    for i in range(k):\n",
    "        _x = np.random.multivariate_normal(means[i], covariances[i], points_per_cluster)\n",
    "        x += list(_x)\n",
    "        y += [i] * points_per_cluster\n",
    "    x = np.array(x)\n",
    "    y = np.array(y)\n",
    "    return x, y\n",
    "\n",
    "def create_positive_unlabeled(\n",
    "    x,\n",
    "    y,\n",
    "    num_labeled,\n",
    "    means=None,\n",
    "    covariances=None,\n",
    "    setting='single_data'\n",
    "):\n",
    "    '''\n",
    "    Creates positive examples (P) and unlabeled examples (0) based on given x and y.\n",
    "    input:\n",
    "        - x: Input data (total_examples, features)\n",
    "        - y: True labels (total_examples,)\n",
    "        - num_labeled: Number of positive examples (P)\n",
    "    output:\n",
    "        - labels: Generated labels (total_examples,)\n",
    "    '''\n",
    "    if setting == 'single_data':\n",
    "        feat = x\n",
    "        labels = np.zeros(len(y))               # Initialize labels with all 0 (unlabeled)\n",
    "        positive_indices = np.where(y == 1)[0]  # Indices of positive examples in the true labels\n",
    "        selected_indices = np.random.choice(positive_indices, size=num_labeled, replace=False)\n",
    "        labels[selected_indices] = 1            # Assign label 1 to selected positive examples\n",
    "\n",
    "    elif setting == 'case_control':\n",
    "        x_, y_ = gen_data(\n",
    "            k=2,\n",
    "            dim=2,\n",
    "            points_per_cluster=num_labeled,\n",
    "            means=means,\n",
    "            covariances=covariances\n",
    "        )\n",
    "        ix = np.where(y_ == 1)[0]\n",
    "        x_P = x_[ix]\n",
    "        feat = np.concatenate((x, x_P), axis=0)\n",
    "        labels_P = np.ones(num_labeled)\n",
    "        labels_U = np.zeros(len(y))\n",
    "        labels = np.concatenate((labels_U, labels_P), axis=0)\n",
    "\n",
    "\n",
    "    return feat, labels\n",
    "\n",
    "\n",
    "# Classification Code\n",
    "def _evaluate_accuracy(model, X_eval_tensor, y_eval):\n",
    "    '''Returns the accuracy of the argmax predictions of model against y_eval.'''\n",
    "    with torch.no_grad():\n",
    "        predicted_labels = torch.argmax(model(X_eval_tensor), dim=1)\n",
    "    return accuracy_score(y_eval, predicted_labels.numpy())\n",
    "\n",
    "\n",
    "def train_linear_model(\n",
    "    x,\n",
    "    y,\n",
    "    loss_fn='ce',\n",
    "    prior=0.5,\n",
    "    verbose=True,\n",
    "    lr=0.1,\n",
    "    num_epochs=100\n",
    "):\n",
    "    '''\n",
    "    Trains a two-output linear classifier with full-batch SGD on an 80/20 split of (x, y).\n",
    "    input:\n",
    "        - x: Input data (total_examples, features)\n",
    "        - y: Labels used for training and evaluation (total_examples,)\n",
    "        - loss_fn: 'ce' for cross-entropy, 'uPU' or 'nnPU' for PULoss\n",
    "        - prior: Class prior forwarded to PULoss (ignored for 'ce')\n",
    "        - verbose: Print loss and held-out accuracy every 10 epochs\n",
    "        - lr: SGD learning rate\n",
    "        - num_epochs: Number of full-batch gradient steps\n",
    "    output:\n",
    "        - model: The trained nn.Linear module\n",
    "        - loss_values: Training loss of every epoch\n",
    "        - acc: Held-out accuracy recorded every 10 epochs\n",
    "    raises:\n",
    "        - NotImplementedError: if loss_fn is not one of the supported names\n",
    "    note:\n",
    "        Accuracy is measured against the held-out part of y. When y holds PU or\n",
    "        pseudo labels, it is accuracy w.r.t. those labels, not the ground truth.\n",
    "        The final accuracy is printed regardless of verbose.\n",
    "    '''\n",
    "    X_train, X_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=42)\n",
    "    X_train_tensor = torch.tensor(X_train, dtype=torch.float32)\n",
    "    y_train_tensor = torch.tensor(y_train, dtype=torch.long)\n",
    "    X_test_tensor = torch.tensor(X_test, dtype=torch.float32)  # built once, reused by every evaluation\n",
    "\n",
    "    # Input width follows the data instead of being fixed to two features.\n",
    "    model = nn.Linear(X_train_tensor.shape[1], 2)\n",
    "    if loss_fn == 'ce':\n",
    "        criterion = nn.CrossEntropyLoss()\n",
    "    elif loss_fn in ['uPU', 'nnPU']:\n",
    "        criterion = PULoss(prior=prior, loss_fn=loss_fn)\n",
    "    else:\n",
    "        raise NotImplementedError(f'Unsupported loss_fn: {loss_fn!r}')\n",
    "\n",
    "    optimizer = optim.SGD(model.parameters(), lr=lr)\n",
    "    loss_values = []\n",
    "    acc = []\n",
    "    for epoch in range(num_epochs):\n",
    "        outputs = model(X_train_tensor)\n",
    "        loss = criterion(outputs, y_train_tensor)\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        loss_values.append(loss.item())\n",
    "        if (epoch + 1) % 10 == 0:\n",
    "            accuracy = _evaluate_accuracy(model, X_test_tensor, y_test)\n",
    "            acc.append(accuracy)\n",
    "            if verbose:\n",
    "                print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')\n",
    "                print(f'Accuracy: {accuracy:.4f}')\n",
    "\n",
    "    accuracy = _evaluate_accuracy(model, X_test_tensor, y_test)\n",
    "    print(f'Accuracy: {accuracy:.4f}')\n",
    "\n",
    "    return model, loss_values, acc\n",
    "\n",
    "\n",
    "\n",
    "class PULoss(nn.Module):\n",
    "    def __init__(self, prior, loss_fn: str):\n",
    "        super(PULoss, self).__init__()\n",
    "        if not 0 < prior < 1:\n",
    "            raise ValueError(\"The class prior should be in [0, 1]\")\n",
    "        self.prior, self.loss_fn = prior, loss_fn\n",
    "        self.meta_loss = nn.CrossEntropyLoss()\n",
    "\n",
    "    def forward(self, logits, targets):\n",
    "        # logits: shape [Batch Size \\times Num of Classes] - un-normalized raw linear combination (w_i * x_i + b)\n",
    "        ix_positive = torch.where(targets == 1)[0]\n",
    "        ix_unlabeled = torch.where(targets == 0)[0]\n",
    "\n",
    "        pos_logits = torch.index_select(input=logits, dim=0, index=ix_positive)\n",
    "        unlabeled_logits = torch.index_select(input=logits, dim=0, index=ix_unlabeled)\n",
    "\n",
    "        targets_pos = torch.ones(len(ix_positive), dtype=targets.dtype)\n",
    "        targets_pos_inverse = torch.zeros(len(ix_positive), dtype=targets.dtype)\n",
    "        targets_unlabeled = torch.zeros(len(ix_unlabeled), dtype=targets.dtype)\n",
    "\n",
    "        # compute empirical estimates\n",
    "        # R_p+\n",
    "        loss_positive = self.meta_loss(pos_logits.to(logits.device), targets_pos.to(targets.device)) \\\n",
    "            if ix_positive.nelement() != 0 else 0\n",
    "        # R_u-\n",
    "        loss_unlabeled = self.meta_loss(unlabeled_logits.to(logits.device), targets_unlabeled.to(targets.device)) \\\n",
    "            if ix_unlabeled.nelement() != 0 else 0\n",
    "        # R_p-\n",
    "        loss_pos_inv = self.meta_loss(pos_logits.to(logits.device), targets_pos_inverse.to(targets.device)) \\\n",
    "            if ix_positive.nelement() != 0 else 0\n",
    "        # (1-pi) Rn- = R_u- - prior * R_p-\n",
    "        loss_negative = loss_unlabeled - self.prior * loss_pos_inv\n",
    "\n",
    "        if self.prior == 0:\n",
    "            prior = ix_positive.nelement() / (ix_positive.nelement() + ix_unlabeled.nelement())\n",
    "            # i.e. fully supervised equivalent to PN strategy\n",
    "            return prior * loss_unlabeled + (1 - prior) * loss_positive\n",
    "        elif self.loss_fn == 'nnPU':\n",
    "            return - loss_negative if loss_negative < 0 else self.prior * loss_positive + loss_negative\n",
    "        elif self.loss_fn == 'uPU':\n",
    "            return self.prior * loss_positive + loss_negative\n",
    "        else:\n",
    "            ValueError('Unsupported Loss')\n",
    "\n",
    "\n",
    "def puPL(x_PU, y_PU, num_clusters=2):\n",
    "    '''\n",
    "    Pseudo-labels a PU dataset with 2-means seeded from the labeled / unlabeled groups.\n",
    "    input:\n",
    "        - x_PU: Features of the PU dataset (total_examples, features)\n",
    "        - y_PU: PU labels, 1 = labeled positive, 0 = unlabeled (total_examples,)\n",
    "        - num_clusters: Number of clusters; only 2 is supported because exactly two\n",
    "          initial centroids are built\n",
    "    output:\n",
    "        - labels: Cluster index per row of data; cluster 0 is seeded at the mean of the\n",
    "          unlabeled points, cluster 1 at the mean of the labeled positives\n",
    "        - data: The clustered points, unlabeled rows first and labeled positives last.\n",
    "          labels follows THIS order, not necessarily the row order of x_PU.\n",
    "    raises:\n",
    "        - ValueError: if num_clusters is not 2 or if either group is empty\n",
    "    '''\n",
    "    if num_clusters != 2:\n",
    "        raise ValueError('puPL builds exactly two initial centroids; num_clusters must be 2.')\n",
    "\n",
    "    x_P = x_PU[y_PU == 1]\n",
    "    x_U = x_PU[y_PU == 0]\n",
    "    if len(x_P) == 0 or len(x_U) == 0:\n",
    "        raise ValueError('puPL needs at least one labeled positive and one unlabeled example.')\n",
    "\n",
    "    ## Initialize Cluster Centers ##\n",
    "    # Mean of x_P seeds the positive cluster: Note centroid_2 since we use P:1 N/U:0\n",
    "    centroid_2 = np.mean(x_P, axis=0)\n",
    "    # Mean of x_U seeds the other cluster. A one-cluster k-means fit on x_U converges\n",
    "    # to exactly this mean, so it is computed directly.\n",
    "    centroid_1 = np.mean(x_U, axis=0)\n",
    "    centroids = np.array([centroid_1, centroid_2])\n",
    "\n",
    "    ## Perform K-means clustering with the initialized centroids\n",
    "    data = np.concatenate((x_U, x_P), axis=0)\n",
    "    # n_init=1: with explicit initial centers there is nothing to restart from.\n",
    "    kmeans = KMeans(n_clusters=num_clusters, init=centroids, n_init=1)\n",
    "    kmeans.fit(data)\n",
    "\n",
    "    return kmeans.labels_, data\n",
    "\n",
    "### Viz Data ---- Scatter Plot\n",
    "def plot_scatter(x, y, label_colors, model=None, lbl_map=None):\n",
    "    '''\n",
    "    Draws the first two feature columns of x as a scatter plot coloured by label.\n",
    "    input:\n",
    "        - x: Points to draw; column 0 goes on the X axis, column 1 on the Y axis\n",
    "        - y: One label per point, used as the hue\n",
    "        - label_colors: Palette handed to seaborn\n",
    "        - model: Unused, kept in the signature for existing callers\n",
    "        - lbl_map: Optional mapping from raw label values to legend names\n",
    "    '''\n",
    "    frame = pd.DataFrame({'X': x[:, 0], 'Y': x[:, 1], 'Label': y})\n",
    "    if lbl_map is not None:\n",
    "        # Replace the raw label values by their display names\n",
    "        frame['Label'] = frame['Label'].map(lbl_map)\n",
    "\n",
    "    marker_style = dict(s=60, alpha=0.8, edgecolor='k', marker='o')\n",
    "    sns.scatterplot(\n",
    "        data=frame,\n",
    "        x='X',\n",
    "        y='Y',\n",
    "        hue='Label',\n",
    "        palette=label_colors,\n",
    "        **marker_style,\n",
    "    )\n",
    "\n",
    "def plot_decision_boundary(X, y, model,\n",
    "                           boundary_color='k',\n",
    "                           boundary_linestyle='-',\n",
    "                           resolution=1000):\n",
    "    '''\n",
    "    Draws the decision boundary of model over the range spanned by X on the current axes.\n",
    "    input:\n",
    "        - X: Points whose first two columns define the plotted range (0.1 margin per side)\n",
    "        - y: Unused, kept in the signature for existing callers\n",
    "        - model: Callable mapping a float32 tensor of grid points (n, 2) to logits (n, classes)\n",
    "        - boundary_color: Colour of the boundary line\n",
    "        - boundary_linestyle: Line style of the boundary line\n",
    "        - resolution: Number of grid points per axis\n",
    "    '''\n",
    "    # Generate a grid of points to plot the decision boundary\n",
    "    x_min, x_max = X[:, 0].min() - 0.1, X[:, 0].max() + 0.1\n",
    "    y_min, y_max = X[:, 1].min() - 0.1, X[:, 1].max() + 0.1\n",
    "    xx, yy = np.meshgrid(np.linspace(x_min, x_max, resolution),\n",
    "                         np.linspace(y_min, y_max, resolution))\n",
    "    grid_points = np.c_[xx.ravel(), yy.ravel()]\n",
    "    grid_points_tensor = torch.tensor(grid_points, dtype=torch.float32)\n",
    "    with torch.no_grad():\n",
    "        logits = model(grid_points_tensor)\n",
    "        predicted_labels = torch.argmax(logits, dim=1)\n",
    "    # Hand matplotlib a NumPy array shaped like the grid\n",
    "    decision_map = predicted_labels.cpu().numpy().reshape(xx.shape)\n",
    "    # The map is a step function over class ids, so one level halfway between\n",
    "    # consecutive ids draws each boundary once; automatic levels would stack\n",
    "    # several near-identical lines on the same step.\n",
    "    boundary_levels = np.arange(logits.shape[1] - 1) + 0.5\n",
    "    plt.contour(\n",
    "        xx, yy,\n",
    "        decision_map,\n",
    "        levels=boundary_levels,\n",
    "        alpha=0.5,\n",
    "        colors=boundary_color,\n",
    "        linestyles=boundary_linestyle\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d5dd2d9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# supervised\n",
    "# Seed both RNGs so that Restart & Run All reproduces the data and the trained models\n",
    "SEED = 42\n",
    "np.random.seed(SEED)\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "# Generate Data (PN) from Gaussians\n",
    "# np.random.multivariate_normal needs symmetric positive semi-definite covariances.\n",
    "# The earlier values [[0, 0.5], [0.5, 0]] and [[0, 0.5], [1, 0]] were not, and only\n",
    "# triggered a warning; the matrices below are the covariances its SVD-based sampler\n",
    "# effectively produced for them, so the clouds keep their shape.\n",
    "mean1 = np.array([0, 0])  # Mean of first cluster\n",
    "cov1 = np.array([[0.5, 0], [0, 0.5]])  # Covariance matrix of first cluster\n",
    "\n",
    "mean2 = np.array([3.5, 3.5])  # Mean of second cluster\n",
    "cov2 = np.array([[1, 0], [0, 0.5]])  # Covariance matrix of second cluster\n",
    "\n",
    "means = [mean1, mean2]\n",
    "covariances = [cov1, cov2]\n",
    "point_per_cluster = 1500\n",
    "\n",
    "x, y = gen_data(\n",
    "    k=2,\n",
    "    dim=2,\n",
    "    points_per_cluster=point_per_cluster,\n",
    "    means=means,\n",
    "    covariances=covariances\n",
    ")\n",
    "# Fully supervised reference: cross-entropy on the true labels\n",
    "model_opt, loss_opt, acc_opt = train_linear_model(x, y, loss_fn='ce', lr=0.1, num_epochs=1000, verbose=False)\n",
    "plt.figure(figsize=(8,6), dpi=300)\n",
    "label_colors = [\n",
    "    '#d62728',\n",
    "    '#1f77b4'\n",
    "]\n",
    "lbl_map = {\n",
    "    0: r'$y^\\ast = 0$',\n",
    "    1: r'$y^\\ast = 1$'\n",
    "}\n",
    "plot_scatter(x, y, label_colors, lbl_map=lbl_map)\n",
    "plot_decision_boundary(x, y, model_opt, boundary_color='k')\n",
    "plt.annotate(\n",
    "    r'CE$^\\ast$',\n",
    "    xy=(4.7, -0.9),\n",
    "    fontsize=12,\n",
    "    fontweight='bold',\n",
    "    color='k'\n",
    ")\n",
    "plt.xlabel('X0', fontweight='bold')\n",
    "plt.ylabel('X1', fontweight='bold')\n",
    "# prop alone sets the legend font; a separate fontsize argument is ignored when prop is given\n",
    "plt.legend(loc='upper left', ncol=3, prop={'weight': 'bold', 'size':'large'}, bbox_to_anchor=(0, 1.1))\n",
    "plt.grid(True, linestyle='--', alpha=0.6)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8164619a",
   "metadata": {},
   "outputs": [],
   "source": [
    "nP_L = 500\n",
    "# Case-control: every row of x stays unlabeled, so the class prior of the unlabeled\n",
    "# set is the positive fraction of y (0.5 for two equally sized clusters).\n",
    "# For setting='single_data' it would instead be\n",
    "# (point_per_cluster - nP_L) / (point_per_cluster*2 - nP_L)\n",
    "prior_true = float(np.mean(y == 1))\n",
    "x_PU, y_PU = create_positive_unlabeled(\n",
    "    x,\n",
    "    y,\n",
    "    num_labeled=nP_L,\n",
    "    means=means,\n",
    "    covariances=covariances,\n",
    "    setting='case_control'\n",
    ")\n",
    "# model_opt (cross-entropy on the true labels) is trained in the previous cell and not repeated here\n",
    "# CE\n",
    "model_CE, loss_CE, acc_CE = train_linear_model(x_PU, y_PU, loss_fn='ce', lr=0.1, num_epochs=1000, verbose=False)\n",
    "# # uPU\n",
    "# model_uPU, loss_uPU, acc_uPU = train_linear_model(x_PU, y_PU, loss_fn='uPU', lr=0.0001, prior=prior_true, num_epochs=100, verbose=False)\n",
    "# nnPU\n",
    "model_nnPU, loss_nnPU, acc_nnPU = train_linear_model(x_PU, y_PU, loss_fn='nnPU', lr=0.005, prior=prior_true, num_epochs=100, verbose=False)\n",
    "# puPL + CE\n",
    "# puPL returns its pseudo-labels in the row order of x_hat_PU (unlabeled rows first,\n",
    "# labeled positives last), so the classifier is trained on x_hat_PU, not on x_PU.\n",
    "y_hat_PU, x_hat_PU = puPL(x_PU, y_PU, num_clusters=2)  # get pseudo-labels\n",
    "model_puPL, loss_puPL, acc_puPL = train_linear_model(x_hat_PU, y_hat_PU, loss_fn='ce', lr=0.1, num_epochs=1000, verbose=False)  # train with puPL + CE"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
