import math

import torch
import torch.nn.functional as F
import numpy as np
import medpy.metric.binary as mmb
import SimpleITK as sitk


def _window_starts(size, patch, step):
    """Return sliding-window start offsets along one axis.

    Windows of length ``patch`` advance by ``step`` over an axis of length
    ``size`` (``size >= patch``). The last start is clamped to
    ``size - patch`` so the final window ends exactly at the border.

    Raises
    ------
    ValueError
        If ``step`` is not positive. Zero gives a division by zero; a
        negative step gives no windows and an all-NaN score map.
    """
    if step <= 0:
        raise ValueError("stride must be positive, got %r" % (step,))
    count = math.ceil((size - patch) / step) + 1
    return [min(step * i, size - patch) for i in range(count)]


def _infer_device(net):
    """Return the device holding ``net``'s first parameter.

    Falls back to the default CUDA device when ``net`` has no callable
    ``parameters()`` or no parameters at all.
    """
    parameters = getattr(net, "parameters", None)
    if callable(parameters):
        first = next(iter(parameters()), None)
        if first is not None:
            return first.device
    return torch.device("cuda")


def test_single_case(net, image, stride, patch_size, num_classes=1):
    """
    Predict a 3D volume with a sliding window.

    Parameters
    ----------
        net : model
            Callable taking a float32 tensor of shape [1, C, *patch_size] and
            returning a 2-tuple. The second element holds the logits, of
            shape [1, num_classes, *patch_size]; the first is ignored.
            Patches go to the device of the net's first parameter, or to the
            default CUDA device if that cannot be determined.
        image : array of shape [C, W, H, D]
        stride : tuple / list of 3 positive ints, window step per axis
        patch_size : tuple / list of 3 ints, window size per axis
        num_classes : number of output channels of ``net``

    Returns
    -------
    label_map : argmax of ``score_map`` over classes, shape [W, H, D]
    score_map : softmax outputs averaged over overlapping windows,
        shape [num_classes, W, H, D]

    Raises
    ------
    ValueError
        If any stride component is not positive.
    """
    _, w, h, d = image.shape

    # Zero-pad every spatial axis that is smaller than the patch, so that at
    # least one full window fits. Any odd extra voxel goes on the right.
    pads = []
    for axis_size, axis_patch in zip((w, h, d), patch_size):
        total = max(axis_patch - axis_size, 0)
        pads.append((total // 2, total - total // 2))
    add_pad = any(right > 0 for _, right in pads)
    if add_pad:
        image = np.pad(
            image, [(0, 0)] + pads, mode="constant", constant_values=0
        )
    _, ww, hh, dd = image.shape
    pw, ph, pd = patch_size[0], patch_size[1], patch_size[2]

    x_starts = _window_starts(ww, pw, stride[0])
    y_starts = _window_starts(hh, ph, stride[1])
    z_starts = _window_starts(dd, pd, stride[2])

    device = _infer_device(net)
    score_map = np.zeros((num_classes, ww, hh, dd), dtype=np.float32)
    # Number of windows covering each voxel; used to average the overlaps.
    cnt = np.zeros((ww, hh, dd), dtype=np.float32)

    # Inference only: no_grad avoids building an autograd graph per patch.
    with torch.no_grad():
        for xs in x_starts:
            for ys in y_starts:
                for zs in z_starts:
                    window = (
                        slice(xs, xs + pw),
                        slice(ys, ys + ph),
                        slice(zs, zs + pd),
                    )
                    tile = np.expand_dims(image[(slice(None),) + window], axis=0)
                    tile = torch.from_numpy(tile.astype(np.float32)).to(device)
                    _, logits = net(tile)
                    probs = F.softmax(logits, dim=1).detach().cpu().numpy()[0]
                    score_map[(slice(None),) + window] += probs
                    cnt[window] += 1

    score_map = score_map / np.expand_dims(cnt, axis=0)
    label_map = np.argmax(score_map, axis=0)
    if add_pad:
        # Undo the padding so outputs match the input's spatial shape.
        crop = tuple(
            slice(left, left + axis_size)
            for (left, _), axis_size in zip(pads, (w, h, d))
        )
        label_map = label_map[crop]
        score_map = score_map[(slice(None),) + crop]
    return label_map, score_map


# The name starts with "test_"; keep pytest from collecting it as a test.
test_single_case.__test__ = False


def decode_label(label):
    """
    Convert a multi-class label map into three binary region masks.

    Label values follow the convention 1 = ET, 2 = NET, 3 = ED.

    Parameters
    ----------
    label : array-like of integer labels (any shape)

    Returns
    -------
    wt : 1 where label != 0, else 0
    tc : 1 where label is 1 or 2, else 0
    ec : 1 where label == 1, else 0
    All three are integer arrays with the same shape as ``label``.
    """
    # No copy is needed: the comparisons below never mutate ``label``.
    # np.asarray returns an ndarray input unchanged and converts lists.
    label = np.asarray(label)
    wt = (label != 0) * 1
    tc = ((label == 1) | (label == 2)) * 1
    ec = (label == 1) * 1
    return wt, tc, ec


def eval_dice(pred, label):
    """
    Dice coefficient between a predicted and a reference mask.

    If both ``pred`` and ``label`` sum to zero (nothing predicted and nothing
    to find), the score is defined as 1. Otherwise the computation is
    delegated to ``medpy.metric.binary.dc``.
    """
    nothing_predicted = np.sum(pred) == 0
    if nothing_predicted and np.sum(label) == 0:
        return 1
    return mmb.dc(pred, label)


def eval_one_dice(pred, label):
    """
    Region-wise Dice scores for a single case (used during validation).

    Both label maps are split into their (wt, co, ec) region masks by
    ``decode_label``. Each prediction/ground-truth pair is then scored with
    ``eval_dice``.

    Returns
    -------
    (dice_wt, dice_co, dice_ec, dice_mean), where dice_mean is the plain
    average of the three region scores.
    """
    region_pairs = zip(decode_label(pred), decode_label(label))
    dice_wt, dice_co, dice_ec = [eval_dice(p, g) for p, g in region_pairs]
    dice_mean = (dice_wt + dice_co + dice_ec) / 3.0
    return dice_wt, dice_co, dice_ec, dice_mean


def evaluate_one_case(pred, label, voxelspacing=None):
    """Evaluate one case region by region.

    ``pred`` and ``label`` are split with ``decode_label`` into the
    (wt, co, ec) region masks. Each prediction/ground-truth pair is scored
    with:

    - hd : 95th-percentile Hausdorff distance (``medpy.metric.binary.hd95``).
      It is NaN when either mask of the pair is empty. Without
      ``voxelspacing`` the distance is in voxel units.
    - dice : (2*TP)/(FP+2*TP+FN), via ``eval_dice``
    - sensitivity : TP/(TP+FN)
    - specificity : TN/(TN+FP)

    Parameters
    ----------
    pred, label : label maps of the same shape
    voxelspacing : optional per-axis voxel spacing forwarded to ``hd95``
        (e.g. to report HD95 in physical units). The default None keeps the
        voxel-unit behaviour.

    Returns
    -------
    hd, dice, sensitivity, specificity : lists of three values each,
    ordered [wt, co, ec].
    """
    pairs = list(zip(decode_label(pred), decode_label(label)))

    # The Hausdorff distance is undefined when either mask has no foreground.
    hd = [
        mmb.hd95(p, g, voxelspacing=voxelspacing)
        if np.sum(p) > 0 and np.sum(g) > 0
        else np.nan
        for p, g in pairs
    ]
    dice = [eval_dice(p, g) for p, g in pairs]
    sensitivity = [mmb.sensitivity(p, g) for p, g in pairs]
    specificity = [mmb.specificity(p, g) for p, g in pairs]

    return hd, dice, sensitivity, specificity


def convert_to_sitk(arr, output_path, modality=None):
    """
    Save a NumPy array to disk through SimpleITK.

    Parameters
    ----------
    arr : array to write. When ``modality`` is given, ``arr[i]`` is written
        for every index ``i`` in ``modality``.
    output_path : when ``modality`` is None, the output file path. Otherwise
        a directory: each selected entry is written to
        ``<output_path>/image_<i>.nii.gz``.
    modality : optional iterable of integer indices into ``arr``
    """
    if modality is None:
        sitk.WriteImage(sitk.GetImageFromArray(arr), output_path)
        return
    for channel in modality:
        channel_img = sitk.GetImageFromArray(arr[channel])
        sitk.WriteImage(channel_img, output_path + "/image_%d.nii.gz" % channel)
