import os
import pandas as pd
import numpy as np
from scipy.stats import norm

# PTGENDER code -> gender label used in the generated text.
gender_map = {1: 'Female', 2: 'Male'}

# DXCURREN diagnosis code -> class label; rows with any other code are dropped.
label_map = {1: 'CN', 2: 'MCI', 3: 'AD'}

# Brain-region name mapping (placeholder: AIBL has no data for these regions).
# NOTE(review): not referenced anywhere in the visible part of this script.
region_map = dict(
    Ventricles='ventricular',
    Hippocampus='hippocampal',
    WholeBrain='whole brain',
    Entorhinal='entorhinal cortex',
    Fusiform='fusiform gyrus',
    MidTemp='middle temporal gyrus',
)

# Medical-history column code -> display name used in the generated text.
# Key order is the order conditions are listed in the description.
medhis_map = dict(
    MHPSYCH="Psychiatric",
    MH2NEURL="Neurologic",
    MH4CARD="Cardiovascular",
    MH6HEPAT="Hepatic",
    MH8MUSCL="Musculoskeletal",
    MH9ENDO="Endocrine-Metabolic",
    MH10GAST="Gastrointestinal",
    MH12RENA="Renal-Genitourinary",
    MH16SMOK="Smoking",
    MH17MALI="Malignancy",
)

# Directory holding the AIBL clinical CSV exports read below.
file_dir = '/data/qiuhui/data/AIBL/Data_extract_3.3.0'

# text_dict[RID][VISCODE] -> dict of per-visit features. It is seeded here with
# the diagnosis label from the diagnostic-conversion table; later sections add
# demographics, scores, lab values, etc.
text_dict = {}
df = pd.read_csv(os.path.join(file_dir, 'aibl_pdxconv_01-Jun-2018.csv'))
# Iterate the three needed columns directly instead of materialising a row
# Series with df.iloc[idx] for every row.
for rid, month, dx_code in zip(df.RID, df.VISCODE, df.DXCURREN):
    # Visits whose diagnosis code is not in label_map are skipped.
    if dx_code not in label_map:
        continue
    # A later row for the same (RID, VISCODE) replaces an earlier one.
    text_dict.setdefault(rid, {})[month] = {'label': label_map[dx_code]}

# Demographics (gender, birth year) come from the first row per RID and are
# copied onto every visit of that participant.
df = pd.read_csv(os.path.join(file_dir, 'aibl_ptdemog_01-Jun-2018.csv'))
# Group once by RID instead of re-filtering the whole table for every visit.
demog_by_rid = {rid: rows for rid, rows in df.groupby('RID', sort=False)}
for rid, visits in text_dict.items():
    rid_rows = demog_by_rid.get(rid)
    if rid_rows is None:
        continue
    first = rid_rows.iloc[0]
    # Unmapped codes (e.g. the -4 missing sentinel) are skipped; indexing
    # gender_map directly raised KeyError and aborted the whole script.
    gender = gender_map.get(first.PTGENDER)
    # The first character of PTDOB is dropped (presumably a leading '/' before
    # the year — TODO confirm against the CSV). Only a 4-digit year is kept, so
    # NaN no longer raises TypeError and "-4" no longer becomes the year "4".
    dob = first.PTDOB
    birth_year = dob[1:] if isinstance(dob, str) else ''
    if not (len(birth_year) == 4 and birth_year.isdigit()):
        birth_year = None
    for month_data in visits.values():
        if gender is not None:
            month_data['gender'] = gender
        if birth_year is not None:
            month_data['birth_year'] = birth_year


# Medical history: the first row per RID is copied onto every visit of that
# participant, keyed by the medhis_map column codes.
df = pd.read_csv(os.path.join(file_dir, 'aibl_medhist_01-Jun-2018.csv'))
# Group once by RID instead of re-filtering the whole table for every visit.
medhist_by_rid = {rid: rows for rid, rows in df.groupby('RID', sort=False)}
for rid, visits in text_dict.items():
    rid_rows = medhist_by_rid.get(rid)
    if rid_rows is None:
        continue
    first = rid_rows.iloc[0]
    for month_data in visits.values():
        # Fresh dict per visit, same columns and order as medhis_map.
        month_data['medical_history'] = {code: first[code] for code in medhis_map}

