from utils import compute_metrics
from os import path as osp
import argparse
from queue import Queue
import re
from subprocess import call
import nibabel as nib
from tqdm import tqdm
import pprint
import pickle as pkl
import SimpleITK as sitk

# Shared collector for evaluation results: every *_transform_fn below puts one
# (fixedid, movingid, metrics) tuple on it, and the main script drains it.
results_queue = Queue()

def trimspace(str):
    """Collapse runs of spaces to a single space, then delete all tabs.

    Args:
        str: Text to normalise. The name shadows the builtin but is kept so
            that existing keyword callers keep working.

    Returns:
        The normalised text. Tabs are removed *after* the spaces are
        collapsed, so spaces that become neighbours only because a tab
        between them was deleted are left as they are.
    """
    single_spaced = re.compile(r" {2,}").sub(" ", str)
    return single_spaced.translate({ord("\t"): None})

def ants_transform_fn(fsegpath, msegpath, fixedid, movingid, output_name="output_seg.nii.gz", worker_id=0):
    """Warp the moving segmentation with the stored ANTs transforms and score it.

    Runs ``antsApplyTransforms`` with the warp field and affine matrix found
    under ``./ANTs`` for this (fixed, moving) pair, writing the resampled
    labels to ``output_name``. The fixed and warped label volumes are then
    loaded and handed to ``compute_metrics``; the outcome is pushed onto the
    module-level ``results_queue`` as ``(fixedid, movingid, metrics)``.

    Args:
        fsegpath: Path of the fixed segmentation (passed as the ``-r`` reference).
        msegpath: Path of the moving segmentation (passed as the ``-i`` input).
        fixedid: Fixed subject identifier, used to build the transform file names.
        movingid: Moving subject identifier, used to build the transform file names.
        output_name: File the warped segmentation is written to.
        worker_id: Not used by this function.

    Raises:
        RuntimeError: If ``antsApplyTransforms`` exits with a non-zero status.
    """
    transform_prefix = f"./ANTs/output_{fixedid}_{movingid}"
    warppath = f"{transform_prefix}1Warp.nii.gz"
    # Build the argument vector directly instead of splitting a command string,
    # so that paths containing whitespace reach the tool intact.
    argv = [str(arg) for arg in (
        "antsApplyTransforms", "-d", "3",
        "-i", msegpath,
        "-r", fsegpath,
        "-t", warppath,
        "-t", f"{transform_prefix}0GenericAffine.mat",
        "-n", "GenericLabel",
        "-o", output_name,
    )]
    print(" ".join(argv))
    status = call(argv)
    if status != 0:
        # output_name is reused across pairs; without this check a failed run
        # would silently score the previous pair's warped segmentation.
        raise RuntimeError(
            f"antsApplyTransforms exited with status {status} "
            f"for pair ({fixedid}, {movingid})"
        )
    # now load the two volumes and compute scores
    fixedlab = nib.load(fsegpath).get_fdata().squeeze()
    movedlab = nib.load(output_name).get_fdata().squeeze()
    results_queue.put((fixedid, movingid, compute_metrics(fixedlab, movedlab, warppath, 'ants')))

def greedy_transform_fn(fsegpath, msegpath, fixedid, movingid, output_name="output_seg.nii.gz", worker_id=0):
    """Warp the moving segmentation with the stored Greedy warp and score it.

    Runs ``greedy`` in reslice mode with the warp found under ``./Greedy`` for
    this (fixed, moving) pair, writing the resampled labels to ``output_name``.
    The fixed and warped label volumes are then loaded and handed to
    ``compute_metrics``; the outcome is pushed onto the module-level
    ``results_queue`` as ``(fixedid, movingid, metrics)``.

    Args:
        fsegpath: Path of the fixed segmentation (passed as ``-rf``).
        msegpath: Path of the moving segmentation (first argument of ``-rm``).
        fixedid: Fixed subject identifier, used to build the warp file name.
        movingid: Moving subject identifier, used to build the warp file name.
        output_name: File the warped segmentation is written to.
        worker_id: Not used by this function.

    Raises:
        RuntimeError: If ``greedy`` exits with a non-zero status.
    """
    warppath = f"./Greedy/output_{fixedid}_{movingid}_warp.nii.gz"
    # Build the argument vector directly instead of splitting a command string,
    # so that paths containing whitespace reach the tool intact.
    argv = [str(arg) for arg in (
        "greedy", "-d", "3",
        "-ri", "LABEL", "0.2vox",
        "-rf", fsegpath,
        "-rm", msegpath, output_name,
        "-r", warppath,
    )]
    print(" ".join(argv))
    status = call(argv)
    if status != 0:
        # output_name is reused across pairs; without this check a failed run
        # would silently score the previous pair's warped segmentation.
        raise RuntimeError(
            f"greedy exited with status {status} "
            f"for pair ({fixedid}, {movingid})"
        )
    # load volumes and put into results
    fixedlab = nib.load(fsegpath).get_fdata().squeeze()
    movedlab = nib.load(output_name).get_fdata().squeeze()
    results_queue.put((fixedid, movingid, compute_metrics(fixedlab, movedlab, warppath, 'greedy')))

