import pandas as pd
from PIL import Image
from tqdm import tqdm

from glob import glob
import os
from pathlib import Path
import json

from transformers import AutoModelForCausalLM, AutoProcessor
import torch

# ViewPosition values (from the metadata table) that are treated as frontal projections.
_FRONTAL_VIEWS = ("PA", "AP", "AP AXIAL", "PA LLD", "AP LLD", "AP RLD", "PA RLD")
# ViewPosition values that are treated as lateral projections.
_LATERAL_VIEWS = ("LATERAL", "LL", "XTABLE LATERAL", "SWIMMERS")
# Procedure description used as a fallback signal for a frontal image.
_PORTABLE_AP_PROCEDURE = "CHEST (PORTABLE AP)"
# Anonymized comparison statement used whenever a prior study exists.
_COMPARISON_WITH_PRIOR = "Chest radiograph dated _."


def _meta_value(image_file, column):
    """Return df_meta[column] for the image's file stem, or None if the stem is not in the index."""
    image_index = os.path.splitext(os.path.basename(image_file))[0]
    return df_meta.loc[image_index, column] if image_index in df_meta.index else None


def _first_section_text(study, section):
    """Return the first report text of `section` within `study`, or "N/A" if that section is absent."""
    texts = study[study["section"] == section]["report"].to_numpy()
    return "N/A" if len(texts) == 0 else texts[0]


def _load_study_images(img_path, include_lateral=True):
    """Open the frontal (and optionally lateral) JPEGs found directly inside one study directory."""
    image_files = sorted(glob(os.path.join(img_path, "*.jpg")))
    frontal_images = []
    lateral_images = []

    for image_file in image_files:
        view_position = _meta_value(image_file, "ViewPosition")
        if view_position in _FRONTAL_VIEWS:
            frontal_images.append(Image.open(image_file))
        elif include_lateral and view_position in _LATERAL_VIEWS:
            lateral_images.append(Image.open(image_file))

    # If the view position did not identify any frontal image, fall back to the procedure description
    if len(frontal_images) == 0:
        for image_file in image_files:
            procedure_code = _meta_value(image_file, "ProcedureCodeSequence_CodeMeaning")
            if procedure_code == _PORTABLE_AP_PROCEDURE:
                frontal_images.append(Image.open(image_file))

    return frontal_images, lateral_images


