import pandas as pd
import numpy as np
from datetime import datetime
import os
import pickle

from collections import defaultdict, Counter


def compute_readmissions(admissions, window_days=15):
    """Count readmissions that begin within ``window_days`` of the previous discharge.

    Admissions are ordered by ``admission_time``. For every consecutive pair,
    the gap is the next ``admission_time`` minus the previous
    ``discharge_time``, truncated to whole days via ``timedelta.days``. A gap
    of at most ``window_days`` counts as one readmission. Negative gaps
    (overlapping stays) therefore also count.

    Args:
        admissions: iterable of dicts with datetime-like ``admission_time`` and
            ``discharge_time`` values.
        window_days: largest gap, in whole days, that counts as a readmission.
            Defaults to 15.

    Returns:
        ``(count, flag)``: the number of readmissions, and whether there was
        at least one.
    """
    ordered = sorted(admissions, key=lambda adm: adm['admission_time'])
    count = sum(
        1
        for prev, nxt in zip(ordered, ordered[1:])
        if (nxt['admission_time'] - prev['discharge_time']).days <= window_days
    )
    return count, count > 0


def describe_patient_statistics(patient_info_all, patient_admission):
    """Print gender and age-group counts for patients that have admissions.

    The function considers each ``pid`` in ``patient_admission`` that also has
    an entry in ``patient_info_all``. For each such patient, it computes the
    age at the latest admission in completed calendar years and stores it in
    ``patient_info_all[pid]["AGE"]`` (the input dict is mutated in place).
    It then tallies the patient by gender and age group. Patients with no DOB
    or no admissions are skipped, because their age is undefined.

    Args:
        patient_info_all: dict mapping pid -> info dict with ``"GENDER"``
            (optional; defaults to ``"Unknown"``) and ``"DOB"``
            (datetime-like, e.g. ``pd.Timestamp``).
        patient_admission: dict mapping pid -> list of admission dicts, each
            with a datetime-like ``"admission_time"``.

    Returns:
        The subset of ``patient_info_all`` whose pids appear in
        ``patient_admission``. The values are the same (mutated) dict objects.
    """
    age_groups = ("<18", "18-30", "30-60", ">60")
    gender_counts = {}
    age_group_counts = dict.fromkeys(age_groups, 0)
    gender_age_counts = {}

    for pid, admissions in patient_admission.items():
        patient_info = patient_info_all.get(pid)
        if not patient_info:
            continue
        dob = patient_info.get("DOB")
        # Without a DOB or at least one admission, no age can be computed;
        # skip the patient entirely so all three tables cover the same patients.
        if not admissions or pd.isna(dob):
            continue

        gender = patient_info.get("GENDER", "Unknown")
        gender_counts[gender] = gender_counts.get(gender, 0) + 1

        latest_admission_time = max(adm['admission_time'] for adm in admissions)
        # Completed years from calendar fields: exact across leap years. It also
        # avoids subtracting Timestamps, whose nanosecond Timedelta overflows for
        # gaps beyond ~292 years, which can occur with date-shifted DOBs.
        age = (latest_admission_time.year - dob.year
               - ((latest_admission_time.month, latest_admission_time.day)
                  < (dob.month, dob.day)))
        patient_info["AGE"] = age

        if age < 18:
            age_group = "<18"
        elif age < 30:
            age_group = "18-30"
        elif age < 60:
            age_group = "30-60"
        else:
            age_group = ">60"

        age_group_counts[age_group] += 1
        gender_age_counts.setdefault(gender, dict.fromkeys(age_groups, 0))[age_group] += 1

    # Print statistics
    print("Gender:", gender_counts)
    print("Age Groups:", age_group_counts)
    print("Gender vs Age Groups:", gender_age_counts)

    return {pid: info for pid, info in patient_info_all.items() if pid in patient_admission}


def has_heart_disease(codes, broad=True):
    """Return True if any ICD-9 code falls in a cardiovascular range.

    The code's prefix before the first '.' is compared (inclusively, as
    strings) against the ranges:
      * broad=True:  410-414 and 420-429
      * broad=False: 427 and 428

    Codes that do not start with a digit (V/E codes, empty strings) are
    ignored. The string comparison assumes dotted ICD-9 codes with a 3-digit
    prefix such as '428.0'. Undotted codes (e.g. '4280') may compare
    differently; confirm the format of the caller's data.

    Args:
        codes: iterable of ICD-9 code strings for one admission.
        broad: choose the broad (True) or narrow (False) definition.

    Returns:
        True as soon as one code matches, otherwise False.
    """
    if broad:
        ranges = [('410', '414'), ('420', '429')]
    else:
        ranges = [('427', '427'), ('428', '428')]

    for code in codes:
        # Slice rather than index so that an empty code string is skipped
        # instead of raising IndexError.
        if not code[:1].isdigit():
            continue
        prefix = code.split('.')[0]
        if any(start <= prefix <= end for start, end in ranges):
            return True
    return False


