import os
from argparse import ArgumentParser

import numpy as np
import torch
from tqdm import tqdm
from torch.nn import functional as F

from Functions import generate_grid_unit, save_img, save_flow, transform_unit_flow_to_flow, load_4D
from Functions import generate_grid, Dataset_epoch, transform_unit_flow_to_flow_cuda, \
    generate_grid_unit, Dataset_oasis
import torch.utils.data as Data
from miccai2020_model_stage import Miccai2020_LDR_laplacian_unit_disp_add_lvl1, \
    Miccai2020_LDR_laplacian_unit_disp_add_lvl2, Miccai2020_LDR_laplacian_unit_disp_add_lvl3, SpatialTransform_unit
import evalutils
import pickle as pkl
from kleindataloader import KleinDatasets

# Command-line interface. Parsed at import time; ``opt`` and the two aliases
# below are read as module-level globals by the evaluation functions.
parser = ArgumentParser()
parser.add_argument(
    "--modelpath",
    type=str,
    required=True,
    help="Pre-trained Model path",
)
parser.add_argument(
    "--start_channel",
    type=int,
    dest="start_channel",
    default=7,
    help="number of start channels",
)
parser.add_argument("--dry_run", action="store_true", help="dry run")
parser.add_argument("--savepath", required=True, type=str)
parser.add_argument("--klein", action="store_true", help="klein data")
opt = parser.parse_args()

savepath, start_channel = opt.savepath, opt.start_channel