# MMSE score per visit, plus the exam year and, when a birth year is known,
# an age computed as a difference of calendar years.
df = pd.read_csv(os.path.join(file_dir, 'aibl_mmse_01-Jun-2018.csv'))
# Group once by RID instead of re-filtering the whole table for every visit.
mmse_by_rid = {rid: rows for rid, rows in df.groupby('RID', sort=False)}
for rid, visits in text_dict.items():
    rid_rows = mmse_by_rid.get(rid)
    if rid_rows is None:
        continue
    for month, month_data in visits.items():
        # First MMSE row whose VISCODE matches this visit.
        visit_df = rid_rows[rid_rows.VISCODE == month]
        if len(visit_df) == 0:
            continue
        record = visit_df.iloc[0]
        month_data['mmse'] = record.MMSCORE
        exam_date = record.EXAMDATE
        if not (isinstance(exam_date, str) and '/' in exam_date):
            continue
        # The exam year is the text after the last '/'.
        exam_year = exam_date.split('/')[-1]
        month_data['exam_year'] = exam_year
        birth_year = month_data.get('birth_year')
        # Both years must be plain integers. Previously, any non-numeric text
        # after the last '/' (e.g. a time suffix) made int() raise and abort.
        if birth_year is not None and birth_year.isdigit() and exam_year.strip().isdigit():
            month_data['age'] = int(exam_year) - int(birth_year)

# Immediate (LIMMTOTAL) and delayed (LDELTOTAL) recall totals per visit.
df = pd.read_csv(os.path.join(file_dir, 'aibl_neurobat_01-Jun-2018.csv'))
# Group once by RID instead of re-filtering the whole table for every visit.
neurobat_by_rid = {rid: rows for rid, rows in df.groupby('RID', sort=False)}
for rid, visits in text_dict.items():
    rid_rows = neurobat_by_rid.get(rid)
    if rid_rows is None:
        continue
    for month, month_data in visits.items():
        # First row whose VISCODE matches this visit.
        visit_df = rid_rows[rid_rows.VISCODE == month]
        if len(visit_df) == 0:
            continue
        record = visit_df.iloc[0]
        month_data['LIMMTOTAL'] = record.LIMMTOTAL
        month_data['LDELTOTAL'] = record.LDELTOTAL


# Per-visit laboratory results, stored under the 'biospecimen' key.
df = pd.read_csv(os.path.join(file_dir, 'aibl_labdata_01-Jun-2018.csv'))
# Lab columns copied for each visit (insertion order is kept in the dict).
biospecimen_codes = (
    'AXT117',  # Thyroid Stim Hormone
    'BAT126',  # Vitamin B12
    'HMT100',  # MCH
    'HMT102',  # MCHC
    'HMT13',   # Platelets
    'HMT3',    # RBC
    'HMT40',   # Hemoglobin
    'HMT7',    # WBC
    'RCT11',   # Serum Glucose
    'RCT20',   # Cholesterol
    'RCT392',  # Creatinine
    'RCT6',    # Urea Nitrogen
)
# Group once by RID instead of re-filtering the whole table for every visit.
labdata_by_rid = {rid: rows for rid, rows in df.groupby('RID', sort=False)}
for rid, visits in text_dict.items():
    rid_rows = labdata_by_rid.get(rid)
    if rid_rows is None:
        continue
    for month, month_data in visits.items():
        # First row whose VISCODE matches this visit.
        visit_df = rid_rows[rid_rows.VISCODE == month]
        if len(visit_df) == 0:
            continue
        record = visit_df.iloc[0]
        month_data['biospecimen'] = {code: record[code] for code in biospecimen_codes}

# APOE: count of alleles equal to 4 (0, 1 or 2) in the first row per RID,
# copied onto every visit of that participant.
df = pd.read_csv(os.path.join(file_dir, 'aibl_apoeres_01-Jun-2018.csv'))
# Group once by RID instead of re-filtering the whole table for every visit.
apoe_by_rid = {rid: rows for rid, rows in df.groupby('RID', sort=False)}
for rid, visits in text_dict.items():
    rid_rows = apoe_by_rid.get(rid)
    if rid_rows is None:
        continue
    gene1 = rid_rows.iloc[0].APGEN1
    gene2 = rid_rows.iloc[0].APGEN2
    # A missing allele (NaN or AIBL's -4 code) used to fall through to
    # "0 ε4 alleles"; leave 'apoe' unset so the description omits it instead.
    if pd.isna(gene1) or pd.isna(gene2) or gene1 == -4 or gene2 == -4:
        continue
    num = int(gene1 == 4) + int(gene2 == 4)
    for month_data in visits.values():
        month_data['apoe'] = num

# Build the lab-value reference distribution (mean, std) from CN-labelled visits.
print("Creating reference group for lab data Z-scores...")

# Lab codes that get a CN reference mean/std.
lab_fields = [
    'AXT117', 'BAT126', 'HMT100', 'HMT102', 'HMT13',
    'HMT3', 'HMT40', 'HMT7', 'RCT11', 'RCT20', 'RCT392', 'RCT6'
]

# One record per CN visit that has lab data, tagged with its RID and VISCODE.
cn_lab_data = []
for rid, visits in text_dict.items():
    for month, visit in visits.items():
        if visit.get('label') != 'CN' or 'biospecimen' not in visit:
            continue
        cn_lab_data.append({**visit['biospecimen'], 'rid': rid, 'month': month})

