import math

import numpy as np
import torch
import torch.nn as nn

def mean_precision(eval_segm, gt_segm):
    """
    Per-class precision of a predicted segmentation against ground truth.

    For every label ``c`` present in ``gt_segm`` the precision is
    ``|eval == c AND gt == c| / |eval == c|``. A label that ``eval_segm``
    never predicts gets 0.0 (instead of the undefined 0/0).

    Despite the name, no averaging is performed: the per-class list is
    returned and the caller decides how to aggregate it.

    Parameters
    ----------
    eval_segm : np.ndarray
        Predicted label map; only its first two dimensions are compared.
    gt_segm : np.ndarray
        Ground-truth label map with the same height/width.

    Returns
    -------
    list
        One precision value per label, ordered like ``np.unique(gt_segm)``.

    Raises
    ------
    EvalSegErr
        If the two maps have different height/width.
    """
    check_size(eval_segm, gt_segm)
    cl, n_cl = extract_classes(gt_segm)
    eval_mask, gt_mask = extract_both_masks(eval_segm, gt_segm, cl, n_cl)

    precisions = [0.] * n_cl
    for i in range(n_cl):
        curr_eval_mask = eval_mask[i]
        n_ij = np.sum(curr_eval_mask)
        # Label never predicted: precision is undefined, report 0 explicitly
        # rather than relying on numpy's 0/0 -> NaN (which also warns).
        if n_ij == 0:
            continue
        n_ii = np.sum(np.logical_and(curr_eval_mask, gt_mask[i]))
        precisions[i] = n_ii / n_ij
    return precisions


def mean_IU(eval_segm, gt_segm):
    '''
    Per-class intersection-over-union of a prediction against ground truth.

    For class i: IU_i = n_ii / (t_i + sum_j(n_ji) - n_ii), where n_ii is the
    intersection, t_i the ground-truth pixel count and sum_j(n_ji) the
    predicted pixel count. The mean (1/n_cl) * sum_i(IU_i) is NOT taken
    here; the per-class list is returned.

    Parameters
    ----------
    eval_segm : np.ndarray
        Predicted label map.
    gt_segm : np.ndarray
        Ground-truth label map with the same height/width.

    Returns
    -------
    list
        One IoU per label in the sorted union of labels of both maps
        (``np.union1d``). Entries are indexed by POSITION in that union,
        not by label value. A label absent from either map gets 0.

    Raises
    ------
    EvalSegErr
        If the two maps have different height/width.
    '''

    check_size(eval_segm, gt_segm)

    cl, n_cl = union_classes(eval_segm, gt_segm)
    eval_mask, gt_mask = extract_both_masks(eval_segm, gt_segm, cl, n_cl)

    IU = [0] * n_cl

    for i in range(n_cl):
        curr_eval_mask = eval_mask[i]
        curr_gt_mask = gt_mask[i]

        n_ij = np.sum(curr_eval_mask)
        t_i = np.sum(curr_gt_mask)

        # Label missing from prediction or ground truth: leave IoU at 0.
        if n_ij == 0 or t_i == 0:
            continue

        n_ii = np.sum(np.logical_and(curr_eval_mask, curr_gt_mask))
        IU[i] = n_ii / (t_i + n_ij - n_ii)

    return IU


'''
Auxiliary functions used during evaluation.
'''


def get_pixel_area(segm):
    """Return the pixel count of ``segm``'s first two dimensions (height * width)."""
    height, width = segm.shape[0], segm.shape[1]
    return height * width


def extract_both_masks(eval_segm, gt_segm, cl, n_cl):
    """
    Build one-hot class masks for the prediction and the ground truth.

    Both maps are expanded over the same label list ``cl`` so that mask
    ``k`` of each result refers to the same label.

    Returns
    -------
    tuple
        ``(eval_mask, gt_mask)``, each as produced by ``extract_masks``.
    """
    return (
        extract_masks(eval_segm, cl, n_cl),
        extract_masks(gt_segm, cl, n_cl),
    )


