{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import random\n",
    "import torch\n",
    "import pandas as pd\n",
    "import scanpy as sc\n",
    "from scipy import stats\n",
    "from sklearn.model_selection import train_test_split\n",
    "import numpy as np\n",
    "import umap\n",
    "import matplotlib.pyplot as plt\n",
    "import torch.nn as nn\n",
    "from tqdm import tqdm\n",
    "from omegaconf import OmegaConf\n",
    "from AutoEncoder_models import MLPAE\n",
    "from train_test import train_and_eval_per_epoch\n",
    "from sklearn.neighbors import NearestNeighbors\n",
    "from torch.utils.data import TensorDataset, DataLoader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "def compute_segregation(samples, labels, k=10):\n",
    "    \"\"\"\n",
    "    Compute the k-nearest-neighbour segregation measure of labelled samples.\n",
    "\n",
    "    For every sample, the fraction of its k nearest neighbours that carry the\n",
    "    same label is computed; the measure is the mean of these fractions.\n",
    "    Neighbours are searched among `samples` themselves, so a sample is normally\n",
    "    returned as one of its own k neighbours.\n",
    "\n",
    "    Parameters:\n",
    "        samples (array-like): Array of shape (n_samples, n_features) containing the samples.\n",
    "        labels (array-like): Array of shape (n_samples,) containing the label of each sample.\n",
    "        k (int): Number of nearest neighbors to consider.\n",
    "\n",
    "    Returns:\n",
    "        float: Segregation measure (mean fraction of same-label neighbours).\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If `samples` and `labels` do not have the same length.\n",
    "    \"\"\"\n",
    "    samples = np.asarray(samples)\n",
    "    labels = np.asarray(labels)\n",
    "    if len(samples) != len(labels):\n",
    "        raise ValueError('samples and labels must have the same length')\n",
    "\n",
    "    # Query all samples in one batched call instead of one kneighbors call per sample.\n",
    "    nbrs = NearestNeighbors(n_neighbors=k).fit(samples)\n",
    "    neighbor_idx = nbrs.kneighbors(samples, return_distance=False)\n",
    "\n",
    "    # same_label[i, j] is True when the j-th neighbour of sample i shares its label.\n",
    "    same_label = labels[neighbor_idx] == labels[:, None]\n",
    "\n",
    "    # Every row has exactly k entries, so the overall mean equals the mean of the per-sample fractions.\n",
    "    return float(same_label.mean())"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "def umap_scatter_gaussian(X, y, shift, hidden_dim, learn=False, save=False):\n",
    "    \"\"\"\n",
    "    Plot a 2-D UMAP embedding of source (label 0) and target (label 1) samples.\n",
    "\n",
    "    Parameters:\n",
    "        X (array-like): Samples of shape (n_samples, n_features) to embed.\n",
    "        y (array-like): Domain label per sample; 0 is drawn as 'Source', 1 as 'Target'.\n",
    "        shift: Shift value, used only in the names of the saved image files.\n",
    "        hidden_dim: Hidden layer size, shown in the title and used in the file name when learn=True.\n",
    "        learn (bool): If True, the figure is square and annotated with the hidden size and\n",
    "            the KNN-DAT value returned by compute_segregation(X, y).\n",
    "        save (bool): If True, the figure is written as a PNG under images/domain_shift/.\n",
    "    \"\"\"\n",
    "    y = np.asarray(y)\n",
    "    features_embedded = umap.UMAP(n_components=2, random_state=42, metric='cosine').fit_transform(X)\n",
    "    if learn:\n",
    "        plt.figure(figsize=(8, 8), dpi=300)\n",
    "        segregation = compute_segregation(X, y)\n",
    "        plt.text(0.5, 1.03, f'UMAP for hidden layer size = {hidden_dim}', horizontalalignment='center', fontsize=22, transform=plt.gca().transAxes)\n",
    "        plt.text(0.5, -0.05, f'KNN-DAT = {round(segregation, 2)}', horizontalalignment='center', fontsize=22, transform=plt.gca().transAxes)\n",
    "    else:\n",
    "        plt.figure(figsize=(12, 8), dpi=300)\n",
    "\n",
    "    # Boolean-mask indexing yields an empty (0, 2) slice for a missing domain,\n",
    "    # so the scatter call still receives valid x and y arrays.\n",
    "    for domain, color, name in ((0, 'b', 'Source'), (1, 'r', 'Target')):\n",
    "        points = features_embedded[y == domain]\n",
    "        plt.scatter(points[:, 0], points[:, 1], marker='o', color=color, s=4, alpha=1, label=name)\n",
    "\n",
    "    plt.tick_params(left=False, right=False, labelleft=False, labelbottom=False, bottom=False)\n",
    "    plt.legend(loc='upper right', fontsize=20)\n",
    "    plt.grid(False)\n",
    "    if save:\n",
    "        if learn:\n",
    "            plt.savefig(f'images/domain_shift/umap_after_learning_shift_{shift}_hidden_dim_{hidden_dim}.png', dpi=300, bbox_inches='tight')\n",
    "        else:\n",
    "            plt.savefig(f'images/domain_shift/umap_before_learning_shift_{shift}.png', dpi=300, bbox_inches='tight')\n",
    "    plt.show()"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "def gaussian_domain_shift_data_loader(n_samples, out_dim, latent_dim, batch_size, noise_per=0, snr_db=0, shift=1):\n",
    "    \"\"\"\n",
    "    Build loaders and validation tensors for the synthetic Gaussian domain-shift task.\n",
    "\n",
    "    Gaussian latent vectors are mapped to `out_dim` dimensions by two random linear maps:\n",
    "    a base matrix (test side) and that matrix plus `shift` times a second random matrix\n",
    "    (train side). Both maps are rescaled before use.\n",
    "\n",
    "    Parameters:\n",
    "        n_samples (int): Number of training samples; 10,000 test samples and two\n",
    "            validation sets of 5,000 samples each are generated in addition.\n",
    "        out_dim (int): Dimension of the generated data.\n",
    "        latent_dim (int): Dimension of the Gaussian latent vectors.\n",
    "        batch_size (int): Batch size of both returned loaders.\n",
    "        noise_per (float): Fraction of training samples that receive additive standard normal noise.\n",
    "        snr_db (float): Both matrices are scaled by 10 ** (snr_db / 20).\n",
    "        shift (float): Weight of the random perturbation added to the base matrix.\n",
    "\n",
    "    Returns:\n",
    "        tuple: (train_loader, train_val_data, test_loader, test_val_data); the loaders wrap\n",
    "        float32 TensorDatasets and the two *_val_data entries are float32 tensors.\n",
    "    \"\"\"\n",
    "    num_test, num_val_train, num_val_test = 10_000, 5000, 5000\n",
    "    total = n_samples + num_test + num_val_train + num_val_test\n",
    "\n",
    "    # Draw order (base matrix, perturbation, latents) determines the data under a fixed seed.\n",
    "    base_mat = np.random.randn(out_dim, latent_dim)\n",
    "    amplitude = 10 ** (snr_db / 20)\n",
    "    shifted_mat = base_mat + shift * np.random.randn(out_dim, latent_dim)\n",
    "    base_mat = base_mat * (amplitude / latent_dim ** 0.5)\n",
    "    shifted_mat = shifted_mat * (amplitude / ((shift ** 2 + 1) * latent_dim) ** 0.5)\n",
    "    latents = np.random.randn(total, latent_dim)\n",
    "\n",
    "    # Three nested splits: train / train-validation / test / test-validation.\n",
    "    z_train, z_test = train_test_split(latents, test_size=num_test + num_val_test, random_state=0)\n",
    "    z_train, z_val_train = train_test_split(z_train, test_size=num_val_train, random_state=0)\n",
    "    z_test, z_val_test = train_test_split(z_test, test_size=num_val_test, random_state=0)\n",
    "\n",
    "    def _project(latent, mat):\n",
    "        # Map latent vectors through a mixing matrix and cast to float32.\n",
    "        return (latent @ mat.T).astype('float32')\n",
    "\n",
    "    train_data = _project(z_train, shifted_mat)\n",
    "    train_val_data = torch.tensor(_project(z_val_train, shifted_mat))\n",
    "    test_data = _project(z_test, base_mat)\n",
    "    test_val_data = torch.tensor(_project(z_val_test, base_mat))\n",
    "\n",
    "    # Corrupt a random subset of the training rows with standard normal noise.\n",
    "    n_noisy = int(noise_per * len(train_data))\n",
    "    noisy_rows = random.sample(range(len(train_data)), n_noisy)\n",
    "    noise = np.random.normal(loc=0, scale=1, size=(n_noisy, out_dim)).astype('float32')\n",
    "    train_data[noisy_rows] += noise\n",
    "\n",
    "    train_loader = DataLoader(TensorDataset(torch.Tensor(train_data)), batch_size=batch_size, shuffle=True)\n",
    "    test_loader = DataLoader(TensorDataset(torch.Tensor(test_data)), batch_size=batch_size, shuffle=True)\n",
    "    return train_loader, train_val_data, test_loader, test_val_data"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "### Train model if needed:"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Run configuration\n",
    "config = OmegaConf.load('config.yaml')\n",
    "cuda = config.cuda\n",
    "scenario = 'domain_shift'\n",
    "seed = 0\n",
    "\n",
    "# Optimisation settings\n",
    "opt = 'Adam'\n",
    "lr = 0.001\n",
    "batch_size = 10\n",
    "epochs = 200\n",
    "\n",
    "# Synthetic data settings\n",
    "n_samples = 5000\n",
    "higher_dim = 50\n",
    "data_latent_dim = 20  # latent dimension of the generated data (distinct from the model's latent_dim)\n",
    "noise_per = 0\n",
    "snr_db = 20\n",
    "shift = 3  # 4\n",
    "\n",
    "# Autoencoder sizes\n",
    "latent_dim = 45\n",
    "hidden_dim = 100\n",
    "\n",
    "# Seed every random number generator before any data is drawn.\n",
    "torch.manual_seed(seed)\n",
    "random.seed(seed)\n",
    "np.random.seed(seed)\n",
    "\n",
    "train_loader, train_val_data, test_loader, test_val_data = gaussian_domain_shift_data_loader(\n",
    "    n_samples, higher_dim, data_latent_dim, batch_size, noise_per, snr_db, shift)\n",
    "\n",
    "# Store the validation sets so that the analysis cells can reload them later.\n",
    "train_val_path = f\"domain_shift_models/train_val_shift_{shift}_noise_per_{noise_per}_snr_{snr_db}_latent_dim_{latent_dim}_hidden_dim_{hidden_dim}.pt\"\n",
    "test_val_path = f\"domain_shift_models/test_val_shift_{shift}_noise_per_{noise_per}_snr_{snr_db}_latent_dim_{latent_dim}_hidden_dim_{hidden_dim}.pt\"\n",
    "torch.save(train_val_data, train_val_path)\n",
    "torch.save(test_val_data, test_val_path)\n",
    "\n",
    "print(f'num train={len(train_loader.dataset)}')\n",
    "print(f'hidden_dim={hidden_dim}')\n",
    "print(f'latent_dim={latent_dim}')\n",
    "print(f'noise_per={noise_per}')\n",
    "print(f'snr_db={snr_db}')\n",
    "print(f'seed={seed}')\n",
    "print(f'shift={shift}')\n",
    "\n",
    "input_dim = train_loader.dataset[0][0].shape[0]\n",
    "model = MLPAE(input_dim=input_dim, latent_dim=latent_dim, hidden_dim=hidden_dim,\n",
    "              n_hidden_layers=0, final_activation=nn.Identity())\n",
    "\n",
    "total_params = sum(param.numel() for param in model.parameters())\n",
    "print(f'total_params={total_params}')\n",
    "\n",
    "# Train, reporting the losses every 10 epochs.\n",
    "for epoch in tqdm(range(epochs)):\n",
    "    train_loss, test_loss, _, _ = train_and_eval_per_epoch(train_loader, test_loader, model, opt, lr)\n",
    "    if (epoch + 1) % 10 == 0:\n",
    "        print(f'Epoch: {epoch + 1}, Train loss: {train_loss}, test loss: {test_loss}')\n",
    "\n",
    "model_path = f\"domain_shift_models/model_shift_{shift}_noise_per_{noise_per}_snr_{snr_db}_latent_dim_{latent_dim}_hidden_dim_{hidden_dim}.pt\"\n",
    "torch.save(model.state_dict(), model_path)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "### Prepare latent vectors:"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Settings that identify which saved run to analyse.\n",
    "hidden_dim = 100\n",
    "noise_per = 0\n",
    "snr_db = 20\n",
    "shift = 3\n",
    "latent_dim = 45\n",
    "train_samples = torch.load(f\"domain_shift_models/train_val_shift_{shift}_noise_per_{noise_per}_snr_{snr_db}_latent_dim_{latent_dim}_hidden_dim_{hidden_dim}.pt\")\n",
    "test_samples = torch.load(f\"domain_shift_models/test_val_shift_{shift}_noise_per_{noise_per}_snr_{snr_db}_latent_dim_{latent_dim}_hidden_dim_{hidden_dim}.pt\")\n",
    "model = MLPAE(input_dim=train_samples[0].shape[0], latent_dim=latent_dim, hidden_dim=hidden_dim, n_hidden_layers=0, final_activation=nn.Identity())\n",
    "model.load_state_dict(torch.load(f\"domain_shift_models/model_shift_{shift}_noise_per_{noise_per}_snr_{snr_db}_latent_dim_{latent_dim}_hidden_dim_{hidden_dim}.pt\"))\n",
    "\n",
    "# Encode each validation set with one batched forward pass. The two sets are\n",
    "# handled independently, so neither is truncated to the length of the other,\n",
    "# and the result is not rebuilt with np.vstack once per sample.\n",
    "model.eval()\n",
    "with torch.no_grad():\n",
    "    _, train_latent = model(train_samples)\n",
    "    _, test_latent = model(test_samples)\n",
    "train_latent_vecs = train_latent.numpy()\n",
    "test_latent_vecs = test_latent.numpy()"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "### UMAP of Data before model learning:"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "data = np.concatenate((train_samples, test_samples))\n",
    "labels = np.concatenate((np.zeros(len(train_samples)), np.ones(len(test_samples))))\n",
    "umap_scatter_gaussian(data, labels, shift, hidden_dim, learn=False, save=True)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "compute_segregation(data, labels, k=10)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "### UMAP of latent vectors from model:"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "hidden = 4, latent = 45, snr = 20, noise per = 0\n"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "latent_data = np.concatenate((train_latent_vecs, test_latent_vecs))\n",
    "latent_labels = np.concatenate((np.zeros(len(train_latent_vecs)), np.ones(len(test_latent_vecs))))\n",
    "umap_scatter_gaussian(latent_data, latent_labels, shift, hidden_dim, learn=True, save=True)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "compute_segregation(latent_data, latent_labels, k=10)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
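  {
   "cell_type": "markdown",
   "source": [
    "Results for `hidden_dim = 12`, `latent_dim = 45`, `snr_db = 20`, `noise_per = 0` (set these values in the preceding *Prepare latent vectors* cell and re-run it before running the cells below):"
   ],
   "metadata": {
    "collapsed": false
   }
  },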
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "latent_data = np.concatenate((train_latent_vecs, test_latent_vecs))\n",
    "latent_labels = np.concatenate((np.zeros(len(train_latent_vecs)), np.ones(len(test_latent_vecs))))\n",
    "umap_scatter_gaussian(latent_data, latent_labels, shift, hidden_dim, learn=True, save=True)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "compute_segregation(latent_data, latent_labels, k=10)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
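  {
   "cell_type": "markdown",
   "source": [
    "Results for `hidden_dim = 100`, `latent_dim = 45`, `snr_db = 20`, `noise_per = 0` (set these values in the preceding *Prepare latent vectors* cell and re-run it before running the cells below):\n"
   ],
   "metadata": {
    "collapsed": false
   }
  },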
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "latent_data = np.concatenate((train_latent_vecs, test_latent_vecs))\n",
    "latent_labels = np.concatenate((np.zeros(len(train_latent_vecs)), np.ones(len(test_latent_vecs))))\n",
    "umap_scatter_gaussian(latent_data, latent_labels, shift, hidden_dim, learn=True, save=True)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "compute_segregation(latent_data, latent_labels, k=10)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "### Results for single-cell RNA data:"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "def single_cell_domain_shift_data_loader(n_samples, n_features, batch_size, target):\n",
    "    \"\"\"\n",
    "    Load the pancreatic 5-batch single-cell dataset and build source/target loaders.\n",
    "\n",
    "    The expression table and its sample annotations are read from\n",
    "    data/batch_effect/dataset4/, preprocessed with scanpy, z-scored per gene and\n",
    "    split by the 'batchlb' annotation. 'Baron_b1' is the source batch; `target`\n",
    "    selects the batch served by the test loader.\n",
    "\n",
    "    Parameters:\n",
    "        n_samples (int): Number of leading Baron_b1 cells used for training; the rest is\n",
    "            returned as held-out source data.\n",
    "        n_features (int): Passed as n_top_genes to the dispersion-based gene filter.\n",
    "        batch_size (int): Batch size of both loaders (incomplete last batches are dropped).\n",
    "        target (str): One of 'Mutaro_b2', 'Segerstolpe_b3', 'Wang_b4', 'Xin_b5'.\n",
    "\n",
    "    Returns:\n",
    "        tuple: (train_loader, test_loader, batch_1_test, batch_1, batch_2, batch_3, batch_4, batch_5);\n",
    "        the entries after the loaders are row subsets of the z-scored data matrix.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If `target` is not one of the supported batch names.\n",
    "    \"\"\"\n",
    "    # Read the expression table (transposed to cells x genes) and attach the annotations.\n",
    "    expression = pd.read_csv('data/batch_effect/dataset4/myData_pancreatic_5batches.txt', sep='\\t', header=0, index_col=0)\n",
    "    adata = sc.AnnData(np.transpose(expression))\n",
    "    annotations = pd.read_csv('data/batch_effect/dataset4/mySample_pancreatic_5batches.txt', header=0, index_col=0,\n",
    "                              sep='\\t')\n",
    "    adata.obs['cell_type'] = annotations.loc[adata.obs_names, ['celltype']]\n",
    "    adata.obs['batch'] = annotations.loc[adata.obs_names, ['batchlb']]\n",
    "\n",
    "    # Preprocessing pipeline; the order of these calls is significant.\n",
    "    sc.pp.filter_cells(adata, min_genes=300)\n",
    "    sc.pp.filter_genes(adata, min_cells=10)\n",
    "    sc.pp.log1p(adata)\n",
    "    sc.pp.normalize_per_cell(adata, counts_per_cell_after=1e4)\n",
    "    sc.pp.filter_genes_dispersion(data=adata, n_top_genes=n_features, min_mean=0.0125, max_mean=3, min_disp=0.5)\n",
    "    scaled = stats.zscore(adata.X, axis=0)\n",
    "\n",
    "    # Split the rows into the five batches (domains).\n",
    "    batch_of_cell = np.array(adata.obs.batch)\n",
    "    batch_names = ('Baron_b1', 'Mutaro_b2', 'Segerstolpe_b3', 'Wang_b4', 'Xin_b5')\n",
    "    batches = {name: scaled[np.where(batch_of_cell == name)[0]] for name in batch_names}\n",
    "    batch_1, batch_2, batch_3, batch_4, batch_5 = (batches[name] for name in batch_names)\n",
    "\n",
    "    # Source domain: the first n_samples Baron_b1 cells train the model, the rest is held out.\n",
    "    batch_1_train, batch_1_test = batch_1[: n_samples], batch_1[n_samples:]\n",
    "    train_set = TensorDataset(torch.Tensor(batch_1_train))\n",
    "\n",
    "    # Target domain: any batch other than the source.\n",
    "    if target not in batch_names[1:]:\n",
    "        raise ValueError(\"'target' not implemented\")\n",
    "    test_set = TensorDataset(torch.Tensor(batches[target]))\n",
    "\n",
    "    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, drop_last=True)\n",
    "    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=True, drop_last=True)\n",
    "    return train_loader, test_loader, batch_1_test, batch_1, batch_2, batch_3, batch_4, batch_5\n"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "train_loader, test_loader, batch_1_test, batch_1, batch_2, batch_3, batch_4, batch_5 = single_cell_domain_shift_data_loader(n_samples=5000, n_features=1000, batch_size=10, target='Mutaro_b2')"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "def cells_umap_scatter(X, y, hidden_dim, learn=False, save=False):\n",
    "    \"\"\"\n",
    "    Plot a 2-D UMAP embedding of the five single-cell batches.\n",
    "\n",
    "    Parameters:\n",
    "        X (array-like): Samples of shape (n_samples, n_features) to embed.\n",
    "        y (numpy.ndarray): Batch label per sample; values 1 to 5 select the plotted groups.\n",
    "        hidden_dim: Hidden layer size, shown in the title and used in the file name when learn=True.\n",
    "        learn (bool): If True, the figure is square, uses a smaller legend and is annotated with\n",
    "            the hidden size and the KNN-DAT value returned by compute_segregation(X, y).\n",
    "        save (bool): If True, the figure is written as a PNG under images/domain_shift/.\n",
    "    \"\"\"\n",
    "    embedding = umap.UMAP(n_components=2, random_state=42, metric='cosine').fit_transform(X)\n",
    "\n",
    "    # (label value, colour, legend entry) for each batch, in drawing order.\n",
    "    batch_styles = (\n",
    "        (1, 'blue', 'source (Baron_b1)'),\n",
    "        (2, 'orange', 'target (Mutaro_b2)'),\n",
    "        (3, 'green', 'target (Segerstolpe_b3)'),\n",
    "        (4, 'red', 'target (Wang_b4)'),\n",
    "        (5, 'purple', 'target (Xin_b5)'),\n",
    "    )\n",
    "    groups = [(embedding[np.where(y == value)[0]], color, name) for value, color, name in batch_styles]\n",
    "\n",
    "    if learn:\n",
    "        legend_fontsize = 10\n",
    "        plt.figure(figsize=(8, 8), dpi=300)\n",
    "        segregation = compute_segregation(X, y)\n",
    "        plt.text(0.5, 1.03, f'UMAP for hidden layer size = {hidden_dim}', horizontalalignment='center', fontsize=22, transform=plt.gca().transAxes)\n",
    "        plt.text(0.5, -0.05, f'KNN-DAT = {round(segregation, 2)}', horizontalalignment='center', fontsize=22, transform=plt.gca().transAxes)\n",
    "    else:\n",
    "        legend_fontsize = 16\n",
    "        plt.figure(figsize=(12, 8), dpi=300)\n",
    "\n",
    "    for points, color, name in groups:\n",
    "        plt.scatter(points[:, 0], points[:, 1], color=color, marker='o', s=1, alpha=0.3, label=name)\n",
    "\n",
    "    plt.legend(loc='upper right', fontsize=legend_fontsize)\n",
    "    plt.grid(False)\n",
    "    plt.tick_params(left=False, right=False, labelleft=False, labelbottom=False, bottom=False)\n",
    "    if save:\n",
    "        if learn:\n",
    "            out_path = f'images/domain_shift/cells_umap_after_learning_hidden_dim_{hidden_dim}.png'\n",
    "        else:\n",
    "            out_path = f'images/domain_shift/cells_umap_before_learning.png'\n",
    "        plt.savefig(out_path, dpi=300, bbox_inches='tight')\n",
    "    plt.show()"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "data = np.concatenate((batch_1, batch_2, batch_3, batch_4, batch_5))\n",
    "labels = np.concatenate((np.ones(len(batch_1)), 2 * np.ones(len(batch_2)), 3 * np.ones(len(batch_3)), 4 * np.ones(len(batch_4)), 5 * np.ones(len(batch_5))))\n",
    "cells_umap_scatter(data, labels, hidden_dim, learn=False, save=True)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "compute_segregation(data, labels, k=10)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "### Train model if needed:"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Run configuration\n",
    "config = OmegaConf.load('config.yaml')\n",
    "cuda = config.cuda\n",
    "scenario = 'domain_shift'\n",
    "seed = 1\n",
    "\n",
    "# Optimisation settings\n",
    "opt = 'Adam'\n",
    "lr = 0.001\n",
    "batch_size = 128\n",
    "epochs = 1000\n",
    "\n",
    "# Data settings\n",
    "n_samples = 5000\n",
    "n_features = 1000\n",
    "\n",
    "# Autoencoder sizes\n",
    "latent_dim = 300\n",
    "hidden_dim = 450\n",
    "\n",
    "# Seed every random number generator before the loaders are built.\n",
    "torch.manual_seed(seed)\n",
    "random.seed(seed)\n",
    "np.random.seed(seed)\n",
    "\n",
    "(train_loader, test_loader, train_val_data,\n",
    " batch_1, batch_2, batch_3, batch_4, batch_5) = single_cell_domain_shift_data_loader(\n",
    "    n_samples, n_features, batch_size, target='Mutaro_b2')\n",
    "test_val_data = [torch.tensor(batch.astype('float32')) for batch in (batch_2, batch_3, batch_4, batch_5)]\n",
    "\n",
    "# Store the held-out source cells and the four target batches for the analysis cells.\n",
    "train_val_path = f\"domain_shift_models/cells_train_val_latent_dim_{latent_dim}_hidden_dim_{hidden_dim}.pt\"\n",
    "test_val_path = f\"domain_shift_models/cells_test_val_latent_dim_{latent_dim}_hidden_dim_{hidden_dim}.pt\"\n",
    "torch.save(torch.tensor(train_val_data.astype('float32')), train_val_path)\n",
    "torch.save(test_val_data, test_val_path)\n",
    "\n",
    "print(f'num train={len(train_loader.dataset)}')\n",
    "print(f'hidden_dim={hidden_dim}')\n",
    "print(f'latent_dim={latent_dim}')\n",
    "print(f'seed={seed}')\n",
    "\n",
    "input_dim = train_loader.dataset[0][0].shape[0]\n",
    "model = MLPAE(input_dim=input_dim, latent_dim=latent_dim, hidden_dim=hidden_dim,\n",
    "              n_hidden_layers=0, final_activation=nn.Identity())\n",
    "\n",
    "total_params = sum(param.numel() for param in model.parameters())\n",
    "print(f'total_params={total_params}')\n",
    "\n",
    "# Train, reporting the losses every 50 epochs.\n",
    "for epoch in tqdm(range(epochs)):\n",
    "    train_loss, test_loss, _, _ = train_and_eval_per_epoch(train_loader, test_loader, model, opt, lr)\n",
    "    if (epoch + 1) % 50 == 0:\n",
    "        print(f'Epoch: {epoch + 1}, Train loss: {train_loss}, test loss: {test_loss}')\n",
    "\n",
    "model_path = f\"domain_shift_models/cells_model_latent_dim_{latent_dim}_hidden_dim_{hidden_dim}.pt\"\n",
    "torch.save(model.state_dict(), model_path)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "### Prepare latent vectors:"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "def calc_latent_vecs(data, model):\n",
    "    \"\"\"\n",
    "    Encode samples with a model and return their latent vectors as a NumPy array.\n",
    "\n",
    "    Parameters:\n",
    "        data (torch.Tensor or sequence of torch.Tensor): Samples, one per row / element.\n",
    "        model: Object with an eval() method that, when called on a tensor, returns a pair\n",
    "            whose second element is the latent representation. It is called once on the\n",
    "            whole 2-D batch, so it must accept batched input.\n",
    "\n",
    "    Returns:\n",
    "        numpy.ndarray: One latent vector per sample, stacked along the first axis\n",
    "        (a single sample yields one row); an empty array when `data` has no samples.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    if len(data) == 0:\n",
    "        return np.array([])\n",
    "    batch = data if torch.is_tensor(data) else torch.stack(list(data))\n",
    "    with torch.no_grad():\n",
    "        # One batched forward pass instead of re-stacking the result after every sample.\n",
    "        _, latent = model(batch)\n",
    "    return latent.numpy()"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Sizes that identify which saved single-cell run to analyse.\n",
    "latent_dim = 300\n",
    "hidden_dim = 450\n",
    "\n",
    "train_samples = torch.load(f\"domain_shift_models/cells_train_val_latent_dim_{latent_dim}_hidden_dim_{hidden_dim}.pt\")\n",
    "test_samples = torch.load(f\"domain_shift_models/cells_test_val_latent_dim_{latent_dim}_hidden_dim_{hidden_dim}.pt\")\n",
    "batch_2, batch_3, batch_4, batch_5 = test_samples\n",
    "\n",
    "# Rebuild the autoencoder with the same sizes and restore its trained weights.\n",
    "input_dim = train_samples[0].shape[0]\n",
    "model = MLPAE(input_dim=input_dim, latent_dim=latent_dim, hidden_dim=hidden_dim, n_hidden_layers=0, final_activation=nn.Identity())\n",
    "model.load_state_dict(torch.load(f\"domain_shift_models/cells_model_latent_dim_{latent_dim}_hidden_dim_{hidden_dim}.pt\"))\n",
    "\n",
    "# Encode the held-out source cells and each target batch, in that order.\n",
    "(train_latent_vecs, batch_2_latent_vecs, batch_3_latent_vecs,\n",
    " batch_4_latent_vecs, batch_5_latent_vecs) = [\n",
    "    calc_latent_vecs(samples, model)\n",
    "    for samples in (train_samples, batch_2, batch_3, batch_4, batch_5)\n",
    "]"
   ],
   "metadata": {
    "collapsed": false
   }
  },
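  {
   "cell_type": "markdown",
   "source": [
    "Results for `hidden_dim = 10`, `latent_dim = 300` (set these values in the preceding *Prepare latent vectors* cell and re-run it before running the cells below):"
   ],
   "metadata": {
    "collapsed": false
   }
  },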
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "latent_data = np.concatenate((train_latent_vecs, batch_2_latent_vecs, batch_3_latent_vecs, batch_4_latent_vecs, batch_5_latent_vecs))\n",
    "labels = np.concatenate((np.ones(len(train_latent_vecs)), 2 * np.ones(len(batch_2_latent_vecs)), 3 * np.ones(len(batch_3_latent_vecs)), 4 * np.ones(len(batch_4_latent_vecs)), 5 * np.ones(len(batch_5_latent_vecs))))\n",
    "cells_umap_scatter(latent_data, labels, hidden_dim, learn=True, save=True)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "compute_segregation(latent_data, labels, k=10)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
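  {
   "cell_type": "markdown",
   "source": [
    "Results for `hidden_dim = 50`, `latent_dim = 300` (set these values in the preceding *Prepare latent vectors* cell and re-run it before running the cells below):"
   ],
   "metadata": {
    "collapsed": false
   }
  },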
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "latent_data = np.concatenate((train_latent_vecs, batch_2_latent_vecs, batch_3_latent_vecs, batch_4_latent_vecs, batch_5_latent_vecs))\n",
    "labels = np.concatenate((np.ones(len(train_latent_vecs)), 2 * np.ones(len(batch_2_latent_vecs)), 3 * np.ones(len(batch_3_latent_vecs)), 4 * np.ones(len(batch_4_latent_vecs)), 5 * np.ones(len(batch_5_latent_vecs))))\n",
    "cells_umap_scatter(latent_data, labels, hidden_dim, learn=True, save=True)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "compute_segregation(latent_data, labels, k=10)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
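  {
   "cell_type": "markdown",
   "source": [
    "Results for `hidden_dim = 450`, `latent_dim = 300` (set these values in the preceding *Prepare latent vectors* cell and re-run it before running the cells below):"
   ],
   "metadata": {
    "collapsed": false
   }
  },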
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "latent_data = np.concatenate((train_latent_vecs, batch_2_latent_vecs, batch_3_latent_vecs, batch_4_latent_vecs, batch_5_latent_vecs))\n",
    "labels = np.concatenate((np.ones(len(train_latent_vecs)), 2 * np.ones(len(batch_2_latent_vecs)), 3 * np.ones(len(batch_3_latent_vecs)), 4 * np.ones(len(batch_4_latent_vecs)), 5 * np.ones(len(batch_5_latent_vecs))))\n",
    "cells_umap_scatter(latent_data, labels, hidden_dim, learn=True, save=True)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "compute_segregation(latent_data, labels, k=10)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
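  {
   "cell_type": "markdown",
   "source": [
    "Results for `hidden_dim = 2000`, `latent_dim = 300` (set these values in the preceding *Prepare latent vectors* cell and re-run it before running the cells below):"
   ],
   "metadata": {
    "collapsed": false
   }
  },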
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "latent_data = np.concatenate((train_latent_vecs, batch_2_latent_vecs, batch_3_latent_vecs, batch_4_latent_vecs, batch_5_latent_vecs))\n",
    "labels = np.concatenate((np.ones(len(train_latent_vecs)), 2 * np.ones(len(batch_2_latent_vecs)), 3 * np.ones(len(batch_3_latent_vecs)), 4 * np.ones(len(batch_4_latent_vecs)), 5 * np.ones(len(batch_5_latent_vecs))))\n",
    "cells_umap_scatter(latent_data, labels, hidden_dim, learn=True, save=True)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "compute_segregation(latent_data, labels, k=10)"
   ],
   "metadata": {
    "collapsed": false
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
