# Helpers for loading and initializing benchmark datasets.

import pandas as pd
import numpy as np
import scanpy
from collections import Counter

from numpy import *
import matplotlib
import keras
from keras.datasets import mnist
from sklearn.decomposition import PCA


import importlib
import embedding as embed
embed = importlib.reload(embed)

import metric as metric
metric=importlib.reload(metric)


import warnings
warnings.filterwarnings('ignore')

import sys, os

#colormap
def _hex_to_rgb(hex_color):
    """Convert a '#rrggbb' hex string into an (r, g, b) tuple of floats in [0, 1]."""
    digits = hex_color.lstrip('#')
    return tuple(int(digits[i:i + 2], 16) / 255 for i in (0, 2, 4))


# Ten-colour palette. The RGB list is derived from clist itself, so its
# length always matches the hex list (no hard-coded count).
clist=['#9e0142','#d53e4f','#f46d43','#fdae61','#fee08b','#e6f598','#abdda4','#66c2a5','#3288bd','#5e4fa2']
rgblist = [_hex_to_rgb(hex_color) for hex_color in clist]
my_cmap = matplotlib.colors.ListedColormap(rgblist)


# Six-colour palette, built the same way.
clist1=['#003f5c','#444e86','#955196','#dd5182','#ff6e54','#ffa600']
rgblist1 = [_hex_to_rgb(hex_color) for hex_color in clist1]
my_cmap1 = matplotlib.colors.ListedColormap(rgblist1)


#dataset names
# Position i in each list corresponds to option i of the prompt in initiate():
# str1 holds the data file names, str2 the matching label (iden) file names.
_dataset_keys = [
    'tasic',
    'klein',
    'treg',
    'mnist',
    'veronica',
    'fmnist',
    'zheng-4743',
    'noisy-waveform',
]
str1 = ['npdata-' + key + '.npy' for key in _dataset_keys]
str2 = ['iden-' + key + '.npy' for key in _dataset_keys]





# npdata is a d x n array (d rows, n columns) and iden is a length-n array of labels.



def parse_h5ad(anndf, cluster_label_obs):
    """
    Return the data, cluster sizes, and cluster labels from an AnnData object,
    with rows sorted by cluster code.

    :param anndf: AnnData object (only ``to_df()`` and ``obs`` are used)
    :param cluster_label_obs: name of the ``obs`` column holding the cluster
        labels; a categorical column keeps its category order, any other
        column is converted with ``pd.Categorical`` (sorted categories)
    :return: (data, cluster_sizes, cluster_labels) where
        data -- 2-D array of the rows of ``anndf.to_df()``, stably sorted by
        cluster code (original order is kept within a cluster);
        cluster_sizes -- list with one count per category, in category order;
        cluster_labels -- array of integer category codes aligned with data.
        Rows with a missing label have code -1: they sort first and are not
        counted in cluster_sizes.
    """
    data = anndf.to_df()
    labels = pd.Categorical(anndf.obs[cluster_label_obs])
    codes = np.asarray(labels.codes)

    # A stable argsort reproduces sorting (code, row) pairs by code only,
    # without materializing one pandas Series per row.
    order = np.argsort(codes, kind='stable')
    sorted_data = data.to_numpy()[order]
    cluster_labels = codes[order]

    # Missing labels carry code -1; exclude them so they are not counted
    # against the last category.
    cluster_sizes = np.bincount(
        codes[codes >= 0], minlength=len(labels.categories)
    ).tolist()

    return sorted_data, cluster_sizes, cluster_labels


def call_ALM():
    """Load the ALM dataset via parse_ALM_VISP, keeping every cluster (large=0)."""
    data, labels = parse_ALM_VISP('ALM', large=0)
    return data, labels


def call_VISP():
    """Load the VISP dataset via parse_ALM_VISP, keeping every cluster (large=0)."""
    data, labels = parse_ALM_VISP('VISP', large=0)
    return data, labels

