import os
import glob
import sys
from argparse import ArgumentParser
import numpy as np
import torch
import torch.nn as nn
from Models import *
from Functions import *
import torch.utils.data as Data
from natsort import natsorted
import csv
import glob
import os.path as osp
from tqdm import tqdm
import evalutils
import pickle as pkl
from kleindataloader import KleinDatasets

# Command-line interface for the evaluation script.
parser = ArgumentParser()
parser.add_argument(
    "--start_channel",
    dest="start_channel",
    type=int,
    default=32,
    help="number of start channels",
)
parser.add_argument(
    "--datapath",
    dest="datapath",
    type=str,
    default="/mnt/anon_data2/neurite-OASIS/",
    help="data path for training images",
)
parser.add_argument(
    "--modelpath",
    type=str,
    required=True,
    help="data directory with checkpoints",
)
parser.add_argument("--dry_run", action="store_true", help="dry run")
parser.add_argument("--klein", action="store_true", help="klein data")
opt = parser.parse_args()

# Module-level shortcuts read by val() and val_klein().
start_channel, datapath, modelpath = opt.start_channel, opt.datapath, opt.modelpath

def val_klein(dataset, iso, crop, savepath):
    """Score the newest ``DiceVal*pth`` checkpoint on one Klein dataset configuration.

    The checkpoint is the last match of ``DiceVal*pth`` in the global
    ``modelpath`` under natural sort order. For every (moving, fixed) pair
    produced by ``KleinDatasets``:

    1. The model predicts a displacement field ``f_xy``.
    2. The one-hot moving segmentation is warped with that field.
    3. The warped segmentation is re-binarised at 0.5.
    4. It is compared against the fixed segmentation with
       ``evalutils.compute_metrics(..., onlydice=True)``.

    Args:
        dataset: Klein dataset name forwarded to ``KleinDatasets``.
        iso: ``isotropic`` flag forwarded to ``KleinDatasets``.
        crop: ``crop`` flag forwarded to ``KleinDatasets``.
        savepath: Pickle file receiving ``{(fix_id, mov_id): metrics}``.

    Unless ``--dry_run`` is set, nothing is recomputed when ``savepath``
    already exists. In dry-run mode the results are only printed, never
    written. Otherwise a partial dry-run pickle would overwrite real results
    and make the next full run skip this configuration as already done.

    Raises:
        FileNotFoundError: if ``modelpath`` holds no ``DiceVal*pth`` checkpoint.
    """
    print("Running validation on Klein data.")
    print(f"Dataset: {dataset}, Isotropic: {iso}, Crop: {crop}")
    if not opt.dry_run and osp.exists(savepath):
        print("Save path exists")
        return

    checkpoints = natsorted(glob.glob(osp.join(modelpath, 'DiceVal*pth')))
    if not checkpoints:
        raise FileNotFoundError(f"No DiceVal*pth checkpoint found in {modelpath}")
    modelname = checkpoints[-1]
    model = UNet(2, 3, start_channel).cuda()
    model.eval()
    model.load_state_dict(torch.load(modelname))
    print("Loaded model from {}".format(modelname))

    transform = SpatialTransform().cuda()
    # The warping module is used for inference only; freeze whatever it holds.
    for param in transform.parameters():
        param.requires_grad = False

    test_set = KleinDatasets(dataset=dataset, isotropic=iso, crop=crop, dry_run=opt.dry_run)
    # batch_size must stay 1: the [0, ...] indexing below keeps only the first sample.
    test_generator = Data.DataLoader(dataset=test_set, batch_size=1, shuffle=False, num_workers=2)

    results_dict = dict()
    with torch.no_grad():
        for i, (mov_img, fix_img, mov_lab, fix_lab) in tqdm(enumerate(test_generator), total=len(test_generator)):
            fix_img = fix_img.cuda().float()
            mov_img = mov_img.cuda().float()
            fix_lab = fix_lab.cuda().float()
            mov_lab = mov_lab.cuda().float()
            # Size the one-hot encoding from the largest label value in either volume.
            maxlabel = int(max(fix_lab.max().item(), mov_lab.max().item()))

            # one_hot appends a class axis; [0, ..., 1:] drops the batch index and
            # label 0 (presumably background), then the class axis is moved to dim 1.
            mov_seg = nn.functional.one_hot(mov_lab.long(), num_classes=maxlabel + 1)[0, ..., 1:]
            mov_seg = mov_seg.permute(0, 4, 1, 2, 3).float().contiguous()
            fix_seg = nn.functional.one_hot(fix_lab.long(), num_classes=maxlabel + 1)[0, ..., 1:]
            fix_seg = fix_seg.permute(0, 4, 1, 2, 3).float().contiguous()

            mov_id, fix_id = test_set.pair_ids[i]

            f_xy = model(mov_img, fix_img)
            # Move the flow's channel axis (dim 1) last before warping; the >= 0.5
            # threshold turns the interpolated one-hot masks back into binary masks.
            moved_seg = (transform(mov_seg, f_xy.permute(0, 2, 3, 4, 1)) >= 0.5).float()
            ret = evalutils.compute_metrics(fix_seg, moved_seg, f_xy, method='lku', onlydice=True, labelmax=maxlabel)
            results_dict[(fix_id, mov_id)] = ret
            print({k: np.mean(v) for k, v in ret.items()}, {k: np.array(v).shape for k, v in ret.items()})

    if opt.dry_run:
        # Never persist partial dry-run results (see docstring).
        print(f"Dry run: not saving results to {savepath}.")
        return
    print(f"Saving to {savepath}.")
    with open(savepath, 'wb') as f:
        pkl.dump(results_dict, f)