def nifty_transform_fn(fsegpath, msegpath, fixedid, movingid, output_name="output_seg.nii.gz", worker_id=0):
    """Warp the moving segmentation with the stored NiftyReg transform and score it.

    Runs ``reg_resample`` with the transform found under ``./niftyreg`` for
    this (fixed, moving) pair, writing the resampled labels to ``output_name``.
    The fixed and warped label volumes are then loaded and handed to
    ``compute_metrics`` (which, for this tool, also receives ``fsegpath``);
    the outcome is pushed onto the module-level ``results_queue`` as
    ``(fixedid, movingid, metrics)``.

    Args:
        fsegpath: Path of the fixed segmentation (passed as ``-ref``).
        msegpath: Path of the moving segmentation (passed as ``-flo``).
        fixedid: Fixed subject identifier, used to build the transform file name.
        movingid: Moving subject identifier, used to build the transform file name.
        output_name: File the warped segmentation is written to (``-res``).
        worker_id: Not used by this function.

    Raises:
        RuntimeError: If ``reg_resample`` exits with a non-zero status.
    """
    # One spelling of the transform path for both the tool and compute_metrics.
    warppath = f"./niftyreg/output_{fixedid}_{movingid}.nii.gz"
    # Build the argument vector directly instead of splitting a command string,
    # so that paths containing whitespace reach the tool intact.
    argv = [str(arg) for arg in (
        "reg_resample",
        "-ref", fsegpath,
        "-flo", msegpath,
        "-trans", warppath,
        "-res", output_name,
        "-inter", "0",
        "-omp", "8",
    )]
    print(" ".join(argv))
    status = call(argv)
    if status != 0:
        # output_name is reused across pairs; without this check a failed run
        # would silently score the previous pair's warped segmentation.
        raise RuntimeError(
            f"reg_resample exited with status {status} "
            f"for pair ({fixedid}, {movingid})"
        )
    # load volumes and put into results
    fixedlab = nib.load(fsegpath).get_fdata().squeeze()
    movedlab = nib.load(output_name).get_fdata().squeeze()
    results_queue.put((fixedid, movingid, compute_metrics(fixedlab, movedlab, warppath, 'niftyreg', fsegpath)))

def demons_transform_fn(fsegpath, msegpath, fixedid, movingid, output_name="output_seg.nii.gz", worker_id=0):
    """Resample the moving segmentation with the stored Demons transform and score it.

    Unlike the other transform functions, no external tool is launched: the
    transform is read from ``./Demons`` and applied in-process with SimpleITK
    using nearest-neighbour interpolation on the fixed image's grid. The
    result of ``compute_metrics`` is pushed onto the module-level
    ``results_queue`` as ``(fixedid, movingid, metrics)``.

    Args:
        fsegpath: Path of the fixed segmentation (also the resampling reference).
        msegpath: Path of the moving segmentation.
        fixedid: Fixed subject identifier, used to build the transform file name.
        movingid: Moving subject identifier, used to build the transform file name.
        output_name: Not used by this function; nothing is written to disk.
        worker_id: Not used by this function.
    """
    warppath = "./Demons/output_{}_{}.h5".format(fixedid, movingid)
    demons_transform = sitk.ReadTransform(warppath)
    fixed_image = sitk.ReadImage(fsegpath)
    moving_image = sitk.ReadImage(msegpath)

    # Configure a label-safe resampler on the fixed image's grid.
    label_resampler = sitk.ResampleImageFilter()
    label_resampler.SetReferenceImage(fixed_image)
    label_resampler.SetTransform(demons_transform)
    label_resampler.SetDefaultPixelValue(0)
    label_resampler.SetInterpolator(sitk.sitkNearestNeighbor)

    moved_array = sitk.GetArrayFromImage(label_resampler.Execute(moving_image))
    fixed_array = sitk.GetArrayFromImage(fixed_image)
    metrics = compute_metrics(fixed_array, moved_array, warppath, 'demons')
    results_queue.put((fixedid, movingid, metrics))

if __name__ == '__main__':
    import os

    # Single source of truth for the supported tools: the same table feeds the
    # argparse choices and the dispatch, so the two cannot drift apart.
    transform_fns = {
        'ants': ants_transform_fn,
        'greedy': greedy_transform_fn,
        'niftyreg': nifty_transform_fn,
        'demons': demons_transform_fn,
    }

    # get args
    parser = argparse.ArgumentParser()
    parser.add_argument("--algo", required=True, choices=list(transform_fns))
    parser.add_argument('--dry-run', action='store_true')
    parser.add_argument('--output_name', default='output_seg.nii.gz')
    args = parser.parse_args()

    algo = args.algo
    transform = transform_fns[algo]

    # perform evaluation on the pairs
    with open("reverse_subjects_OASIS.txt", "r") as fi:
        subjects = [line for line in fi.read().split("\n") if line]

    # Consecutive subjects form a pair: each entry is the moving image for the
    # entry that follows it in the list.
    moving, fixed = subjects[:-1], subjects[1:]
    # just select 3 examples
    if args.dry_run:
        moving, fixed = moving[:3], fixed[:3]

    warped_seg_name = algo + "_" + args.output_name
    for f, m in tqdm(list(zip(fixed, moving))):
        fsegpath = f"../neurite-OASIS/{f}/aligned_seg35.nii.gz"
        msegpath = f"../neurite-OASIS/{m}/aligned_seg35.nii.gz"
        # run script
        transform(fsegpath, msegpath, f, m, output_name=warped_seg_name)

    # results are stored in queue
    results_dict = dict()
    while not results_queue.empty():
        f, m, res = results_queue.get()
        results_dict[(f, m)] = res

    # print if dry run, else save
    if args.dry_run:
        # A dry run only covers three pairs, so it must not overwrite the
        # metrics file of a full evaluation.
        pprint.pprint(results_dict)
    else:
        results_dir = "results"
        metrics_path = "{}/metrics_{}.pkl".format(results_dir, algo)
        # Create the directory on demand so a long run is not lost at the
        # very last step because the folder is missing.
        os.makedirs(results_dir, exist_ok=True)
        with open(metrics_path, "wb") as fi:
            pkl.dump(results_dict, fi)
        print("Results written to {}".format(metrics_path))
