import torch
import torch.nn as nn
import pandas as pd
import numpy as np
import scipy
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.utils import data
from torchvision import transforms
from PIL import Image
import os
from collections import OrderedDict
import matplotlib.pyplot as plt
import torchvision.models as models
# This is for the progress bar.
from tqdm import tqdm
import csv
import torch.utils.data as Data
import os
from sklearn.model_selection import train_test_split
import sys
import numpy as np
import pandas as pd
from sklearn.manifold import TSNE
from tqdm import tqdm
from sklearn.preprocessing import MinMaxScaler
from sklearn import metrics
from glob import glob
import datetime
import torch
import math
from sklearn.metrics import roc_auc_score, precision_recall_fscore_support
import random
from scipy.stats import t
from sklearn.metrics import roc_curve, roc_auc_score
# from openTSNE import TSNE
import torchvision
from torch.utils import data
from torchvision import transforms
import numpy as np
from torch.utils.data import Subset
import torch.nn.functional as F
from matplotlib.pyplot import figure
from torch import Tensor
import torchvision.transforms as transforms
import numpy as np
import torchvision.datasets as datasets
import torch
import numpy as np
import torch.utils.data as utils
from torch.utils.data import Sampler, Dataset
from scipy.io import savemat
import numpy as np
import pandas as pd
import scanpy as sc
import os
import argparse


def convert_to_num_label(verb_label):
    """Encode categorical labels as consecutive integer ids.

    Ids are handed out in order of first appearance, so the same input
    always yields the same mapping. (Building the mapping from a ``set``
    does not guarantee that: string hashing is randomised per interpreter
    process, so set iteration order can change between runs.)

    Args:
        verb_label: Iterable of hashable labels, e.g. a pandas Series of
            cell-type names. It is consumed in a single pass.

    Returns:
        A tuple ``(converted, label_dict)``: ``converted`` is a NumPy array
        with the integer id of every input label in input order, and
        ``label_dict`` maps each distinct label to its id.
    """
    label_dict = {}
    # setdefault() returns the existing id for a known label; for a new one
    # it stores and returns the current dict size, i.e. the next free id.
    codes = [label_dict.setdefault(label, len(label_dict)) for label in verb_label]
    return np.array(codes), label_dict