def parse_ALM_VISP(name,large=1):
    """
    Load an exon-count matrix and its cluster labels from ``datapath + name``.

    Reads ``<name>/data-exon.csv`` and ``<name>/Labels.csv``. The first CSV
    column is dropped and the matrix transposed, so each row of the result
    corresponds to one data column of the CSV. The ``cluster`` column of
    Labels.csv is encoded as integers in order of first appearance.

    :param name: sub-directory of the module-level ``datapath`` prefix;
        'ALM' labels are read with ``encoding='unicode_escape'``
    :param large: if 1, keep only rows whose cluster has more than sqrt(n)
        members, where n is the number of rows
    :return: (X, label) -- X is the float matrix log2(data + 1); label is an
        integer array aligned with the rows of X
    """
    chunksize = 1000000
    reader = pd.read_csv(datapath + name + '/data-exon.csv', chunksize=chunksize, iterator=True)
    df_exon = pd.concat(reader, ignore_index=True)
    # Drop the first column (row identifiers) and transpose.
    data = df_exon.to_numpy()[:, 1:].T
    print(data.shape)
    print(data[0, 0])

    if name == 'ALM':
        df_clusters = pd.read_csv(datapath + name + '/Labels.csv', encoding='unicode_escape')
    else:
        df_clusters = pd.read_csv(datapath + name + '/Labels.csv')
    label_st = df_clusters.loc[:, 'cluster'].to_numpy()

    # Encode labels in order of first appearance. Iterating a set of strings
    # depends on the interpreter's hash seed, so codes built that way change
    # from run to run.
    codes = {value: code for code, value in enumerate(dict.fromkeys(label_st))}
    label = np.array([codes[value] for value in label_st])

    # Cleanup: optionally drop rows that belong to small clusters.
    n = data.shape[0]
    if large == 1:
        # Count cluster sizes once rather than once per row.
        cluster_sizes = Counter(label)
        threshold = np.sqrt(n)
        subset = [i for i in range(n) if cluster_sizes[label[i]] > threshold]
        data = data[subset, :]
        label = label[subset]

    # When the CSV's first column is text, to_numpy() yields an object array,
    # on which np.log2 fails; casting to float first handles both cases.
    X = np.log2(data.astype(float) + 1)
    print(X.shape, len(set(label)))

    return X, label



def call_Zheng():
    """
    Load ``sce_full_Zhengmix8eq.h5ad`` from ``datapath`` and log2(x + 1) it.

    Labels are the codes of the ``phenoid`` obs column, as returned by
    parse_h5ad. NOTE: call_Zheng is defined again, identically, near the end
    of this module; that later definition is the one bound after import.
    """
    adata = scanpy.read_h5ad(datapath + 'sce_full_Zhengmix8eq.h5ad')
    data, _cluster_sizes, labels = parse_h5ad(adata, 'phenoid')
    return np.log2(data + 1), labels




def call_Medicine():
    """
    Load the T-cell "medicine" dataset from ``datapath + '/Tcell-medicine/'``.

    ``data.tsv`` is read without a header and transposed. Its first row and
    first column (after transposing) are dropped, and the rest is cast to
    float and log2(x + 1)-transformed. Labels come from column index 3 of
    ``Labels.txt``, skipping the first row after the header (NOTE(review):
    presumably a type/annotation row -- confirm against the file).

    :return: (X, label) -- X has one row per column of data.tsv after the
        first; label is a list of integer cluster codes, one per row of X,
        assigned in order of first appearance
    :raises ValueError: if Labels.txt has fewer label rows than X has rows
    """
    # Load the medicine dataset.
    trm_counts = pd.read_csv(datapath + '/Tcell-medicine/data.tsv', sep='\t', header=None).T
    data = trm_counts.to_numpy()[1:, 1:].astype(float)
    X = np.log2(data + 1)

    cluster_metadata = pd.read_csv(datapath + '/Tcell-medicine/Labels.txt', sep='\t')
    label1 = cluster_metadata.to_numpy()[1:, 3]

    # Encode labels in order of first appearance; iterating a set of strings
    # would make the codes depend on the interpreter's hash seed.
    codes = {value: code for code, value in enumerate(dict.fromkeys(label1))}

    n_cells = X.shape[0]
    if len(label1) < n_cells:
        raise ValueError(
            f"Labels.txt provides {len(label1)} labels but the data has {n_cells} rows"
        )
    label = [codes[value] for value in label1[:n_cells]]

    return X, label




def call_USPS(path='usps_path'):
    """
    Load the training split of the USPS dataset from an HDF5 file.

    :param path: path to the HDF5 file; the default is the original
        placeholder value, kept for backward compatibility
    :return: (X_tr, y_tr) -- the ``train/data`` and ``train/target`` datasets
        read fully into memory
    :raises KeyError: if the file has no ``train`` group or the group lacks
        ``data`` / ``target``
    """
    import h5py
    with h5py.File(path, 'r') as hf:
        # Indexing (rather than .get) fails with a clear KeyError when a group
        # or dataset is missing, instead of an AttributeError on None.
        train = hf['train']
        X_tr = train['data'][:]
        y_tr = train['target'][:]

    return X_tr, y_tr



def SCRNA_benchmark(name,large=0):
    """
    Load ``data.csv`` and ``Labels.csv`` from ``datapath + name + '/'``.

    The first column of data.csv is dropped; the data is returned as read
    (no transform). For name == 'AMB' the label is column index 2 of
    Labels.csv. For other names the full Labels.csv array (2-D) is returned.

    :param name: dataset sub-directory under the module-level ``datapath``
    :param large: only for 'AMB': if 1, keep only rows whose label occurs
        more than sqrt(n) times, where n is the number of data rows
    :return: (data, label)
    """
    pathname = datapath + name + '/'
    print(pathname)

    chunksize = 1000000
    reader = pd.read_csv(pathname + 'data.csv', chunksize=chunksize, iterator=True)
    data = pd.concat(reader, ignore_index=True).to_numpy()[:, 1:]
    print(data.shape)

    n = data.shape[0]

    label = pd.read_csv(pathname + 'Labels.csv', encoding='unicode_escape').to_numpy()
    print(label.shape)

    if name == 'AMB':
        label = label[:, 2]

        if large == 1:
            # Count label frequencies once rather than once per row.
            cluster_sizes = Counter(label)
            threshold = np.sqrt(n)
            subset = [i for i in range(n) if cluster_sizes[label[i]] > threshold]

            label = label[subset]
            print(Counter(label))
            data = data[subset, :]

    return data, label