def has_heart_failure(codes):
    """Return True if any code's prefix before the first '.' is exactly '428'.

    ICD-9 428.* is the heart-failure category. A prefix equal to '428' already
    implies a leading digit, so no separate digit check is needed. This also
    keeps empty code strings from raising. Undotted codes such as '4280' do
    not match.

    Args:
        codes: iterable of ICD-9 code strings for one admission.
    """
    return any(code.split('.')[0] == '428' for code in codes)


def _load_pickle(path):
    """Load and return the object pickled in ``path``, closing the file afterwards.

    Security note: unpickling can execute arbitrary code. Only use this on
    trusted, locally generated files (here: the ``parsed/`` outputs).
    """
    with open(path, 'rb') as f:
        return pickle.load(f)


def _new_group_stats(groups, counters):
    """Return ``{group: {counter: 0, ...}}``, keeping ``counters`` in the given order."""
    return {group: dict.fromkeys(counters, 0) for group in groups}


def _add_to_group(stats, group, increments):
    """Add each value of ``increments`` to ``stats[group]``.

    Groups absent from ``stats`` (e.g. a gender other than 'F'/'M') are ignored.
    """
    if group not in stats:
        return
    for key, value in increments.items():
        stats[group][key] += value


def _demographic_groups(info):
    """Map one patient-info record to ``(gender, age_group, race_group)``.

    ``age_group`` is '<=60' or '>60'; a missing AGE counts as 0, i.e. '<=60'.
    ``race_group`` is 'White' for any ethnicity that starts with 'WHITE'
    (case-insensitive), so sub-categories such as 'WHITE - RUSSIAN' count as
    White. It is 'Non-White' otherwise, including missing or non-string values.
    """
    gender = info.get('GENDER')
    age_group = '<=60' if info.get('AGE', 0) <= 60 else '>60'
    ethnicity = info.get('ethnicity')
    is_white = isinstance(ethnicity, str) and ethnicity.upper().startswith('WHITE')
    return gender, age_group, 'White' if is_white else 'Non-White'


def _print_group_stats(title, stats):
    """Print ``title`` followed by one ``group: counters`` line per group."""
    print(title)
    for group, counters in stats.items():
        print(f"{group}: {counters}")


