{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create a dataset with heterogenous gaussian clusters in 128D"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import os\n",
    "import sys\n",
    "from torch.distributions.multivariate_normal import MultivariateNormal\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.decomposition import PCA\n",
    "sys.path.append('../..')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# CHOOSE WHERE TO SAVE DATA\n",
    "dataset_dir = f'./data/highdgaussian_intrinsicdim/' #change to your desired directory\n",
    "if not os.path.exists(dataset_dir): #create directory if it doesn't exist\n",
    "    os.makedirs(dataset_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hyperparameters\n",
    "numclusters = 5\n",
    "dimspace = 128 #\n",
    "numpoints_perconc = 100000*(dimspace//2) #points per concept\n",
    "intrinsic_dims = torch.tensor([2**i - 2 for i in range(3, 3+numclusters)]) #each concept has different dimensionality\n",
    "\n",
    "#choose centers (means) of clusters\n",
    "K = 1/50\n",
    "Q = 300\n",
    "Kc = Q*K/dimspace\n",
    "torch.manual_seed(500)\n",
    "centers = Kc*torch.rand(numclusters, dimspace)\n",
    "\n",
    "#choose variances (covariances- isotropic) of clusters\n",
    "Kv = K/intrinsic_dims.float() #per-dimension variance is inversely proportional to intrinsic dim\n",
    "variances = [torch.cat((Kv[i]*torch.ones((intrinsic_dims[i],)), torch.zeros((dimspace-intrinsic_dims[i],)))) for i in range(numclusters)]\n",
    "Covmats = [1e-6*torch.eye(dimspace) + torch.diag(variances[i]) for i in range(numclusters)]\n",
    "truefeatures = {'centers': centers, 'variances': variances, 'intrinsic_dims': intrinsic_dims}\n",
    "#sample multivariate gaussians from mean and covariance\n",
    "torch.manual_seed(754)\n",
    "\n",
    "\n",
    "data_all = torch.zeros((numclusters*numpoints_perconc, dimspace))\n",
    "class_id_all = torch.zeros((numclusters*numpoints_perconc,), dtype=int)\n",
    "for k in range(numclusters):\n",
    "    clusterk = MultivariateNormal(centers[k,:], Covmats[k])\n",
    "    data_all[k*numpoints_perconc:(k+1)*numpoints_perconc, :] = clusterk.sample((numpoints_perconc,))\n",
    "    class_id_all[k*numpoints_perconc:(k+1)*numpoints_perconc] = k\n",
    "numpoints_total = data_all.shape[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#visualize 2D PCA of data\n",
    "# pca = PCA(n_components=2)\n",
    "# pca.fit(data_all)\n",
    "# data_all_pca = pca.transform(data_all)\n",
    "# for k in range(numclusters):\n",
    "#     plt.scatter(data_all_pca[class_id_all==k,0], data_all_pca[class_id_all==k,1], label=f\"dim={str(intrinsic_dims[k].item())}\")\n",
    "# plt.legend()\n",
    "# plt.title(f\"PCA of {dimspace}dim data\")\n",
    "# plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "#shuffle data\n",
    "torch.manual_seed(41)\n",
    "shuffle_indices = torch.randperm(numpoints_total)\n",
    "data_all = data_all[shuffle_indices,:]\n",
    "class_id_all = class_id_all[shuffle_indices]\n",
    "\n",
    "#CREATE TRAIN, TEST SPLITS\n",
    "torch.manual_seed(4)\n",
    "frac_train = 0.7 #70% train, 30% test\n",
    "total_points = numpoints_total\n",
    "train_data_size = int(frac_train*total_points)\n",
    "test_data_size = total_points-train_data_size\n",
    "random_ordering = torch.randperm(total_points)\n",
    "train_indices = random_ordering[:train_data_size]\n",
    "test_indices = random_ordering[train_data_size:]\n",
    "\n",
    "train_data_all = data_all[train_indices,:]\n",
    "test_data_all = data_all[test_indices,:]\n",
    "train_class_id_all = class_id_all[train_indices]\n",
    "test_class_id_all = class_id_all[test_indices]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# SAVE DATA\n",
    "#location to save data\n",
    "# labdir = os.environ['USERDIR']\n",
    "# data_loc = labdir+'/data/'\n",
    "# dataset_dir = data_loc+f'/{dimspace}dgaussian_intrinsicdim/'\n",
    "dim = dimspace #data dimension\n",
    "\n",
    "\n",
    "\n",
    "torch.save({'numclusters':numclusters,\\\n",
    "            'dim':dim,\\\n",
    "            'data':train_data_all,\\\n",
    "            'labels':train_class_id_all,\\\n",
    "            'truefeatures':truefeatures}, dataset_dir+f'traindata.pt')\n",
    "\n",
    "torch.save({'numclusters':numclusters,\\\n",
    "            'dim':dim,\\\n",
    "            'data':test_data_all,\\\n",
    "            'labels':test_class_id_all,\\\n",
    "            'truefeatures':truefeatures}, dataset_dir+f'testdata.pt')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