def extract_classes(segm):
    """Return the sorted distinct labels of ``segm`` and how many there are."""
    labels = np.unique(segm)
    return labels, len(labels)


def union_classes(eval_segm, gt_segm):
    """
    Return the sorted union of labels appearing in either map, and its size.
    """
    merged = np.union1d(np.unique(eval_segm), np.unique(gt_segm))
    return merged, len(merged)


def extract_masks(segm, cl, n_cl):
    """
    Expand a label map into a stack of binary masks, one per label in ``cl``.

    Returns
    -------
    np.ndarray
        Float array of shape ``(n_cl, height, width)`` whose slice ``k`` is
        1.0 where ``segm == cl[k]`` and 0.0 elsewhere.
    """
    stacked = np.zeros((n_cl, segm.shape[0], segm.shape[1]))

    for idx, label in enumerate(cl):
        stacked[idx] = (segm == label)

    return stacked


def segm_size(segm):
    """
    Return ``(height, width)`` of a segmentation map.

    Only the first two dimensions are read; any trailing dimensions are
    ignored.

    Raises
    ------
    IndexError
        If ``segm`` has fewer than two dimensions.
    """
    # The former ``try/except IndexError: raise`` was a no-op; the IndexError
    # from ``shape[...]`` propagates unchanged without it.
    height = segm.shape[0]
    width = segm.shape[1]
    return height, width


def check_size(eval_segm, gt_segm):
    """
    Ensure both maps share the same height and width.

    Raises
    ------
    EvalSegErr
        If ``(height, width)`` of the two maps differ.
    """
    if segm_size(eval_segm) != segm_size(gt_segm):
        raise EvalSegErr("DiffDim: Different dimensions of matrices!")


def cal_iou_training_previous(batch_dict, output_dict):
    """
    Calculate IoU during training.

    Only the FIRST sample of the batch is evaluated (the previous
    implementation returned from inside its batch loop on the first
    iteration; that behavior is kept and made explicit here).

    Parameters
    ----------
    batch_dict: dict
        The data that contains the gt; reads ``batch_dict['ego']['gt_static']``
        and ``batch_dict['ego']['gt_dynamic']`` (channel 0 of sample 0).

    output_dict : dict
        The output directory with predictions; reads ``'static_map'`` and
        ``'dynamic_map'`` (sample 0).

    Returns
    -------
    tuple or None
        ``(iou_dynamic, iou_static)``, each the per-class list returned by
        ``mean_IU``; ``None`` if the batch is empty.
    """
    batch_size = batch_dict['ego']['gt_static'].shape[0]
    if batch_size == 0:
        return None

    def to_int_array(tensor, index):
        # Move only the requested slice to the CPU, then cast (truncating)
        # to the builtin int dtype. ``np.int`` was an alias of ``int`` and
        # was removed in NumPy 1.24.
        return np.array(tensor.detach()[index].cpu().numpy(), dtype=int)

    sample = 0
    gt_static = to_int_array(batch_dict['ego']['gt_static'], (sample, 0))
    gt_dynamic = to_int_array(batch_dict['ego']['gt_dynamic'], (sample, 0))
    pred_static = to_int_array(output_dict['static_map'], sample)
    pred_dynamic = to_int_array(output_dict['dynamic_map'], sample)

    iou_dynamic = mean_IU(pred_dynamic, gt_dynamic)
    iou_static = mean_IU(pred_static, gt_static)

    return iou_dynamic, iou_static

