"""
These functions are adapted from tomotwin.modules.inference.locator and tomotwin.scripts.evaluation.
"""
import sys
sys.path.append("..")
from utils.tomotwin import SIZE_DICT, _add_size

from tomotwin.modules.common.preprocess import label_filename
import numpy as np
import pandas as pd
import os

def _interval_overlap_vec(x1_min, x1_max, x2_min, x2_max):
    """
    Element-wise overlap length of paired 1D intervals.

    tomotwin.modules.inference.locator

    Element ``i`` is the length of the overlap between
    ``[x1_min[i], x1_max[i]]`` and ``[x2_min[i], x2_max[i]]``. It is 0.0 where
    the intervals are disjoint. Assumes every ``*_min <= *_max``; that is not
    checked here.

    :return: float64 array with one overlap length per interval pair.
    """
    overlap = np.zeros(len(x1_min))
    upper = np.minimum(x1_max, x2_max)

    # The interval that starts later determines the lower edge of the overlap.
    second_starts_first = x2_min < x1_min
    starts_at_x1 = second_starts_first & (x2_max >= x1_min)
    starts_at_x2 = ~second_starts_first & (x1_max >= x2_min)

    overlap[starts_at_x1] = (upper - x1_min)[starts_at_x1]
    overlap[starts_at_x2] = (upper - x2_min)[starts_at_x2]
    return overlap

def _bbox_iou_vec_3d(boxesA: np.array, boxesB: np.array) -> np.array:
    """
    Row-wise 3D intersection-over-union of two stacks of axis-aligned boxes.

    tomotwin.modules.inference.locator

    Row ``i`` of ``boxesA`` is compared only with row ``i`` of ``boxesB``, so
    this is not a cross-product. Each row is laid out as
    ``[x, y, z, width, height, depth]``, where ``(x, y, z)`` is the box centre.

    :param boxesA: 2D array, one box per row.
    :param boxesB: 2D array with the same number of rows as ``boxesA``.
    :return: 1D array of IoU values, one per row pair.
    """
    def _extent(boxes, axis):
        # Column ``axis`` holds the centre and column ``axis + 3`` the size.
        half = boxes[:, axis + 3] / 2
        return boxes[:, axis] - half, boxes[:, axis] + half

    per_axis = []
    for axis in range(3):
        lo_a, hi_a = _extent(boxesA, axis)
        lo_b, hi_b = _extent(boxesB, axis)
        per_axis.append(_interval_overlap_vec(lo_a, hi_a, lo_b, hi_b))
    intersect = per_axis[0] * per_axis[1] * per_axis[2]

    volume_a = boxesA[:, 3] * boxesA[:, 4] * boxesA[:, 5]
    volume_b = boxesB[:, 3] * boxesB[:, 4] * boxesB[:, 5]
    return intersect / (volume_a + volume_b - intersect)

def locate_positions_stats(locate_results, class_positions, iou_thresh):
    """
    Match predicted boxes against the ground-truth boxes of one class and score them.

    tomotwin.scripts.evaluation

    For every ground-truth box, the IoU against all predictions is computed.
    - If any IoU is strictly greater than ``iou_thresh``, the ground truth is a
      true positive and the prediction with the highest IoU is marked as found.
    - Otherwise the ground truth is a false negative.

    Predictions that are never marked as found are false positives. One
    prediction may be the best match for several ground truths.

    Ratios with a zero denominator are reported as 0.0 instead of raising
    ``ZeroDivisionError`` or producing NaN. This happens when there are no
    ground-truth positions, or no true and no false positives.

    :param locate_results: DataFrame of predictions with columns
        ``X, Y, Z, width, height, depth``.
    :param class_positions: DataFrame of ground-truth boxes. Rows are read
        positionally via ``to_numpy()``, so columns must be in the order
        ``X, Y, Z, width, height, depth``.
    :param iou_thresh: IoU a match must strictly exceed.
    :return: dict with keys ``F1, Recall, Precision, TruePositiveRate, TP, FP, FN``.
    """
    columns = ["X", "Y", "Z", "width", "height", "depth"]
    predictions = locate_results[columns].to_numpy()
    found = np.zeros(len(predictions), dtype=bool)
    # Loop-invariant: used to repeat one ground-truth box once per prediction row.
    ones = np.ones((len(predictions), 6))

    true_positive = 0
    false_negative = 0
    for class_pos in class_positions.to_numpy():
        ious = _bbox_iou_vec_3d(ones * class_pos, predictions)
        if np.any(ious > iou_thresh):
            found[np.argmax(ious)] = True
            true_positive += 1
        else:
            false_negative += 1
    false_positive = int(np.count_nonzero(~found))

    n_positions = len(class_positions)
    true_positive_rate = true_positive / n_positions if n_positions else 0.0

    n_truth = true_positive + false_negative
    recall = true_positive / n_truth if n_truth else 0.0
    n_predicted = true_positive + false_positive
    precision = true_positive / n_predicted if n_predicted else 0.0
    if precision + recall > 0:
        f1_score = 2 * precision * recall / (precision + recall)
    else:
        f1_score = 0.0

    return {
        "F1": float(f1_score),
        "Recall": float(recall),
        "Precision": float(precision),
        "TruePositiveRate": float(true_positive_rate),
        "TP": int(true_positive),
        "FP": int(false_positive),
        "FN": int(false_negative),
    }