if __name__ == '__main__':
    data_path = '../../DG-EHR/data'
    dataset = 'mimic3'  # 'mimic3', 'mimic4'
    dataset_path = os.path.join(data_path, dataset)
    raw_path = os.path.join(dataset_path, 'raw')
    parsed_path = os.path.join(dataset_path, 'parsed')

    # ---- Patient demographics (gender, DOB) ----
    # GENDER and DOB are read without converters so that empty cells stay NaN
    # and are removed by dropna(). DOB is converted to pd.Timestamp below.
    patients = pd.read_csv(os.path.join(raw_path, 'PATIENTS.csv'),
                           usecols=['SUBJECT_ID', 'GENDER', 'DOB'],
                           converters={'SUBJECT_ID': int}).dropna()
    # Convert patients into a dictionary: SUBJECT_ID -> {'GENDER': ..., 'DOB': ...}
    patients = patients.set_index('SUBJECT_ID').to_dict(orient='index')
    for info in patients.values():
        if not isinstance(info.get("DOB"), pd.Timestamp):
            info["DOB"] = pd.Timestamp(info["DOB"])
    print("Total Patients", len(patients))
    print(patients.get(98768))  # sample record; None if this id is absent (e.g. mimic4)

    patient_admission = _load_pickle(os.path.join(parsed_path, 'patient_admission.pkl'))
    print("Total Patients in trained data", len(patient_admission))
    describe_patient_statistics(patients, patient_admission)

    # ---- Check demographic information recorded per admission ----
    # Missing cells may arrive as '' (via the str converters) or as NaN; the
    # validity filter in Step 3 treats both as missing.
    demos = pd.read_csv(os.path.join(raw_path, 'ADMISSIONS.csv'),
                        usecols=['SUBJECT_ID', 'HADM_ID', 'ETHNICITY', 'INSURANCE', 'LANGUAGE',
                                 'RELIGION', 'MARITAL_STATUS'],
                        converters={'SUBJECT_ID': int, 'HADM_ID': int, 'ETHNICITY': str,
                                    'INSURANCE': str, 'LANGUAGE': str, 'RELIGION': str,
                                    'MARITAL_STATUS': str})

    # Step 1: Construct patient_demo dictionary: SUBJECT_ID -> per-admission records
    patient_demo = defaultdict(list)
    for row in demos.to_dict(orient='records'):
        patient_demo[row['SUBJECT_ID']].append({
            'HADM_ID': row['HADM_ID'],
            'ethnicity': row['ETHNICITY'],
            'insurance': row['INSURANCE'],
            'language': row['LANGUAGE'],
            'religion': row['RELIGION'],
            'marital_status': row['MARITAL_STATUS']
        })

    # Initial Descriptive Report
    total_patients = len(patient_demo)
    total_admissions = len(demos)
    avg_admissions = total_admissions / total_patients if total_patients else 0

    print("Initial Demographic Data Report:")
    print(f"Total Patients: {total_patients}")
    print(f"Total Admissions: {total_admissions}")
    print(f"Average Admissions per Patient: {avg_admissions:.2f}")

    # Step 2: Filtering patients based on patient_admission
    patient_admission_ids = set(patient_admission)
    filtered_patient_demo = {pid: admissions for pid, admissions in patient_demo.items()
                             if pid in patient_admission_ids}

    filtered_patients = len(filtered_patient_demo)
    filtered_admissions = sum(len(adm_list) for adm_list in filtered_patient_demo.values())
    filtered_avg_admissions = filtered_admissions / filtered_patients if filtered_patients else 0

    print("\nFiltered Demographic Data Report (Aligned with Training Data):")
    print(f"Total Patients after Filtering: {filtered_patients}")
    print(f"Total Admissions after Filtering: {filtered_admissions}")
    print(f"Average Admissions per Patient after Filtering: {filtered_avg_admissions:.2f}")

    # Step 3: Checking Validity and Value Counts for Each Demographic Feature
    features = ['ethnicity', 'insurance', 'language', 'religion', 'marital_status']
    # Placeholder values that carry no demographic information.
    missing_markers = {'', 'UNKNOWN', 'UNABLE TO OBTAIN'}
    feature_valid_counts = {}
    feature_no_conflict_counts = {}
    feature_distributions = {}

    for feature in features:
        valid_patients = 0
        no_conflict_patients = 0
        value_counter = Counter()

        for admissions in filtered_patient_demo.values():
            # Distinct informative values this patient has across all admissions.
            feature_values = {adm[feature] for adm in admissions
                              if not pd.isna(adm[feature]) and adm[feature] not in missing_markers}

            if feature_values:
                valid_patients += 1
            # Exactly one distinct value means the admissions agree.
            if len(feature_values) == 1:
                no_conflict_patients += 1
                value_counter.update(feature_values)

        feature_valid_counts[feature] = valid_patients
        feature_no_conflict_counts[feature] = no_conflict_patients
        feature_distributions[feature] = value_counter

    # Reporting Results
    for feature in features:
        print(f"\nFeature: {feature.capitalize()}")
        print(f"Total patients with at least one valid record: {feature_valid_counts[feature]}")
        print(f"Patients without conflicting values: {feature_no_conflict_counts[feature]}")
        print("Value Distribution among Non-conflicting Patients:")
        dist_df = pd.DataFrame(feature_distributions[feature].items(), columns=['Value', 'Count'])
        dist_df = dist_df.sort_values(by='Count', ascending=False).reset_index(drop=True)
        print(dist_df.head(10))  # Display top 10 values for brevity

    # ---- Heart-disease prevalence: broad (ICD-9 410-414, 420-429) vs narrow (427, 428) ----
    admission_codes = _load_pickle(os.path.join(parsed_path, 'admission_codes.pkl'))

    total_patients = len(patient_admission)
    total_admissions = sum(len(adm_list) for adm_list in patient_admission.values())
    avg_admissions_per_patient = total_admissions / total_patients if total_patients else 0
    broad_patients = set()
    broad_admissions_count = 0
    narrow_patients = set()
    narrow_admissions_count = 0

    for pid, admissions in patient_admission.items():
        codes_per_adm = [admission_codes.get(adm['admission_id'], []) for adm in admissions]
        # A patient qualifies if ANY admission carries a matching code; all of
        # that patient's admissions are then counted.
        if any(has_heart_disease(codes, broad=True) for codes in codes_per_adm):
            broad_patients.add(pid)
            broad_admissions_count += len(admissions)
        if any(has_heart_disease(codes, broad=False) for codes in codes_per_adm):
            narrow_patients.add(pid)
            narrow_admissions_count += len(admissions)

    avg_broad_admissions = broad_admissions_count / len(broad_patients) if broad_patients else 0
    avg_narrow_admissions = narrow_admissions_count / len(narrow_patients) if narrow_patients else 0

    print("\nDescriptive Statistics Report for Heart Disease:")
    print("===============================================")
    print("Overall:")
    print(f"- Total number of patients: {total_patients}")
    print(f"- Total number of admissions: {total_admissions}")
    print(f"- Average admissions per patient: {avg_admissions_per_patient:.2f}\n")

    print("Broad Definition of Heart Disease (ICD9: 410-414, 420-429):")
    print(f"- Total number of patients with heart disease: {len(broad_patients)}")
    print(f"- Total admissions for these patients: {broad_admissions_count}")
    print(f"- Average admissions per patient (broad): {avg_broad_admissions:.2f}\n")

    print("Narrow Definition of Heart Disease (ICD9: 427, 428):")
    print(f"- Total number of patients with heart disease: {len(narrow_patients)}")
    print(f"- Total admissions for these patients: {narrow_admissions_count}")
    print(f"- Average admissions per patient (narrow): {avg_narrow_admissions:.2f}")

    # ---- Readmission rates in different groups (age, gender, race) ----
    patient_info = _load_pickle(os.path.join(parsed_path, 'patient_info.pkl'))
    print("\n", patient_info.get(44083))  # sample records; None if the id is absent
    print("\n", patient_admission.get(44083))

    readmission_counters = ('patients', 'admissions', 'readmission_visits', 'readmission_patients')
    stats_gender = _new_group_stats(('F', 'M'), readmission_counters)
    stats_age = _new_group_stats(('<=60', '>60'), readmission_counters)
    stats_race = _new_group_stats(('White', 'Non-White'), readmission_counters)

    for pid, info in patient_info.items():
        gender, age_group, race_group = _demographic_groups(info)
        admissions = patient_admission.get(pid, [])
        readmission_visits, has_readmission = compute_readmissions(admissions)
        increments = {
            'patients': 1,
            'admissions': len(admissions),
            'readmission_visits': readmission_visits,
            'readmission_patients': int(has_readmission),
        }
        for stats, group in ((stats_gender, gender), (stats_age, age_group),
                             (stats_race, race_group)):
            _add_to_group(stats, group, increments)

    _print_group_stats("Gender Statistics:", stats_gender)
    _print_group_stats("\nAge Statistics:", stats_age)
    _print_group_stats("\nRace Statistics:", stats_race)

    # ---- Heart-failure records in different groups (age, gender, race) ----
    print('\n', admission_codes.get(125157))  # sample record; None if the id is absent

    hf_counters = ('patients', 'admissions', 'hf_admissions', 'hf_patients')
    stats_gender = _new_group_stats(('F', 'M'), hf_counters)
    stats_age = _new_group_stats(('<=60', '>60'), hf_counters)
    stats_race = _new_group_stats(('White', 'Non-White'), hf_counters)

    for pid, info in patient_info.items():
        gender, age_group, race_group = _demographic_groups(info)
        admissions = patient_admission.get(pid, [])
        hf_adm_count = sum(
            1 for adm in admissions
            if has_heart_failure(admission_codes.get(adm['admission_id'], []))
        )
        increments = {
            'patients': 1,
            'admissions': len(admissions),
            'hf_admissions': hf_adm_count,
            'hf_patients': int(hf_adm_count > 0),
        }
        for stats, group in ((stats_gender, gender), (stats_age, age_group),
                             (stats_race, race_group)):
            _add_to_group(stats, group, increments)

    _print_group_stats("Gender:", stats_gender)
    _print_group_stats("\nAge:", stats_age)
    _print_group_stats("\nRace:", stats_race)
