import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 
import argparse
import numpy as np
import voxelmorph as vxm
import tensorflow as tf
import evalutils
from kleindataloader import KleinDatasets
from tqdm import tqdm
from torch.nn import functional as F
import torch
import os
import pickle as pkl
import time
from glob import glob
from natsort import natsorted
import nibabel as nib

# Parse command-line arguments. This runs at module import time, so simply
# importing this file consumes sys.argv (and exits if --model is missing).
parser = argparse.ArgumentParser()
parser.add_argument('--model', required=True, help='keras model for nonlinear registration')
# NOTE(review): --warp is parsed but never read anywhere in this script.
parser.add_argument('--warp', help='output warp deformation filename')
# When set, main() breaks out of its loop after the first registered pair.
parser.add_argument('--dry_run', action='store_true', help='dry run')
# NOTE(review): --multichannel is only referenced by commented-out loading
# code in this file; it currently has no effect.
parser.add_argument('--multichannel', action='store_true',
                    help='specify that data has multiple channels')
parser.add_argument('--gpu_id', type=int, default=0)
args = parser.parse_args()

# TensorFlow device handling: `device` is printed and passed to tf.device(...)
# in main(); `nb_devices` is not used elsewhere in this script.
device, nb_devices = vxm.tf.utils.setup_device(args.gpu_id)

# load moving and fixed images
# add_feat_axis = not args.multichannel
# # moving = vxm.py.utils.load_volfile(args.moving, add_batch_axis=True, add_feat_axis=add_feat_axis)
# fixed, fixed_affine = vxm.py.utils.load_volfile(
#     args.fixed, add_batch_axis=True, add_feat_axis=add_feat_axis, ret_affine=True)

def dataset(root="/mnt/anon_data2/neurite-OASIS",
            img_name="aligned_norm.nii.gz",
            seg_name="aligned_seg35.nii.gz"):
    """Yield consecutive (moving, fixed) subject pairs for evaluation.

    Subjects are the sub-directories of ``root`` that contain ``img_name``,
    in natural-sort order. For each consecutive pair ``(subjects[i],
    subjects[i + 1])`` the first subject is the fixed one and the second the
    moving one, so N subjects yield N - 1 pairs (none when N < 2).

    Args:
        root: directory holding one sub-directory per subject.
        img_name: intensity-volume filename inside each subject directory.
        seg_name: label-volume filename inside each subject directory.

    Yields:
        ``(movimg, fiximg, movseg, fixseg, fid, mid)``. Images have shape
        ``(1, *vol_shape, 1)`` and segmentations ``(1, *vol_shape)``. Both are
        floating point, as returned by nibabel's ``get_fdata``. ``fid`` and
        ``mid`` are the fixed and moving subject directory names.

        The fixed arrays yielded for pair ``i + 1`` are the same objects that
        were yielded as moving arrays for pair ``i``, so callers must not
        modify them in place.

    Raises:
        FileNotFoundError: if a subject directory has an image but no
            segmentation. This is raised before the first pair is yielded.
    """
    image_paths = natsorted(glob(os.path.join(root, "*", img_name)))

    # Pair every image with the segmentation from the *same* subject
    # directory instead of zipping two independently globbed lists. With
    # zipping, one missing file would shift the lists and silently pair an
    # image with another subject's labels. Validate everything up front so a
    # problem surfaces before any (slow) registration runs.
    subjects = []
    for img_path in image_paths:
        subj_dir = os.path.dirname(img_path)
        seg_path = os.path.join(subj_dir, seg_name)
        if not os.path.isfile(seg_path):
            raise FileNotFoundError(
                f"missing segmentation for {img_path}: expected {seg_path}")
        subjects.append((os.path.basename(subj_dir), img_path, seg_path))

    def _load(subject):
        """Load one subject as (id, image[1, ..., 1], seg[1, ...])."""
        sid, img_path, seg_path = subject
        img = nib.load(img_path).get_fdata().squeeze()[None, ..., None]
        seg = nib.load(seg_path).get_fdata().squeeze()[None]
        return sid, img, seg

    if len(subjects) < 2:
        return

    # Each interior volume appears in two consecutive pairs (first as moving,
    # then as fixed). Keep the previously loaded subject instead of reading
    # it from disk a second time.
    fid, fiximg, fixseg = _load(subjects[0])
    for subject in subjects[1:]:
        mid, movimg, movseg = _load(subject)
        yield movimg, fiximg, movseg, fixseg, fid, mid
        fid, fiximg, fixseg = mid, movimg, movseg


