import os
from argparse import ArgumentParser
import numpy as np
import torch
from Models import SYMNet,SpatialTransform, DiffeomorphicTransform, CompositionTransform
from Functions import generate_grid,save_img,save_flow, load_4D_with_header, imgnorm
import glob
from tqdm import tqdm
import nibabel as nib
from torch.nn import functional as F
from evalutils import compute_metrics
import pickle as pkl
from kleindataloader import KleinDatasets
from os import path as osp

parser = ArgumentParser(description="Evaluate a trained SYMNet model on OASIS or the Klein datasets.")
# required=True means argparse never falls back to a default, so none is given.
parser.add_argument("--modelpath", type=str, required=True,
                    help="path to the trained SYMNet checkpoint (.pth) to evaluate")
parser.add_argument('--savepath', required=True,
                    help="pickle file for the OASIS results (not used with --klein, "
                         "which writes results next to --modelpath)")
parser.add_argument("--start_channel", type=int,
                    dest="start_channel", default=7,
                    help="number of start channels")
parser.add_argument('--dry_run', action='store_true',
                    help="quick smoke test: OASIS evaluation stops after 5 pairs; "
                         "also forwarded to KleinDatasets")
parser.add_argument('--klein', action='store_true',
                    help="evaluate on the Klein datasets instead of OASIS")
opt = parser.parse_args()

def test(datapath="/mnt/anon_data2/neurite-OASIS/"):
    """Evaluate SYMNet on consecutive OASIS subject pairs and pickle the metrics.

    Subject images are sorted by path and paired as (fixed = subject i,
    moving = subject i + 1). For each pair the moving segmentation is warped
    into the fixed space with the Y->X deformation and scored with
    ``compute_metrics``. Results are stored as
    ``{(fixed_id, moving_id): metrics}`` and pickled to ``opt.savepath``.

    Relies on the module-level globals ``opt`` (parsed CLI args), ``imgshape``
    and ``range_flow`` (assigned in the ``__main__`` block).

    Args:
        datapath: directory containing the ``OASIS_OAS1_*_MR1`` subject folders.
            Defaults to the previously hard-coded location.

    Raises:
        FileNotFoundError: if fewer than two subject images are found, so no
            pair can be evaluated.
    """
    model = SYMNet(2, 3, opt.start_channel).cuda()
    transform = SpatialTransform().cuda()
    diff_transform = DiffeomorphicTransform(time_step=7).cuda()
    com_transform = CompositionTransform().cuda()

    model.load_state_dict(torch.load(opt.modelpath))
    for module in (model, transform, diff_transform, com_transform):
        module.eval()

    grid = generate_grid(imgshape)
    grid = torch.from_numpy(np.reshape(grid, (1,) + grid.shape)).cuda().float()

    names = sorted(glob.glob(osp.join(datapath, 'OASIS_OAS1_*_MR1', 'aligned_norm.nii.gz')))
    if len(names) < 2:
        raise FileNotFoundError(
            f"Need at least two OASIS subjects under {datapath!r}, found {len(names)}")
    # Derive each segmentation from its image's folder so images and labels can
    # never drift out of alignment (two independent globs could).
    labels = [osp.join(osp.dirname(name), 'aligned_seg35.nii.gz') for name in names]

    def load_image(path):
        # Add batch and channel dims -> [1, 1, H, W, D] on the GPU.
        return torch.from_numpy(nib.load(path).get_fdata()).float()[None, None].cuda()

    def load_onehot_label(path):
        # Volume is (H, W, D); num_classes=36 assumes label values lie in 0..35.
        # Drop background channel 0 and move classes to dim 1 -> [1, 35, H, W, D].
        label = torch.from_numpy(nib.load(path).get_fdata()).long()[None]
        onehot = F.one_hot(label, num_classes=36)[..., 1:]
        return onehot.permute(0, 4, 1, 2, 3).contiguous().float().cuda()

    def subject_id(path):
        # The subject folder name, e.g. "OASIS_OAS1_0001_MR1".
        return osp.basename(osp.dirname(path))

    results_dict = dict()

    with torch.no_grad():
        for i in tqdm(range(len(names) - 1)):
            fixed_img = load_image(names[i])
            moving_img = load_image(names[i + 1])
            fid, mid = subject_id(names[i]), subject_id(names[i + 1])
            fixed_label = load_onehot_label(labels[i])
            moving_label = load_onehot_label(labels[i + 1])

            F_xy, F_yx = model(fixed_img, moving_img)
            # Only the Y->X (moving -> fixed) map is used below: it is the forward
            # half of Y->X composed with the inverse half of X->Y.
            F_Y_X_half = diff_transform(F_yx, grid, range_flow)
            F_X_Y_half_inv = diff_transform(-F_xy, grid, range_flow)
            F_Y_X = com_transform(F_Y_X_half, F_X_Y_half_inv, grid, range_flow)
            # Channels-last flow scaled by range_flow, as consumed by transform().
            flow_yx = F_Y_X.permute(0, 2, 3, 4, 1) * range_flow

            moved_label = transform(moving_label, flow_yx, grid)
            metrics = compute_metrics(fixed_label, moved_label, flow_yx, method='symnet')
            results_dict[(fid, mid)] = metrics
            print({k: np.mean(v) for k, v in metrics.items()})
            if opt.dry_run and i >= 4:
                break

    with open(opt.savepath, 'wb') as fi:
        pkl.dump(results_dict, fi)
    print("Saved results to ", opt.savepath)