cn_lab_df = pd.DataFrame(cn_lab_data)


def _iqr_inliers(values):
    """Return the entries of *values* inside the Tukey fences (1.5 * IQR beyond Q1/Q3)."""
    low_q, high_q = values.quantile(0.25), values.quantile(0.75)
    spread = high_q - low_q
    return values[values.between(low_q - 1.5 * spread, high_q + 1.5 * spread)]


# field -> (mean, std), computed only when more than 5 values remain both
# before and after outlier removal.
lab_reference_stats = {}
available_fields = [name for name in lab_fields if name in cn_lab_df.columns]
for field in available_fields:
    # Non-numeric entries become NaN; -4 is treated as missing (AIBL's missing-value code).
    numeric = pd.to_numeric(cn_lab_df[field], errors='coerce').replace(-4, np.nan).dropna()
    if len(numeric) <= 5:
        continue
    inliers = _iqr_inliers(numeric)
    if len(inliers) <= 5:
        continue
    mean, std = inliers.mean(), inliers.std()
    lab_reference_stats[field] = (mean, std)
    print(f"Calculated reference for {field}: mean={mean:.2f}, std={std:.2f} (based on {len(inliers)} samples)")

# Z-score of a lab value against the CN reference statistics.
def calculate_lab_zscore(field, value, reference_stats=None):
    """Return the z-score of *value* for lab code *field*, or ``None``.

    Args:
        field: Lab code, e.g. ``'RCT11'``.
        value: Raw measurement; anything ``float()`` accepts.
        reference_stats: Optional mapping ``field -> (mean, std)``. Defaults to
            the module-level ``lab_reference_stats`` built from CN visits.

    Returns:
        ``(value - mean) / std`` as a float, or ``None`` when the value is not
        numeric, is NaN or AIBL's ``-4`` missing code, the field has no
        reference entry, or the reference std is zero/NaN.
    """
    if reference_stats is None:
        reference_stats = lab_reference_stats

    try:
        num_value = float(value)
    except (ValueError, TypeError):
        return None

    # -4 is the missing-value code, also excluded when building the reference.
    if pd.isna(num_value) or num_value == -4 or field not in reference_stats:
        return None

    mean, std = reference_stats[field]
    if std == 0 or pd.isna(std):  # avoid division by zero / NaN
        return None

    return (num_value - mean) / std

# Display names for lab codes; codes not listed are shown verbatim.
# Built once at import instead of on every call.
_LAB_NAME_MAP = {
    'AXT117': 'Thyroid Stim. Hormone',
    'BAT126': 'Vitamin B12',
    'HMT10': 'Monocytes',
    'HMT100': 'MCH',
    'HMT102': 'MCHC',
    'HMT11': 'Eosinophils',
    'HMT12': 'Basophils',
    'HMT13': 'Platelets',
    'HMT15': 'Neutrophils',
    'HMT16': 'Lymphocytes',
    'HMT17': 'Monocytes',
    'HMT18': 'Eosinophils',
    'HMT19': 'Basophils',
    'HMT2': 'Hematocrit',
    'HMT3': 'RBC',
    'HMT40': 'Hemoglobin',
    'HMT7': 'WBC',
    'HMT8': 'Neutrophils',
    'HMT9': 'Lymphocytes',
    'RCT1': 'Total Bilirubin',
    'RCT11': 'Serum Glucose',
    'RCT12': 'Total Protein',
    'RCT13': 'Albumin',
    'RCT14': 'Creatine Kinase',
    'RCT1407': 'Alkaline Phosphatase',
    'RCT1408': 'LDH',
    'RCT183': 'Calcium (EDTA)',
    'RCT19': 'Triglycerides (GPO)',
    'RCT20': 'Cholesterol (High Performance)',
    'RCT29': 'Direct Bilirubin',
    'RCT3': 'GGT',
    'RCT392': 'Creatinine (Rate Blanked)',
    'RCT4': 'ALT (SGPT)',
    'RCT5': 'AST (SGOT)',
    'RCT6': 'Urea Nitrogen',
    'RCT8': 'Serum Uric Acid',
    'RCT9': 'Phosphorus',
}