def handle_to_dataset(root, data_name=None, OD=1, rare_types=1, pre=1,
                      save_dir='/mnt/data01/public/aad_data/gene'):
    """Load one raw gene-expression dataset and extract its arrays.

    For ``"BA"`` the two matrices and the test labels are returned directly.
    For every other supported dataset the dense matrix, the integer-encoded
    labels and the label mapping are written with ``torch.save`` to
    ``{save_dir}/{data_name}_extracted.tar`` and nothing is returned.

    Args:
        root: Directory that contains one sub-directory per dataset.
        data_name: One of ``"Campbell"``, ``"Baron Human"``,
            ``"Mouse_retina"``, ``"PBMC68K"`` or ``"BA"``.
        OD: Not used by the active code path; kept so existing callers keep
            working.
        rare_types: Not used by the active code path. According to the
            original inline note it was meant to be the number of
            least-frequent classes to select.
        pre: When equal to 1, ``np.log1p`` is applied to the matrices.
        save_dir: Directory receiving the extracted archive. Defaults to the
            location that used to be hard-coded.

    Returns:
        ``(train_data, test_data, test_labels)`` for ``"BA"``; ``None`` for
        all other datasets.

    Raises:
        ValueError: If ``data_name`` is not one of the supported names.
    """
    supported = ("Campbell", "Baron Human", "Mouse_retina", "PBMC68K", "BA")
    # Fail early with a clear message instead of a TypeError (data_name=None)
    # or an UnboundLocalError (unknown name) further down.
    if data_name not in supported:
        raise ValueError(
            f"Unsupported data_name {data_name!r}; expected one of {supported}"
        )

    path = root + '/' + data_name

    if data_name == "BA":
        # BA ships as two separate matrices (BFdata / AFdata) and only the
        # AF labels are read; nothing is written to disk for this dataset.
        train_data = sc.read(path + '/BFdata.mtx', cache=True).X.toarray()
        print(train_data.shape)
        test_data = sc.read(path + '/AFdata.mtx', cache=True).X.toarray()
        print(test_data.shape)
        if pre == 1:
            train_data = np.log1p(train_data)
            test_data = np.log1p(test_data)
        test_labels = pd.read_csv(path + "/AF_label.csv", dtype="str")["type_mapped"]
        test_labels = test_labels.astype(int)
        print(test_labels)
        return train_data, test_data, test_labels

    if data_name == "Campbell":
        adata = sc.read(path + '/matrix.mtx', cache=True).T
        data_array = adata.X.toarray()
        ori_labels = pd.read_csv(path + "/labels.csv", dtype="str")
        labels = ori_labels['10.clust_neurons']
    elif data_name == "Baron Human":
        adata = sc.read(path + '/matrix.mtx', cache=True)
        # Here the transpose is applied to the dense array rather than to
        # the AnnData object.
        data_array = adata.X.toarray().T
        # The file is read without a header and its first row is dropped.
        labels = pd.read_csv(path + "/labels.csv", dtype="str", header=None)[1:].squeeze()
    elif data_name == "Mouse_retina":
        adata = sc.read(path + '/Gene_Cell.mtx', cache=True).T
        data_array = adata.X.toarray()
        labels = pd.read_csv(path + "/Cell_type.tsv", dtype="str", header=None).squeeze()
    else:  # "PBMC68K"
        adata = sc.read(path + '/matrix.mtx', cache=True).T
        data_array = adata.X.toarray()
        labels = pd.read_csv(path + "/PBMC_label.csv", dtype="str")["original_label"]

    print("data shape: " + str(data_array.shape))
    if pre == 1:
        data_array = np.log1p(data_array)

    label_counts = labels.value_counts()
    print("cell types: " + str(len(label_counts)))
    print(label_counts)
    converted, label_dict = convert_to_num_label(labels)
    print(converted[:10])
    torch.save((data_array, converted, label_dict), f'{save_dir}/{data_name}_extracted.tar', pickle_protocol=4)
    # print(converted[:10])
    
    # labels = pd.read_csv(path + "/PBMC_label.csv", dtype="str")
    # print(labels)
    # if OD == 1:
    #     min_label_counts = labels.value_counts().nsmallest(rare_types)
    #     print("数量最少的 %d 个类别及其数量：" % rare_types)
    #     print(min_label_counts)
    #     new_labels = np.zeros(len(labels))
    #     for label in min_label_counts.index:
    #         new_labels[labels == label] = 1
    #     train_data, test_data = data_array, data_array
    #     train_samples = train_data.shape[0]
    #     test_samples = test_data.shape[0]
    #     print("Features: ", str(train_data.shape[1]))
    #     print("Train Samples: ", train_samples)
    #     print("Test Samples: ", test_samples)
    #     # train_set,test_set=CustomDataset(np.array(train_data), new_labels), CustomDataset(np.array(test_data), new_labels)
    #     return data_array, new_labels
    #
    # if OD == 0:
    #     # labels = labels.reset_index(drop=True)
    #     min_label_counts = label_counts.nsmallest(rare_types)
    #     print("数量最少的 %d 个类别及其数量：" % rare_types)
    #     print(min_label_counts)
    #     min_label_indices = min_label_counts.index
    #     normal_data = data_array[~labels.isin(min_label_indices)]
    #     ab_data = data_array[labels.isin(min_label_indices)]
    #     train_data, test_normal_data = train_test_split(normal_data, test_size=0.25, random_state=42)
    #     test_data = np.concatenate((test_normal_data, ab_data))
    #     test_label = np.concatenate((np.zeros(test_normal_data.shape[0]), np.ones(ab_data.shape[0])))
    #     train_samples = train_data.shape[0]
    #     test_samples = test_data.shape[0]
    #     print("Features: ", str(train_data.shape[1]))
    #     print("Train Samples: ", train_samples)
    #     print("Test Samples: ", test_samples)
    #     # train_data,test_data=data_preprocessing("1",train_data,test_data)
    #     train_set, test_set = CustomDataset(np.array(train_data), np.zeros(train_data.shape[0])), CustomDataset(
    #         np.array(test_data), test_label)
    #     return train_set, test_set

if __name__ == "__main__":
    # Root directory passed to handle_to_dataset; each dataset is looked up
    # in the sub-directory named after it.
    data_root = "/mnt/data01/public/aad_data/gene/data"
    # Datasets to extract on this run. The commented-out names are other
    # datasets handle_to_dataset supports but that are disabled here.
    selected_datasets = (
        'Campbell',
        # 'PBMC68K',
        # 'Mouse_retina',
        # 'Baron Human',
    )
    for dataset_name in selected_datasets:
        handle_to_dataset(data_root, dataset_name)
