import argparse
import re

import pandas as pd

def parse_args(argv=None):
    """Parse the command-line options for the note-sectioning script.

    Args:
        argv: List of argument strings to parse. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, matching the original behaviour.

    Returns:
        argparse.Namespace with ``input_file`` and ``output_file`` attributes.
    """
    parser = argparse.ArgumentParser(description="Process some notes.")
    # Both paths are mandatory: without them the script would only fail later,
    # inside pandas, with a far less helpful error than argparse's usage message.
    parser.add_argument("--input_file", type=str, required=True, help="Path to the input file")
    parser.add_argument("--output_file", type=str, required=True, help="Path to the output file")
    return parser.parse_args(argv)

# Header strings that open each note section, keyed by section name.
# Headers are matched literally and case-sensitively (surrounding whitespace is
# significant), except that a trailing ":?" marks the colon as optional, e.g.
# "IMPRESSION:?" matches both "IMPRESSION" and "IMPRESSION:".
# Key order is the final tie-break when two sections match the exact same span.
_SECTION_HEADERS = {
    "Subjective": [
        'SUBJECTIVE:', 'INTERVAL HISTORY:', 'Interval history:', 'Chief Complaint: ',
        'returns to our GI Oncology practice',
        'Patient Identification and Oncology History', 'Interval History:',
        'Subjective  *****', 'HPI      ',
        'HISTORY OF PRESENT ILLNESS:', 'History of Present Illness', 'HPI:', 'HPI ',
        'Dear Dr.', 'Subjective: ', '  Subjective   ',
        'presents for   Chief Complaint', 'Chief Complaint   Patient presents',
        '24 Hour Course', 'ID:', "FOLLOW-UP GASTROINTESTINAL MEDICAL ONCOLOGY VISIT",
        "UCSF Cancer Center GI Medical Oncology Program", "FOLLOW UP VISIT",
    ],
    "Onc_history": [
        'Summary of oncologic history:?', 'ONCOLOGY HISTORY', 'Oncology history',
        'ONCOLOGIC HISTORY:', "Oncologic History, reviewed:",
    ],
    "PMH": ['Past Medical History', 'Medical History     '],
    "Meds": [
        'CURRENT MEDICATIONS:', 'Medications the patient states to be taking prior',
        'Prescription Medications as of', 'Current Medications',
        'Current Outpatient Medications', 'Current Outpatient Prescriptions',
        'Prescriptions Prior to Admission', 'Current Outpatient Medications ',
    ],
    "Allergies": ['ALLERGIES:', 'Allergen Reactions', 'Allergies as of'],
    "ROS": [
        'Review of Systems',
        'present review of systems was reviewed and notable for the following:',
    ],
    "Labs": ['LABORATORY RESULTS:', 'Laboratory data:', 'Lab data:', 'LABORATORY RESULTS'],
    "Path": ['Pathology:', 'PATHOLOGY'],
    "SH": [
        'SOCIAL HISTORY:', 'Social Documentation', '    Social History     ',
        '      Social History', 'PERSONAL AND SOCIAL HISTORY', 'Social History     ',
    ],
    "FH": ['FAMILY HISTORY:', 'Family History'],
    "PE": ['PHYSICAL EXAM:', 'Physical Exam:', '    Physical Exam   ', 'OBJECTIVE ASSESSMENT'],
    "Imaging": ['IMAGING:', 'RADIOGRAPHIC AND PATHOLOGY RESULTS', 'Relevant Diagnostic Studies:'],
    "A_P": [
        'ASSESSMENT:?', 'ASSESSMENT AND PLAN', 'Assessment and plan:?',
        'Assessment   Impression:?', 'ASSESSMENT & PLAN', '    Assessment  ',
        'Impression and Recommendations:?', 'Impression and Plan:?', 'Impression and plan:?',
        'Assessment & Recommendations:?', 'Assessment and Plan', ' IMPRESSION AND PLAN:?',
        'IMPRESSION:?', 'ASSESSMENT AND RECOMMENDATIONS:?',
        # Previously spelled as regex-style escapes (A\/p, A\/P, A\P), which
        # re.escape turned into literal backslashes that never occur in notes.
        'A/p', 'A/P',
        'IMPRESSION AND PLAN', "Assessment    Impression: ",
        "Impression and Recommendations:", "Assessment:    ",
    ],
}

# Column order of the section columns in the returned DataFrame.
_OUTPUT_SECTIONS = (
    'Onc_history', 'Subjective', 'ROS', 'Allergies', 'Meds', 'PE', 'Imaging',
    'Path', 'PMH', 'FH', 'SH', 'Labs', 'A_P',
)


