from data.dataset_3d_lungs import LabelledDS
from models import deeplabv3
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import argparse
import os
import numpy as np
import math
import monai.transforms as mt
import os
from monai.metrics import DiceHelper



def get_args(known=False, argv=None):
    """Build and parse the command-line options for evaluation.

    Args:
        known: if True, unrecognised arguments are ignored
            (``parse_known_args``) instead of causing an error.
        argv: optional list of argument strings to parse. ``None`` (the
            default) lets argparse read ``sys.argv[1:]`` as before.

    Returns:
        argparse.Namespace with the parsed options.
    """
    def _str2bool(value):
        # argparse's ``type=bool`` turns any non-empty string (including
        # "False" or "0") into True, so parse the usual spellings explicitly.
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
            return True
        if lowered in ('0', 'false', 'f', 'no', 'n', 'off'):
            return False
        raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))

    parser = argparse.ArgumentParser(description='PyTorch Implementation')
    parser.add_argument('--json_path', type=str, default='/path/to/data/labels/test_data.json', help='path to the data')
    # Sizes take three integers (e.g. ``--image_size 180 180 70``). ``type=list``
    # or ``type=tuple`` would split a single string into characters.
    parser.add_argument('--image_size', type=int, nargs=3, default=[180, 180, 70], help='the size of images for training and testing')
    parser.add_argument('--patch_size', type=int, nargs=3, default=(144, 144, 64), help='the size of patch')
    parser.add_argument('--cuda', type=_str2bool, default=True, help='use gpu')
    parser.add_argument('--stride_xy', type=int, default=16, help='stride in xy plane')
    parser.add_argument('--stride_z', type=int, default=6, help='stride in z plane')
    parser.add_argument('--num_workers', type=int, default=4, help='number of workers to use for dataloader')
    parser.add_argument('--in_channels', type=int, default=1, help='input channels')
    parser.add_argument('--num_classes', type=int, default=3, help='number of target categories')
    parser.add_argument('--model_weights', type=str, default='/path/to/output/weights/best.pth', help='model weights')
    parser.add_argument('--dest_path', type=str, default='test_outputs/results', help='destination path')
    parser.add_argument('--json_key', type=str, default='val', help='which split to evaluate on from the json file')
    args = parser.parse_known_args(argv)[0] if known else parser.parse_args(argv)
    return args




def test_single_case(net, image, stride_xy, stride_z, patch_size, num_classes=2):
    """Run sliding-window inference on one 3D volume and average class probabilities.

    If an axis of the volume is shorter than ``patch_size``, that axis is first
    zero-padded symmetrically. The volume is then tiled with overlapping windows.
    Each window is passed through ``net`` and soft-maxed over the channel axis.
    Overlapping probabilities are averaged voxel-wise, and the padding is cropped
    off again before returning.

    Args:
        net: network called as ``net(patch, perturbation=True)``. Its result must
            be indexable with ``'out'`` and yield logits of shape
            ``(1, num_classes, *patch_size)``.
        image: 3D volume of shape ``(W, H, D)``, either a numpy array or a torch
            tensor (on any device).
        stride_xy: window stride along the first two axes.
        stride_z: window stride along the third axis.
        patch_size: window size ``(pw, ph, pd)``.
        num_classes: number of output channels produced by ``net``.

    Returns:
        float32 numpy array of shape ``(num_classes, W, H, D)`` holding the
        averaged per-class softmax probabilities.
    """
    if isinstance(image, torch.Tensor):
        # Do all bookkeeping in numpy on the CPU. Converting once avoids
        # pushing a (possibly GPU) tensor through numpy for every window.
        image = image.detach().cpu().numpy()
    image = np.asarray(image)
    w, h, d = image.shape
    pw, ph, pd = patch_size

    # Zero-pad any axis shorter than the patch so at least one window fits.
    w_pad, h_pad, d_pad = max(pw - w, 0), max(ph - h, 0), max(pd - d, 0)
    wl_pad, hl_pad, dl_pad = w_pad // 2, h_pad // 2, d_pad // 2
    add_pad = w_pad > 0 or h_pad > 0 or d_pad > 0
    if add_pad:
        image = np.pad(
            image,
            [(wl_pad, w_pad - wl_pad), (hl_pad, h_pad - hl_pad), (dl_pad, d_pad - dl_pad)],
            mode='constant',
            constant_values=0,
        )
    ww, hh, dd = image.shape

    # Number of window positions per axis. The last window is clamped below
    # so that it ends exactly at the volume border.
    sx = math.ceil((ww - pw) / stride_xy) + 1
    sy = math.ceil((hh - ph) / stride_xy) + 1
    sz = math.ceil((dd - pd) / stride_z) + 1
    score_map = np.zeros((num_classes,) + image.shape, dtype=np.float32)
    cnt = np.zeros(image.shape, dtype=np.float32)

    # Use the network's exact device, including its index (e.g. cuda:1).
    device = next(net.parameters()).device
    # Inference only: don't build an autograd graph for every window.
    with torch.no_grad():
        for ix in range(sx):
            xs = min(stride_xy * ix, ww - pw)
            for iy in range(sy):
                ys = min(stride_xy * iy, hh - ph)
                for iz in range(sz):
                    zs = min(stride_z * iz, dd - pd)
                    window = (slice(xs, xs + pw), slice(ys, ys + ph), slice(zs, zs + pd))
                    patch = image[window][np.newaxis, np.newaxis].astype(np.float32)
                    logits = net(torch.from_numpy(patch).to(device), perturbation=True)['out']
                    probs = F.softmax(logits, dim=1)[0].cpu().numpy()
                    score_map[(slice(None),) + window] += probs
                    cnt[window] += 1
    # Every voxel is covered by at least one window, so cnt is never zero.
    score_map = score_map / cnt[np.newaxis]
    # Crop the padding back off so the output matches the input shape.
    if add_pad:
        score_map = score_map[:, wl_pad:wl_pad + w, hl_pad:hl_pad + h, dl_pad:dl_pad + d]
    return score_map


