{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create a dataset of Gaussian clusters in 2D, with linearly and nonlinearly separable clusters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import os\n",
    "import sys\n",
    "import matplotlib.pyplot as plt\n",
    "from torch.distributions.multivariate_normal import MultivariateNormal\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.model_selection import train_test_split\n",
    "sys.path.append('../..')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# CHOOSE WHERE TO SAVE THE DATA\n",
    "dataset_dir = './data/2dgaussian_diffmag/' #change this to your location\n",
    "# exist_ok=True creates the directory (and any missing parents) and is a no-op if it\n",
    "# already exists, so no separate existence check is needed.\n",
    "os.makedirs(dataset_dir, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate the Gaussian clusters: one isotropic Gaussian per concept, with centers placed\n",
    "# at equally spaced angles and alternating radii so that concepts differ in magnitude.\n",
    "torch.manual_seed(0)\n",
    "\n",
    "#hyperparameters\n",
    "dimspace = 2 #ambient dimension of the data\n",
    "numpoints_perconc = 1000000*(dimspace//2) #points per concept\n",
    "numclusters = 6 #number of concepts\n",
    "intrinsic_dims = 2*torch.ones((numclusters,), dtype=int) #all concepts are 2D here\n",
    "\n",
    "#choose centers of clusters\n",
    "K = 1\n",
    "Kc = K/dimspace #scale of centers\n",
    "rad = Kc*torch.tensor([3.0, 1.0, 3.0, 1.0, 3.0, 1.0]) #radius (magnitude) of clusters\n",
    "angles = torch.arange(0, 2*torch.pi, 2*torch.pi/numclusters)\n",
    "# rad is hard-coded per cluster, so fail loudly if numclusters is changed without updating it\n",
    "assert rad.shape == angles.shape == (numclusters,), 'need exactly one radius and one angle per cluster'\n",
    "centers = torch.stack((rad*torch.cos(angles), rad*torch.sin(angles)), dim=1)\n",
    "\n",
    "# choose variances of clusters\n",
    "scaler_alpha = 4.5\n",
    "Qv = 1/(2**scaler_alpha)\n",
    "Kv = Qv*K/intrinsic_dims.float().max() #variance per dimension\n",
    "# No random draw happens between the seed at the top of the cell and this one,\n",
    "# so this is the seed that determines the samples drawn below.\n",
    "torch.manual_seed(625)\n",
    "variances = [Kv*torch.ones((intrinsic_dims[i],)) for i in range(numclusters)]\n",
    "# the small 1e-6 diagonal term is added on top of the per-dimension variances\n",
    "Covmats = [1e-6*torch.eye(dimspace) + torch.diag(variances[i]) for i in range(numclusters)] #isotropic\n",
    "truefeatures = {'centers': centers, 'variances': variances}\n",
    "\n",
    "#sample gaussian clusters from centers and variances\n",
    "data_all = torch.zeros((numclusters*numpoints_perconc, dimspace))\n",
    "class_id_all = torch.zeros((numclusters*numpoints_perconc,), dtype=int)\n",
    "for k in range(numclusters):\n",
    "    rows_k = slice(k*numpoints_perconc, (k+1)*numpoints_perconc) #contiguous block of rows for concept k\n",
    "    clusterk = MultivariateNormal(centers[k,:], Covmats[k])\n",
    "    data_all[rows_k, :] = clusterk.sample((numpoints_perconc,))\n",
    "    class_id_all[rows_k] = k\n",
    "numpoints_total = data_all.shape[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#check separability using logistic regression- can skip\n",
    "# (this cell fits seven logistic regressions on the training split: one multi-class, six one-vs-all)\n",
    "X_train, X_test, y_train, y_test = train_test_split(data_all, class_id_all, test_size=0.33, random_state=42)\n",
    "clf = LogisticRegression(random_state=0, max_iter=1000, C=1e-1,penalty='l2').fit(X_train, y_train)\n",
    "score = clf.score(X_test, y_test) #overall multi-class accuracy\n",
    "\n",
    "#get score separately for each concept (one vs all)\n",
    "# separate names are used inside the loop so the overall `score` and `clf` are not overwritten\n",
    "scoresperconcept = []\n",
    "for k in range(numclusters):\n",
    "    y_train_concept = (y_train == k)\n",
    "    y_test_concept = (y_test == k)\n",
    "    clf_concept = LogisticRegression(random_state=0, max_iter=1000, C=1e-1,penalty='l2').fit(X_train, y_train_concept)\n",
    "    scoresperconcept.append(clf_concept.score(X_test, y_test_concept))\n",
    "print(f\"Scaler={scaler_alpha}, Score={score}\")\n",
    "print(f\"Per-concept (one-vs-all) scores={scoresperconcept}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# visualize clusters: scatter a random subsample of the points, colored by concept id\n",
    "numpoints_viz = 1000\n",
    "id_viz = torch.randint(0, numpoints_total, (numpoints_viz,)) #row indices drawn with replacement\n",
    "points_viz = data_all[id_viz]\n",
    "fig, ax = plt.subplots()\n",
    "ax.scatter(points_viz[:,0], points_viz[:,1], c=class_id_all[id_viz], cmap='tab10')\n",
    "ax.set_title(\"Data with different magnitudes\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#shuffle data (the clusters were generated one after another, so rows are grouped by concept)\n",
    "torch.manual_seed(41)\n",
    "shuffle_indices = torch.randperm(numpoints_total)\n",
    "datax, classidx = data_all[shuffle_indices], class_id_all[shuffle_indices]\n",
    "\n",
    "#CREATE TRAIN, TEST SPLITS from a second, separately seeded permutation\n",
    "torch.manual_seed(4)\n",
    "frac_train = 0.7 #70% train, 30% test\n",
    "total_points = numpoints_total\n",
    "train_data_size = int(frac_train*total_points)\n",
    "test_data_size = total_points - train_data_size\n",
    "random_ordering = torch.randperm(total_points)\n",
    "# first train_data_size entries go to train, the remaining test_data_size entries to test\n",
    "train_indices, test_indices = random_ordering.split([train_data_size, test_data_size])\n",
    "\n",
    "train_datax, train_classidx = datax[train_indices], classidx[train_indices]\n",
    "test_datax, test_classidx = datax[test_indices], classidx[test_indices]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# SAVE DATA\n",
    "# Writes traindata.pt and testdata.pt into dataset_dir (defined in an earlier cell).\n",
    "dim = dimspace #data dimension\n",
    "\n",
    "# Both files carry the same metadata; only the data and labels differ per split.\n",
    "splits = {'traindata.pt': (train_datax, train_classidx),\n",
    "          'testdata.pt': (test_datax, test_classidx)}\n",
    "for filename, (split_data, split_labels) in splits.items():\n",
    "    # os.path.join keeps the path valid whether or not dataset_dir ends with a separator\n",
    "    torch.save({'numclusters': numclusters,\n",
    "                'dim': dim,\n",
    "                'data': split_data,\n",
    "                'labels': split_labels,\n",
    "                'truefeatures': truefeatures}, os.path.join(dataset_dir, filename))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
