import argparse
from pathlib import Path
import pandas as pd


def _find_outcome_files(outcomes_path: Path) -> list:
    """Return the outcome file(s) located at *outcomes_path* (possibly empty).

    A directory is searched for split files named ``Outcomes-*.txt`` first and
    falls back to a single combined ``Outcomes.txt``. A regular file is used
    as-is. Split files are returned in sorted order because ``glob`` yields
    them in arbitrary order, and the order decides which duplicate RecordID
    row is kept later on.
    """
    if outcomes_path.is_dir():
        split_files = sorted(outcomes_path.glob("Outcomes-*.txt"))
        if split_files:
            return split_files
        combined = outcomes_path / "Outcomes.txt"
        return [combined] if combined.exists() else []
    if outcomes_path.is_file():
        return [outcomes_path]
    return []


def analyze_patient_data(psv_dir: str, outcomes_dir: str, threshold_days: int = 3, label_column: str = "newlabel"):
    """Check the PSV/outcomes ID mapping and write a length-of-stay label into each PSV file.

    Every ``*.psv`` file in *psv_dir* must have a filename stem that parses as
    an integer RecordID. For each RecordID present both as a PSV file and in
    the outcomes table, ``Length_of_stay`` is converted with ``int()``,
    recorded in the result, and the PSV file is rewritten in place with
    *label_column* set to ``1`` when the stay exceeds *threshold_days* and
    ``0`` otherwise.

    Args:
        psv_dir: Directory containing the pipe-separated patient files.
        outcomes_dir: Directory holding ``Outcomes-*.txt`` or ``Outcomes.txt``,
            or the path of a single outcomes file. The file(s) are read with
            ``pandas.read_csv`` and need ``RecordID`` and ``Length_of_stay``
            columns.
        threshold_days: Label threshold; converted with ``int()``.
        label_column: Name of the column written into each PSV file.

    Returns:
        A dict with ``"length_of_stay"`` (RecordID -> int days) and
        ``"mapping_check"`` (IDs missing on either side plus an
        ``"is_one_to_one"`` flag), or ``None`` when the inputs cannot be
        located or a PSV filename is not numeric.

    Raises:
        ValueError, TypeError: If *threshold_days* cannot be converted to int.

    Exceptions raised by pandas while loading the outcome files (including a
    missing ``RecordID`` column) are not caught here.
    """
    results = {
        "length_of_stay": {},
        "mapping_check": {
            "psv_ids_not_in_outcomes": [],
            "outcome_ids_not_in_psv": [],
            "is_one_to_one": False,
        },
    }

    psv_path = Path(psv_dir)
    if not psv_path.is_dir():
        print(f"Error: PSV directory not found at '{psv_dir}'")
        return None

    try:
        # Remember the real path for every ID: rebuilding the name from the
        # integer would lose zero padding ('00123.psv' -> '123.psv').
        # Sorting makes the choice deterministic when several filenames parse
        # to the same ID (the last one in sorted order is used).
        psv_files = {int(p.stem): p for p in sorted(psv_path.glob("*.psv"))}
    except ValueError:
        print(
            "Error: Could not process filenames. Ensure all .psv filenames are numbers (e.g., '12345.psv')."
        )
        return None
    psv_ids = set(psv_files)
    print(f"Found {len(psv_ids)} unique patient IDs in '{psv_dir}'.")

    outcomes_path = Path(outcomes_dir)
    if not outcomes_path.exists():
        print(f"Error: Outcomes path not found at '{outcomes_dir}'")
        return None

    outcome_files = _find_outcome_files(outcomes_path)
    if not outcome_files:
        print(f"Error: No outcomes file(s) found at '{outcomes_dir}'. Expected 'Outcomes-*.txt' or 'Outcomes.txt'.")
        return None

    all_outcomes_df = pd.concat([pd.read_csv(f) for f in outcome_files], ignore_index=True)
    all_outcomes_df = all_outcomes_df.drop_duplicates(subset="RecordID", keep="first")
    outcome_ids = set(all_outcomes_df["RecordID"])
    print(f"Found {len(outcome_ids)} unique patient records in the outcome files.")

    psv_not_in_outcomes = psv_ids - outcome_ids
    outcomes_not_in_psv = outcome_ids - psv_ids
    mapping = results["mapping_check"]
    mapping["psv_ids_not_in_outcomes"] = sorted(psv_not_in_outcomes)
    mapping["outcome_ids_not_in_psv"] = sorted(outcomes_not_in_psv)
    mapping["is_one_to_one"] = not psv_not_in_outcomes and not outcomes_not_in_psv

    outcomes_indexed_df = all_outcomes_df.set_index("RecordID")
    threshold = int(threshold_days)  # loop-invariant, convert once

    # Iterate over the PSV-side IDs so the keys are always plain ints that
    # index psv_files directly.
    for record_id in sorted(i for i in psv_ids if i in outcome_ids):
        psv_file_path = psv_files[record_id]

        try:
            los_value = outcomes_indexed_df.loc[record_id, "Length_of_stay"]
        except KeyError:
            print(
                f"Warning: Could not find RecordID {record_id} in outcomes after all."
            )
            continue

        try:
            los_days = int(los_value)
        except (ValueError, TypeError):
            print(
                f"Warning: Could not convert Length_of_stay for RecordID {record_id} to a number. Value was '{los_value}'."
            )
            continue

        # The stay is recorded even if the PSV rewrite below fails.
        results["length_of_stay"][record_id] = los_days
        new_label = 1 if los_days > threshold else 0

        try:
            patient_df = pd.read_csv(psv_file_path, sep="|")
            patient_df[label_column] = new_label
            patient_df.to_csv(psv_file_path, sep="|", index=False)
        except FileNotFoundError:
            print(
                f"Warning: Could not find PSV file for RecordID {record_id} at '{psv_file_path}'"
            )
        except (ValueError, TypeError) as exc:
            # pandas' EmptyDataError and ParserError derive from ValueError;
            # report them as a file problem rather than a Length_of_stay one.
            print(
                f"Warning: Could not update PSV file for RecordID {record_id} at '{psv_file_path}': {exc}"
            )
        else:
            print(f"Updated {psv_file_path} with {label_column}: {new_label}")

    return results