# The ``test_`` prefix is historical. It stops pytest from collecting this
# helper as a test case if the module is ever imported into a test file.
test_single_case.__test__ = False

def load_model(model_weights, in_channels, num_classes, backbone='VNet',device=torch.device("cuda")):
    """Build a network from ``models.deeplabv3`` and load a checkpoint into it.

    Args:
        model_weights: path to a saved ``state_dict``, loaded onto the CPU first.
        in_channels: number of input channels passed to the constructor.
        num_classes: passed to the constructor as ``out_channels``.
        backbone: name of the constructor looked up in ``deeplabv3``.
        device: device the model is moved to.

    Returns:
        The model with the checkpoint weights loaded.
    """
    model = deeplabv3.__dict__[backbone](in_channels=in_channels, out_channels=num_classes).to(device)
    print('#parameters:', sum(param.numel() for param in model.parameters()))
    state_dict = torch.load(model_weights, map_location='cpu')
    # strict=False tolerates key mismatches. A mismatched checkpoint would
    # otherwise leave parts of the model at random init without any sign,
    # so report what was skipped.
    incompatible = model.load_state_dict(state_dict, strict=False)
    if incompatible.missing_keys:
        print('WARNING: {} model keys missing from checkpoint, e.g. {}'.format(
            len(incompatible.missing_keys), incompatible.missing_keys[:5]))
    if incompatible.unexpected_keys:
        print('WARNING: {} unexpected checkpoint keys ignored, e.g. {}'.format(
            len(incompatible.unexpected_keys), incompatible.unexpected_keys[:5]))
    return model


if __name__ == '__main__':
    args = get_args()

    # Init dataloader and model. Fall back to CPU if CUDA was requested but is
    # not available instead of crashing later.
    use_cuda = args.cuda and torch.cuda.is_available()
    if args.cuda and not use_cuda:
        print('WARNING: CUDA requested but not available, running on CPU')
    device = torch.device("cuda" if use_cuda else "cpu")
    dataset = LabelledDS(json_file=args.json_path, image_size=args.image_size, stage=args.json_key, is_augmentation=False)
    # Pass the device explicitly; load_model would otherwise default to CUDA.
    model = load_model(model_weights=args.model_weights, in_channels=args.in_channels, num_classes=args.num_classes, device=device)
    model = model.eval()

    # create destination folder
    os.makedirs(args.dest_path, exist_ok=True)
    with open(os.path.join(args.dest_path, 'args.txt'), 'w') as f:
        f.write(str(args))
    save_transform = mt.SaveImageD(['image'], output_dir=args.dest_path, output_postfix='', separate_folder=False, output_ext='.nii.gz')
    # One Dice helper is reused for every case and class.
    dice_metric = DiceHelper(include_background=False, softmax=False)

    L_lung = []
    L_liver = []
    with torch.no_grad():
        for batch_fnames in dataset.data:
            batch = dataset.pre_transform(batch_fnames)
            batch = dataset.post_transform(batch)
            image = batch['image']
            image = image.unsqueeze(0).to(device)

            y_pred = test_single_case(net=model, image=image[0][0], stride_xy=args.stride_xy, stride_z=args.stride_z, patch_size=args.patch_size, num_classes=args.num_classes)
            y_pred = torch.argmax(torch.from_numpy(y_pred), dim=0).unsqueeze(0).to(device)
            # Replace the image with the predicted labels so the inverse
            # transforms map the prediction back into the original space.
            batch['image'].set_array(y_pred[:, :, :, :].detach().cpu().numpy().astype(int))
            reversed_batch = dataset.post_transform.inverse(batch)
            reversed_batch = dataset.pre_transform.inverse(reversed_batch)
            save_transform(reversed_batch)

            if 'label' in batch_fnames.keys():
                label = batch['label']
                label = label.unsqueeze(0).to(device)
                dice_lung = dice_metric((y_pred == 1).unsqueeze(0), label.as_tensor() == 1)[0].item()
                dice_liver = dice_metric((y_pred == 2).unsqueeze(0), label.as_tensor() == 2)[0].item()
                L_lung.append(dice_lung)
                L_liver.append(dice_liver)
                print(dice_lung, dice_liver)
                print(os.path.basename(batch_fnames['image']))

    # Summarise only if any labelled case was scored. This avoids a NameError
    # on an empty dataset and a ZeroDivisionError when no case had labels.
    if L_lung:
        mean_liver = sum(L_liver) / len(L_liver)
        mean_lung = sum(L_lung) / len(L_lung)
        print('mean liver:', mean_liver)
        print('mean lung:', mean_lung)
        # save file with results
        with open(os.path.join(args.dest_path, 'results.txt'), 'w') as f:
            f.write('mean liver: {}\nmean lung: {}'.format(mean_liver, mean_lung))