def kleintest(dataset, isotropic, crop, savepath):
    """Evaluate SYMNet on one Klein dataset configuration and pickle Dice results.

    Loads weights from ``opt.modelpath`` and iterates over every pair yielded by
    ``KleinDatasets``. For each pair the moving segmentation is warped into the
    fixed space with the Y->X deformation and binarized at 0.5. It is then
    scored with ``compute_metrics(..., onlydice=True)``.

    Relies on the module-level globals ``opt`` (parsed CLI args) and
    ``range_flow`` (assigned in the ``__main__`` block).

    Args:
        dataset: Klein dataset name forwarded to ``KleinDatasets`` (e.g. 'IBSR18').
        isotropic: forwarded to ``KleinDatasets``.
        crop: forwarded to ``KleinDatasets``.
        savepath: output pickle path. It receives ``{(fid, mid): metrics}``, keyed
            by the dataset's ``pair_ids``.
    """
    print(f"Testing on {dataset} with isotropic={isotropic} and crop={crop}")
    model = SYMNet(2, 3, opt.start_channel).cuda()
    transform = SpatialTransform().cuda()
    diff_transform = DiffeomorphicTransform(time_step=7).cuda()
    com_transform = CompositionTransform().cuda()

    model.load_state_dict(torch.load(opt.modelpath))
    for module in (model, transform, diff_transform, com_transform):
        module.eval()

    # Separate name so the `dataset` argument (a string) is not shadowed.
    klein_data = KleinDatasets(dataset=dataset, isotropic=isotropic, crop=crop, dry_run=opt.dry_run)
    dataloader = torch.utils.data.DataLoader(klein_data, batch_size=1, shuffle=False, num_workers=1)
    imgshape = klein_data.getimgsize()

    grid = generate_grid(imgshape)
    grid = torch.from_numpy(np.reshape(grid, (1,) + grid.shape)).cuda().float()

    def to_onehot(label, num_classes):
        # label[0] must be 4-D for the permute below (presumably [1, H, W, D]).
        # one_hot requires an int64 index tensor, hence .long(). The background
        # channel 0 is dropped and the class axis moved to dim 1.
        onehot = F.one_hot(label[0].long(), num_classes=num_classes)[..., 1:]
        return onehot.permute(0, 4, 1, 2, 3).contiguous().float().cuda()

    results_dict = dict()

    with torch.no_grad():
        for i, batch in tqdm(enumerate(dataloader), total=len(klein_data)):
            moving_img, fixed_img, moving_label, fixed_label = batch
            moving_img, fixed_img = moving_img.cuda(), fixed_img.cuda()
            # Share one class count across the pair so channels line up.
            labelmax = int(max(torch.max(moving_label).item(), torch.max(fixed_label).item()))
            moving_label = to_onehot(moving_label, labelmax + 1)
            fixed_label = to_onehot(fixed_label, labelmax + 1)

            fid, mid = klein_data.pair_ids[i]

            F_xy, F_yx = model(fixed_img, moving_img)
            # Only the Y->X (moving -> fixed) map is used below.
            F_Y_X_half = diff_transform(F_yx, grid, range_flow)
            F_X_Y_half_inv = diff_transform(-F_xy, grid, range_flow)
            F_Y_X = com_transform(F_Y_X_half, F_X_Y_half_inv, grid, range_flow)
            flow_yx = F_Y_X.permute(0, 2, 3, 4, 1) * range_flow

            # Interpolated one-hot maps are re-binarized at 0.5 before scoring.
            moved_label = (transform(moving_label, flow_yx, grid) >= 0.5).float()
            ret = compute_metrics(fixed_label, moved_label, flow_yx, method='symnet', onlydice=True, labelmax=labelmax)
            results_dict[(fid, mid)] = ret
            print({k: (np.mean(v), np.array(v).shape) for k, v in ret.items()})

    with open(savepath, 'wb') as fi:
        pkl.dump(results_dict, fi)
    print("Saved results to ", savepath)


if __name__ == '__main__':
    import traceback

    # Module-level globals: imgshape builds the OASIS sampling grid in test().
    # range_flow is passed to the transforms and scales the predicted flow in
    # test() and kleintest().
    imgshape = (160, 192, 224)
    range_flow = 100
    if opt.klein:
        failures = []
        for dataset in ['IBSR18', 'CUMC12', 'MGH10', 'LPBA40']:
            for isotropic in [True, False]:
                for crop in [True, False]:
                    isostr = "isotropic" if isotropic else "anisotropic"
                    cropstr = "crop" if crop else "nocrop"
                    # osp.join avoids writing to "/" when --modelpath has no directory part.
                    savepath = osp.join(osp.dirname(opt.modelpath), f"results_{dataset}_{isostr}_{cropstr}.pkl")
                    try:
                        kleintest(dataset, isotropic, crop, savepath)
                    except Exception:
                        # Best-effort sweep: report the failure in full, then continue.
                        print("Error in ", dataset, isotropic, crop)
                        traceback.print_exc()
                        failures.append((dataset, isotropic, crop))
        if failures:
            print(f"{len(failures)} configuration(s) failed: {failures}")
    else:
        test()