def get_data(subject_id, seq_nr, synth_prior_find):
    """
    Retrieves all frontal and lateral chest X-ray images for a given subject ID and sequence number.
    If seq_nr > 1.0, also retrieves images and the report from the prior study (seq_nr - 1).

    Reads the module-level names `df_report` (report sections per study), `df_meta`
    (image metadata indexed by image file stem) and `root_img_path` (image root directory).

    Args:
        subject_id (str): The subject ID. Its first 3 characters select the image sub-folder.
        seq_nr (float): The sequence number.
        synth_prior_find (str): The synthetic prior findings (None to use the real prior report).
            When given, the prior report contains no IMPRESSION section.

    Returns:
        dict: A dictionary that always has the same keys, also when no study is found:
            - "frontal" (list of PIL images): List of current frontal chest X-rays.
            - "lateral" (list of PIL images): List of current lateral chest X-rays (if available).
            - "prior_frontal" (list of PIL images): List of prior frontal chest X-rays (if available).
            - "indication" (str): Clinical indication for the X-ray study (history), "N/A" if absent.
            - "technique" (str): Always "N/A" (no technique was extracted for the gold set).
            - "comparison" (str): "Chest radiograph dated _." if a prior study exists, else "N/A".
            - "prior_report" (str or None): Report text from the prior study, None if there is none.
    """

    # Get the study for the given subject_id and sequence number
    study = df_report[(df_report["subject_id"] == subject_id) & (df_report["sequence"] == seq_nr)]

    if study.empty:
        print(f"No study found for subject {subject_id} with sequence {seq_nr}")
        # Same schema as the regular return value, so callers can index every key safely
        return {
            "frontal": [],
            "lateral": [],
            "indication": "N/A",
            "technique": "N/A",
            "comparison": "N/A",
            "prior_frontal": [],
            "prior_report": None,
        }

    study_id = study["study_id"].unique()[0]
    hist = _first_section_text(study, "hist")

    # Images live under <root>/<first 3 chars of subject_id>/<subject_id>/<study_id>
    fold = subject_id[:3]
    frontal_images, lateral_images = _load_study_images(
        f"{root_img_path}/{fold}/{subject_id}/{study_id}"
    )

    # Defaults used when there is no prior study
    prior_frontal_images = []
    prior_report = None
    comparison = "N/A"

    # If seq_nr > 1.0, get prior study (seq_nr - 1)
    if seq_nr > 1.0:
        prior_study = df_report[(df_report["subject_id"] == subject_id) &
                                (df_report["sequence"] == seq_nr - 1)]

        if not prior_study.empty:
            prior_study_id = prior_study["study_id"].unique()[0]

            # The indication always comes from the real prior report
            prior_hist = _first_section_text(prior_study, "hist")

            # A prior that is first in the sequence had nothing to compare against
            prior_comparison = "None." if seq_nr == 2.0 else _COMPARISON_WITH_PRIOR

            # Combine prior study sections into one report (follow example in MAIRA-2 paper, Figure E12)
            if synth_prior_find is None:
                # Real prior: findings plus impression
                prior_find = _first_section_text(prior_study, "find")
                prior_impr = _first_section_text(prior_study, "impr")
                prior_report = (
                    f"INDICATION: {prior_hist}\n\nCOMPARISON: {prior_comparison}"
                    f"\n\nFINDINGS: {prior_find}\n\nIMPRESSION: {prior_impr}"
                )
            else:
                # Synthetic prior: findings only -> no impression section!
                prior_report = (
                    f"INDICATION: {prior_hist}\n\nCOMPARISON: {prior_comparison}"
                    f"\n\nFINDINGS: {synth_prior_find}"
                )

            # Only frontal images are needed from the prior study
            prior_frontal_images, _ = _load_study_images(
                f"{root_img_path}/{fold}/{subject_id}/{prior_study_id}",
                include_lateral=False,
            )

            # Comparison statement should have anonymized date identifier
            comparison = _COMPARISON_WITH_PRIOR

    return {
        "frontal": frontal_images,
        "lateral": lateral_images,
        "indication": hist,
        "technique": "N/A", # no technique was extracted for gold set, so just leave it empty
        "comparison": comparison,
        "prior_frontal": prior_frontal_images,
        "prior_report": prior_report,
    }

def prompt_maira(data): 
    """
    Produces a non-grounded radiology report for one study using the module-level MAIRA
    `processor` and `model`, running on the module-level `device`.

    Args:
        data (dict): Study inputs with the following keys:
            - "frontal" (list of PIL images): Current frontal chest X-rays; only the first is used
              and the list must not be empty.
            - "lateral" (list of PIL images): Current lateral chest X-rays; the first is used if any.
            - "prior_frontal" (list of PIL images): Prior frontal chest X-rays; the first is used if any.
            - "indication" (str): Clinical indication for the X-ray study (history).
            - "technique" (str): Imaging technique used.
            - "comparison" (str): Information about comparison to prior studies (if available).
            - "prior_report" (str): Report text from the prior study (if available).

    Returns:
        The processor's plaintext conversion of the generated completion.
    """

    def _first_or_none(images):
        # Optional views are passed as None when no image of that kind exists
        return images[0] if len(images) > 0 else None

    model_inputs = processor.format_and_preprocess_reporting_input(
        current_frontal=data["frontal"][0],
        current_lateral=_first_or_none(data["lateral"]),
        prior_frontal=_first_or_none(data["prior_frontal"]),
        indication=data["indication"],
        technique=data["technique"],
        comparison=data["comparison"],
        prior_report=data["prior_report"],
        return_tensors="pt",
        get_grounding=False,  # non-grounded report
    )

    # Move tensor entries to the target device; leave everything else untouched
    inputs_on_device = {}
    for name, value in model_inputs.items():
        if isinstance(value, torch.Tensor):
            value = value.to(device)
        inputs_on_device[name] = value

    with torch.no_grad():
        generated_ids = model.generate(
            **inputs_on_device,
            max_new_tokens=300,  # Set to 450 for grounded reporting
            use_cache=True,
        )

    # The output starts with the prompt tokens; decode only what comes after them
    n_prompt_tokens = inputs_on_device["input_ids"].shape[-1]
    completion_ids = generated_ids[0][n_prompt_tokens:]
    completion = processor.decode(completion_ids, skip_special_tokens=True)

    # Findings generation completions have a single leading space
    completion = completion.lstrip()

    return processor.convert_output_to_plaintext_or_grounded_sequence(completion)