def val():
    """Score the newest ``DiceVal*pth`` checkpoint on the OASIS ``maxval`` split.

    The checkpoint is the last match of ``DiceVal*pth`` in the global
    ``modelpath`` under natural sort order. Each pair from
    ``OASISNeuriteDataset(datapath, split='maxval')`` is processed as follows:

    1. The model predicts a displacement field.
    2. The one-hot moving segmentation is warped with it.
    3. The warped segmentation is re-binarised at 0.5.
    4. It is scored against the fixed segmentation with
       ``evalutils.compute_metrics``.

    Metrics are pickled to ``osp.join(modelpath, 'results.pkl')`` as
    ``{(fix_id, mov_id): metrics}``. With ``--dry_run`` only the first few
    pairs are evaluated and printed, and nothing is written. This keeps an
    existing full ``results.pkl`` from being overwritten by partial results.

    Raises:
        FileNotFoundError: if ``modelpath`` holds no ``DiceVal*pth`` checkpoint.
    """
    checkpoints = natsorted(glob.glob(osp.join(modelpath, 'DiceVal*pth')))
    if not checkpoints:
        raise FileNotFoundError(f"No DiceVal*pth checkpoint found in {modelpath}")
    modelname = checkpoints[-1]
    model = UNet(2, 3, start_channel).cuda()
    model.eval()
    model.load_state_dict(torch.load(modelname))
    print("Loaded model from {}".format(modelname))

    transform = SpatialTransform().cuda()
    # The warping module is used for inference only; freeze whatever it holds.
    for param in transform.parameters():
        param.requires_grad = False

    test_set = OASISNeuriteDataset(datapath, split='maxval')
    # batch_size must stay 1: the [0, ...] indexing below keeps only the first sample.
    test_generator = Data.DataLoader(dataset=test_set, batch_size=1, shuffle=False, num_workers=2)

    # Fixed one-hot size for the OASIS label maps (presumably labels 0..35).
    num_classes = 36

    results_dict = dict()
    with torch.no_grad():
        for i, (mov_img, fix_img, mov_lab, fix_lab) in tqdm(enumerate(test_generator), total=len(test_generator)):
            fix_img = fix_img.cuda().float()
            mov_img = mov_img.cuda().float()
            fix_lab = fix_lab.cuda().float()
            mov_lab = mov_lab.cuda().float()

            # one_hot appends a class axis; [0, ..., 1:] drops the batch index and
            # label 0 (presumably background), then the class axis is moved to dim 1.
            mov_seg = nn.functional.one_hot(mov_lab.long(), num_classes=num_classes)[0, ..., 1:]
            mov_seg = mov_seg.permute(0, 4, 1, 2, 3).float().contiguous()
            fix_seg = nn.functional.one_hot(fix_lab.long(), num_classes=num_classes)[0, ..., 1:]
            # Cast to float like mov_seg/moved_seg (and like val_klein) so the
            # metric code receives same-dtype masks.
            fix_seg = fix_seg.permute(0, 4, 1, 2, 3).float().contiguous()

            # NOTE(review): IDs come from consecutive entries of test_set.files,
            # which presumes the maxval split pairs file i with file i+1 — confirm.
            fix_id = test_set.files[i].split("/")[-2]
            mov_id = test_set.files[i + 1].split("/")[-2]

            f_xy = model(mov_img, fix_img)
            # Move the flow's channel axis (dim 1) last before warping; the >= 0.5
            # threshold turns the interpolated one-hot masks back into binary masks.
            moved_seg = (transform(mov_seg, f_xy.permute(0, 2, 3, 4, 1)) >= 0.5).float()
            ret = evalutils.compute_metrics(fix_seg, moved_seg, f_xy, method='lku')
            results_dict[(fix_id, mov_id)] = ret
            if opt.dry_run:
                print({k: np.mean(v) for k, v in ret.items()})
                if i > 4:
                    break

    # osp.join works whether or not --modelpath ends with a separator; plain
    # string concatenation wrote "<dir>results.pkl" next to the directory instead.
    savepath = osp.join(modelpath, 'results.pkl')
    if opt.dry_run:
        print(f"Dry run: not saving results to {savepath}.")
        return
    with open(savepath, 'wb') as f:
        pkl.dump(results_dict, f)

    
if __name__ == '__main__':
    if opt.klein:
        import itertools
        import traceback

        # Evaluate every Klein dataset under all isotropic/crop combinations,
        # one pickle per configuration. A failure in one configuration is
        # reported and the sweep continues; Ctrl-C (KeyboardInterrupt) still
        # aborts, because only Exception subclasses are caught.
        datasets = ['IBSR18', 'CUMC12', 'MGH10', 'LPBA40']
        for dataset, iso, crop in itertools.product(datasets, [True, False], [True, False]):
            isostr = 'isotropic' if iso else 'anisotropic'
            cropstr = 'crop' if crop else 'nocrop'
            savepath = osp.join(modelpath, f'{dataset}_{isostr}_{cropstr}.pkl')
            try:
                val_klein(dataset, iso, crop, savepath)
            except Exception:
                print(f"Failed on {dataset}, {iso}, {crop}.")
                traceback.print_exc()
    else:
        val()
