import os, time, shutil
import numpy as np
import torch
from sklearn.metrics import confusion_matrix
from scipy import ndimage
from scipy.ndimage import label
from functools import partial
import monai
from monai.inferers import sliding_window_inference
from monai.data import load_decathlon_datalist
from monai.transforms import AsDiscrete,AsDiscreted,Compose,Invertd,SaveImaged
from monai import transforms, data
from networks.swin3d_unetrv2 import SwinUNETR as SwinUNETR_v2
import nibabel as nib

import warnings
warnings.filterwarnings("ignore")

import argparse
# Command-line interface for sweeping saved checkpoints model_{epoch}.pt in --log_dir.
parser = argparse.ArgumentParser(description='find best checkpoint')

parser.add_argument('--val_dir', default=None, type=str,
                    help='base_dir passed to load_decathlon_datalist for the validation file paths')
parser.add_argument('--json_dir', default=None, type=str,
                    help='path to the datalist JSON; its "validation" section is evaluated')
parser.add_argument('--log_dir', default=None, type=str,
                    help='directory holding model_{epoch}.pt checkpoints; the results file is written here')
parser.add_argument('--sw_batch_size', default=64, type=int,
                    help='number of windows per forward pass in sliding-window inference')
parser.add_argument('--swin_type', default='tiny', type=str,
                    help='Swin-UNETR size for --model swin_unetrv2: tiny, small or base')
parser.add_argument('--suffix', default=None, type=str,
                    help='results are appended to <log_dir>/<suffix>.txt')
parser.add_argument('--val_overlap', default=0.5, type=float,
                    help='overlap ratio between neighbouring sliding windows')
parser.add_argument('--num_classes', default=3, type=int,
                    help='number of output channels; channels 1 and 2 are scored as liver and tumor')
parser.add_argument('--val_every', default=100, type=int,
                    help='epoch stride between evaluated checkpoints')
parser.add_argument('--max_epochs', default=1000, type=int,
                    help='largest checkpoint epoch to evaluate')

parser.add_argument('--syn', action='store_true',
                    help='accepted for compatibility; not used by this script')
parser.add_argument('--model', default='swin_unetrv2', type=str,
                    help='network architecture: swin_unetrv2 or unet')


def cal_dice(pred, true):
    """Return the Dice coefficient between a prediction map and a ground-truth map.

    ``true`` is treated as a mask via ``true == 1``; ``pred`` values at those
    positions are summed as the overlap. A tiny epsilon (1e-9) keeps the
    division finite, so two empty maps score 0.0 rather than NaN.
    """
    overlap = 2.0 * np.sum(pred[true == 1])
    denominator = np.sum(pred) + np.sum(true) + 1e-9
    return overlap / denominator

def _get_loader(args):
    """Create the validation DataLoader and the prediction post-processing pipeline.

    Args:
        args: parsed CLI namespace; reads ``json_dir`` (datalist file),
            ``val_dir`` (base directory for the listed files) and ``num_classes``.

    Returns:
        ``(loader, post_pipeline)``. The loader yields single-case batches from the
        datalist's "validation" section. The pipeline inverts the preprocessing on
        ``"pred"`` (using the transforms recorded for ``"image"``), then one-hot
        encodes both the argmaxed prediction and the label.
    """
    # Only "image" is reoriented, resampled, rescaled and padded; "label" is just
    # loaded, channel-expanded and tensorized, so it stays in the original grid.
    # The [-21, 189] intensity clip is presumably a CT HU window — confirm with data owner.
    preprocessing = transforms.Compose([
        transforms.LoadImaged(keys=["image", "label"]),
        transforms.AddChanneld(keys=["image", "label"]),
        transforms.Orientationd(keys=["image"], axcodes="RAS"),
        transforms.Spacingd(keys=["image"], pixdim=(1.0, 1.0, 1.0), mode="bilinear"),
        transforms.ScaleIntensityRanged(keys=["image"], a_min=-21, a_max=189, b_min=0.0, b_max=1.0, clip=True),
        transforms.SpatialPadd(keys=["image"], mode="minimum", spatial_size=[96, 96, 96]),
        transforms.ToTensord(keys=["image", "label"]),
    ])

    file_list = load_decathlon_datalist(args.json_dir, True, "validation", base_dir=args.val_dir)
    loader = data.DataLoader(
        data.Dataset(file_list, transform=preprocessing),
        batch_size=1,
        shuffle=False,
        num_workers=4,
        sampler=None,
        pin_memory=True,
    )

    # Map predictions back onto the label's grid before scoring.
    undo_preprocessing = Invertd(
        keys="pred",
        transform=preprocessing,
        orig_keys="image",
        meta_keys="pred_meta_dict",
        orig_meta_keys="image_meta_dict",
        meta_key_postfix="meta_dict",
        nearest_interp=False,
        to_tensor=True,
    )
    post_pipeline = Compose([
        undo_preprocessing,
        AsDiscreted(keys="pred", argmax=True, to_onehot=args.num_classes),
        AsDiscreted(keys="label", to_onehot=args.num_classes),
    ])

    return loader, post_pipeline