def main():
    """Command-line entry point.

    Parses ``--psv_dir``, ``--outcomes_dir``, ``--threshold`` and
    ``--label_column``, runs ``analyze_patient_data`` with them, then prints
    the mapping check and a RecordID / Length_of_stay table. Any exception
    raised along the way is printed as ``Error: ...`` instead of propagating.
    """
    cli = argparse.ArgumentParser(
        description="Write LoS-derived labels into PSV files based on outcomes file(s)."
    )
    cli.add_argument(
        "--psv_dir", type=str, required=True,
        help="Directory with patient .psv files (filenames must be numeric RecordIDs)",
    )
    cli.add_argument(
        "--outcomes_dir", type=str, required=True,
        help="Path to outcomes directory (with Outcomes-*.txt or Outcomes.txt) or a single outcomes file",
    )
    cli.add_argument(
        "--threshold", type=int, default=3,
        help="Threshold in days; label=1 if LOS > threshold (default: 3)",
    )
    cli.add_argument(
        "--label_column", type=str, default="newlabel",
        help="Name of the output label column to write into PSV files",
    )
    options = cli.parse_args()

    try:
        outcome = analyze_patient_data(
            psv_dir=options.psv_dir,
            outcomes_dir=options.outcomes_dir,
            threshold_days=options.threshold,
            label_column=options.label_column,
        )
        # None means the analysis already reported why it could not run.
        if outcome is None:
            return

        # --- mapping report ---
        mapping = outcome["mapping_check"]
        if mapping["is_one_to_one"]:
            print("Success: one-to-one between PSV files and outcomes RecordIDs.")
        else:
            print("Warning: The mapping is NOT one-to-one.")
            orphan_psv = mapping["psv_ids_not_in_outcomes"]
            if orphan_psv:
                print(
                    f"  - PSV files without a matching outcome record: {orphan_psv}"
                )
            orphan_outcomes = mapping["outcome_ids_not_in_psv"]
            if orphan_outcomes:
                print(
                    f"  - Outcome records without a matching PSV file: {orphan_outcomes}"
                )

        # --- length-of-stay table ---
        stays = outcome["length_of_stay"]
        if not stays:
            print("No matching records found to extract data from.")
        else:
            table = pd.DataFrame(list(stays.items()), columns=["RecordID", "Length_of_stay"])
            print(table.to_string(index=False))
    except Exception as err:
        print(f"Error: {err}")

# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()


