import glob, sys
import os, losses, utils
from torch.utils.data import DataLoader
from data import datasets, trans
import numpy as np
import torch
from torchvision import transforms
import matplotlib.pyplot as plt
from natsort import natsorted
from models.TransMorph import CONFIGS as CONFIGS_TM
import models.TransMorph as TransMorph
from scipy.ndimage.interpolation import map_coordinates, zoom
from torch.nn import functional as F
import argparse
import evalutils
import pickle as pkl
from tqdm import tqdm
from kleindataloader import KleinDatasets

def _select_config(modelsize):
    """Return the TransMorph config for ``modelsize``.

    Args:
        modelsize: 'regular' or 'large'.

    Raises:
        ValueError: for any other value (instead of leaving ``config``
            undefined and failing later with a NameError).
    """
    if modelsize == 'regular':
        return CONFIGS_TM['TransMorph']
    if modelsize == 'large':
        return CONFIGS_TM['TransMorph-Large']
    raise ValueError(f"Unknown modelsize {modelsize!r}; expected 'regular' or 'large'.")


def _find_checkpoint(model_dir, model_idx=-1):
    """Return the path of the checkpoint to evaluate.

    Candidates are the files matching ``model_dir + "*dsc*pth.tar"`` in
    natural sort order; ``model_idx`` selects one (default: the last).
    ``model_dir`` is string-concatenated with the pattern, so it must end
    with a path separator to match files *inside* that directory.

    Raises:
        FileNotFoundError: if no file matches the pattern.
    """
    pattern = model_dir + "*dsc*pth.tar"
    candidates = natsorted(glob.glob(pattern))
    if not candidates:
        raise FileNotFoundError(f"No checkpoint matching {pattern!r} found.")
    return candidates[model_idx]


def _to_one_hot(seg, labelmax):
    """One-hot encode a label volume and drop the background (label 0) channel.

    Assumes ``seg`` carries a singleton channel axis at dim 1, i.e.
    (B, 1, D, H, W) -- TODO confirm for every dataset. Returns a contiguous
    tensor of shape (B, labelmax, D, H, W).
    """
    one_hot = F.one_hot(seg.long(), num_classes=labelmax + 1)[..., 1:]
    one_hot = torch.squeeze(one_hot, 1)
    # Move the class axis from last position to the channel position.
    return one_hot.permute(0, 4, 1, 2, 3).contiguous()


def _pair_ids(val_set, i):
    """Return ``(fixed_id, moving_id)`` for the ``i``-th pair of ``val_set``.

    First tries ``val_set.paths``: the IDs are the parent-directory names of
    the first path in entries ``i`` and ``i + 1``. If that lookup is not
    possible, falls back to ``val_set.pair_ids[i]``.
    """
    try:
        fixed_id = val_set.paths[i][0].split("/")[-2]
        moving_id = val_set.paths[i + 1][0].split("/")[-2]
    except (AttributeError, IndexError, KeyError, TypeError):
        # No usable ``paths`` attribute/entry: use the dataset's explicit pair IDs.
        fixed_id, moving_id = val_set.pair_ids[i]
    return fixed_id, moving_id