def _header_to_regex(header):
    """Translate one header string into a regex fragment.

    The header is escaped so it matches literally; only a trailing ":?" is
    kept as regex syntax, meaning "optional colon".
    """
    if header.endswith(":?"):
        return re.escape(header[:-2]) + ":?"
    return re.escape(header)


def _compile_section_patterns(header_dict):
    """Compile one alternation regex per section of *header_dict*.

    Alternatives are tried longest-first: Python's ``|`` takes the first
    alternative that matches, so a short header such as "IMPRESSION" would
    otherwise pre-empt "IMPRESSION AND PLAN" starting at the same position.
    Sections with no headers are skipped (an empty pattern matches everywhere).
    """
    patterns = {}
    for section_name, headers in header_dict.items():
        if not headers:
            continue
        ordered = sorted(dict.fromkeys(headers), key=len, reverse=True)
        patterns[section_name] = re.compile("|".join(map(_header_to_regex, ordered)))
    return patterns


def _find_headers(note, section_patterns):
    """Return non-overlapping header matches in *note*, ordered by position.

    Each item is ``(start, end, matched_text, section_name)``. When matches
    from different sections overlap (e.g. "PATHOLOGY" inside "RADIOGRAPHIC AND
    PATHOLOGY RESULTS"), the one starting first wins; at equal starts the
    longer match wins, then the section listed first in *section_patterns*.
    """
    candidates = []
    for rank, (section_name, pattern) in enumerate(section_patterns.items()):
        for m in pattern.finditer(note):
            candidates.append((m.start(), m.end(), m.group(), section_name, rank))
    candidates.sort(key=lambda c: (c[0], -c[1], c[4]))

    positions = []
    last_end = 0
    for start, end, text, section_name, _ in candidates:
        if start < last_end:
            continue  # lies inside a header that has already been accepted
        positions.append((start, end, text, section_name))
        last_end = end
    return positions


def _split_note_by_headers(note, headers_positions):
    """Slice *note* into ``{section_name: text}`` using sorted header positions.

    Each section runs from the end of its header to the start of the next
    header, stripped of surrounding whitespace. The first "A_P" header ends
    the scan: it extends to the end of the note and later headers are ignored.
    Text before the first header is discarded. If a section's header occurs
    more than once, the last occurrence (before A_P) wins.
    """
    sections = {}
    for i, (_, end, _, section_name) in enumerate(headers_positions):
        if section_name == "A_P":
            sections[section_name] = note[end:].strip()
            break
        if i + 1 < len(headers_positions):
            section_end = headers_positions[i + 1][0]
        else:
            section_end = len(note)
        sections[section_name] = note[end:section_end].strip()
    return sections


def sectionize_notes(df):
    """Split every note in *df* into named sections.

    Parameters
    ----------
    df : pandas.DataFrame
        Must have ``note_id`` and ``note_text`` columns. Any index is
        accepted; rows are processed in their current order. Non-string
        ``note_text`` values (e.g. NaN from empty CSV cells) are treated as
        empty notes.

    Returns
    -------
    pandas.DataFrame
        One row per input row, in the same order, with a fresh 0..n-1 index.
        Columns are ``note_id`` and ``note_text`` (copied from *df*), followed
        by one column per section in ``_OUTPUT_SECTIONS`` order. A section not
        found in a note is ``''``.

    Progress is printed to stdout every 10,000 notes.
    """
    section_patterns = _compile_section_patterns(_SECTION_HEADERS)
    notes = df["note_text"].tolist()
    total = len(notes)

    rows = []
    # Iterate the values rather than df.loc[i]: the index need not be 0..n-1.
    for i, note in enumerate(notes):
        text = note if isinstance(note, str) else ""
        rows.append(_split_note_by_headers(text, _find_headers(text, section_patterns)))
        if i % 10000 == 0:
            print(f"finished note: {i} out of {total}")

    columns = {
        "note_id": df["note_id"].tolist(),
        "note_text": notes,
    }
    for name in _OUTPUT_SECTIONS:
        columns[name] = [row.get(name, "") for row in rows]
    return pd.DataFrame(columns)


if __name__ == "__main__":
    # CLI entry point: read the notes CSV, sectionize it, write the result CSV.
    cli_args = parse_args()
    sectioned = sectionize_notes(pd.read_csv(cli_args.input_file))
    sectioned.to_csv(cli_args.output_file, index=False)