def load_model(args, epoch):
    """Build the configured network, load checkpoint ``model_{epoch}.pt`` and wrap it for inference.

    Args:
        args: parsed CLI namespace. Reads ``model``, ``swin_type``, ``num_classes``,
            ``log_dir``, ``sw_batch_size`` and ``val_overlap``.
        epoch: checkpoint index; ``<log_dir>/model_{epoch}.pt`` is loaded and its
            ``'state_dict'`` entry is copied into the network.

    Returns:
        ``(model, model_inferer)``: the network moved to CUDA, and a
        ``sliding_window_inference`` partial (96^3 windows, Gaussian blending)
        that only needs the input tensor.

    Raises:
        ValueError: ``args.model == 'swin_unetrv2'`` with an unknown ``args.swin_type``.
        RuntimeError: unsupported ``args.model``.
    """
    inf_size = [96, 96, 96]
    if args.model == 'swin_unetrv2':
        # feature_size used for each supported Swin-UNETR variant.
        feature_sizes = {'tiny': 12, 'small': 24, 'base': 48}
        if args.swin_type not in feature_sizes:
            # Fail fast with a clear message instead of leaving feature_size unbound.
            raise ValueError(
                f"Unsupported swin_type {args.swin_type!r}; "
                f"expected one of {sorted(feature_sizes)}."
            )
        model = SwinUNETR_v2(in_channels=1,
                             out_channels=args.num_classes,
                             img_size=(96, 96, 96),
                             feature_size=feature_sizes[args.swin_type],
                             patch_size=2,
                             depths=[2, 2, 2, 2],
                             num_heads=[3, 6, 12, 24],
                             window_size=[7, 7, 7])

    elif args.model == 'unet':
        from monai.networks.nets import UNet
        model = UNet(
            spatial_dims=3,
            in_channels=1,
            out_channels=args.num_classes,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
        )
    else:
        raise RuntimeError("Unsupported Model.")

    # Load tensors onto the CPU: the checkpoint may have been written from another
    # GPU index, and load_state_dict copies into the model's own parameters anyway.
    checkpoint_path = os.path.join(args.log_dir, f'model_{epoch}.pt')
    model_dict = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(model_dict['state_dict'])
    print(f'Load EPOCH:{epoch} model_{epoch}!', end='\t')

    # Move to GPU and bind the sliding-window settings so callers pass only the input.
    model = model.cuda()
    model_inferer = partial(sliding_window_inference, roi_size=inf_size, sw_batch_size=args.sw_batch_size,
                            predictor=model, overlap=args.val_overlap, mode='gaussian')
    pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('Total parameters count', pytorch_total_params)

    return model, model_inferer


