import os
from argparse import ArgumentParser
import numpy as np
import torch
from Models import SYMNet,SpatialTransform, DiffeomorphicTransform, CompositionTransform
from Functions import generate_grid,save_img,save_flow, load_4D_with_header, imgnorm
import glob
from tqdm import tqdm
import nibabel as nib
from torch.nn import functional as F
from evalutils import compute_metrics
import pickle as pkl
from kleindataloader import KleinDatasets
from os import path as osp

# Command-line interface. Arguments are parsed at import time and exposed to the
# rest of the file through the module-level ``opt`` namespace.
parser = ArgumentParser()
# --modelpath is mandatory, so argparse would never fall back to a default value;
# the previously configured (dead) default is kept only as an example in the help.
parser.add_argument("--modelpath", type=str, required=True,
                    help="path to the trained SYMNet checkpoint (.pth) to evaluate, "
                         "e.g. ../Model_v2/SYMNet_neurite_oasis_smo30_update_80000.pth")
parser.add_argument("--start_channel", type=int,
                    dest="start_channel", default=7,
                    help="number of start channels")
parser.add_argument('--dry_run', action='store_true',
                    help="forwarded to KleinDatasets as its dry_run flag")
parser.add_argument('--klein', action='store_true',
                    help="evaluate on the Klein datasets (IBSR18, CUMC12, MGH10, LPBA40) "
                         "instead of the neurite-OASIS scan pairs")
opt = parser.parse_args()

def get_lipschitz(flow):
    """Return the largest nuclear norm of the flow's Jacobian and its base-2 log.

    Args:
        flow: array laid out as [H W D 3]; the last axis holds the three
            components of the field.

    Returns:
        ``(L, Mstar)`` where ``L`` is the maximum over all voxels of the nuclear
        norm of the 3x3 finite-difference Jacobian and ``Mstar = log2(L)``.
    """
    per_component = []
    for comp in range(3):
        # Partial derivatives of this component along the three spatial axes.
        partials = np.gradient(flow[..., comp], axis=(0, 1, 2))
        per_component.append(np.stack(partials, axis=-1))
    # Trailing two axes are (spatial axis, component).
    jacobian = np.stack(per_component, axis=-1)

    nuclear = np.linalg.norm(jacobian, ord='nuc', axis=(-2, -1))
    L = np.max(nuclear)
    return L, np.log2(L)

def get_DetJac(warp):
    """Return the Jacobian determinant of a warp at every interior voxel.

    Args:
        warp: array laid out as [H W D 3]; the last axis holds the three
            components of the mapping.

    Returns:
        Array of determinants of the 3x3 finite-difference Jacobian. The
        outermost voxel layer is dropped on every spatial axis, so an input of
        spatial size (H, W, D) yields shape (H-2, W-2, D-2). The shape is also
        printed.
    """
    per_component = []
    for comp in range(3):
        partials = np.gradient(warp[..., comp], axis=(0, 1, 2))
        per_component.append(np.stack(partials, axis=-1))
    jacobian = np.stack(per_component, axis=-1)

    # Differentiate on the full volume first, then discard the border voxels.
    interior = jacobian[1:-1, 1:-1, 1:-1]
    determinants = np.linalg.det(interior)
    print(determinants.shape)
    return determinants