def main(args):
    """Evaluate a trained TransMorph checkpoint and pickle per-pair metrics.

    Loads the last ``*dsc*pth.tar`` checkpoint (natural sort order) from
    ``'experiments/' + args.model_folder``. It then registers every pair of the
    OASIS ``valmax`` split, or of the Klein dataset named by
    ``args.klein_data``. Finally it writes ``{(fixed_id, moving_id): metrics}``
    to a pickle in the same folder.

    Args:
        args: parsed CLI namespace; reads ``model_folder``, ``modelsize``,
            ``klein_data``, ``isotropic``, ``crop`` and ``dry_run``.

    Raises:
        ValueError: if ``args.modelsize`` is not 'regular' or 'large'.
        FileNotFoundError: if no checkpoint matches.
    """
    if args.klein_data is not None:
        print(f"Running {args.klein_data} with isotropic = {args.isotropic}, crop = {args.crop}.")
    model_idx = -1
    model_dir = 'experiments/' + args.model_folder
    data_dir = "/mnt/anon_data2/neurite-OASIS/"
    config = _select_config(args.modelsize)

    model = TransMorph.TransMorph(config)
    ckpt_path = _find_checkpoint(model_dir, model_idx)
    best_model = torch.load(ckpt_path)['state_dict']
    print("Best model: {}".format(ckpt_path))
    model.load_state_dict(best_model)
    model.cuda()

    val_composed = transforms.Compose([trans.NumpyType((np.float32, np.int16))])
    if args.klein_data is None:
        val_set = datasets.OASISNiftiDataset(data_dir, split='valmax', transforms=val_composed)
    else:
        val_set = KleinDatasets(dataset=args.klein_data, isotropic=args.isotropic, crop=args.crop, dry_run=args.dry_run)

    val_loader = DataLoader(val_set, batch_size=1, shuffle=False, num_workers=4, pin_memory=True, drop_last=True)
    results_dict = {}
    model.eval()
    with torch.no_grad():
        for i, data in tqdm(enumerate(val_loader), total=len(val_loader)):
            x, y, x_seg, y_seg = data
            x, y, x_seg, y_seg = [tensor.cuda() for tensor in [x, y, x_seg, y_seg]]
            # Use one label range for both volumes so the one-hot tensors have
            # the same channel count; int() because one_hot needs an int num_classes.
            labelmax = int(max(torch.max(x_seg).item(), torch.max(y_seg).item()))
            x_seg_oh = _to_one_hot(x_seg, labelmax)
            y_seg_oh = _to_one_hot(y_seg, labelmax)

            # x is the moving image (its segmentation is warped below), y the fixed one.
            _, flow = model(torch.cat((x, y), dim=1))
            x_seg_trans = (model.spatial_trans(x_seg_oh.float(), flow.float()) >= 0.5).float()
            # warp = grid + flow, as in the spatial transformer's forward pass;
            # voxel coordinates coincide with physical coordinates here.
            warp = model.spatial_trans.grid + flow

            fixed_id, moving_id = _pair_ids(val_set, i)
            ret = evalutils.compute_metrics(y_seg_oh, x_seg_trans, warp, onlydice=args.klein_data is not None, labelmax=labelmax)
            results_dict[(fixed_id, moving_id)] = ret

            for key, val in ret.items():
                print(key, np.mean(val))

            # Dry run on OASIS: stop after two pairs. Klein datasets receive
            # dry_run through the KleinDatasets constructor instead.
            if args.dry_run and i >= 1 and not args.klein_data:
                break

    iso_str = "isotropic" if args.isotropic else "anisotropic"
    results = "results.pkl" if args.klein_data is None else "results_{}_{}.pkl".format(args.klein_data, iso_str)
    results_path = model_dir + results
    with open(results_path, 'wb') as f:
        pkl.dump(results_dict, f)
    print("Saved results to {}".format(results_path))

if __name__ == '__main__':
    '''
    GPU configuration
    '''
    # GPU_iden = 1
    # GPU_num = torch.cuda.device_count()
    # print('Number of GPU: ' + str(GPU_num))
    # for GPU_idx in range(GPU_num):
    #     GPU_name = torch.cuda.get_device_name(GPU_idx)
    #     print('     GPU #' + str(GPU_idx) + ': ' + GPU_name)
    # torch.cuda.set_device(GPU_iden)
    # GPU_avai = torch.cuda.is_available()
    # print('Currently using: ' + torch.cuda.get_device_name(GPU_iden))
    # print('If the GPU is available? ' + str(GPU_avai))

    # get args
    parser = argparse.ArgumentParser(description='TransMorph')
    parser.add_argument('--model_folder', type=str, required=True)
    parser.add_argument('--dry_run', action='store_true')
    parser.add_argument('--isotropic', action='store_true')
    parser.add_argument('--crop', action='store_true')
    parser.add_argument('--klein_data', type=str, default=None)
    parser.add_argument('--allklein', action='store_true')
    parser.add_argument('--modelsize', default='regular')

    args = parser.parse_args()

    if args.allklein:
        # run all klein data
        print("Running all klein data")
        for k in ['IBSR18', 'CUMC12', 'MGH10', 'LPBA40']:
            for iso in [True, False]:
                for crop in [True,]:
                    args.klein_data = k
                    args.isotropic = iso
                    args.crop = crop
                    main(args)
                    # except:
                    #     print(f"Failed for {k} with isotropic = {iso}, crop = {crop}.")
    else:
        main(args)