def evaluation_model(model, model_inferer, val_loader, post_transforms):
    """Score ``model`` on every validation case and return the mean liver and tumor Dice.

    For each batch the prediction is written to ``batch["pred"]``. The batch is then
    decollated and passed through ``post_transforms``. Channels 1 and 2 of the first
    decollated item are scored with ``cal_dice``. A line with the case name, the
    spatial shape, both Dice values and the elapsed wall time since the start of
    this call is printed per case.

    Returns:
        ``(mean_liver_dice, mean_tumor_dice)`` as computed by ``np.mean`` over cases.
    """
    model.eval()
    t0 = time.time()
    scores = {"liver": [], "tumor": []}
    with torch.no_grad():
        for batch in val_loader:
            inputs = batch["image"].cuda()
            source_path = batch['label_meta_dict']['filename_or_obj'][0]
            # File name without directory and without any extension(s).
            case_name = source_path.rsplit('/', 1)[-1].partition('.')[0]

            batch["pred"] = model_inferer(inputs)
            processed = [post_transforms(item) for item in data.decollate_batch(batch)]
            first = processed[0]

            pred = first['pred'].detach().cpu().numpy()
            gt = first['label'].detach().cpu().numpy()

            liver = cal_dice(pred[1, ...], gt[1, ...])
            tumor = cal_dice(pred[2, ...], gt[2, ...])

            dice_text = 'dice: [{:.4f}  {:.4f}]'.format(liver, tumor)
            time_text = 'time {:.2f}s'.format(time.time() - t0)
            print(case_name, pred[0].shape, dice_text, time_text)

            scores["liver"].append(liver)
            scores["tumor"].append(tumor)

    return np.mean(scores["liver"]), np.mean(scores["tumor"])

def main():
    """Evaluate every ``model_{epoch}.pt`` in ``--log_dir`` and log the Dice scores.

    Epochs ``val_every, 2*val_every, ...`` are evaluated, never exceeding
    ``max_epochs``. Each result line is printed and appended to
    ``<log_dir>/<suffix>.txt``.
    """
    args = parser.parse_args()
    # These are used unconditionally below; fail with a usage message instead of
    # a TypeError/AttributeError deep inside the run.
    if args.log_dir is None:
        parser.error("--log_dir is required")
    if args.json_dir is None:
        parser.error("--json_dir is required")
    if args.suffix is None:
        parser.error("--suffix is required (results go to <log_dir>/<suffix>.txt)")

    # normpath drops a trailing separator so 'runs/exp1/' still yields 'exp1'.
    args.runs_name = os.path.basename(os.path.normpath(args.log_dir))
    print("MAIN Argument values:")
    for k, v in vars(args).items():
        print(k, '=>', v)
    print('-----------------')

    torch.cuda.set_device(0)  # evaluation runs on a single GPU, device 0
    torch.backends.cudnn.benchmark = True

    val_loader, post_transforms = _get_loader(args)
    print(args.model)

    save_txt_path = os.path.join(args.log_dir, args.suffix + '.txt')

    # Upper bound max_epochs + 1 keeps the last evaluated epoch <= max_epochs even
    # when max_epochs is not a multiple of val_every.
    epochs = range(args.val_every, args.max_epochs + 1, args.val_every)
    if not epochs:
        print(f"No checkpoints to evaluate: val_every={args.val_every} > max_epochs={args.max_epochs}")

    for epoch in epochs:
        print("=" * 64)
        print(f"Start Evaluate model_{epoch}")
        model, model_inferer = load_model(args, epoch)

        liver_dice, tumor_dice = evaluation_model(model, model_inferer, val_loader, post_transforms)
        mean_dice = (liver_dice + tumor_dice) / 2
        # Drop this checkpoint's GPU weights before the next load_model call.
        del model, model_inferer

        summary = "Epoch: {} || Current Dice: [{:.4f} {:.4f} {:.4f}]".format(
            epoch, liver_dice, tumor_dice, mean_dice)
        print(summary)
        with open(save_txt_path, 'a') as f:
            print(summary, file=f)
        print("=" * 64)


# Script entry point: run the checkpoint sweep only when executed directly, not on import.
if __name__ == "__main__":
    main()
