import csv
import os
import random
from pathlib import Path
from typing import Optional

import nibabel as nib
from batchgenerators.utilities.file_and_folder_operations import load_json, save_json

from nnunetv2.dataset_conversion.Dataset027_ACDC import make_out_dirs
from nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json
from nnunetv2.paths import nnUNet_preprocessed


def read_csv(csv_file: str):
    """Read the M&Ms dataset-information CSV into a per-patient lookup table.

    Args:
        csv_file: path to the CSV file. Its header row must contain the columns
            ``External code``, ``ED``, ``ES`` and ``Vendor``.

    Returns:
        dict mapping each ``External code`` value to ``{"ed": int, "es": int, "vendor": str}``.
        If a code appears in more than one row, the last row wins.

    Raises:
        ValueError: if the file is empty or one of the required columns is missing from the header.
    """
    patient_info = {}

    # newline="" is required by the csv module so quoted fields and \r\n line endings are parsed correctly.
    with open(csv_file, newline="") as csvfile:
        reader = csv.reader(csvfile)
        headers = next(reader, None)
        if headers is None:
            raise ValueError(f"CSV file {csv_file!r} is empty; expected a header row.")
        patient_index = headers.index("External code")
        ed_index = headers.index("ED")
        es_index = headers.index("ES")
        vendor_index = headers.index("Vendor")

        for row in reader:
            # csv.reader yields [] for blank lines (e.g. a trailing newline); skip them instead of raising IndexError.
            if not row:
                continue
            patient_info[row[patient_index]] = {
                "ed": int(row[ed_index]),
                "es": int(row[es_index]),
                "vendor": row[vendor_index],
            }

    return patient_info


# ------------------------------------------------------------------------------
# Conversion to nnUNet format
# ------------------------------------------------------------------------------
def convert_mnms(src_data_folder: Path, csv_file_name: str, dataset_id: int):
    """Convert the raw M&Ms download into an nnU-Net raw dataset.

    Labeled training patients are written as image/label pairs and testing patients as images only. In both
    cases only the ED and ES frames listed in the CSV are kept. Finally ``generate_dataset_json`` is called for
    the output folder.

    Args:
        src_data_folder: root of the raw dataset (holds ``Training/Labeled``, ``Testing`` and the CSV).
        csv_file_name: file name of the dataset-information CSV inside ``src_data_folder``.
        dataset_id: nnU-Net dataset id, forwarded to ``make_out_dirs``.
    """
    out_dir, out_train_dir, out_labels_dir, out_test_dir = make_out_dirs(dataset_id, task_name="MNMs")

    def subdirectories(folder: Path) -> list[Path]:
        # One directory per patient; stray files next to them are ignored.
        return [entry for entry in folder.iterdir() if entry.is_dir()]

    labeled_patients = subdirectories(src_data_folder / "Training" / "Labeled")
    testing_patients = subdirectories(src_data_folder / "Testing")
    csv_records = read_csv(str(src_data_folder / csv_file_name))

    save_cardiac_phases(labeled_patients, csv_records, out_train_dir, out_labels_dir)
    save_cardiac_phases(testing_patients, csv_records, out_test_dir)

    # Both the test and the validation data contain non-orthonormal direction cosines. The validation split is
    # left out until it is decided whether to fix that data or skip the problematic cases. To include it:
    #   patients_val = [f for f in (src_data_folder / "Validation").iterdir() if f.is_dir()]
    #   save_cardiac_phases(patients_val, csv_records, out_train_dir, out_labels_dir)

    # Every labeled patient contributes two training cases: its ED frame and its ES frame.
    num_cases = 2 * len(labeled_patients)
    generate_dataset_json(
        str(out_dir),
        channel_names={0: "cineMRI"},
        labels={"background": 0, "LVBP": 1, "LVM": 2, "RV": 3},
        file_ending=".nii.gz",
        num_training_cases=num_cases,
    )