def test():
    """Print deformation-regularity statistics for consecutive neurite-OASIS scans.

    Scan ``i`` (fixed) is paired with scan ``i + 1`` (moving) from the sorted
    file list. For every pair the function prints three values: ``L`` and
    ``Mstar`` from ``get_lipschitz`` on the scaled ``F_xy`` model output, and
    the fraction of non-positive Jacobian determinants of the half-way warp.

    Reads the module globals ``opt``, ``imgshape`` and ``range_flow``. Requires
    CUDA. Returns nothing.
    """
    # Network and transform modules, all placed on the GPU.
    model = SYMNet(2, 3, opt.start_channel).cuda()
    transform = SpatialTransform().cuda()
    diff_transform = DiffeomorphicTransform(time_step=7).cuda()
    com_transform = CompositionTransform().cuda()

    model.load_state_dict(torch.load(opt.modelpath))
    for module in (model, transform, diff_transform, com_transform):
        module.eval()

    # Sampling grid with a leading batch axis.
    grid_np = generate_grid(imgshape)
    grid = torch.from_numpy(np.reshape(grid_np, (1, *grid_np.shape))).cuda().float()

    datapath = "/mnt/anon_data2/neurite-OASIS/"
    names = sorted(glob.glob(datapath + '/OASIS_OAS1_*_MR1/aligned_norm.nii.gz'))
    labels = sorted(glob.glob(datapath + '/OASIS_OAS1_*_MR1/aligned_seg35.nii.gz'))
    num_pairs = len(names) - 1

    def load_volume(path):
        # NIfTI file -> CPU tensor of the voxel data.
        return torch.from_numpy(nib.load(path).get_fdata())

    def encode_foreground(label_volume):
        # One-hot over 36 classes, drop class 0, channels first: [B, C, H, W, D].
        one_hot = F.one_hot(label_volume, num_classes=36)[..., 1:]
        return one_hot.permute(0, 4, 1, 2, 3).contiguous().float().cuda()

    # NOTE(review): results_dict, fid/mid, both label tensors, F_X_Y and F_Y_X are
    # prepared but not consumed below; they look like inputs to a Dice evaluation
    # via compute_metrics that is currently disabled. Confirm whether it should be
    # restored or this preparation dropped.
    results_dict = {}

    with torch.no_grad():
        for idx in tqdm(range(num_pairs)):
            fixed_path, moving_path = names[idx], names[idx + 1]
            fixed_img = load_volume(fixed_path).float()[None, None].cuda()
            moving_img = load_volume(moving_path).float()[None, None].cuda()
            # Subject ids are the names of the directories holding the scans.
            fid = fixed_path.split("/")[-2]
            mid = moving_path.split("/")[-2]

            # Label volumes (HWD) with a batch axis, then one-hot encoded.
            fixed_label = load_volume(labels[idx]).long()[None]
            moving_label = load_volume(labels[idx + 1]).long()[None]
            fixed_label = encode_foreground(fixed_label)
            moving_label = encode_foreground(moving_label)

            # Model outputs, their half-way transforms (forward and negated),
            # and the two compositions.
            F_xy, F_yx = model(fixed_img, moving_img)
            F_X_Y_half = diff_transform(F_xy, grid, range_flow)
            F_Y_X_half = diff_transform(F_yx, grid, range_flow)
            F_X_Y_half_inv = diff_transform(-F_xy, grid, range_flow)
            F_Y_X_half_inv = diff_transform(-F_yx, grid, range_flow)
            F_X_Y = com_transform(F_X_Y_half, F_Y_X_half_inv, grid, range_flow)
            F_Y_X = com_transform(F_Y_X_half, F_X_Y_half_inv, grid, range_flow)

            # Norm statistics of the scaled F_xy output, channels last: [H, W, D, 3].
            velocity_np = F_xy.permute(0, 2, 3, 4, 1).data.cpu().numpy()[0]
            F_xynp = range_flow * velocity_np
            L, Mstar = get_lipschitz(F_xynp)

            # Jacobian determinant of grid + scaled half-way displacement.
            warp = range_flow * F_X_Y_half + grid.permute(0, 4, 1, 2, 3)
            warp_np = warp.permute(0, 2, 3, 4, 1).data.cpu().numpy()[0]
            # Sign is flipped "due to rolled axis" (original author's note);
            # presumably generate_grid's channel order reverses orientation -- confirm.
            detjac = -get_DetJac(warp_np)
            print(L, Mstar, (detjac <= 0).mean())


