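# This script extracts the conclusion section from each MIMIC-CXR report and
# writes the results to CSV files holding at most 10,000 reports each (or to
# a single file with --no_split). It also writes one CSV with every parsed
# section per report, then builds a cleaned long-form reports CSV from it.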
import sys
import os
import argparse
import csv
from pathlib import Path

from tqdm import tqdm
import ast
import pandas as pd
import re


# local folder import
import section_parser as sp


# Command-line options. The parser lives at module level because both main()
# and the __main__ block parse sys.argv with it.
parser = argparse.ArgumentParser()

# Input location of the free-text reports.
parser.add_argument(
    '--reports_path',
    default=None,
    help=(
        'Path to directory with radiology reports. '
        'If omitted, will try known defaults under '
        '/nfs_data_storage and /home/data_storage'
    ),
)

# Output directory and output layout.
parser.add_argument(
    '--output_path',
    default='./processed_reports',
    help='Path to output CSV files.',
)
parser.add_argument(
    '--no_split',
    action='store_true',
    help='Do not output batched CSV files.',
)

# Optional overrides for the split / metadata tables.
parser.add_argument(
    '--split_path',
    default=None,
    help='Override path to mimic-cxr split file (.csv or .csv.gz).',
)
parser.add_argument(
    '--metadata_path',
    default=None,
    help='Override path to mimic-cxr metadata file (.csv or .csv.gz).',
)


def list_rindex(l, s):
    """Return the index of the *last* element of `l` that equals `s`.

    Raises ValueError (from ``index``) when `s` does not occur in `l`.
    """
    # Search a reversed copy, then translate the hit back to a forward index.
    backwards = l[::-1]
    steps_from_end = backwards.index(s)
    return len(l) - 1 - steps_from_end

def clean(report):
    """Normalise the text of one report section into tidy sentences.

    Args:
        report: Raw section text. Any value that is not a ``str`` is treated
            as missing.

    Returns:
        The cleaned text: underscores runs, repeated dots/spaces, newlines,
        list numbering and assorted punctuation removed; each sentence
        capitalised and terminated with a single ".". Returns ``""`` for
        non-string input or when nothing remains after cleaning.
    """
    if not isinstance(report, str):
        return ""

    # Remove "__" at the beginning followed by a single letter (e.g. F or M)
    report = re.sub(r"^__+[a-zA-Z] ", "", report)

    # Consolidate multiple underscores into a single space
    report = re.sub(r"__+", " ", report)
    report = report.replace(" // ", ", ")

    # Replace multiple dots with a single dot.
    # (Raw string: "\." in a plain literal is an invalid escape sequence.)
    report = re.sub(r"\.\.+", ".", report)

    # Replace multiple spaces with a single space
    report = re.sub(r" +", " ", report)

    # Remove new lines and the first list marker.
    # NOTE(review): this also strips "1. " inside text such as "T11. Next";
    # left unchanged here so existing output stays identical.
    cleaned_report = report.replace("\n", " ").replace("1. ", "")

    # Remove numbering patterns like ". 2. ", ". 3. " and so on
    cleaned_report = re.sub(r"\. [2-5]\. ", ". ", cleaned_report)
    cleaned_report = re.sub(r" [2-5]\. ", ". ", cleaned_report)

    cleaned_report = cleaned_report.strip()

    cleaned_sentences = []
    for sent in cleaned_report.split(". "):
        # Fix spaces before commas
        sent = re.sub(r"\s,", ",", sent)

        # Fix spaces before hyphens like "Post -operative"
        sent = re.sub(r"\b -\b", "-", sent)

        # A " // " can reappear once newlines/spaces were collapsed above
        sent = sent.replace(" // ", ", ")

        # Drop quotes and backslashes, close up "- ", then trim ...
        sent = (
            sent.replace('"', "")
            .replace("- ", "-")
            .replace("\\", "")
            .replace("'", "")
            .strip()
        )
        # ... and remove the remaining punctuation / special characters.
        sent = re.sub(r"[.?;*!%^&_+()\[\]{}]", "", sent)

        # Capitalize the first letter and restore the sentence terminator;
        # sentences that ended up empty are skipped.
        if sent:
            cleaned_sentences.append(sent[0].capitalize() + sent[1:] + ".")

    # Combine the cleaned sentences and remove any double spaces
    result = " ".join(cleaned_sentences)
    return re.sub(r" +", " ", result).strip()

# Section titles in priority order: the first title present in a report
# supplies its "conclusion" text. The same order defines the section columns
# of mimic_cxr_sectioned.csv.
_SECTION_NAMES = (
    'impression', 'preamble', 'wet read', 'comment', 'addendum',
    'comparison', 'history', 'notification', 'examination', 'date',
    'indication', 'findings', 'technique', 'recommendations',
    'last_paragraph',
)

# Directories probed, in order, when --reports_path is not given.
_DEFAULT_REPORT_DIRS = (
    '/nfs_data_storage/mimic-cxr-jpg/2.0.0/reports',
    '/nfs_data_storage/mimic-cxr-dcm/reports',
    '/home/data_storage/mimic-cxr-jpg/2.0.0/reports',
)

# Maximum number of reports per batched output file.
_REPORTS_PER_FILE = 10000


def _resolve_reports_path(reports_path):
    """Return the reports directory as a Path, probing the defaults if unset."""
    if reports_path is not None:
        return Path(reports_path)
    for candidate in _DEFAULT_REPORT_DIRS:
        if os.path.isdir(candidate):
            return Path(candidate)
    raise FileNotFoundError(
        'Could not locate reports directory. Provide --reports_path or ensure one of the defaults exists: '
        + ', '.join(_DEFAULT_REPORT_DIRS)
    )


def _write_csv_rows(path, rows, header=None):
    """Write an optional header row followed by `rows` to `path` as CSV."""
    # newline='' is what the csv module expects: it keeps embedded newlines in
    # report text intact and avoids doubled line endings on Windows.
    with open(path, 'w', newline='') as fp:
        writer = csv.writer(fp)
        if header is not None:
            writer.writerow(header)
        writer.writerows(rows)


def main(args):
    """Section every report under the reports directory and write CSV outputs.

    Args:
        args: Command-line arguments (without the program name); parsed with
            the module-level ``parser``.

    Files written into ``--output_path`` (only if at least one report exists):
        mimic_cxr_sectioned.csv: full report text, study name and one column
            per entry of ``_SECTION_NAMES``. Reports handled by the custom
            rules from ``section_parser`` are not included in this file.
        mimic_cxr_sections.csv (with ``--no_split``) or mimic_cxr_NN.csv
            batches of at most 10,000 rows: ``[study, conclusion text]``.

    Raises:
        FileNotFoundError: No --reports_path given and no default exists.
    """
    args = parser.parse_args(args)

    reports_path = _resolve_reports_path(args.reports_path)
    output_path = Path(args.output_path)
    # parents/exist_ok: nested output directories work and re-runs are fine.
    output_path.mkdir(parents=True, exist_ok=True)

    # not all reports can be automatically sectioned
    # we load in some dictionaries which have manually determined sections
    custom_section_names, custom_indices = sp.custom_mimic_cxr_rules()

    # get all higher up folders (p00, p01, etc)
    p_grp_folders = sorted(
        p for p in os.listdir(reports_path)
        if p.startswith('p') and len(p) == 3
    )

    # patient_studies holds [study, conclusion text] for use in NLP labeling
    patient_studies = []

    # study_sections has one row per study: [study, text of each section]
    # whole_reports holds the matching full report text
    whole_reports, study_sections = [], []

    for p_grp in p_grp_folders:
        # patient folders inside this group folder
        cxr_path = reports_path / p_grp
        p_folders = sorted(
            p for p in os.listdir(cxr_path) if p.startswith('p')
        )

        for p in tqdm(p_folders):
            patient_path = cxr_path / p

            # filenames of all free-text reports of this patient
            studies = [s for s in os.listdir(patient_path)
                       if s.endswith('.txt') and s.startswith('s')]

            for s in studies:
                # load in the free-text report
                with open(patient_path / s, 'r') as fp:
                    text = fp.read()

                # study name without the ".txt" extension
                s_stem = s[0:-4]

                # custom rules for some poorly formatted reports:
                # the conclusion is a fixed character range of the text
                if s_stem in custom_indices:
                    idx = custom_indices[s_stem]
                    patient_studies.append([s_stem, text[idx[0]:idx[1]]])
                    continue

                # split text into sections
                sections, section_names, section_idx = sp.section_text(text)

                # mis-named sections, e.g. sometimes the impression is in
                # the comparison section: use the manually chosen section
                if s_stem in custom_section_names:
                    sn = custom_section_names[s_stem]
                    idx = list_rindex(section_names, sn)
                    patient_studies.append([s_stem, sections[idx].strip()])
                    continue

                # For every known title that occurs, keep the text of its
                # *last* occurrence.
                #
                # "last_paragraph" is text up to the end of the report:
                # many reports are simple, with a single section header
                # followed by a few paragraphs, which are grouped into it.
                #
                # "comparison" seems unusual, but if no other sections exist
                # the radiologist has usually written the report there.
                found = {
                    name: sections[list_rindex(section_names, name)].strip()
                    for name in _SECTION_NAMES
                    if name in section_names
                }

                # conclusion = first title in priority order that is present
                conclusion_name = next(
                    (name for name in _SECTION_NAMES if name in found), None
                )
                if conclusion_name is None:
                    # we didn't find any sections we can use :(
                    patient_studies.append([s_stem, ''])
                    print(f'no impression/findings: {patient_path / s}')
                else:
                    patient_studies.append([s_stem, found[conclusion_name]])

                # one cell per known section, None where it is absent
                study_sections.append(
                    [s_stem] + [found.get(name) for name in _SECTION_NAMES]
                )
                whole_reports.append(text)

    if not patient_studies:
        return

    # write distinct files to facilitate modular processing:
    # first a single CSV with all the sections ...
    _write_csv_rows(
        output_path / 'mimic_cxr_sectioned.csv',
        ([whole_report] + row
         for whole_report, row in zip(whole_reports, study_sections)),
        header=['reports', 'study', *_SECTION_NAMES],
    )

    # ... then the conclusions, either in one file or in batches
    if args.no_split:
        _write_csv_rows(output_path / 'mimic_cxr_sections.csv', patient_studies)
        return

    for start in range(0, len(patient_studies), _REPORTS_PER_FILE):
        batch_no = start // _REPORTS_PER_FILE
        _write_csv_rows(
            output_path / f'mimic_cxr_{batch_no:02d}.csv',
            patient_studies[start:start + _REPORTS_PER_FILE],
        )


# Default locations of the split / metadata tables, probed in order when no
# explicit path is passed to filter_sectioned_file().
_DEFAULT_SPLIT_PATHS = (
    '/nfs_data_storage/mimic-cxr-jpg/physionet.org/files/mimic-cxr-jpg/2.0.0/mimic-cxr-2.0.0-split.csv.gz',
    '/home/data_storage/mimic-cxr-jpg/2.0.0/mimic-cxr-2.0.0-split.csv',
)
_DEFAULT_METADATA_PATHS = (
    '/nfs_data_storage/mimic-cxr-jpg/physionet.org/files/mimic-cxr-jpg/2.0.0/mimic-cxr-2.0.0-metadata.csv.gz',
    '/home/data_storage/mimic-cxr-jpg/2.0.0/mimic-cxr-2.0.0-metadata.csv',
)


def _read_csv_or_defaults(path, default_paths):
    """Read `path` if given, otherwise the first readable default path."""
    if path is not None:
        return pd.read_csv(path)
    *earlier, last = default_paths
    for candidate in earlier:
        try:
            return pd.read_csv(candidate)
        except OSError:
            # Missing/unreadable file: try the next location. Parse errors
            # are not OSError and propagate instead of being masked.
            continue
    # Let the final candidate raise its own error if it cannot be read.
    return pd.read_csv(last)


def _has_text(column):
    """Return a boolean mask that is True where `column` has non-blank text."""
    return column.fillna('').astype(str).str.strip().str.len() > 0


def filter_sectioned_file(sectioned_file_path, output_path, split_path=None, metadata_path=None):
    """Build the cleaned long-form reports CSV from the sectioned CSV.

    The sectioned reports are joined with the study metadata and split
    tables, each text section is passed through ``clean``, and the result is
    written with one row per (study, section).

    Args:
        sectioned_file_path: CSV with a 'study_id' or 'study' column plus the
            'reports', 'history', 'indication', 'findings', 'impression',
            'last_paragraph' and 'recommendations' columns.
        output_path: Destination CSV; parent directories are created. Nothing
            is written when this is falsy.
        split_path: Optional split table (.csv / .csv.gz); defaults are
            probed when None.
        metadata_path: Optional metadata table (.csv / .csv.gz); defaults are
            probed when None.

    Output columns: subject_id, study_id, sequence, split, StudyDate,
    StudyTime, section ('hist', 'find' or 'impr'), report. Rows with empty
    text are dropped. 'hist' is the cleaned history followed by the cleaned
    indication, separated by a single space when both are present.

    Raises:
        KeyError: The sectioned file has neither 'study_id' nor 'study'.
    """
    every_report = pd.read_csv(sectioned_file_path)

    # Ensure study_id exists and is an 's'-prefixed string BEFORE subsetting
    if 'study_id' in every_report.columns:
        every_report['study_id'] = every_report['study_id'].astype(str)
    elif 'study' in every_report.columns:
        every_report['study_id'] = every_report['study'].astype(str)
    else:
        raise KeyError("Neither 'study_id' nor 'study' found in sectioned file")
    mask_not_s = ~every_report['study_id'].str.startswith('s', na=False)
    if mask_not_s.any():
        every_report.loc[mask_not_s, 'study_id'] = 's' + every_report.loc[mask_not_s, 'study_id']

    # Now select only the needed columns (without 'study')
    every_report = every_report[['reports', 'study_id', 'history', 'indication', 'findings', 'impression', 'last_paragraph', 'recommendations']]

    split = _read_csv_or_defaults(split_path, _DEFAULT_SPLIT_PATHS)
    meta = _read_csv_or_defaults(metadata_path, _DEFAULT_METADATA_PATHS)

    # One row per study, ordered per subject by date/time; 'sequence' is the
    # 1-based position of the study within its subject.
    seq_meta = meta.drop_duplicates('study_id').sort_values(
        by=['subject_id', 'StudyDate', 'StudyTime']
    )
    seq_meta['temp_sequence'] = seq_meta.groupby('subject_id').cumcount() + 1
    # study_id is unique at this point, so each (subject_id, study_id) group
    # holds one row and this normally just carries temp_sequence over.
    seq_meta['sequence'] = seq_meta.groupby(['subject_id', 'study_id'])['temp_sequence'].transform('first')
    seq_meta = seq_meta[['subject_id', 'study_id', 'StudyDate', 'StudyTime', 'sequence']]

    # Attach the split label (merge happens on the raw ids, before prefixing)
    seq_meta = seq_meta.merge(
        split.drop_duplicates('study_id')[['subject_id', 'study_id', 'split']],
        on=['subject_id', 'study_id'],
        how='left',
    )

    # Normalize study_id to an 's'-prefixed string
    seq_meta['study_id'] = 's' + seq_meta['study_id'].astype(str)

    # Normalize subject_id to a 'p'-prefixed string if not already prefixed
    seq_meta['subject_id'] = seq_meta['subject_id'].astype(str)
    mask_not_p = ~seq_meta['subject_id'].str.startswith('p', na=False)
    if mask_not_p.any():
        seq_meta.loc[mask_not_p, 'subject_id'] = 'p' + seq_meta.loc[mask_not_p, 'subject_id']

    # Single merge for all rows
    merged = seq_meta.merge(every_report, on=['study_id'], how='inner')
    print('Total merged rows:', len(merged))
    merged['study_id'] = merged['study_id'].astype(str)

    # Cleaned-text columns start empty and are filled only where the source
    # column actually contains text.
    merged['findings_applied'] = ''
    merged['impression_applied'] = ''
    merged['history_indication_applied'] = ''

    for source_col, target_col in (
        ('findings', 'findings_applied'),
        ('impression', 'impression_applied'),
        ('history', 'history_indication_applied'),
    ):
        mask = _has_text(merged[source_col])
        if mask.any():
            merged.loc[mask, target_col] = merged.loc[mask, source_col].map(clean)

    # Append the cleaned indication to the cleaned history. clean() returns
    # stripped text, so joining with ' ' and stripping separates the two parts
    # with one space when both exist and adds nothing when either is empty.
    mask_indi = _has_text(merged['indication'])
    if mask_indi.any():
        history_text = merged.loc[mask_indi, 'history_indication_applied'].fillna('')
        indication_text = merged.loc[mask_indi, 'indication'].map(clean)
        merged.loc[mask_indi, 'history_indication_applied'] = (
            (history_text + ' ' + indication_text).str.strip()
        )

    # Fallback: findings from last_paragraph when findings and impression
    # are both empty
    both_empty = (
        merged['findings_applied'].fillna('').astype(str).str.len().eq(0)
        & merged['impression_applied'].fillna('').astype(str).str.len().eq(0)
        & merged['last_paragraph'].notna()
    )
    if both_empty.any():
        merged.loc[both_empty, 'findings_applied'] = merged.loc[both_empty, 'last_paragraph'].map(clean)

    # Build long-form: one frame per output section, stacked
    id_cols_all = ['subject_id', 'study_id', 'sequence', 'split', 'StudyDate', 'StudyTime']
    id_cols = [c for c in id_cols_all if c in merged.columns]
    cols_for_output = id_cols_all + ['section', 'report']

    frames = []
    for section, text_col in (
        ('hist', 'history_indication_applied'),
        ('find', 'findings_applied'),
        ('impr', 'impression_applied'),
    ):
        frame = merged[id_cols].copy()
        frame['section'] = section
        frame['report'] = merged[text_col]
        frames.append(frame)

    combined = pd.concat(frames, ignore_index=True)
    combined['report'] = combined['report'].fillna('')
    combined = combined[combined['report'].astype(str).str.len() > 0]
    keep_cols = [c for c in cols_for_output if c in combined.columns]
    combined = combined[keep_cols].copy()

    # Write output
    if output_path:
        output_file = Path(output_path)
        output_file.parent.mkdir(parents=True, exist_ok=True)
        combined.to_csv(output_file, index=False)


if __name__ == '__main__':
    # Stage 1: section the reports and write the CSVs into --output_path.
    main(sys.argv[1:])

    # Stage 2: build the cleaned long-form CSV. Parse the arguments again to
    # fetch the output directory and the optional split / metadata paths.
    # The sectioned file is read from the directory stage 1 wrote to, so a
    # custom --output_path works instead of always using ./processed_reports.
    _args = parser.parse_args(sys.argv[1:])
    _output_dir = Path(_args.output_path)
    filter_sectioned_file(
        _output_dir / 'mimic_cxr_sectioned.csv',
        _output_dir / 'mimic_cxr_reports.csv',
        split_path=_args.split_path,
        metadata_path=_args.metadata_path,
    )