def save_cardiac_phases(
    patients: list[Path],
    patient_info: dict[str, dict],
    out_dir: Path,
    labels_dir: Optional[Path] = None,
):
    """Extract and save the ED and ES frames for each patient (and optionally their labels).

    Args:
        patients: patient directories. Each must contain ``<name>_sa.nii.gz`` and, when ``labels_dir`` is
            given, ``<name>_sa_gt.nii.gz``.
        patient_info: mapping from patient directory name to its record as produced by ``read_csv``.
            Only the ``"ed"`` and ``"es"`` frame indices are used here.
        out_dir: destination folder for the extracted image frames.
        labels_dir: destination folder for the extracted label frames; ``None`` skips labels (e.g. test data).

    Raises:
        KeyError: if a patient directory has no entry in ``patient_info``.
    """
    for patient in patients:
        print(f"Processing patient: {patient.name}")

        # Look the patient up before loading any volume so a CSV mismatch fails fast with a readable message.
        try:
            info = patient_info[patient.name]
        except KeyError:
            raise KeyError(f"Patient {patient.name!r} ({patient}) not found in the dataset information CSV.") from None
        ed_frame, es_frame = info["ed"], info["es"]

        image = nib.load(patient / f"{patient.name}_sa.nii.gz")
        save_extracted_nifti_slice(image, ed_frame=ed_frame, es_frame=es_frame, out_dir=out_dir, patient=patient)

        if labels_dir is not None:
            label = nib.load(patient / f"{patient.name}_sa_gt.nii.gz")
            save_extracted_nifti_slice(label, ed_frame=ed_frame, es_frame=es_frame, out_dir=labels_dir, patient=patient)


def save_extracted_nifti_slice(image, ed_frame: int, es_frame: int, out_dir: Path, patient: Path):
    """Write the ED and ES time points of a 4D (H x W x D x time) volume as two separate NIfTI files.

    Output names follow ``<patient>_frameNN<suffix>``. The suffix is ``.nii.gz`` when the source file name
    ends with ``_gt.nii.gz`` (a label map), otherwise ``_0000.nii.gz`` (the image channel identifier).
    NOTE(review): only the source affine is carried over; other header fields are not copied.
    """
    frames = (ed_frame, es_frame)
    # Slice along the last (time) axis for each requested frame before writing anything.
    volumes = [nib.Nifti1Image(image.dataobj[..., t], image.affine) for t in frames]

    # Label files carry no modality identifier; their source names always end with "_gt".
    is_label = image.get_filename().endswith("_gt.nii.gz")
    suffix = ".nii.gz" if is_label else "_0000.nii.gz"

    for t, volume in zip(frames, volumes):
        nib.save(volume, str(out_dir / f"{patient.name}_frame{t:02d}{suffix}"))


# ------------------------------------------------------------------------------
# Create custom splits
# ------------------------------------------------------------------------------
def create_custom_splits(src_data_folder: Path, csv_file: str, dataset_id: int, num_val_patients: int = 25):
    """Append vendor-specific train/validation splits to an existing nnU-Net ``splits_final.json``.

    The splits file must already exist in the preprocessed dataset folder, e.g. created by a first training run.
    For each of four training sets (vendor A, vendor B, and two A+B mixes), three splits are appended: validation
    on A, on B and on A + B. The file is then written back in place. Calling this again appends another batch
    of splits; it is not idempotent.

    Args:
        src_data_folder: root of the raw M&Ms dataset (contains ``Training/Labeled``).
        csv_file: file name of the dataset-information CSV inside ``src_data_folder``.
        dataset_id: nnU-Net dataset id; zero-padded to three digits as in ``DatasetXXX_MNMs``.
        num_val_patients: number of hold-out validation patients drawn per vendor.
    """
    # nnU-Net dataset folders use a zero-padded three-digit id (DatasetXXX_Name); without the padding an id
    # below 100 would point at a folder that does not exist.
    existing_splits = os.path.join(nnUNet_preprocessed, f"Dataset{dataset_id:03d}_MNMs", "splits_final.json")
    splits = load_json(existing_splits)

    labeled_patients = {f.name for f in (src_data_folder / "Training" / "Labeled").iterdir() if f.is_dir()}
    # The CSV also covers patients outside Training/Labeled; keep only the labeled training ones.
    patient_info = {
        patient: data
        for patient, data in read_csv(str(src_data_folder / csv_file)).items()
        if patient in labeled_patients
    }

    # Random train/validation split per vendor.
    patients_a = [patient for patient, data in patient_info.items() if data["vendor"] == "A"]
    patients_b = [patient for patient, data in patient_info.items() if data["vendor"] == "B"]
    train_a, val_a = get_vendor_split(patients_a, num_val_patients)
    train_b, val_b = get_vendor_split(patients_b, num_val_patients)

    def frame_ids(patients: list[str]) -> list[str]:
        # One case identifier per cardiac phase, matching the file names written by save_extracted_nifti_slice.
        return [f"{patient}_frame{patient_info[patient][phase]:02d}" for patient in patients for phase in ("es", "ed")]

    # Halve the training *patients* (not their frames) so a patient's ES and ED frames always end up in the same
    # mixed training set, even when the number of training patients is odd.
    half_a, half_b = len(train_a) // 2, len(train_b) // 2
    train_sets = [
        frame_ids(train_a),
        frame_ids(train_b),
        frame_ids(train_a[:half_a]) + frame_ids(train_b[:half_b]),
        frame_ids(train_a[half_a:]) + frame_ids(train_b[half_b:]),
    ]
    val_a_ids, val_b_ids = frame_ids(val_a), frame_ids(val_b)

    for train_set in train_sets:
        # For each train set, we evaluate on A, B and (A + B) respectively.
        # See table 3 from the original paper for more details.
        splits.append({"train": train_set, "val": val_a_ids})
        splits.append({"train": train_set, "val": val_b_ids})
        splits.append({"train": train_set, "val": val_a_ids + val_b_ids})

    save_json(splits, existing_splits)