def cal_iou_training(batch_dict, output_dict):
    """
    Calculate IoU during training.

    Only the FIRST sample of the batch is evaluated (the previous
    implementation returned from inside its batch loop on the first
    iteration; that behavior is kept and made explicit here).

    Parameters
    ----------
    batch_dict: dict
        The data that contains the gt; reads
        ``batch_dict['ego']['label_dict']['label_map']`` and takes
        ``[0][0]`` of it (presumably (B, 1, H, W) — TODO confirm).

    output_dict : dict
        The output directory with predictions; ``output_dict['seg']`` holds
        per-class logits with the class axis at dim 1.

    Returns
    -------
    tuple or None
        ``(res, pred_map, gt_map)`` where ``res`` is the IoU of label 1
        (0 when label 1 is missing from either map), ``pred_map`` the
        argmax label map and ``gt_map`` the ground-truth map, both numpy
        arrays for sample 0. ``None`` if the batch is empty.
    """
    label_map = batch_dict['ego']['label_dict']['label_map']
    if label_map.shape[0] == 0:
        return None

    gt_map = label_map[0][0].detach().cpu().numpy()

    # Softmax is monotonic, so argmax over the raw logits yields the same
    # labels without the extra pass (and without ties caused by rounding).
    pred_map = torch.argmax(output_dict['seg'][0].detach(), dim=0).cpu().numpy()

    iou_res = mean_IU(pred_map, gt_map)

    # mean_IU indexes its result by position in the sorted union of labels,
    # not by label value. Look up label 1's position explicitly: the old
    # ``iou_res[1]`` raised (and reported 0) when label 0 was absent, e.g.
    # when both maps were entirely label 1.
    classes, _ = union_classes(pred_map, gt_map)
    positions = np.flatnonzero(classes == 1)
    res = iou_res[int(positions[0])] if positions.size else 0

    return res, pred_map, gt_map

def cal_iou_training_bev(output_dict, target_map):
    """
    IoU of label 1 between the predicted BEV segmentation and the target.

    Only the FIRST sample of the batch is evaluated (the previous
    implementation returned from inside its batch loop on the first
    iteration; that behavior is kept and made explicit here).

    Parameters
    ----------
    output_dict : dict
        ``output_dict['seg']`` holds per-class logits with the class axis at
        dim 1.
    target_map : torch.Tensor
        Ground-truth labels, batch first; cast to int64 before comparison.

    Returns
    -------
    float or int or None
        IoU of label 1 for sample 0; 0 when label 1 is missing from either
        map; ``None`` if the batch is empty.
    """
    if target_map.shape[0] == 0:
        return None

    gt_map = target_map[0].detach().long().cpu().numpy()

    # Softmax is monotonic, so argmax over the raw logits yields the same
    # labels without the extra pass (and without ties caused by rounding).
    pred_map = torch.argmax(output_dict['seg'][0].detach(), dim=0).cpu().numpy()

    iou_res = mean_IU(pred_map, gt_map)

    # mean_IU indexes its result by position in the sorted union of labels,
    # not by label value. Look up label 1's position explicitly: the old
    # ``iou_res[1]`` raised (and reported 0) when label 0 was absent.
    classes, _ = union_classes(pred_map, gt_map)
    positions = np.flatnonzero(classes == 1)
    return iou_res[int(positions[0])] if positions.size else 0

def normalize_maps(output_dict, target_map):
    """
    Return the predicted label map and integer ground-truth map of sample 0.

    Despite its name, no value normalization is applied to the outputs:
    the prediction is the per-pixel argmax over the class axis (dim 1) of
    ``output_dict['seg']``, and the target is cast to int64. Only the FIRST
    sample of the batch is used (the previous implementation returned from
    inside its batch loop on the first iteration).

    Parameters
    ----------
    output_dict : dict
        ``output_dict['seg']`` holds per-class logits, batch first.
    target_map : torch.Tensor
        Ground-truth labels, batch first.

    Returns
    -------
    tuple or None
        ``(pred_map, gt_map)`` as numpy arrays, or ``None`` if the batch is
        empty.
    """
    if target_map.shape[0] == 0:
        return None

    # Convert only the sample we return instead of the whole batch.
    gt_map = target_map[0].detach().long().cpu().numpy()

    # Softmax is monotonic, so argmax over the raw logits yields the same
    # labels without the extra pass.
    pred_map = torch.argmax(output_dict['seg'][0].detach(), dim=0).cpu().numpy()

    return pred_map, gt_map

class EvalSegErr(Exception):
    """Error raised by the segmentation-evaluation helpers (e.g. size mismatch)."""

    def __init__(self, value):
        # Keep the raw payload so callers can inspect it.
        self.value = value

    def __str__(self):
        return f"{self.value!r}"
