{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "import matplotlib.pyplot as plt\n",
    "from dataloader import GeneticDataloaders, SynGeneticDataset, GeneticDataloaders1k\n",
    "from torch.utils.data import DataLoader\n",
    "from config import *\n",
    "import numpy as np\n",
    "import sklearn\n",
    "num_samples = 200\n",
    "\n",
    "import torch\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def draw_samples(dataloader,num_samples):\n",
    "    \"\"\"Gather samples and labels from ``dataloader`` and return them as NumPy arrays.\n",
    "\n",
    "    Whole batches are read until at least ``num_samples`` samples have been\n",
    "    collected (or the loader runs out); both outputs are then cut to the first\n",
    "    ``num_samples`` entries. Labels with more than one dimension are reduced\n",
    "    with ``argmax`` over axis 1; one-dimensional labels are returned as they are.\n",
    "    \"\"\"\n",
    "    samples, targets = [], []\n",
    "    for x, y in dataloader:\n",
    "        # extend() walks the batch along its first dimension, one sample at a time\n",
    "        samples.extend(x)\n",
    "        targets.extend(y)\n",
    "        if len(samples) >= num_samples:\n",
    "            break\n",
    "\n",
    "    sample_array = torch.stack(samples).numpy()\n",
    "    target_array = torch.stack(targets).numpy()\n",
    "\n",
    "    if target_array.ndim > 1:\n",
    "        target_array = np.argmax(target_array, axis = 1)\n",
    "    return sample_array[:num_samples], target_array[:num_samples]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "\n",
    "datas = []\n",
    "\n",
    "# Generated samples, one directory per model. The order of this tuple fixes each\n",
    "# dataset's position in `datas`, which later cells rely on.\n",
    "# (\"finalruns/Baseline/\" and \"newgeneration/\" are currently not loaded.)\n",
    "for run_dir in (\n",
    "    \"finalruns/Transformer/\",\n",
    "    \"finalruns/UnetMLP/\",\n",
    "    \"finalruns/Unet/\",\n",
    "    \"finalruns/UnetCombined/\",\n",
    "):\n",
    "    geneticData = SynGeneticDataset(run_dir)\n",
    "    syn_dataloader = DataLoader(geneticData, batch_size=config[\"batch_size\"])\n",
    "    syn_data, syn_data_labels = draw_samples(syn_dataloader, num_samples)\n",
    "    datas.append(syn_data)\n",
    "\n",
    "# Samples from the train and test loaders are appended last: train first, test second.\n",
    "train_dataloader,test_dataloader = GeneticDataloaders(config[\"batch_size\"])\n",
    "for real_loader in (train_dataloader, test_dataloader):\n",
    "    syn_data, syn_data_labels = draw_samples(real_loader, num_samples)\n",
    "    datas.append(syn_data)\n",
    "\n",
    "\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# datasfull =[]\n",
    "# num_samplesfull = 9000\n",
    "\n",
    "# train_dataloader,test_dataloader = GeneticDataloaders(config[\"batch_size\"], True) \n",
    "# syn_data, syn_data_labels = draw_samples(train_dataloader, num_samplesfull)\n",
    "# datasfull.append(syn_data)\n",
    "\n",
    "# geneticData = SynGeneticDataset(\"finalruns/UnetCombined/\")\n",
    "# syn_dataloader = DataLoader(geneticData, batch_size=config[\"batch_size\"])\n",
    "# syn_data, syn_data_labels = draw_samples(syn_dataloader, num_samplesfull)\n",
    "# datasfull.append(syn_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Stack all datasets row-wise, then flatten every sample into one feature vector.\n",
    "stacked = np.concatenate(datas, axis=0)\n",
    "combined_data = stacked.reshape(len(stacked), -1)\n",
    "print(combined_data.shape)\n",
    "\n",
    "# One float label per row: the position in `datas` of the dataset the row came from.\n",
    "combined_labels = np.concatenate([np.full(len(block), float(idx)) for idx, block in enumerate(datas)])\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Embed the data samples with different models:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from config import *\n",
    "\n",
    "# Embed every row of `combined_data` with the UnetMLP checkpoint's bottleneck.\n",
    "# NOTE(review): torch.load unpickles a full model object; only load checkpoints you trust.\n",
    "model = torch.load(\"finalruns/UnetMLP/model.pt\")\n",
    "model.eval()\n",
    "\n",
    "# One timestep / label tensor shared by all samples. `max_steps`, `num_classes` and\n",
    "# `device` presumably come from the star import of config - TODO confirm.\n",
    "current_t = torch.tensor([max_steps-1], device=device)\n",
    "y = torch.tensor([num_classes], device=device)\n",
    "\n",
    "embeddings = []\n",
    "with torch.no_grad():\n",
    "    for sample in combined_data:\n",
    "        # Restore the flattened row to a single-item batch of shape (1, 8, 18432).\n",
    "        batch = torch.tensor(sample).to(device).reshape(1, 8, 18432)\n",
    "        output, bottleneck = model(batch, t=current_t, y=y, output_bottleneck=True)\n",
    "        embeddings.append(bottleneck.squeeze().cpu().detach().numpy())\n",
    "\n",
    "embeddings = np.array(embeddings)\n",
    "print(embeddings.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.manifold import TSNE\n",
    "\n",
    "# t-SNE is stochastic; a fixed seed keeps the layout (and the saved figure) reproducible.\n",
    "reducer = TSNE(n_components=2, random_state=0)\n",
    "embedding = reducer.fit_transform(embeddings)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "colors = plt.get_cmap('nipy_spectral')(np.linspace(0, 1, len(datas)))\n",
    "labels =  [\"Transformer\", \"MLP\", \"CNN\", \"MLP + CNN\" ,\"Train Data\", \"Test Data\", ]\n",
    "\n",
    "# Rows of `embedding` follow the concatenation order of `datas`, so each group's\n",
    "# row range is taken from the real dataset sizes rather than from `num_samples`\n",
    "# (a dataset that yielded fewer samples would otherwise shift every later group).\n",
    "group_edges = np.cumsum([0] + [len(d) for d in datas])\n",
    "for i in range(len(datas)):\n",
    "    plt_data = embedding[group_edges[i]: group_edges[i+1]]\n",
    "    plt.scatter(plt_data[:, 0], plt_data[:, 1], color = colors[i], label = labels[i], s = 5)\n",
    "plt.legend()\n",
    "# NOTE(review): the file name says \"umap\" although the projection above is t-SNE;\n",
    "# kept unchanged in case the file is referenced elsewhere.\n",
    "plt.savefig(\"umap_mlp_euclidean.png\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Embed every row of `combined_data` with the Unet checkpoint's bottleneck.\n",
    "# NOTE(review): torch.load unpickles a full model object; only load checkpoints you trust.\n",
    "model = torch.load(\"finalruns/Unet/model.pt\")\n",
    "model.eval()\n",
    "\n",
    "# One timestep / label tensor shared by all samples.\n",
    "current_t = torch.tensor([max_steps-1], device=device)\n",
    "y = torch.tensor([num_classes], device=device)\n",
    "\n",
    "embeddingscnn = []\n",
    "with torch.no_grad():\n",
    "    for sample in combined_data:\n",
    "        # Restore the flattened row to a single-item batch of shape (1, 8, 18432).\n",
    "        batch = torch.tensor(sample).to(device).reshape(1, 8, 18432)\n",
    "        output, bottleneck = model(batch, current_t, y=y, output_bottleneck=True)\n",
    "        # Collapse dim 2 of the bottleneck by averaging\n",
    "        # (presumably the length axis of the feature map - TODO confirm).\n",
    "        pooled = torch.mean(bottleneck, dim=2)\n",
    "        embeddingscnn.append(pooled.squeeze().cpu().detach().numpy())\n",
    "\n",
    "embeddingscnn = np.array(embeddingscnn)\n",
    "print(embeddingscnn.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# t-SNE is stochastic; a fixed seed keeps the layout (and the saved figure) reproducible.\n",
    "reducer = TSNE(n_components=2, random_state=0)\n",
    "embedding = reducer.fit_transform(embeddingscnn)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "colors = plt.get_cmap('nipy_spectral')(np.linspace(0, 1, len(datas)))\n",
    "labels =  [\"Transformer\", \"MLP\", \"CNN\", \"MLP + CNN\" ,\"Train Data\", \"Test Data\", ]\n",
    "\n",
    "# Rows of `embedding` follow the concatenation order of `datas`, so each group's\n",
    "# row range is taken from the real dataset sizes rather than from `num_samples`\n",
    "# (a dataset that yielded fewer samples would otherwise shift every later group).\n",
    "group_edges = np.cumsum([0] + [len(d) for d in datas])\n",
    "for i in range(len(datas)):\n",
    "    plt_data = embedding[group_edges[i]: group_edges[i+1]]\n",
    "    plt.scatter(plt_data[:, 0], plt_data[:, 1], color = colors[i], label = labels[i], s = 5)\n",
    "plt.legend()\n",
    "plt.savefig(\"tsne_unet_euclidean.png\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Embed every row of `combined_data` with the combined model's bottleneck.\n",
    "# NOTE(review): this checkpoint lives in finalruns/UnetCombinednew/ while the generated\n",
    "# samples were read from finalruns/UnetCombined/ - confirm the pairing is intended.\n",
    "# NOTE(review): torch.load unpickles a full model object; only load checkpoints you trust.\n",
    "model = torch.load(\"finalruns/UnetCombinednew/model.pt\")\n",
    "model.eval()\n",
    "\n",
    "# One timestep / label tensor shared by all samples.\n",
    "current_t = torch.tensor([max_steps-1], device=device)\n",
    "y = torch.tensor([num_classes], device=device)\n",
    "\n",
    "# Reuses the name `embeddingscnn`, replacing the embeddings of the previous model.\n",
    "embeddingscnn = []\n",
    "with torch.no_grad():\n",
    "    for sample in combined_data:\n",
    "        # Restore the flattened row to a single-item batch of shape (1, 8, 18432).\n",
    "        batch = torch.tensor(sample).to(device).reshape(1, 8, 18432)\n",
    "        output, bottleneck = model(batch, current_t, y=y, output_bottleneck=True)\n",
    "        embeddingscnn.append(bottleneck.squeeze().cpu().detach().numpy())\n",
    "\n",
    "embeddingscnn = np.array(embeddingscnn)\n",
    "print(embeddingscnn.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# t-SNE is stochastic; a fixed seed keeps the layout (and the saved figure) reproducible.\n",
    "reducer = TSNE(n_components=2, random_state=0)\n",
    "embedding = reducer.fit_transform(embeddingscnn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "colors = plt.get_cmap('nipy_spectral')(np.linspace(0, 1, len(datas)))\n",
    "labels =  [\"Transformer\", \"MLP\", \"CNN\", \"MLP + CNN\" ,\"Train Data\", \"Test Data\", ]\n",
    "\n",
    "# Rows of `embedding` follow the concatenation order of `datas`, so each group's\n",
    "# row range is taken from the real dataset sizes rather than from `num_samples`\n",
    "# (a dataset that yielded fewer samples would otherwise shift every later group).\n",
    "group_edges = np.cumsum([0] + [len(d) for d in datas])\n",
    "for i in range(len(datas)):\n",
    "    plt_data = embedding[group_edges[i]: group_edges[i+1]]\n",
    "    plt.scatter(plt_data[:, 0], plt_data[:, 1], color = colors[i], label = labels[i], s = 5)\n",
    "plt.legend()\n",
    "plt.savefig(\"tsne_unetcombined_euclidean.png\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The projection computed here is t-SNE on the raw flattened samples, so the\n",
    "# progress message names t-SNE (it previously said UMAP).\n",
    "print(\"computing t-SNE\")\n",
    "# t-SNE is stochastic; a fixed seed keeps the layout (and the saved figure) reproducible.\n",
    "reducer = TSNE(n_components=2, perplexity=15, random_state=0)\n",
    "embedding = reducer.fit_transform(combined_data)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "colors = plt.get_cmap('nipy_spectral')(np.linspace(0, 1, len(datas)))\n",
    "# Label order must match the order in which the loading cell appends to `datas`:\n",
    "# the train samples are appended before the test samples.\n",
    "labels =  [\"Transformer\", \"MLP\", \"CNN\", \"MLP + CNN\" ,\"Train Data\", \"Test Data\"]\n",
    "\n",
    "# Rows of `embedding` follow the concatenation order of `datas`, so each group's\n",
    "# row range is taken from the real dataset sizes rather than from `num_samples`\n",
    "# (a dataset that yielded fewer samples would otherwise shift every later group).\n",
    "group_edges = np.cumsum([0] + [len(d) for d in datas])\n",
    "for i in range(len(datas)):\n",
    "    plt_data = embedding[group_edges[i]: group_edges[i+1]]\n",
    "    plt.scatter(plt_data[:, 0], plt_data[:, 1], color = colors[i], label = labels[i], s = 5)\n",
    "plt.legend()\n",
    "plt.savefig(\"tsne.png\")\n",
    "plt.show()\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_checks = 200\n",
    "\n",
    "def distance(x1, x2):\n",
    "    \"\"\"Mean squared difference between two samples, compared as flat vectors.\n",
    "\n",
    "    The function works on NumPy arrays only, so it is a plain Python function\n",
    "    (it was previously wrapped in ``torch.compile``, which adds nothing here).\n",
    "\n",
    "    NOTE(review): a later cell defines another function that is also called\n",
    "    ``distance``; once that cell has run, the global name refers to the later\n",
    "    definition instead of this one.\n",
    "    \"\"\"\n",
    "    x1 = x1.flatten()\n",
    "    x2 = x2.flatten()\n",
    "    #return np.arccos(np.dot(x1,x2)/(np.linalg.norm(x1)*np.linalg.norm(x2)))/np.pi\n",
    "    return np.mean((x1-x2)**2)\n",
    "\n",
    "def check_diversity_by_closest(data1, data2, metric=None):\n",
    "    \"\"\"Print nearest-neighbour comparison scores between two datasets.\n",
    "\n",
    "    For up to ``num_checks`` samples of each dataset, the smallest distance to\n",
    "    the other dataset is compared with the smallest distance to the sample's\n",
    "    own dataset. Distances of 1e-3 or less are ignored, which also skips the\n",
    "    sample itself.\n",
    "\n",
    "    Args:\n",
    "        data1: first dataset (called \"truth\" in the printed output).\n",
    "        data2: second dataset (called \"syn\" in the printed output).\n",
    "        metric: optional callable ``metric(a, b)`` returning a distance. When\n",
    "            omitted, the global ``distance`` is looked up at call time, i.e.\n",
    "            whichever definition of ``distance`` ran most recently.\n",
    "\n",
    "    Prints:\n",
    "        ``AA_truth``: share of checked ``data1`` samples whose nearest ``data2``\n",
    "        neighbour is farther away than their nearest ``data1`` neighbour;\n",
    "        ``AA_syn``: the same with the roles swapped; ``AA_TS``: their mean.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if either dataset is empty.\n",
    "    \"\"\"\n",
    "    if metric is None:\n",
    "        metric = distance\n",
    "\n",
    "    def find_closest(dataset1, dataset2, num = num_checks):\n",
    "        # Smallest accepted distance from each of the first `num` samples of\n",
    "        # dataset1 to any sample of dataset2 (1e50 when nothing is accepted).\n",
    "        all_min_dist = []\n",
    "        for i,datapoint in enumerate(dataset1):\n",
    "            if i >= num:\n",
    "                break\n",
    "            min_dist = 1e50\n",
    "            for sample2 in dataset2:\n",
    "                dist = metric(datapoint, sample2)\n",
    "                if dist < min_dist and dist > 1e-3:\n",
    "                    min_dist = dist\n",
    "            all_min_dist.append(min_dist)\n",
    "        return all_min_dist\n",
    "\n",
    "    AA_ts = find_closest(data1, data2)\n",
    "    AA_st = find_closest(data2, data1)\n",
    "    AA_tt  = find_closest(data1, data1)\n",
    "    AA_ss = find_closest(data2, data2)\n",
    "\n",
    "    if not AA_ts or not AA_st:\n",
    "        raise ValueError(\"both datasets must contain at least one sample\")\n",
    "\n",
    "    # Each share is taken over the number of samples actually checked, so\n",
    "    # datasets with fewer than `num_checks` samples no longer raise IndexError.\n",
    "    truth_hits = sum(1 for ts, tt in zip(AA_ts, AA_tt) if ts > tt)\n",
    "    syn_hits = sum(1 for st, ss in zip(AA_st, AA_ss) if st > ss)\n",
    "    aa_truth = truth_hits / len(AA_ts)\n",
    "    aa_syn = syn_hits / len(AA_st)\n",
    "\n",
    "    print(f\"AA_truth: {aa_truth}, AA_syn: {aa_syn} AA_TS {(aa_truth + aa_syn) /2}\")\n",
    "\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "def distance(x1, x2):\n",
    "    \"\"\"Angular distance between two samples, scaled to the range [0, 1].\n",
    "\n",
    "    Both inputs are flattened, the cosine of the angle between them is\n",
    "    computed, and ``arccos(cosine) / pi`` is returned: 0 for vectors pointing\n",
    "    the same way, 0.5 for orthogonal vectors, 1 for opposite vectors.\n",
    "\n",
    "    Alternatives that used to follow the first ``return`` as unreachable\n",
    "    statements: ``1 - cosine``, ``np.mean((x1-x2)**2)`` and\n",
    "    ``np.mean(np.abs(x1-x2))``.\n",
    "    \"\"\"\n",
    "    x1 = x1.flatten()\n",
    "    x2 = x2.flatten()\n",
    "    cosine = np.dot(x1,x2)/(np.linalg.norm(x1)*np.linalg.norm(x2))\n",
    "    # Rounding can push the cosine slightly outside [-1, 1], where arccos returns\n",
    "    # NaN; NaN sorts last in argsort, so near-identical vectors would be ranked\n",
    "    # as the farthest neighbours instead of the closest.\n",
    "    return np.arccos(np.clip(cosine, -1.0, 1.0))/np.pi\n",
    "def compute_closest(data1,data2, topn = 10, verbose = False):\n",
    "    \"\"\"Average share of each sample's nearest neighbours that come from its own set.\n",
    "\n",
    "    For every sample of ``data1`` the distances (global ``distance``) to all\n",
    "    samples of ``data2`` and to all other samples of ``data1`` are ranked; the\n",
    "    ``topn`` closest are kept and the fraction of them belonging to ``data1``\n",
    "    is recorded.\n",
    "\n",
    "    Args:\n",
    "        data1: dataset whose samples are queried.\n",
    "        data2: dataset compared against.\n",
    "        topn: number of nearest neighbours inspected per sample.\n",
    "        verbose: when True, print the ``topn`` smallest distances of every\n",
    "            sample (this used to be printed unconditionally).\n",
    "\n",
    "    Returns:\n",
    "        Mean over all ``data1`` samples of the fraction of ``data1`` neighbours\n",
    "        among the ``topn`` closest.\n",
    "    \"\"\"\n",
    "    total = len(data1) if hasattr(data1, \"__len__\") else None\n",
    "    closest_neighborss = []\n",
    "    for i,p1 in tqdm(enumerate(data1), total = total):\n",
    "        dists = []\n",
    "        # Origin of each candidate neighbour: 0 = data2, 1 = data1\n",
    "        closest_neighbors = []\n",
    "        for p2 in data2:\n",
    "            dists.append(distance(p1, p2))\n",
    "            closest_neighbors.append(0)\n",
    "        for a,p1_1 in enumerate(data1):\n",
    "            # Skip the sample itself before spending a distance evaluation on it\n",
    "            if a == i:\n",
    "                continue\n",
    "            dists.append(distance(p1, p1_1))\n",
    "            closest_neighbors.append(1)\n",
    "\n",
    "        dists = np.array(dists)\n",
    "        closest_neighbors = np.array(closest_neighbors)\n",
    "        indx = np.argsort(dists)[:topn]\n",
    "\n",
    "        if verbose:\n",
    "            print(dists[indx])\n",
    "\n",
    "        closest_neighborss.append(np.mean(closest_neighbors[indx]))\n",
    "    closest_neighborss = np.array(closest_neighborss)\n",
    "    avg_class = np.mean(closest_neighborss)\n",
    "    return avg_class\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Positions follow the order in which the loading cell appends to `datas`:\n",
    "# Transformer, MLP, CNN, MLP+CNN, then train, then test. The previous hard-coded\n",
    "# indices 5 and 6 assumed a seventh (baseline) dataset that is not loaded, so\n",
    "# datas[6] raised IndexError and several names pointed at the wrong arrays.\n",
    "# The baseline comparisons are left out until that dataset is loaded again.\n",
    "assert len(datas) == 6, \"expected 4 generated sets + train + test in `datas`\"\n",
    "generated = {\"transformer\": datas[0], \"mlp\": datas[1], \"cnn\": datas[2], \"mlp+cnn\": datas[3]}\n",
    "train_data, test_data = datas[-2], datas[-1]\n",
    "\n",
    "print(\"Checking diversity for test and train\")\n",
    "print(compute_closest(test_data, train_data))\n",
    "for name, gen_data in generated.items():\n",
    "    print(f\"Checking diversity for test and {name}\")\n",
    "    print(compute_closest(test_data, gen_data))\n",
    "\n",
    "for name, gen_data in generated.items():\n",
    "    print(f\"Checking diversity for {name} and test\")\n",
    "    print(compute_closest(gen_data, test_data))\n",
    "print(\"Checking diversity for train and test\")\n",
    "print(compute_closest(train_data, test_data))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Positions follow the order in which the loading cell appends to `datas`:\n",
    "# Transformer, MLP, CNN, MLP+CNN, then train, then test. The previous hard-coded\n",
    "# index 6 assumed a seventh (baseline) dataset that is not loaded, so datas[6]\n",
    "# raised IndexError and several names pointed at the wrong arrays.\n",
    "# The baseline comparisons are left out until that dataset is loaded again.\n",
    "assert len(datas) == 6, \"expected 4 generated sets + train + test in `datas`\"\n",
    "generated = {\"transformer\": datas[0], \"mlp\": datas[1], \"cnn\": datas[2], \"mlp+cnn\": datas[3]}\n",
    "train_data = datas[-2]\n",
    "\n",
    "for name, gen_data in generated.items():\n",
    "    print(f\"Checking diversity for train and {name}\")\n",
    "    print(compute_closest(train_data, gen_data))\n",
    "\n",
    "for name, gen_data in generated.items():\n",
    "    print(f\"Checking diversity for {name} and train\")\n",
    "    print(compute_closest(gen_data, train_data))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# datas[-1] holds the samples drawn from the test loader; every other dataset\n",
    "# listed here is scored against it in turn.\n",
    "for other_name, other_data in (\n",
    "    (\"transformer\", datas[0]),\n",
    "    (\"mlp\", datas[1]),\n",
    "    (\"cnn\", datas[2]),\n",
    "    (\"mlp+cnn\", datas[3]),\n",
    "    (\"train\", datas[-2]),\n",
    "):\n",
    "    print(f\"nnaa for dataset test and {other_name}\")\n",
    "    check_diversity_by_closest(datas[-1], other_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# datas[-2] holds the samples drawn from the train loader; each generated\n",
    "# dataset is scored against it in turn.\n",
    "for other_name, other_data in (\n",
    "    (\"transformer\", datas[0]),\n",
    "    (\"mlp\", datas[1]),\n",
    "    (\"cnn\", datas[2]),\n",
    "    (\"mlp+cnn\", datas[3]),\n",
    "):\n",
    "    print(f\"nnaa for dataset train and {other_name}\")\n",
    "    check_diversity_by_closest(datas[-2], other_data)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "sls3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