def generate_lab_description(field, value, zscore):
    """Return a one-line description of a lab result.

    Args:
        field: Lab code; mapped to a display name when known.
        value: Measured value; formatted with two decimals when numeric.
        zscore: Z-score against the reference group, or ``None``.

    Returns:
        ``"<name>: <value>"`` when *value* is not numeric or *zscore* is None;
        otherwise ``"<name>: <value> (<severity>[ <direction>])"``, where
        severity is profound (|z| > 3), significant (> 2), moderate (> 1.5),
        mild (> 1) or normal, and direction is "elevated" (z > 1) or
        "reduced" (z < -1). Previously a None z-score produced a dangling
        "( )" and normal values a trailing space, as in "(normal )".
    """
    name = _LAB_NAME_MAP.get(field, field)

    try:
        formatted_value = f"{float(value):.2f}"
    except (ValueError, TypeError):
        return f"{name}: {value}"

    if zscore is None:
        return f"{name}: {formatted_value}"

    magnitude = abs(zscore)
    if magnitude > 3:
        severity = "profound"
    elif magnitude > 2:
        severity = "significant"
    elif magnitude > 1.5:
        severity = "moderate"
    elif magnitude > 1:
        severity = "mild"
    else:
        severity = "normal"

    if zscore > 1:
        direction = "elevated"
    elif zscore < -1:
        direction = "reduced"
    else:
        direction = ""

    # Join only non-empty parts so no stray whitespace ends up inside the parentheses.
    qualifier = " ".join(part for part in (severity, direction) if part)
    return f"{name}: {formatted_value} ({qualifier})"

def _describe_visit(month_data):
    """Return the clinical free-text description of one visit record.

    Sections are emitted in a fixed order: age, gender, medical history,
    cognitive scores, lab findings with |z| > 1, APOE. Entries that are
    missing, NaN, or equal to -4 are omitted.

    Args:
        month_data: One ``text_dict[rid][viscode]`` dict.

    Returns:
        The concatenated sentences (possibly ""), each ending with ". ".
    """
    parts = []

    # 1. Basic information
    if 'age' in month_data and not pd.isna(month_data['age']):
        parts.append(f"Age is {month_data['age']:.1f} years. ")
    if 'gender' in month_data and not pd.isna(month_data['gender']):
        parts.append(f"Gender is {month_data['gender']}. ")

    # 2. Medical history: list only conditions coded 1 (NaN and -1 never equal 1).
    if 'medical_history' in month_data:
        med_items = [
            medhis_map.get(field, field)
            for field, value in month_data['medical_history'].items()
            if value == 1
        ]
        if med_items:
            parts.append("Medical history: " + "; ".join(med_items) + ". ")

    # 3. Cognitive tests
    for key, label in (
        ('mmse', 'MMSE'),
        ('LDELTOTAL', 'Delayed Recall'),
        ('LIMMTOTAL', 'Immediate Recall'),
    ):
        score = month_data.get(key)
        if not pd.isna(score) and score != -4:
            parts.append(f"{label}: {score}. ")

    # 4. Laboratory data: only values more than 1 SD from the CN reference mean.
    if 'biospecimen' in month_data:
        lab_items = []
        for test, value in month_data['biospecimen'].items():
            if pd.isna(value) or value == -4.0:
                continue
            try:
                num_value = float(value)
            except (ValueError, TypeError):
                continue
            zscore = calculate_lab_zscore(test, num_value)
            if zscore is not None and abs(zscore) > 1:
                lab_items.append(generate_lab_description(test, num_value, zscore))
        if lab_items:
            parts.append("Laboratory findings: " + "; ".join(lab_items) + ". ")

    # 5. APOE
    if 'apoe' in month_data and not pd.isna(month_data['apoe']):
        parts.append(f"APOEε4 alleles: {month_data['apoe']}. ")

    return "".join(parts)


to_write = []
aibl_root = '/data/qiuhui/data/AIBL/images_skull_stripping_and_align'
img_list = sorted(os.listdir(aibl_root))

for img_name in img_list:
    if not img_name.endswith('.nii.gz'):
        continue

    file_path = os.path.join(aibl_root, img_name)
    base_name = img_name.split('.nii.gz')[0]
    # Expected file name: "<RID>_<VISCODE>.nii.gz". Anything else raises
    # ValueError (unpacking or int()) and is skipped; other errors propagate
    # instead of being swallowed by a bare except.
    try:
        ptid_str, img_month = base_name.split('_', 1)
        rid = int(ptid_str)
    except ValueError:
        continue

    month_data = text_dict.get(rid, {}).get(img_month)
    if not month_data:
        continue

    text = _describe_visit(month_data)
    # Placeholder: AIBL provides no volumetric image findings.
    imgfinding = "Image findings: No volumetric data available."
    diagnosis = f"Diagnosis: {month_data['label']}."

    # Four tab-separated columns: image path, image findings, clinical text, diagnosis.
    to_write.append(f"{file_path}\t{imgfinding}\t{text}\t{diagnosis}\n")

# Write the collected records (tab-separated despite the .csv name). UTF-8 is
# explicit because the text contains non-ASCII characters ("APOEε4"), which the
# platform default encoding may be unable to encode.
with open('./AIBL2.csv', 'w', encoding='utf-8') as f:
    f.writelines(to_write)

print(f"Processing completed. {len(to_write)} records written to AIBL2.csv")