def get_vendor_split(patients: list[str], num_val_patients: int):
    """Randomly split ``patients`` into ``(train, val)`` lists.

    Args:
        patients: patient identifiers to split. The list is not modified; a shuffled copy is used.
        num_val_patients: requested size of the validation list. It is clamped to ``[0, len(patients)]``, so
            asking for more patients than exist puts everyone in validation.

    Returns:
        Tuple ``(train, val)`` of disjoint lists whose union is ``patients``.
    """
    shuffled = list(patients)
    random.shuffle(shuffled)
    # Clamp so the cut index never goes negative. A negative index would slice from the end and yield wrong sizes
    # (e.g. 10 patients with 12 requested gave 8 train / 2 val).
    num_val = min(max(num_val_patients, 0), len(shuffled))
    num_training_patients = len(shuffled) - num_val
    return shuffled[:num_training_patients], shuffled[num_training_patients:]


if __name__ == "__main__":
    import argparse

    class RawTextArgumentDefaultsHelpFormatter(argparse.ArgumentDefaultsHelpFormatter, argparse.RawTextHelpFormatter):
        pass

    parser = argparse.ArgumentParser(add_help=False, formatter_class=RawTextArgumentDefaultsHelpFormatter)
    parser.add_argument(
        "-h",
        "--help",
        action="help",
        default=argparse.SUPPRESS,
        help="MNMs conversion utility helper. This script can be used to convert MNMs data into the expected nnUNet "
        "format. It can also be used to create additional custom splits, for explicitly training on combinations "
        "of vendors A and B (see `--custom-splits`).\n"
        "If you wish to generate the custom splits, run the following pipeline:\n\n"
        "(1) Run `Dataset114_MNMs -i <raw_Data_dir>\n"
        "(2) Run `nnUNetv2_plan_and_preprocess -d 114 --verify_dataset_integrity`\n"
        "(3) Start training, but stop after initial splits are created: `nnUNetv2_train 114 2d 0`\n"
        "(4) Re-run `Dataset114_MNMs`, with `-s True`.\n"
        "(5) Re-run training.\n",
    )
    parser.add_argument(
        "-i",
        "--input_folder",
        type=str,
        default="./data/M&Ms/OpenDataset/",
        help="The downloaded MNMs dataset dir. Should contain a csv file, as well as Training, Validation and Testing "
        "folders.",
    )
    parser.add_argument(
        "-c",
        "--csv_file_name",
        type=str,
        default="211230_M&Ms_Dataset_information_diagnosis_opendataset.csv",
        help="The csv file containing the dataset information.",
    ),
    parser.add_argument("-d", "--dataset_id", type=int, default=114, help="nnUNet Dataset ID.")
    parser.add_argument(
        "-s",
        "--custom_splits",
        type=bool,
        default=False,
        help="Whether to append custom splits for training and testing on different vendors. If True, will create "
        "splits for training on patients from vendors A, B or a mix of A and B. Splits are tested on a hold-out "
        "validation sets of patients from A, B or A and B combined. See section 2.4 and table 3 from "
        "https://arxiv.org/abs/2011.07592 for more info.",
    )

    args = parser.parse_args()
    args.input_folder = Path(args.input_folder)

    if args.custom_splits:
        print("Appending custom splits...")
        create_custom_splits(args.input_folder, args.csv_file_name, args.dataset_id)
    else:
        print("Converting...")
        convert_mnms(args.input_folder, args.csv_file_name, args.dataset_id)

    print("Done!")
