import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import numpy as np
import re
from collections import defaultdict

# Load both cohorts. The files are tab-separated with no header row, so the
# shared column schema is supplied explicitly.
_record_columns = ['img_path', 'img_finding', 'text', 'diagnosis']
adni_df = pd.read_csv(
    '/data/qiuhui/code/ADReasoning/local_data/ADNI4.csv',
    sep='\t',
    header=None,
    names=_record_columns,
)
aibl_df = pd.read_csv(
    '/data/qiuhui/code/ADReasoning/local_data/AIBL2.csv',
    sep='\t',
    header=None,
    names=_record_columns,
)

# Tag every row with the cohort it came from.
adni_df['dataset'] = 'ADNI'
aibl_df['dataset'] = 'AIBL'

# Stack the two cohorts into a single frame with a fresh 0..n-1 index.
combined_df = pd.concat([adni_df, aibl_df], ignore_index=True)

# Accumulators filled in by the per-record analysis loop.
total_records = len(combined_df)
diagnosis_counts = defaultdict(int)  # diagnosis label -> number of records
unique_elements_per_category = defaultdict(set)  # category -> distinct element names
element_presence_counts = defaultdict(lambda: defaultdict(int))  # category -> element -> count

# Analyse every record: tally its diagnosis and note which information
# elements of each category are mentioned in its text / image findings.

# Compiled once instead of on every row.
_region_pattern = re.compile(r'([a-zA-Z\s]+volume measures)')
_demographic_keywords = ('Age', 'Gender', 'Education', 'Handedness', 'Race')
# (marker substrings, element name): a test counts as present when any of its
# markers occurs in the record text.
_cognitive_markers = (
    (("MMSE:", "MMSE score is"), "MMSE"),
    (("MoCA:",), "MoCA"),
    (("Logical Memory:",), "Logical Memory"),
    (("Immediate Recall:",), "Immediate Recall"),
    (("Delayed Recall:",), "Delayed Recall"),
)


def _note_element(category, element):
    """Register one occurrence of `element` under `category`."""
    unique_elements_per_category[category].add(element)
    element_presence_counts[category][element] += 1


def _section_after(text, marker, terminator):
    """Return the text between the first `marker` and the next `terminator`.

    Returns None when `marker` does not occur in `text`.
    """
    if marker not in text:
        return None
    return text.split(marker)[1].split(terminator)[0]


for _, row in combined_df.iterrows():
    # An empty cell is read by pandas as NaN (a float); treat it as empty text
    # so the substring checks below cannot raise TypeError.
    text = row['text'] if isinstance(row['text'], str) else ''

    raw_diagnosis = row['diagnosis']
    if isinstance(raw_diagnosis, str):
        diagnosis = raw_diagnosis.replace('Diagnosis:', '').replace('.', '').strip()
    else:
        # The old bare `except` dropped into pdb here and then counted a stale
        # (or undefined) label; bucket such rows explicitly instead.
        diagnosis = 'Unknown'
    diagnosis_counts[diagnosis] += 1

    # 1. Demographics: sentences before the "Medical history:" marker that
    #    mention one of the demographic keywords; keep only the field name.
    demo_section = text.split("Medical history:")[0]
    for sentence in demo_section.split('.'):
        if any(keyword in sentence for keyword in _demographic_keywords):
            _note_element('Demographics', sentence.strip().split(':')[0].split(' is')[0])

    # 2. Medical history: '; '-separated entries up to the first full stop.
    med_section = _section_after(text, "Medical history:", '.')
    if med_section is not None:
        for entry in med_section.split('; '):
            entry = entry.strip()
            if entry:
                _note_element('Medical History', entry)

    # 3. Cognitive tests: detected purely by marker substrings.
    for markers, test_name in _cognitive_markers:
        if any(marker in text for marker in markers):
            _note_element('Cognitive Tests', test_name)

    # 4. Laboratory data: "name: value" pairs, '; '-separated, up to '. '.
    lab_section = _section_after(text, "Laboratory findings:", '. ')
    if lab_section is not None:
        for pair in lab_section.split('; '):
            if ':' in pair:
                _note_element('Laboratory Data', pair.split(':')[0].strip())

    # 5. Biomarkers: same "name: value" layout as the laboratory section.
    bio_section = _section_after(text, "Biomarker levels:", '. ')
    if bio_section is not None:
        for pair in bio_section.split('; '):
            if ':' in pair:
                _note_element('Biomarkers', pair.split(':')[0].strip())

    # 6. Genetic data: a single element, present when either phrasing occurs.
    if "APOEε4 alleles:" in text or "APOE is" in text:
        _note_element('Genetic Data', "APOEε4 alleles")

    # 7. Image findings: every "<region> volume measures" phrase, unless the
    #    cell is missing or explicitly states that no volumetric data exists.
    img_finding = row['img_finding']
    if not pd.isna(img_finding) and "No volumetric data available" not in str(img_finding):
        for region in _region_pattern.findall(str(img_finding)):
            region = region.strip()
            if region:
                _note_element('Image Findings', region)