def testklein(dataset, isotropic, crop, savepath):
    """Evaluate the pre-trained model on one Klein dataset configuration.

    Builds the three-level model for the dataset's image size and loads the
    weights from ``opt.modelpath``. It then registers every (fixed, moving)
    pair in loader order, warps the one-hot moving segmentation with the
    predicted flow, and collects Dice-only metrics from
    ``evalutils.compute_metrics``.

    Args:
        dataset: Dataset name forwarded to ``KleinDatasets`` (e.g. 'MGH10').
        isotropic: Forwarded to ``KleinDatasets``.
        crop: Forwarded to ``KleinDatasets``.
        savepath: File that receives the pickled ``{(fid, mid): metrics}`` dict.

    Reads the module-level globals ``opt``, ``start_channel`` and
    ``range_flow`` (``range_flow`` is only defined under ``__main__``).
    """
    klein_data = KleinDatasets(dataset=dataset, isotropic=isotropic, crop=crop, dry_run=opt.dry_run)
    print("dataset has {} pairs".format(len(klein_data)))
    print("evaluating with isotropic = {}, crop = {}".format(isotropic, crop))
    imgshape = klein_data.getimgsize()
    imgshape_2 = (imgshape[0]//2, imgshape[1]//2, imgshape[2]//2)
    imgshape_4 = (imgshape[0]//4, imgshape[1]//4, imgshape[2]//4)
    # shuffle must stay False: results are keyed by pair_ids[i], i.e. by batch index.
    val_generator = Data.DataLoader(klein_data, batch_size=1, shuffle=False, num_workers=2)

    # lvl1 runs at 1/4 resolution, lvl2 at 1/2, and lvl3 (the model used for inference) at full size.
    model_lvl1 = Miccai2020_LDR_laplacian_unit_disp_add_lvl1(2, 3, start_channel, is_train=True, imgshape=imgshape_4,
                                                             range_flow=range_flow).cuda()
    model_lvl2 = Miccai2020_LDR_laplacian_unit_disp_add_lvl2(2, 3, start_channel, is_train=True, imgshape=imgshape_2,
                                                             range_flow=range_flow, model_lvl1=model_lvl1).cuda()
    model = Miccai2020_LDR_laplacian_unit_disp_add_lvl3(2, 3, start_channel, is_train=False, imgshape=imgshape,
                                                        range_flow=range_flow, model_lvl2=model_lvl2).cuda()

    transform = SpatialTransform_unit().cuda()
    model.load_state_dict(torch.load(opt.modelpath))
    print("Loaded model from {}".format(opt.modelpath))
    model.eval()
    transform.eval()

    # Identity sampling grid with a leading batch axis added.
    grid = generate_grid_unit(imgshape)
    grid = torch.from_numpy(np.reshape(grid, (1,) + grid.shape)).cuda().float()

    device = torch.device("cuda")
    results_dict = {}
    with torch.no_grad():
        for i, (X, Y, X_seg, Y_seg) in tqdm(enumerate(val_generator), total=len(val_generator)):
            fid, mid = klein_data.pair_ids[i]
            X = X.to(device).float()
            Y = Y.to(device).float()
            X_seg = X_seg.to(device)
            Y_seg = Y_seg.to(device)
            # int() guards against float-typed label volumes: .item() would then
            # return a float, and F.one_hot requires an integer num_classes.
            labelmax = int(max(torch.max(X_seg).item(), torch.max(Y_seg).item()))
            # One-hot over labels 0..labelmax, move the class axis to dim 1, and
            # drop channel 0 (presumably background — confirm against the loader).
            X_seg_oh = F.one_hot(X_seg[0].long(), num_classes=labelmax+1).permute(0, 4, 1, 2, 3).float()[:, 1:].contiguous()
            Y_seg_oh = F.one_hot(Y_seg[0].long(), num_classes=labelmax+1).permute(0, 4, 1, 2, 3).float()[:, 1:].contiguous()
            F_X_Y = model(X, Y)
            # Warp the moving one-hot labels, then binarize the interpolated result at 0.5.
            X_seg_oh_warp = (transform(X_seg_oh, F_X_Y.permute(0, 2, 3, 4, 1), grid) >= 0.5).float()
            ret = evalutils.compute_metrics(X_seg_oh_warp, Y_seg_oh, F_X_Y, method='lapirn', onlydice=True, labelmax=labelmax)
            results_dict[(fid, mid)] = ret
            print({k: np.mean(v) for k, v in ret.items()}, {k: np.array(v).shape for k, v in ret.items()})

    with open(savepath, 'wb') as f:
        pkl.dump(results_dict, f)


def test(datapath="/mnt/anon_data2/neurite-OASIS/"):
    """Evaluate the pre-trained model on the OASIS 'maxval' split.

    Builds the three-level model at the fixed module-level ``imgshape`` and
    loads the weights from ``opt.modelpath``. It then registers each pair
    from ``Dataset_oasis`` in index order, warps the 35 foreground one-hot
    channels of the moving segmentation, and pickles the per-pair
    ``evalutils.compute_metrics`` results to the global ``savepath``. With
    ``--dry_run``, only the first 5 pairs are evaluated.

    Args:
        datapath: Root directory passed to ``Dataset_oasis``. Defaults to the
            previously hard-coded location.

    Reads the module-level globals ``opt``, ``start_channel``, ``savepath``,
    ``imgshape*`` and ``range_flow`` (``range_flow`` is only defined under
    ``__main__``).
    """
    # lvl1 runs at 1/4 resolution, lvl2 at 1/2, and lvl3 (the model used for inference) at full size.
    model_lvl1 = Miccai2020_LDR_laplacian_unit_disp_add_lvl1(2, 3, start_channel, is_train=True, imgshape=imgshape_4,
                                                             range_flow=range_flow).cuda()
    model_lvl2 = Miccai2020_LDR_laplacian_unit_disp_add_lvl2(2, 3, start_channel, is_train=True, imgshape=imgshape_2,
                                                             range_flow=range_flow, model_lvl1=model_lvl1).cuda()
    model = Miccai2020_LDR_laplacian_unit_disp_add_lvl3(2, 3, start_channel, is_train=False, imgshape=imgshape,
                                                        range_flow=range_flow, model_lvl2=model_lvl2).cuda()

    transform = SpatialTransform_unit().cuda()
    model.load_state_dict(torch.load(opt.modelpath))
    print("Loaded model from {}".format(opt.modelpath))
    model.eval()
    transform.eval()

    # Identity sampling grid with a leading batch axis added.
    grid = generate_grid_unit(imgshape)
    grid = torch.from_numpy(np.reshape(grid, (1,) + grid.shape)).cuda().float()

    device = torch.device("cuda")
    # shuffle=False is required: batch i is labelled with dataset.index_pairs[i]
    # below, so a shuffled loader would store metrics under the wrong pair.
    val_generator = Data.DataLoader(Dataset_oasis(datapath, norm=False, split='maxval'), batch_size=1,
                                    shuffle=False, num_workers=2)
    results_dict = {}
    with torch.no_grad():
        for i, (X, Y, X_seg, Y_seg) in tqdm(enumerate(val_generator), total=len(val_generator)):
            # Name each image by the parent directory of its file path.
            fid, mid = val_generator.dataset.index_pairs[i]
            fid, mid = val_generator.dataset.images[fid], val_generator.dataset.images[mid]
            fid, mid = fid.split("/")[-2], mid.split("/")[-2]
            X = X.to(device).float()
            Y = Y.to(device).float()
            X_seg = X_seg.to(device)
            Y_seg = Y_seg.to(device)
            # One-hot over 36 labels, class axis to dim 1, drop channel 0
            # (presumably background — confirm against the dataset).
            X_seg_oh = F.one_hot(X_seg[0].long(), num_classes=36).permute(0, 4, 1, 2, 3).float()[:, 1:].contiguous()
            Y_seg_oh = F.one_hot(Y_seg[0].long(), num_classes=36).permute(0, 4, 1, 2, 3).float()[:, 1:].contiguous()
            F_X_Y = model(X, Y)
            # Warp the moving one-hot labels, then binarize the interpolated result at 0.5.
            X_seg_oh_warp = (transform(X_seg_oh, F_X_Y.permute(0, 2, 3, 4, 1), grid) >= 0.5).float()
            ret = evalutils.compute_metrics(X_seg_oh_warp, Y_seg_oh, F_X_Y, method='lapirn')
            results_dict[(fid, mid)] = ret
            if opt.dry_run:
                print({k: np.mean(v) for k, v in ret.items()})
                if i >= 4:
                    break

    with open(savepath, 'wb') as f:
        pkl.dump(results_dict, f)

# Full-resolution volume size used by test(). The /4 and /2 shapes are what
# test() passes to the lvl1 and lvl2 models respectively.
imgshape = (160, 192, 224)
imgshape_4 = tuple(dim // 4 for dim in imgshape)
imgshape_2 = tuple(dim // 2 for dim in imgshape)

if __name__ == '__main__':
    import traceback

    # Read as a module-level global by test() and testklein() when building the models.
    range_flow = 0.4
    if opt.klein:
        # Sweep every Klein dataset x {isotropic, anisotropic} x {crop, nocrop},
        # writing one pickle per configuration into opt.savepath.
        os.makedirs(opt.savepath, exist_ok=True)
        for dataset in ['MGH10', 'CUMC12', 'IBSR18', 'LPBA40']:
            for isotropic in [True, False]:
                for crop in [True, False]:
                    isostr = "isotropic" if isotropic else "anisotropic"
                    cropstr = "crop" if crop else "nocrop"
                    savepath = opt.savepath + f"/{dataset}_{isostr}_{cropstr}.pkl"
                    try:
                        testklein(dataset, isotropic, crop, savepath)
                    except Exception:
                        # Best-effort sweep: report the failing configuration and
                        # continue, but let KeyboardInterrupt/SystemExit stop the run.
                        print(f"testklein failed for {dataset} ({isostr}, {cropstr}):")
                        traceback.print_exc()
    else:
        test()
