import numpy as np
import pandas as pd
import scanpy as sc
import random

import os

# Resolve the project root from this file's absolute directory: keep the part
# of the path that precedes the first (case-sensitive) occurrence of
# 'causally'. The raw counts are read from <root>/dataset/kang_count.h5ad and
# the per-pair CSVs are written under <root>/dataset/kang.
current_dir = os.path.dirname(os.path.abspath(__file__))
_marker_pos = current_dir.find('causally')
if _marker_pos == -1:
    # str.find returns -1 when the marker is absent; slicing with -1 would
    # silently drop the last character and yield a bogus dataset path.
    raise RuntimeError(
        f"Could not locate 'causally' in the script directory {current_dir!r}; "
        "cannot resolve the dataset root."
    )
current_dir = current_dir[:_marker_pos]
root_path = os.path.join(current_dir, 'dataset')
adata = sc.read(os.path.join(root_path, 'kang_count.h5ad'))
saved_path = os.path.join(root_path, 'kang')

# Normalise and log-transform; the return values are discarded, so these
# calls act on adata in place.
sc.pp.normalize_total(adata)
sc.pp.log1p(adata)
# Flag the top 1000 highly variable genes in adata.var['highly_variable'].
sc.pp.highly_variable_genes(adata, n_top_genes=1000)

# NOTE: the HVG subset is currently NOT applied, so every gene is exported.
# To restrict the export to HVGs, re-enable:
#     adata = adata[:, adata.var['highly_variable']]
# Sanity check: show the preprocessed object and how many cells fall into
# each experimental condition.
print(adata)
condition_labels = adata.obs['condition']
print(condition_labels.value_counts())
n_conditions = len(condition_labels.unique())

# Cell-type pairs to export, one CSV per pair; list order determines the
# kang<k>.csv numbering.
cell_types = [
    ['B', 'CD4 T'],
    ['NK', 'CD8 T'],
    ['B', 'CD14 Mono'],
    ['B', 'CD16 Mono'],
    ['B', 'DC'],
    ['B', 'NK'],
    ['B', 'T'],
    ['CD16 Mono', 'T'],
    ['CD4 T', 'T'],
    ['CD8 T', 'T'],
]

# Integer encoding of adata.obs['condition'] used for the 'treatment' column.
treatment_index = dict(control=0, stimulated=1)
# Export one CSV per cell-type pair to <saved_path>/kang<k>.csv (k = 1, 2, ...).
# Columns:
#   treatment - condition encoded via treatment_index (control=0, stimulated=1)
#   yf        - 0 for the pair's first cell type, 1 for the second
#   e         - uniformly random 0/1 draw per cell
#   x1..xN    - expression values for every gene (column) of adata
# to_csv does not create missing directories, so make sure the target exists.
os.makedirs(saved_path, exist_ok=True)
available_types = set(adata.obs["cell_type"])
for index, cell in enumerate(cell_types, start=1):
    print(cell)
    missing = [c for c in cell if c not in available_types]
    if missing:
        # Without both types present, 'yf' would silently be constant.
        print(f"warning: cell type(s) {missing} not found in adata.obs['cell_type']")
    cell_index = {
        cell[0]: 0,
        cell[1]: 1,
    }
    adata_train = adata[adata.obs["cell_type"].isin(cell)]
    data = {
        'treatment': [treatment_index[x] for x in adata_train.obs['condition'].to_numpy()],
        'yf': [cell_index[x] for x in adata_train.obs['cell_type'].to_numpy()],
        # NOTE(review): NumPy's global RNG is unseeded here, so 'e' differs on
        # every run; seed it if reproducible exports are required.
        'e': np.random.choice([0, 1], size=len(adata_train)),
    }

    X = adata_train.X
    # adata.X may be a scipy sparse matrix (scanpy keeps sparsity through
    # normalize_total/log1p); pandas needs a dense 2-D array.
    X = X.toarray() if hasattr(X, 'toarray') else np.asarray(X)
    df_X = pd.DataFrame(X, columns=[f'x{i + 1}' for i in range(X.shape[1])])
    # Both frames carry a fresh RangeIndex, so the column-wise concat aligns
    # row-for-row.
    df = pd.concat([pd.DataFrame(data), df_X], axis=1)
    df.to_csv(os.path.join(saved_path, f"kang{index}.csv"), index=False)