# Number of distinct elements seen in each category (category -> count).
unique_counts = dict(
    zip(
        unique_elements_per_category.keys(),
        map(len, unique_elements_per_category.values()),
    )
)

# Wrap both tallies in Series so they can be plotted and printed directly.
unique_series = pd.Series(unique_counts)
diagnosis_series = pd.Series(diagnosis_counts)

# ================== Colour palettes for the figure (RGB, 0-255 range) ==================
# 1. Diagnosis-distribution pie chart - sequential blues.
#    NOTE: the pie chart takes the first len(diagnosis_series) entries by
#    position, i.e. in the order diagnoses first appear in the data, so the
#    per-diagnosis labels below are the intended mapping, not a guaranteed one.
sci_diagnosis_colors = [
    (158, 202, 225),  # light blue - intended for CN
    (107, 174, 214),  # medium blue - intended for SCD
    (66, 146, 198),   # blue - intended for MCI
    (33, 113, 181),   # dark blue - intended for AD
    (8, 69, 148)      # deepest blue - intended for any other diagnosis
]

# 2. Unique-elements-per-category pie chart - sequential greens.
#    Entries are looked up by a category's index in `category_order`.
sci_category_colors = [
    (197, 224, 180),  # mint green - Demographics
    (169, 209, 142),  # spring green - Medical History
    (141, 183, 124),  # olive green - Cognitive Tests
    (112, 173, 71),   # leaf green - Laboratory Data
    (84, 130, 53),    # forest green - Biomarkers
    (56, 87, 35),     # pine green - Genetic Data
    (28, 52, 18)      # very dark green - Image Findings
]

# 3. Element-presence bar chart - one distinct base hue per category.
category_base_colors = {
    'Demographics': (78, 121, 167),   # blue
    'Medical History': (242, 142, 43),  # orange
    'Cognitive Tests': (89, 161, 79),  # green
    'Laboratory Data': (237, 201, 72),  # yellow
    'Biomarkers': (225, 87, 89),      # red
    'Genetic Data': (176, 122, 161),   # purple
    'Image Findings': (118, 183, 178)  # teal
}

# Convert 0-255 channel values to the 0-1 floats matplotlib expects.
def rgb_to_normal(rgb_tuple):
    """Return `rgb_tuple` as a tuple with every channel divided by 255.0."""
    scaled_channels = []
    for channel in rgb_tuple:
        scaled_channels.append(channel / 255.0)
    return tuple(scaled_channels)

# Normalised (0-1) copies of every palette, ready to hand to matplotlib.
sci_diagnosis_colors_norm = list(map(rgb_to_normal, sci_diagnosis_colors))
sci_category_colors_norm = list(map(rgb_to_normal, sci_category_colors))
category_base_colors_norm = {
    name: rgb_to_normal(category_base_colors[name])
    for name in category_base_colors
}

# ================== Plot configuration ==================
# One wide canvas that will hold the three side-by-side panels.
plt.figure(figsize=(28, 10))

# Global font settings, applied in a single update.
plt.rcParams.update({
    'font.family': 'serif',
    'font.serif': ['Times New Roman'],
    'font.size': 12,
})

# 1. Diagnosis-distribution pie chart (left panel)
ax1 = plt.subplot(1, 3, 1)
# One palette entry per diagnosis, taken in order from the blue palette.
diagnosis_palette = sci_diagnosis_colors_norm[:len(diagnosis_series)]
wedges1, texts1, autotexts1 = ax1.pie(
    diagnosis_series,
    colors=diagnosis_palette,
    autopct='%1.1f%%',
    pctdistance=0.8,
    startangle=90,
    textprops=dict(fontsize=10, fontweight='bold'),
    wedgeprops=dict(linewidth=1, edgecolor='white'),
)
ax1.set_title('Diagnosis Distribution', fontsize=14, fontweight='bold')

# Legend sits just outside the right edge of the panel.
ax1.legend(
    wedges1,
    diagnosis_series.index,
    title="Diagnosis Types",
    loc="center left",
    bbox_to_anchor=(1, 0, 0.5, 1),
    prop=dict(size=14),
)

# 2. Pie chart of the number of unique elements per category (middle panel)
ax2 = plt.subplot(1, 3, 2)
# Canonical category order. Each wedge's colour is looked up by its category's
# index here, so a category keeps its colour whatever order the data yields.
category_order = ['Demographics', 'Medical History', 'Cognitive Tests', 'Laboratory Data', 'Biomarkers', 'Genetic Data', 'Image Findings']
colors_for_pie = [sci_category_colors_norm[category_order.index(cat)] for cat in unique_series.index]

wedges2, texts2, autotexts2 = ax2.pie(
    unique_series,
    startangle=90,
    colors=colors_for_pie,
    wedgeprops={'linewidth': 1, 'edgecolor': 'white'},
    autopct='%1.1f%%',
    pctdistance=0.8,
    textprops={'fontsize': 10, 'fontweight': 'bold'}
)