def _as_labels(seg):
    """Round a floating-point label volume (e.g. from ``get_fdata``) to int32.

    ``tf.one_hot`` only accepts integer indices. Rounding rather than
    truncating guards against labels stored as e.g. ``34.9999``.
    """
    return np.rint(seg).astype(np.int32)


def _one_hot(labels, maxlabel):
    """One-hot encode integer ``labels`` with depth ``maxlabel + 1``, dropping label 0."""
    return tf.one_hot(labels, depth=maxlabel + 1)[..., 1:]


def main(savepath):
    """Register consecutive subject pairs and pickle per-pair metrics.

    Loads the VoxelMorph model given by ``--model`` and registers every
    (moving, fixed) pair yielded by ``dataset()``. For each pair it warps the
    one-hot moving segmentation with the predicted warp and thresholds the
    result at 0.5. It then passes the result, the one-hot fixed segmentation
    and the warp to ``evalutils.compute_metrics``.

    Args:
        savepath: pickle file to write. It holds a dict mapping
            ``(fixed_id, moving_id)`` to the value returned by
            ``evalutils.compute_metrics`` for that pair.
    """
    results_dict = {}
    inshape = (160, 192, 224)
    print(device)
    gen = dataset()
    config = dict(inshape=inshape, input_model=None)
    with tf.device(device):
        model = vxm.networks.VxmDense.load(args.model, **config)
        # maxlabel is computed per pair, so the number of one-hot channels can
        # differ between pairs, e.g. when the highest label is absent from
        # both scans. Each Transform is built for one channel count, so cache
        # one per count instead of reusing the first pair's transform.
        transforms = {}
        # total=413 is a hard-coded estimate used only for the progress bar.
        for batch in tqdm(gen, total=413):
            # images: [1, H, W, D, 1]; segmentations: [1, H, W, D]
            moving_img, fixed_img, moving_seg, fixed_seg, fid, mid = batch
            moving_lbl = _as_labels(moving_seg)
            fixed_lbl = _as_labels(fixed_seg)
            maxlabel = int(max(moving_lbl.max(), fixed_lbl.max()))
            moving_oh = _one_hot(moving_lbl, maxlabel)
            fixed_oh = _one_hot(fixed_lbl, maxlabel)
            nb_feats = moving_oh.shape[-1]
            transform = transforms.get(nb_feats)
            if transform is None:
                transform = vxm.networks.Transform(inshape, nb_feats=nb_feats)
                transforms[nb_feats] = transform

            # Time registration plus label warping.
            start = time.perf_counter()
            warp = model.register(moving_img, fixed_img)
            moved_seg = transform.predict([moving_oh.numpy(), warp])
            print(time.perf_counter() - start)

            # Binarize the warped one-hot labels, then reorder
            # channels-last [N, H, W, D, C] -> channels-first [N, C, H, W, D].
            moved_seg = (torch.from_numpy(moved_seg) >= 0.5).float().permute(0, 4, 1, 2, 3)
            fixed_t = torch.from_numpy(fixed_oh.numpy()).permute(0, 4, 1, 2, 3)
            print(moved_seg.shape, fixed_t.shape)
            ret = evalutils.compute_metrics(moved_seg, fixed_t, warp, onlydice=False, labelmax=maxlabel, method='fireants')
            results_dict[(fid, mid)] = ret
            print({k: (np.mean(v), np.array(v).shape) for k, v in ret.items()})
            if args.dry_run:
                break

        # save results
        print(f"Saving results to {savepath}.")
        with open(savepath, 'wb') as fi:
            pkl.dump(results_dict, fi)


if __name__ == "__main__":
    # Write one metrics pickle per model, named after the model path.
    # normpath strips a trailing separator, so a path like "ckpt/" gives
    # "ckpt.pkl" instead of an empty basename (".pkl").
    results_dir = "results_oasis"
    os.makedirs(results_dir, exist_ok=True)
    model_name = os.path.basename(os.path.normpath(args.model))
    savepath = os.path.join(results_dir, f"{model_name}.pkl")
    main(savepath)
