import os

import gseapy as gp
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
# Print DataFrames in full: no truncation of rows, columns or long cell
# contents, and no fixed line width (None lets pandas auto-detect it).
for _display_option in (
    'display.max_rows',
    'display.max_columns',
    'display.width',
    'display.max_colwidth',
):
    pd.set_option(_display_option, None)

# 定义基因列表
genes = pd.read_csv('./data/uq1000_feature.csv').columns.tolist()[1:]
class_label = []
unique_class = []
# KEGG_2016
# KEGG_2021_Human
# 使用gseapy进行基因集富集分析
#['KEGG_2021_Human', 'Reactome_2021', 'GO_Biological_Process_2021']
enr = gp.enrichr(gene_list=genes,
                 gene_sets='GO_Biological_Process_2021',  # 使用KEGG通路数据库
                 organism='Human',  # 物种
                 outdir='enrichr_go_biological_process',  # 输出目录
                 cutoff=0.05  # p值截断
                )
results1 = enr.results

# enr1 = gp.enrichr(gene_list=genes,
#                  gene_sets='KEGG_2021_Human',  # 使用KEGG通路数据库
#                  organism='Human',  # 物种
#                  outdir='enrichr_kegg',  # 输出目录
#                  cutoff=0.05  # p值截断
#                 )
# results2 = enr1.results

# Donut chart of the 10 most significant enriched terms; each slice is sized
# by the number of input genes that overlap the term.

# Sort explicitly by adjusted p-value instead of relying on the row order of
# enr.results. The stable (mergesort) sort leaves an already-sorted table
# unchanged, so ties keep their original order.
top_results = results1.sort_values('Adjusted P-value', kind='mergesort').head(10)

if top_results.empty:
    # Nothing to draw; min()/max() on an empty Series below would raise.
    print('No enrichment results to plot.')
else:
    # Term names. Not drawn (the pie uses labels=None); kept for a legend.
    labels = top_results['Term']
    # 'Overlap' is a "k/n" string; k is the number of input genes in the term.
    sizes = top_results['Overlap'].apply(lambda x: int(x.split('/')[0]))
    sizes = sizes / sizes.sum()  # normalize so the fractions sum to 1

    # Colour slices with the Reds colormap: larger overlap -> darker red.
    cmap = plt.cm.Reds
    norm = plt.Normalize(vmin=min(sizes), vmax=max(sizes))
    if min(sizes) == max(sizes):
        # Normalize maps every value to 0 when vmin == vmax, and Reds(0) is
        # nearly white, so the ring would vanish on the white background.
        # Use one mid-range shade instead.
        colors = cmap([0.6] * len(sizes))
    else:
        colors = cmap(norm(sizes))

    # Draw the ring: wedge width 0.3 leaves a hole of radius 0.7.
    plt.figure(figsize=(8, 8))
    wedges, texts, autotexts = plt.pie(sizes, labels=None, autopct='%1.1f%%',
                                       startangle=140, colors=colors,
                                       wedgeprops=dict(width=0.3))

    # Enlarge the percentage labels (font size 20; adjust as needed).
    for autotext in autotexts:
        autotext.set_fontsize(20)

    # White disc filling the centre hole (radius 0.7, matching the wedge
    # width above), painting the hole explicitly white.
    centre_circle = plt.Circle((0, 0), 0.70, fc='white')
    fig = plt.gcf()
    fig.gca().add_artist(centre_circle)

    # Equal aspect ratio so the ring is drawn as a circle, not an ellipse.
    plt.axis('equal')
    # plt.title('Top 10 Enriched GO Biological Processes')
    # Enrichr normally creates this directory; don't depend on that here.
    os.makedirs('./enrichr_go_biological_process', exist_ok=True)
    plt.savefig('./enrichr_go_biological_process/go.jpg', dpi=1000, bbox_inches='tight')
    plt.show()


# results = pd.concat([results1,results2])
#
# x = results[['Term', 'Overlap', 'Adjusted P-value', 'Genes']].values.tolist()
# num_null = 0
# for i in range(len(genes)):
#     gene = genes[i]
#     label = ' '
#     label_len = 0
#     for s in x:
#         s1 = s[3].split(';')
#         if(gene in s1):
#             if(s[0] == 'Pathways in cancer'):
#                 label = s[0]
#                 label_len = len(s1)
#                 break
#             elif(s[0] == 'PI3K-Akt-mTOR'):
#                 label = 'Pathways in cancer'
#                 label_len = len(s1)
#                 break
#             elif (s[0] == 'Wnt'):
#                 label = 'Pathways in cancer'
#                 label_len = len(s1)
#                 break
#             elif (s[0] == 'Notch'):
#                 label = 'Pathways in cancer'
#                 label_len = len(s1)
#                 break
#             elif (s[0] == 'Hedgehog'):
#                 label = 'Pathways in cancer'
#                 label_len = len(s1)
#                 break
#             elif(label_len < len(s1)):
#                 label = s[0]
#                 label_len = len(s1)
#     if(label_len == 0):
#         num_null += 1
#     class_label.append(label)
#
# final_class_label = []

# gene_class = pd.read_csv('./used_gene_class.csv', sep='>')
# for i in range(len(genes)):
#     gene = genes[i]
#     label = class_label[i]
#     gene_path = gene_class.values.tolist()
#     iffind = False
#     for j in range(len(gene_path)):
#         y = gene_path[j][1]
#         y = y.split(';')
#         if(label in y):
#             label = gene_path[j][0]
#             iffind = True
#             break
#     if(iffind):
#         final_class_label.append(label)
#
# unique_class = list(set(final_class_label))
# unique_class_num = []
# value_10 = 0
# for i in range(len(unique_class)):
#     num = 0
#     for j in final_class_label:
#         if(unique_class[i] == j):
#             num += 1
#     if(num<=2):
#         value_10 += 1
#     unique_class_num.append(num)
# # print(results[['Term', 'Overlap', 'Adjusted P-value', 'Genes']])
# split_cells = pd.DataFrame({
#     'genes': genes,
#     'classify': final_class_label
# })
# split_cells.to_csv('./data/split_cell_lines.csv',index=False)
# print('gdueiw')