# Choose each percentage label's colour from the brightness of its own wedge.
# The previous hard-coded wedge-index list assumed wedge i was always category
# i (wedges follow data order, not `category_order`) and put black text on the
# darkest green. Luminance uses the Rec. 601 luma weights; the 0.6 cut-off
# gives black text on the three lightest greens and white on the rest.
for wedge_rgb, autotext in zip(colors_for_pie, autotexts2):
    red, green, blue = wedge_rgb
    luminance = 0.299 * red + 0.587 * green + 0.114 * blue
    autotext.set_color('black' if luminance > 0.6 else 'white')

# Legend sits just outside the right edge of the panel.
ax2.legend(wedges2, unique_series.index,
          title="Information Categories",
          loc="center left",
          bbox_to_anchor=(1, 0, 0.5, 1),
          prop={'size': 14})
ax2.set_title('Unique Elements per Category', fontsize=14, fontweight='bold')

# 3. Horizontal bar chart of element presence (right panel)
ax3 = plt.subplot(1, 3, 3)

# Flattened bar data: up to five most frequent elements per category, with a
# zero-length blank row between consecutive category groups.
bar_data = []
element_labels = []
bar_colors = []
category_labels = []
category_positions = []  # index of the first bar of each group
category_centers = []    # (category, base colour, y of the group's vertical centre)

for category in category_order:
    if category not in element_presence_counts:
        continue

    base_color = category_base_colors_norm.get(category, (0.31, 0.47, 0.65))
    r, g, b = base_color
    top_elements = sorted(
        element_presence_counts[category].items(), key=lambda x: x[1], reverse=True
    )[:5]
    if not top_elements:
        continue

    # Separator row before every group except the first. Keying this on
    # "is there a previous group" (rather than on a specific category name)
    # avoids a trailing blank row when the last category is absent.
    if bar_data:
        bar_data.append(0)
        element_labels.append("")
        bar_colors.append("white")
        category_labels.append("")

    group_start = len(bar_data)
    category_positions.append(group_start)

    for i, (element, count) in enumerate(top_elements):
        bar_data.append(count)
        element_labels.append(element)
        # Same hue for the whole group, brightened a little per rank.
        shade = 0.7 + (i * 0.06)
        bar_colors.append((min(r * shade, 1.0), min(g * shade, 1.0), min(b * shade, 1.0)))
        category_labels.append(category)

    # Centre the category label on the bars actually drawn; groups may hold
    # fewer than five bars, so a fixed "+2" offset would land in the next group.
    category_centers.append((category, base_color, group_start + (len(top_elements) - 1) / 2))

# Longest bar drives label offsets and x-limits; fall back to 1 so an empty
# data set still yields valid, non-degenerate axis limits.
max_count = max(bar_data, default=0) or 1

# Draw the bars.
y_pos = np.arange(len(bar_data))
ax3.barh(y_pos, bar_data, align='center', color=bar_colors, edgecolor='grey', linewidth=0.5, height=0.5)
ax3.set_yticks(y_pos)
ax3.set_yticklabels(element_labels, fontsize=9)

# Category names, drawn left of the zero line in the category's base colour.
for category_name, label_color, category_middle in category_centers:
    ax3.text(-max_count * 0.05, category_middle, category_name,
            ha='right', va='center', fontweight='bold',
            fontsize=10, color=label_color)

ax3.invert_yaxis()
ax3.set_xlabel('Number of Records Present', fontweight='bold')
ax3.set_title('Top 5 Elements per Category', fontsize=14, fontweight='bold')
ax3.grid(axis='x', linestyle='--', alpha=0.6)
ax3.set_xlim(left=-max_count * 0.1, right=max_count * 1.15)

# Figure-level title, placed above the three panels.
figure = plt.gcf()
figure.suptitle(
    'ADNI and AIBL Dataset Analysis: Information Distribution and Presence',
    fontsize=16,
    fontweight='bold',
    y=0.98,
)

# Fit the panels into the area below the title.
figure.tight_layout(rect=[0, 0, 1, 0.96], pad=3.0)

# Write the figure to disk, then display it.
plt.savefig('combined_dataset_analysis_with_biomarkers.png', dpi=300, bbox_inches='tight')
plt.show()

# Console summary of the statistics that were plotted.
separator = "=" * 80
print(separator)
print(f"Total Records Analyzed: {total_records}")
print(f"ADNI Records: {len(adni_df)}")
print(f"AIBL Records: {len(aibl_df)}")
print("\nDiagnosis Distribution:")
print(diagnosis_series.to_string())
print("\nUnique Elements per Category:")
print(unique_series.to_string())
print("\nTop Elements per Category:")
for category in category_order:
    if category not in element_presence_counts:
        continue
    print(f"\n{category}:")
    tallies = element_presence_counts[category]
    # Five most frequent elements; ties keep their first-seen order.
    for name in sorted(tallies, key=tallies.get, reverse=True)[:5]:
        hits = tallies[name]
        print(f"  {name}: {hits} ({hits/total_records*100:.1f}%)")
print(separator)