{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "# Presumably meant to make the project-local modules (plots, benchmark, metric, ...)\n",
    "# importable. os.path.join with a single argument returns it unchanged, so this\n",
    "# appends a second copy of sys.path[0]. NOTE(review): confirm the intended directory.\n",
    "sys.path.append(os.path.join(sys.path[0]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import colorsys\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "def generate_shades_of_color(tone, k):\n",
    "    \"\"\"Return k RGB colours clustered around a primary hue.\n",
    "\n",
    "    For i = 1..k the hue is base + 60*i/k degrees and the lightness is\n",
    "    0.1 + 0.5*i/k, so the colours get lighter (from 0.1 + 0.5/k up to 0.6)\n",
    "    while drifting up to 60 degrees around the hue wheel. Saturation is 1.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    tone : str\n",
    "        'red', 'green' or 'blue' (base hue 0, 120 or 240 degrees).\n",
    "    k : int\n",
    "        Number of colours; a non-positive k gives an empty list.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    list of tuple\n",
    "        k (r, g, b) tuples as produced by colorsys.hls_to_rgb.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If tone is not one of the supported names.\n",
    "    \"\"\"\n",
    "    base_hues = {'red': 0.0, 'green': 120.0, 'blue': 240.0}\n",
    "    if tone not in base_hues:\n",
    "        raise ValueError(\"Unsupported tone. Please choose from 'red', 'green', or 'blue'.\")\n",
    "    base_hue = base_hues[tone]\n",
    "    saturation = 1  # fixed for every shade\n",
    "\n",
    "    colors = []\n",
    "    for i in range(1, k + 1):\n",
    "        frac = i / k\n",
    "        hue = (base_hue + 60 * frac) / 360.0\n",
    "        lightness = 0.1 + 0.5 * frac  # darkest first, lightest (0.6) last\n",
    "        colors.append(colorsys.hls_to_rgb(hue, lightness, saturation))\n",
    "    return colors\n",
    "\n",
    "def plot_colors(colors):\n",
    "    \"\"\"Draw the colours as a row of adjacent unit squares and show the figure.\"\"\"\n",
    "    _, axis = plt.subplots()\n",
    "    for position, swatch in enumerate(colors):\n",
    "        axis.add_patch(plt.Rectangle((position, 0), 1, 1, color=swatch))\n",
    "    axis.set_xlim(0, len(colors))\n",
    "    axis.set_ylim(0, 1)\n",
    "    axis.axis('off')  # swatches only, no ticks or frame\n",
    "    plt.show()\n",
    "\n",
    "# Preview the three palettes. The last shade of each is dropped, the same way\n",
    "# ranking_scRNA trims its palettes.\n",
    "red_shades = generate_shades_of_color('red', 6)\n",
    "green_shades = generate_shades_of_color('green', 4)\n",
    "blue_shades = generate_shades_of_color('blue', 3)\n",
    "\n",
    "for shades in (red_shades, green_shades, blue_shades):\n",
    "    plot_colors(shades[:-1])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Import packages\n",
    "#imports\n",
    "import time\n",
    "import pandas as pd\n",
    "from collections import Counter\n",
    "import sknetwork\n",
    "\n",
    "\n",
    "from sknetwork.ranking import PageRank\n",
    "from sknetwork.ranking import Betweenness\n",
    "from sknetwork.ranking import Closeness\n",
    "\n",
    "import umap as umap\n",
    "\n",
    "from numba.typed import List\n",
    "import warnings\n",
    "from numba import njit\n",
    "import pynndescent\n",
    "import numpy as np\n",
    "from sklearn.cluster import SpectralClustering\n",
    "import numpy as np\n",
    "from sklearn.decomposition import PCA\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib\n",
    "import operator\n",
    "from sklearn.utils.extmath import randomized_svd\n",
    "from random import randint\n",
    "from sklearn.utils.extmath import randomized_svd\n",
    "from sklearn.decomposition import TruncatedSVD\n",
    "import scipy\n",
    "from umap.umap_ import *\n",
    "import math\n",
    "from random import randint\n",
    "import keras\n",
    "from keras.datasets import mnist\n",
    "from keras.datasets import cifar10\n",
    "from keras.datasets import cifar100\n",
    "from keras.datasets import fashion_mnist\n",
    "import scanpy\n",
    "from sklearn.metrics.cluster import normalized_mutual_info_score\n",
    "from sklearn.metrics import adjusted_mutual_info_score, roc_auc_score\n",
    "from sklearn.metrics.cluster import adjusted_rand_score\n",
    "from sklearn.metrics.cluster import adjusted_mutual_info_score\n",
    "\n",
    "import community as community_louvain\n",
    "from sklearn.cluster import KMeans\n",
    "from sklearn.metrics import fowlkes_mallows_score\n",
    "\n",
    "from plots import *\n",
    "import scipy\n",
    "\n",
    "import igraph \n",
    "import networkx as nx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload edited project modules automatically before each cell runs.\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "import importlib\n",
    "\n",
    "# NOTE(review): plots is already star-imported in the previous cell.\n",
    "from plots import *\n",
    "\n",
    "\n",
    "# Project-local modules; they must be importable from sys.path (see the first cell).\n",
    "import benchmark as bmarks\n",
    "import metric as met \n",
    "import FlowRank as algo\n",
    "import datasets as dsets\n",
    "import simulation as simm\n",
    "import embedding as embed\n",
    "import newflow as newalgo\n",
    "\n",
    "# Force a fresh copy of each module when this cell is re-run.\n",
    "# NOTE(review): likely redundant while %autoreload 2 is active.\n",
    "met=importlib.reload(met)\n",
    "newalgo=importlib.reload(newalgo)\n",
    "algo = importlib.reload(algo)\n",
    "dsets = importlib.reload(dsets)\n",
    "simm = importlib.reload(simm)\n",
    "embed=importlib.reload(embed)\n",
    "bmarks=importlib.reload(bmarks)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Names of every callable attribute exposed by the benchmark module.\n",
    "functions = list(filter(lambda attr: callable(getattr(bmarks, attr)), dir(bmarks)))\n",
    "print(functions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset configuration.\n",
    "# datapath is the directory holding one sub-directory per dataset, each with\n",
    "# data.npz and labels.npy. Replace the placeholder (ending it with a path\n",
    "# separator) before running the benchmarks.\n",
    "datapath = 'add datapath'\n",
    "datanames = [\n",
    "    'Baron_Human', 'Baron_Mouse', 'Muraro', 'Segerstolpe', 'Xin', 'Zhengmix8eq',\n",
    "    'Tcell-medicine', 'ALM', 'AMB', 'TM', 'VISP',\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def calc_this(subset,label,edge_list):\n",
    "    \"\"\"Print, per class, the concentration ratio of `subset` next to the whole-graph ratio.\n",
    "\n",
    "    calc_concentration is run twice: once on all nodes and once on `subset`.\n",
    "    For each class i this prints i followed by either 'n/a' (no subset node\n",
    "    has class i) or 'CC= ' and CC_count[i] / (number of subset nodes of\n",
    "    class i), and then the same ratio over all nodes. All classes share one\n",
    "    line, which is followed by a blank line. Nothing is returned.\n",
    "    \"\"\"\n",
    "    everyone = list(range(len(label)))\n",
    "    base_count, base_groups = calc_concentration(everyone, label, edge_list)\n",
    "    sub_count, sub_groups = calc_concentration(subset, label, edge_list)\n",
    "\n",
    "    for cls in range(len(set(label))):\n",
    "        sub_size = len(sub_groups[cls])\n",
    "        baseline = '%.2f' % (base_count[cls] / len(base_groups[cls]))\n",
    "        if sub_size:\n",
    "            print(cls, \"CC= \", '%.2f' % (sub_count[cls] / sub_size), baseline, end=' ')\n",
    "        else:\n",
    "            print(cls, 'n/a', baseline, end=' ')\n",
    "\n",
    "    print('\\n')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def calc_concentration(subset,label,edge_list):\n",
    "    \"\"\"Count, per class, the cross-class edges that touch nodes of `subset`.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    subset : iterable of int\n",
    "        Node indices in range(len(label)). It is materialised once, so\n",
    "        generators and numpy arrays both work.\n",
    "    label : sequence of int\n",
    "        label[i] is the class of node i. Classes are used directly as list and\n",
    "        array indices, so they are assumed to be 0..ell-1 with\n",
    "        ell = len(set(label)).\n",
    "    edge_list : iterable of (u, v) pairs\n",
    "        Graph edges between node indices.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    CC_count : numpy.ndarray of shape (ell,)\n",
    "        For class c: the number of edges (u, v) with v in subset,\n",
    "        label[v] == c and label[u] != c, minus the number of edges with\n",
    "        u in subset, label[u] == c and label[v] != c.\n",
    "    new_clusters : list of ell lists\n",
    "        new_clusters[c] holds the subset nodes of class c, in subset order.\n",
    "    \"\"\"\n",
    "    subset = list(subset)  # iterated more than once below\n",
    "    ell = len(set(label))\n",
    "\n",
    "    new_clusters = [[] for _ in range(ell)]\n",
    "    in_subset = [False] * len(label)\n",
    "    for node in subset:\n",
    "        new_clusters[label[node]].append(node)\n",
    "        in_subset[node] = True\n",
    "\n",
    "    CC_count = np.zeros(ell)\n",
    "    for u, v in edge_list:\n",
    "        src_in, dst_in = in_subset[u], in_subset[v]\n",
    "        if not (src_in or dst_in):\n",
    "            continue\n",
    "        lu, lv = label[u], label[v]\n",
    "        if lu == lv:\n",
    "            continue  # only edges between different classes are counted\n",
    "        if src_in:\n",
    "            CC_count[lu] -= 1  # subset node is the first endpoint (u)\n",
    "        if dst_in:\n",
    "            CC_count[lv] += 1  # subset node is the second endpoint (v)\n",
    "    return CC_count, new_clusters\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Single cell datasets.\n",
    "\n",
    "def ranking_scRNA(datanames,algonames,bnames_ranking,bnames_outlier,kchoice=20,random_state=None):\n",
    "    \"\"\"Benchmark node rankings on single-cell RNA-seq datasets.\n",
    "\n",
    "    For every dataset this:\n",
    "      1. loads data.npz (sparse counts) and labels.npy from the dataset's\n",
    "         sub-directory of `datapath` (read from the notebook globals),\n",
    "      2. applies log1p to the counts and reduces them to 50 components with\n",
    "         TruncatedSVD,\n",
    "      3. reports the initial kNN accuracy and builds the kNN edge list\n",
    "         (embed.dir_KNN_graph),\n",
    "      4. ranks the nodes with every benchmark in `bnames_ranking` (looked up\n",
    "         on bmarks) and every method in `algonames` (looked up on algo); for\n",
    "         the latter it also prints the class concentration of the first\n",
    "         n // 3 ranked nodes via calc_this,\n",
    "      5. scores all rankings with met.accuracy and met.preservstion.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    datanames : list of str\n",
    "        Dataset directory names under `datapath`.\n",
    "    algonames : list of str\n",
    "        Functions in algo, called as f(edge_list, vlist, walk_len, 0).\n",
    "    bnames_ranking : list of str\n",
    "        Functions in bmarks, called as f(edge_list, vlist, PX, n).\n",
    "    bnames_outlier : object\n",
    "        Unused; kept so existing calls keep working.\n",
    "    kchoice : int, default 20\n",
    "        Number of neighbours used for the kNN graph.\n",
    "    random_state : int or None, default None\n",
    "        Passed to TruncatedSVD; set it to make the embedding reproducible.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    fn : list of lists\n",
    "        fn[d] == [met.accuracy result for dataset d].\n",
    "    fp : list of lists\n",
    "        fp[d] == [met.preservstion result for dataset d].\n",
    "    flouvain : list\n",
    "        One placeholder 0 per dataset (Louvain evaluation is disabled).\n",
    "    \"\"\"\n",
    "    import os\n",
    "\n",
    "    total_names = bnames_ranking + algonames\n",
    "    # Red shades for the benchmarks, green shades for the algo methods. One extra\n",
    "    # shade is generated and dropped, so no palette reaches the full 60-degree\n",
    "    # hue shift applied by generate_shades_of_color.\n",
    "    colors = (generate_shades_of_color('red', len(bnames_ranking) + 1)[:-1]\n",
    "              + generate_shades_of_color('green', len(algonames) + 1)[:-1])\n",
    "\n",
    "    fn = [[] for _ in datanames]\n",
    "    fp = [[] for _ in datanames]\n",
    "    flouvain = []\n",
    "\n",
    "    for d, name in enumerate(datanames):\n",
    "        # os.path.join works whether or not datapath ends with a separator.\n",
    "        X = scipy.sparse.load_npz(os.path.join(datapath, name, 'data.npz'))\n",
    "        label = np.load(os.path.join(datapath, name, 'labels.npy'))\n",
    "        print(name, len(label))\n",
    "\n",
    "        # Log transform, then a 50-component truncated SVD. Unlike PCA,\n",
    "        # TruncatedSVD does not centre the data first.\n",
    "        X.data = np.log1p(X.data)\n",
    "        print(\"Log transform done\")\n",
    "        PX = TruncatedSVD(n_components=50, random_state=random_state).fit_transform(X)\n",
    "        n = PX.shape[0]\n",
    "        walk_len_c1 = int(np.log2(n))\n",
    "        print(PX.shape)\n",
    "\n",
    "        # Initial kNN accuracy, then the kNN edge list used by every ranking.\n",
    "        met.KNN_graph_acc(PX, kchoice, 0, label)\n",
    "        edge_list, vlist = embed.dir_KNN_graph(PX, kchoice, 0)\n",
    "        print(len(edge_list))\n",
    "\n",
    "        v_cover_orders = [getattr(bmarks, bname)(edge_list, vlist, PX, n)\n",
    "                          for bname in bnames_ranking]\n",
    "\n",
    "        for algoname in algonames:\n",
    "            v_cover_order = getattr(algo, algoname)(edge_list, vlist, walk_len_c1, 0)\n",
    "            v_cover_orders.append(v_cover_order)\n",
    "            # Concentration diagnostics for the top third of the ranking;\n",
    "            # column 1 of the result is taken as the node index.\n",
    "            calc_this(v_cover_order[0:n // 3, 1].astype(int), label, edge_list)\n",
    "\n",
    "        NNacc = met.accuracy(v_cover_orders, edge_list, total_names, colors, n, label, name=name)\n",
    "        p_ratio = met.preservstion(v_cover_orders, total_names, colors, n, label, name=name)\n",
    "\n",
    "        fn[d].append(NNacc)\n",
    "        fp[d].append(p_ratio)\n",
    "        flouvain.append(0)  # placeholder: met.louvain_plots evaluation is disabled\n",
    "\n",
    "    return fn, fp, flouvain\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ranking methods to compare: six benchmarks (from bmarks) followed by three\n",
    "# methods looked up on the FlowRank module (algo).\n",
    "bnames_ranking = ['Katz_score', 'pagerank_5', 'pagerank_85', 'pagerank_99', 'cores', 'deg_centrality']\n",
    "algonames = ['FLOW_ng', 'FLOW_ng_prop', 'FLOW_ng2hopsimple']\n",
    "total_names = bnames_ranking + algonames\n",
    "print(total_names)\n",
    "\n",
    "# Datasets to benchmark.\n",
    "datanames = [\n",
    "    'Baron_Human', 'Baron_Mouse', 'Muraro', 'Segerstolpe', 'Xin', 'Zhengmix8eq',\n",
    "    'Tcell-medicine', 'ALM', 'AMB', 'TM', 'VISP',\n",
    "]\n",
    "\n",
    "# Produces the ICEF and preservation-ratio plots for every dataset.\n",
    "ICEF, PRESERVE, LOUVAIN = ranking_scRNA(datanames, algonames, bnames_ranking, bnames_outlier='')"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "aee8b7b246df8f9039afb4144a1f6fd8d2ca17a180786b69acc140d282b71a49"
  },
  "kernelspec": {
   "display_name": "Python 3.10.1 64-bit",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.1"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