# Script entry point: loads the gold report table, the image metadata and the MAIRA-2 model,
# then writes one generated report per (patient, sequence) as a line of NDJSON.
if __name__ == "__main__": 

    # TODO: set parameters
    TOKEN = "set-your-own-token" # TODO: please set your own huggingface token for MAIRA access
    OUTPUT_DIR = "set-your-output-directory-path" # TODO: please set the path to your output directory
    SYNTH_PRIOR = False # TODO: specify whether to generate reports in standard setting (SYNTH_PRIOR = False) or cascaded setting (SYNTH_PRIOR = True)

    # get gold set of patients and path to images
    # TODO: set image path
    # TODO: set gold dataset path
    root_img_path = "/nfs_data_storage/mimic-cxr-jpg/physionet.org/files/mimic-cxr-jpg/2.0.0/files"
    gold_path = "../gold_revise/0.gold_dataset_v2.csv"
    report_columns = ["subject_id", "study_id", "sequence", "section", "report"]
    df_report = pd.read_csv(gold_path)
    df_report = df_report.drop_duplicates(subset=report_columns)[report_columns]

    # get image metadata (for determining frontal/lateral view)
    # TODO: set metadata path
    metadata_path = "/nfs_data_storage/mimic-cxr-jpg/physionet.org/files/mimic-cxr-jpg/2.0.0/mimic-cxr-2.0.0-metadata.csv.gz"
    df_meta = pd.read_csv(metadata_path, compression="gzip")
    df_meta.set_index("dicom_id", inplace=True)

    # load in MAIRA model and processor
    model = AutoModelForCausalLM.from_pretrained("microsoft/maira-2", trust_remote_code=True, token=TOKEN)
    processor = AutoProcessor.from_pretrained("microsoft/maira-2", trust_remote_code=True, token=TOKEN)

    # put model on GPU when one is available; otherwise run on CPU instead of crashing
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        print("CUDA is not available, running MAIRA on CPU (this will be slow).")
        device = torch.device("cpu")
    model = model.eval()
    model = model.to(device)

    # if SYNTH_PRIOR == True: pass along the synthetic prior report that was just generated, instead of the true prior report
    file_name = "gen_reports_cascade.ndjson" if SYNTH_PRIOR else "gen_reports_true_prior.ndjson"

    # generate one report per study
    subject_id_list = df_report["subject_id"].unique()
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    file_path = os.path.join(OUTPUT_DIR, file_name)

    def _to_builtin(value):
        # numpy scalars (e.g. int64) are not JSON serializable; .item() yields the Python equivalent
        return value.item() if hasattr(value, "item") else value

    with open(file_path, 'w', encoding='utf-8') as fp:

        for subj in tqdm(subject_id_list, desc="Processing patients"):
            seq_nrs = sorted(df_report[df_report["subject_id"] == subj]["sequence"].unique())
            # reports generated so far for this patient, keyed by sequence number
            synthetic_reports = {}

            for nr in tqdm(seq_nrs, desc=f"Patient {subj} sequences", leave=False):
                print(f"generating report for patient {subj} sequence {nr}")

                # use synthetic prior report that was just generated
                if SYNTH_PRIOR and nr > 1.0:

                    # get prior generated report; pass N/A if it is empty (i.e., because prior
                    # had no frontal) or if there is no study at sequence nr - 1 at all
                    prior_find = synthetic_reports.get(nr - 1) or "N/A"

                    # get data object for this patient, pass prior find
                    data = get_data(subj, nr, prior_find)

                else:
                    # do not pass prior finding
                    data = get_data(subj, nr, None)

                # call maira
                if len(data["frontal"]) != 0:
                    pred_report = prompt_maira(data)
                else:
                    print(f"no frontal image for patient {subj} sequence {nr}! skipping generation.")
                    pred_report = ""

                # store synthetic report
                synthetic_reports[nr] = pred_report

                # construct JSON entry
                report_entry = {
                    "subject_id": _to_builtin(subj),
                    "sequence": _to_builtin(nr),
                    "report": pred_report
                }

                # flush after every line so partial results survive an interrupted run
                fp.write(json.dumps(report_entry) + "\n")
                fp.flush()