def _filter(df, min_val=None, max_val=None, field=None):
    """
    Return a copy of ``df`` restricted to rows whose ``field`` lies strictly inside the bounds.

    tomotwin.scripts.evaluation

    :param df: DataFrame to filter. It is not modified; a copy is always returned.
    :param min_val: Exclusive lower bound, or ``None`` for no lower bound.
    :param max_val: Exclusive upper bound, or ``None`` for no upper bound.
    :param field: Column to filter on. If ``None``, an unfiltered copy is returned.
    :return: Filtered copy of ``df``.
    """
    dfc = df.copy()
    if field is None:
        return dfc

    # Compare with None by identity: ``!= None`` is elementwise for NumPy
    # values, and the truth value of the result is then ambiguous.
    if min_val is not None:
        dfc = dfc[dfc[field] > min_val]

    if max_val is not None:
        dfc = dfc[dfc[field] < max_val]
    return dfc


def get_stats(df, positions, iou_thresh=0.6):
    """
    Score the picks in ``df`` against the ground-truth ``positions`` of their class.

    tomotwin.scripts.evaluation

    The class is derived from the first row of ``df``:
    - its ``predicted_class`` value indexes ``df.attrs['references']``;
    - the file extension is stripped from that reference;
    - the result is passed through ``label_filename``.

    Ground-truth rows are selected by a case-insensitive match of
    ``positions["class"]`` against that name.

    :param df: Picks with a ``predicted_class`` column and a ``references``
        entry in ``df.attrs``. A renamed copy of it is scored.
    :param positions: Ground truth with ``class, X, Y, Z, width, height, depth`` columns.
    :param iou_thresh: IoU threshold forwarded to ``locate_positions_stats``.
    :return: The stats dict produced by ``locate_positions_stats``.
    """
    references = df.attrs['references']
    first_class_id = df['predicted_class'].iloc[0]
    reference_stem, _ = os.path.splitext(references[first_class_id])
    target = label_filename(reference_stem)

    # Case-insensitive comparison of the ground-truth labels with the target class.
    upper_labels = np.array([label.upper() for label in positions["class"]])
    gt_boxes = positions[upper_labels == target.upper()][["X", "Y", "Z", "width", "height", "depth"]]

    picks = df.rename(columns={"predicted_class_name": "class"})
    picks["class"] = target
    # NOTE(review): 37 is forwarded as `size` to `_add_size` from utils.tomotwin.
    # Its meaning (presumably a fallback box size) is defined there; confirm.
    picks = _add_size(picks, size=37, size_dict=SIZE_DICT)

    return locate_positions_stats(locate_results=picks, class_positions=gt_boxes, iou_thresh=iou_thresh)

def optim(locate_results, positions, min_size_range=(1, 500), max_size_range=(1, 500), min_size_step=2,
          max_size_step=2, iou_thresh=0.6):
    """
    Grid-search a minimum and then a maximum ``size`` cut-off that maximise F1.

    tomotwin.scripts.evaluation

    The search runs in two stages:
    1. Find the best exclusive lower bound on the ``size`` column over
       ``np.arange(min_size_range[0], min_size_range[1], min_size_step)``.
    2. On the result of stage 1, find the best exclusive upper bound over
       ``np.arange(max_size_range[0], max_size_range[1], max_size_step)``.

    A cut-off is accepted only if it strictly improves F1 over the best value
    so far, starting from the unfiltered input of that stage. If no cut-off
    helps, the stage reports 0 and leaves its input unfiltered. Cut-offs that
    remove every row are skipped.

    :param locate_results: Picks accepted by ``get_stats``. Assumes a ``size``
        column is already present, because ``_filter`` reads it directly.
    :param positions: Ground-truth positions forwarded to ``get_stats``.
    :param min_size_range: ``(start, stop)`` of the lower-bound search.
    :param max_size_range: ``(start, stop)`` of the upper-bound search.
    :param min_size_step: Step of the lower-bound search.
    :param max_size_step: Step of the upper-bound search.
    :param iou_thresh: IoU threshold forwarded to ``get_stats``.
    :return: Best stats dict of the upper-bound stage, extended with
        ``O_MIN_SIZE`` and ``O_MAX_SIZE``.
    """
    def find_best(results, field, value_range, stepsize, kind):
        # Baseline is the unfiltered input; a cut-off has to strictly beat it.
        best_stats = get_stats(results, positions, iou_thresh=iou_thresh)
        best_f1 = best_stats["F1"]
        best_value = 0
        best_df = results

        for value in np.arange(start=value_range[0], stop=value_range[1], step=stepsize):
            if kind == "min":
                candidate = _filter(results, min_val=value, field=field)
            else:
                candidate = _filter(results, max_val=value, field=field)
            if len(candidate) == 0:
                continue
            stats = get_stats(candidate, positions, iou_thresh=iou_thresh)
            if stats["F1"] > best_f1:
                best_f1 = stats["F1"]
                best_stats = stats
                best_value = value
                # _filter already returned a fresh copy, so no extra copy is needed.
                best_df = candidate
        return best_stats, best_df, best_value

    o_dict = {}

    _, min_filtered, best_min = find_best(
        locate_results, field="size", value_range=min_size_range, stepsize=min_size_step, kind="min"
    )
    o_dict["O_MIN_SIZE"] = int(best_min)

    stats, _, best_max = find_best(
        min_filtered, field="size", value_range=max_size_range, stepsize=max_size_step, kind="max"
    )
    o_dict["O_MAX_SIZE"] = int(best_max)

    stats.update(o_dict)
    return stats