def kleintest(dataset, isotropic, crop, savepath):
    """Evaluate the checkpoint on one Klein dataset configuration and pickle the metrics.

    Args:
        dataset: dataset identifier forwarded to ``KleinDatasets``.
        isotropic: forwarded to ``KleinDatasets`` as ``isotropic``.
        crop: forwarded to ``KleinDatasets`` as ``crop``.
        savepath: file that receives the pickled results dictionary.

    The pickled dictionary maps each ``(fid, mid)`` entry of the dataset's
    ``pair_ids`` to the value returned by ``compute_metrics`` for that pair.
    Reads the module globals ``opt`` and ``range_flow``. Requires CUDA.
    """
    print(f"Testing on {dataset} with isotropic={isotropic} and crop={crop}")

    # Network and transform modules, all placed on the GPU.
    model = SYMNet(2, 3, opt.start_channel).cuda()
    transform = SpatialTransform().cuda()
    diff_transform = DiffeomorphicTransform(time_step=7).cuda()
    com_transform = CompositionTransform().cuda()

    model.load_state_dict(torch.load(opt.modelpath))
    for module in (model, transform, diff_transform, com_transform):
        module.eval()

    # Image pairs; the sampling grid is sized from what the dataset reports.
    pair_data = KleinDatasets(dataset=dataset, isotropic=isotropic, crop=crop, dry_run=opt.dry_run)
    loader = torch.utils.data.DataLoader(pair_data, batch_size=1, shuffle=False, num_workers=1)
    imgshape = pair_data.getimgsize()

    grid_np = generate_grid(imgshape)
    grid = torch.from_numpy(np.reshape(grid_np, (1, *grid_np.shape))).cuda().float()

    def encode_foreground(label_batch, num_classes):
        # Strip the loader's batch axis, one-hot encode, drop class 0 and move
        # the class axis to position 1.
        one_hot = F.one_hot(label_batch[0], num_classes=num_classes)[..., 1:]
        return one_hot.permute(0, 4, 1, 2, 3).contiguous().float().cuda()

    results_dict = {}

    with torch.no_grad():
        for i, batch in tqdm(enumerate(loader), total=len(pair_data)):
            moving_img, fixed_img, moving_label, fixed_label = batch
            moving_img = moving_img.cuda()
            fixed_img = fixed_img.cuda()

            # Both label maps share one class count so their channels line up.
            labelmax = int(max(torch.max(moving_label).item(), torch.max(fixed_label).item()))
            moving_label = encode_foreground(moving_label, labelmax + 1)
            fixed_label = encode_foreground(fixed_label, labelmax + 1)

            # shuffle=False keeps loader order aligned with pair_ids.
            fid, mid = pair_data.pair_ids[i]

            # Model outputs, their half-way transforms (forward and negated),
            # and the two compositions.
            F_xy, F_yx = model(fixed_img, moving_img)
            F_X_Y_half = diff_transform(F_xy, grid, range_flow)
            F_Y_X_half = diff_transform(F_yx, grid, range_flow)
            F_X_Y_half_inv = diff_transform(-F_xy, grid, range_flow)
            F_Y_X_half_inv = diff_transform(-F_yx, grid, range_flow)
            F_X_Y = com_transform(F_X_Y_half, F_Y_X_half_inv, grid, range_flow)
            F_Y_X = com_transform(F_Y_X_half, F_X_Y_half_inv, grid, range_flow)

            # Warp the moving labels and binarise the interpolated values at 0.5.
            warp_flow = F_Y_X.permute(0, 2, 3, 4, 1) * range_flow
            warped = transform(moving_label, warp_flow, grid)
            moved_label = (warped >= 0.5).float()

            # A fresh flow tensor is built for the metrics call, as in the original.
            metric_flow = F_Y_X.permute(0, 2, 3, 4, 1) * range_flow
            ret = compute_metrics(fixed_label, moved_label, metric_flow,
                                  method='symnet', onlydice=True, labelmax=labelmax)
            results_dict[(fid, mid)] = ret
            print({k: (np.mean(v), np.array(v).shape) for k, v in ret.items()})

    # Persist all per-pair metrics.
    with open(savepath, 'wb') as fi:
        pkl.dump(results_dict, fi)
    print("Saved results to ", savepath)

# Evaluation settings that test() and kleintest() read as module globals. They
# are defined at module level (not only under the main guard) so those functions
# do not raise NameError when this file is imported instead of run as a script.
imgshape = (160, 192, 224)
range_flow = 500

if __name__ == '__main__':
    if opt.klein:
        import traceback

        # Sweep every dataset / resampling / cropping combination.
        for dataset in ['IBSR18', 'CUMC12', 'MGH10', 'LPBA40']:
            for isotropic in [True, False]:
                for crop in [True, False]:
                    isostr = "isotropic" if isotropic else "anisotropic"
                    cropstr = "crop" if crop else "nocrop"
                    # Results go next to the checkpoint. osp.join keeps the path
                    # relative when --modelpath has no directory part; plain
                    # string concatenation produced "/results_..." at the
                    # filesystem root in that case.
                    savepath = osp.join(osp.dirname(opt.modelpath),
                                        f"results_{dataset}_{isostr}_{cropstr}.pkl")
                    try:
                        kleintest(dataset, isotropic, crop, savepath)
                    except Exception as e:
                        # Best effort: report the failure with its traceback and
                        # continue with the remaining configurations.
                        print("Error in ", dataset, isotropic, crop)
                        print(e)
                        traceback.print_exc()
    else:
        test()