def single_cell_more(name,dimension,kchoice=20,large=0):
    """
    Load a benchmark dataset, log-transform it, optionally reduce it with PCA,
    and build a directed kNN graph.

    :param name: dataset name passed to SCRNA_benchmark
    :param dimension: number of PCA components; 0 skips PCA
    :param kchoice: number of neighbours for the kNN graph
    :param large: forwarded to SCRNA_benchmark as its ``large`` flag. Earlier
        code forwarded ``dimension`` there, so subsetting ran only when
        dimension happened to be 1; the flag is now independent.
    :return: (edge_list, vlist, PX, label) -- the graph from
        embed.dir_KNN_graph, the (possibly PCA-reduced) data, and the labels
        as an array
    """
    data, label = SCRNA_benchmark(name, large)
    # Cast to float first: data.csv with a text first column yields an
    # object array, on which np.log2 fails.
    X = np.log2(np.asarray(data, dtype=float) + 1)

    label = np.array(label)
    if dimension != 0:
        pca = PCA(n_components=dimension, svd_solver='randomized')
        PX = pca.fit_transform(X)
    else:
        PX = X

    print(PX.shape)

    # NOTE(review): presumably reports kNN-graph label agreement -- confirm in metric.
    metric.KNN_graph_acc(PX, kchoice, 0, label)

    edge_list, vlist = embed.dir_KNN_graph(PX, kchoice, 0)

    return edge_list, vlist, PX, label

    




def initiate(**kwargs):
    """
    Load one of the pre-saved ``.npy`` datasets listed in ``str1``/``str2``.

    Keyword arguments:
        fix_ch -- index into ``str1``/``str2``. If omitted (or -1), the user
            is prompted for it on stdin.
        lnorm -- if 1 (the default), apply log2(x + 1) to the data.

    :return: (npdata, iden) -- npdata as a float array and iden as an int32
        array, loaded from ``datapath + str1[choice]`` and
        ``datapath + str2[choice]``
    :raises ValueError: if the chosen index is outside ``range(len(str1))``
    """
    choice = int(kwargs.get('fix_ch', -1))
    if choice == -1:
        choice = int(input("Enter 0 for Tasic, 1 for Klein and 2 for Treg, 3 for MNIST, 4 for Veronica, 5 for F-MNIST and 6 for Zheng-PBMC and 7 for noisy waveform"))

    # Reject out-of-range choices: a negative index would otherwise silently
    # select a dataset counted from the end of the list.
    if not 0 <= choice < len(str1):
        raise ValueError(
            f"dataset choice must be between 0 and {len(str1) - 1}, got {choice}"
        )

    print("Chosen dataset is", str1[choice], "\n")

    npdata = np.load(datapath + str1[choice]).astype('float')
    iden = np.load(datapath + str2[choice]).astype('int32')

    # Noisy waveform (choice 7): keep only the first 21 rows.
    if choice == 7:
        npdata = npdata[0:21, :]

    print(npdata.shape, iden.shape)

    # Log-normalize unless lnorm is something other than 1.
    if kwargs.get('lnorm', 1) == 1:
        print("log transform")
        npdata = np.log2(npdata + 1)

    return npdata, iden

def call_Zheng():
    """
    Read ``sce_full_Zhengmix8eq.h5ad`` under ``datapath`` and return
    (log2(data + 1), labels), with labels taken from the ``phenoid`` column
    via parse_h5ad.

    NOTE: this redefines an identical call_Zheng declared earlier in the
    module; being last, this is the definition callers get.
    """
    sce = scanpy.read_h5ad(datapath + 'sce_full_Zhengmix8eq.h5ad')
    parsed = parse_h5ad(sce, 'phenoid')
    data = parsed[0]
    labels = parsed[2]
    log_data = np.log2(data + 1)
    return log_data, labels

def call_FMNIST():
    """Load dataset 5 (F-MNIST) via initiate() without log transform, transposed."""
    npdata, iden = initiate(fix_ch=5, lnorm=0)
    return npdata.T, iden

def call_MNIST():
    """Return the Keras MNIST training images, flattened to one row each, and their labels."""
    train_split, _test_split = mnist.load_data()
    images, digits = train_split
    flat_images = images.reshape(images.shape[0], -1)
    return flat_images, digits

def call_Treg():
    """Load dataset 2 (Treg) via initiate() with log transform, transposed."""
    npdata, iden = initiate(fix_ch=2, lnorm=1)
    return npdata.T, iden
