import argparse
import os
import warnings

import cv2
import numpy as np
from scipy.ndimage import convolve, binary_erosion
from scipy.ndimage import distance_transform_edt as bwdist
from skimage import morphology
from tqdm import tqdm

# Floating-point dtype used for all metric arithmetic.
TYPE = np.float64
# Smallest step above 1.0 in float64; added to denominators to avoid division by zero.
EPS = np.spacing(TYPE(1))


def get_adaptive_threshold(matrix: np.ndarray, max_value: float = 1) -> float:
    """Return twice the mean of ``matrix``, clipped from above at ``max_value``."""
    doubled_mean = 2 * matrix.mean()
    return max_value if doubled_mean > max_value else doubled_mean


def normalize_input(pred: np.ndarray, gt: np.ndarray):
    """Bring a prediction/GT pair into a common form for scoring.

    * ``pred`` is resized (bilinear) to the GT's shape when they differ, cast to
      ``TYPE`` and min-max stretched to [0, 1]. A constant map is only divided by
      255 when its value exceeds 1; otherwise it is left as-is.
    * ``gt`` is binarized at 127 when it looks 8-bit (max > 1), else at 0.5.

    Returns:
        (pred, gt): float prediction and boolean ground truth.
    """
    if gt.shape != pred.shape:
        # OpenCV takes the destination size as (width, height).
        height_width = gt.shape
        pred = cv2.resize(pred, height_width[::-1], interpolation=cv2.INTER_LINEAR)

    pred = pred.astype(TYPE)
    low = pred.min()
    high = pred.max()
    if high == low:
        # Constant map: only rescale when it is clearly on a 0-255 scale.
        if high > 1:
            pred = pred / 255.0
    else:
        pred = (pred - low) / (high - low + EPS)

    gt_cut = 127 if gt.max() > 1 else 0.5
    return pred, gt > gt_cut


def compute_surface_distances(result, reference, voxelspacing=None):
    """
    Compute bidirectional surface distances between two binary masks.

    A mask's surface is the set of foreground pixels removed by one binary
    erosion with a full 3x...x3 structuring element of the input's rank.
    Distances are Euclidean, scaled per axis by ``voxelspacing``, and computed
    with SciPy's ``distance_transform_edt``. This is the shared backend for
    HD95 and ASD.

    Args:
        result: Predicted mask of any rank; interpreted as boolean.
        reference: Reference mask with the same shape as ``result``.
        voxelspacing: Per-axis sampling distances; defaults to 1 on every axis.

    Returns:
        1-D array: the distance from each ``result`` surface pixel to the
        nearest ``reference`` surface pixel, followed by the distance from each
        ``reference`` surface pixel to the nearest ``result`` surface pixel.
    """
    result = np.atleast_1d(np.asarray(result).astype(bool))
    reference = np.atleast_1d(np.asarray(reference).astype(bool))

    # Default spacing is derived after atleast_1d so it always matches the
    # rank the distance transform actually sees.
    if voxelspacing is None:
        voxelspacing = [1] * result.ndim

    # Structuring element must match the input rank: a hard-coded 3x3 made
    # binary_erosion raise for any non-2-D input. For 2-D this is still 3x3.
    footprint = np.ones((3,) * result.ndim, dtype=bool)

    # Surface extraction via XOR with the eroded mask.
    border_result = result ^ binary_erosion(result, structure=footprint)
    border_reference = reference ^ binary_erosion(reference, structure=footprint)

    # Distance from every pixel to the nearest surface pixel of each mask.
    dt_ref_border = bwdist(~border_reference, sampling=voxelspacing)
    dt_res_border = bwdist(~border_result, sampling=voxelspacing)

    distances_r2t = dt_ref_border[border_result]
    distances_t2r = dt_res_border[border_reference]

    return np.concatenate([distances_r2t, distances_t2r])


class MAE:
    """Mean absolute error between the normalized prediction and the binary GT."""

    def __init__(self):
        # One MAE value per evaluated image.
        self.maes = []

    def step(self, pred: np.ndarray, gt: np.ndarray):
        """Score one prediction/GT pair and record the per-image MAE."""
        norm_pred, bin_gt = normalize_input(pred, gt)
        abs_err = np.abs(norm_pred - bin_gt.astype(TYPE))
        self.maes.append(np.mean(abs_err))

    def get_results(self):
        """Return the dataset MAE, i.e. the mean of the per-image values."""
        return {"MAE": np.mean(self.maes)}


class Fmeasure:
    """F-measure of a saliency map against a binary ground truth.

    Two quantities are recorded per image:
      * an adaptive F-measure, binarizing at ``get_adaptive_threshold`` (twice
        the mean prediction, capped at 1);
      * an F-measure curve over the 256 thresholds of the 8-bit quantized map.

    ``beta`` appears as ``(1 + beta) * P * R / (beta * P + R)``, so it plays
    the role of beta squared in the usual F_beta definition.
    """

    def __init__(self, beta: float = 0.3):
        self.beta = beta
        self.adaptive_fms = []    # one scalar per image
        self.changeable_fms = []  # one 256-long curve per image

    def _f_score(self, precision, recall):
        """Weighted harmonic mean of precision and recall (scalars or arrays)."""
        return (1 + self.beta) * precision * recall / (self.beta * precision + recall + EPS)

    def step(self, pred: np.ndarray, gt: np.ndarray):
        """Normalize one pair and record its adaptive F-measure and F-measure curve."""
        pred, gt = normalize_input(pred, gt)
        n_gt_fg = np.count_nonzero(gt)

        # Adaptive threshold.
        hard_pred = pred > get_adaptive_threshold(pred, max_value=1)
        true_pos = hard_pred[gt].sum()
        if true_pos == 0:
            self.adaptive_fms.append(0)
        else:
            precision = true_pos / (np.count_nonzero(hard_pred) + EPS)
            recall = true_pos / (n_gt_fg + EPS)
            self.adaptive_fms.append(self._f_score(precision, recall))

        # Threshold sweep using histograms of the quantized prediction.
        levels = (pred * 255).astype(np.uint8)
        edges = np.linspace(0, 256, 257)
        fg_counts = np.histogram(levels[gt], bins=edges)[0]
        bg_counts = np.histogram(levels[~gt], bins=edges)[0]
        # Reversed cumulative sums: index k counts pixels with level >= 255 - k.
        tp_cum = np.cumsum(fg_counts[::-1])
        fp_cum = np.cumsum(bg_counts[::-1])
        precisions = tp_cum / (tp_cum + fp_cum + EPS)
        recalls = tp_cum / (n_gt_fg + EPS)
        self.changeable_fms.append(self._f_score(precisions, recalls))

    def get_results(self):
        """Max and mean of the image-averaged curve, plus the mean adaptive score."""
        mean_curve = np.mean(self.changeable_fms, axis=0)
        return {
            "Fm_max": mean_curve.max(),
            "Fm_mean": mean_curve.mean(),
            "Fm_adp": np.mean(self.adaptive_fms),
        }


class WeightedFmeasure:
    """Weighted F-measure of a soft prediction against a binary ground truth.

    Per-pixel errors |pred - gt| are adjusted in two ways.
      * Dependency: background errors are first replaced by the error at
        their nearest foreground pixel, then Gaussian-smoothed. Inside the
        foreground, each error becomes the smaller of itself and the smoothed
        value.
      * Importance: background errors are scaled by
        ``2 - exp(log(0.5) / 5 * dist)``, which grows from 1 toward 2 with
        distance from the foreground.
    The score is ``(1 + beta) * R * P / (R + beta * P)`` on the weighted counts.
    """

    def __init__(self, beta: float = 1):
        self.beta = beta
        self.weighted_fms = []  # one score per image

    def matlab_style_gauss2D(self, shape=(7, 7), sigma=5):
        """Return a normalized 2-D Gaussian kernel.

        Intended to mirror MATLAB's ``fspecial('gaussian', shape, sigma)``:
        values below ``eps * max`` are zeroed before normalizing to sum 1.
        """
        half_rows, half_cols = ((size - 1) / 2 for size in shape)
        rows, cols = np.ogrid[-half_rows: half_rows + 1, -half_cols: half_cols + 1]
        kernel = np.exp(-(cols * cols + rows * rows) / (2 * sigma * sigma))
        kernel[kernel < np.finfo(kernel.dtype).eps * kernel.max()] = 0
        total = kernel.sum()
        if total != 0:
            kernel /= total
        return kernel

    def _weighted_fm(self, pred, gt):
        """Weighted F-measure for a GT that contains at least one foreground pixel."""
        bg = ~gt
        gt_f = gt.astype(TYPE)
        # Distance to, and index of, the nearest foreground pixel.
        dist_to_fg, nearest_idx = bwdist(bg, return_indices=True)
        err = np.abs(pred - gt_f)

        # Pixel dependency: propagate foreground errors outward, then smooth.
        err_dep = np.copy(err)
        err_dep[bg] = err_dep[nearest_idx[0][bg], nearest_idx[1][bg]]
        kernel = self.matlab_style_gauss2D((7, 7), sigma=5)
        err_smooth = convolve(err_dep, weights=kernel, mode="constant", cval=0)
        err_min = np.where(gt & (err_smooth < err), err_smooth, err)

        # Pixel importance: background errors weigh more the farther they are.
        importance = np.where(bg, 2 - np.exp(np.log(0.5) / 5 * dist_to_fg), np.ones_like(gt_f))
        err_w = err_min * importance

        tp_w = np.sum(gt_f) - np.sum(err_w[gt])
        fp_w = np.sum(err_w[bg])
        recall = 1 - np.mean(err_w[gt])
        precision = tp_w / (tp_w + fp_w + EPS)
        return (1 + self.beta) * recall * precision / (recall + self.beta * precision + EPS)

    def step(self, pred: np.ndarray, gt: np.ndarray):
        """Normalize one pair and record its weighted F-measure (0 for an empty GT)."""
        pred, gt = normalize_input(pred, gt)
        if np.any(gt):
            score = self._weighted_fm(pred, gt)
        else:
            score = 0
        self.weighted_fms.append(score)

    def get_results(self):
        """Return the mean weighted F-measure over all images."""
        return {"wFm": np.mean(self.weighted_fms)}


class Emeasure:
    """Enhanced-alignment measure of a saliency map against a binary GT.

    The score is built from the four confusion-matrix cells between the
    binarized prediction and the GT. Each cell count is weighted by an
    "enhanced alignment" term computed from the mean-subtracted 0/1 values of
    that cell. One score is recorded at the adaptive threshold, and a
    256-point curve over the quantized thresholds.
    """

    def __init__(self):
        self.adaptive_ems = []    # one scalar per image
        self.changeable_ems = []  # one 256-long list per image

    def step(self, pred: np.ndarray, gt: np.ndarray):
        """Normalize one pair and record its adaptive E-measure and curve."""
        pred, gt = normalize_input(pred, gt)
        # Per-image GT statistics, read by _calc_em_core.
        self.gt_fg_numel = np.count_nonzero(gt)
        self.gt_size = gt.size

        threshold = get_adaptive_threshold(pred, max_value=1)
        self.adaptive_ems.append(self.cal_em_with_threshold(pred, gt, threshold))

        quantized = (pred * 255).astype(np.uint8)
        edges = np.linspace(0, 256, 257)
        hist_on_fg = np.histogram(quantized[gt], bins=edges)[0]
        hist_on_bg = np.histogram(quantized[~gt], bins=edges)[0]
        # Reversed cumulative sums: index k counts pixels predicted >= 255 - k.
        self.changeable_ems.append(
            self.cal_em_curve(np.cumsum(hist_on_fg[::-1]), np.cumsum(hist_on_bg[::-1]))
        )

    def cal_em_with_threshold(self, pred, gt, threshold):
        """E-measure of ``pred > threshold`` against ``gt``."""
        hard_pred = pred > threshold
        return self._calc_em_core(
            np.count_nonzero(hard_pred & gt),
            np.count_nonzero(hard_pred & ~gt),
        )

    def cal_em_curve(self, fg_fg_list, fg_bg_list):
        """E-measure for each (pred-fg & gt-fg, pred-fg & gt-bg) count pair."""
        return list(map(self._calc_em_core, fg_fg_list, fg_bg_list))

    def _calc_em_core(self, fg_fg, fg_bg):
        """E-measure from predicted-foreground counts, using the cached GT stats.

        ``fg_fg``: predicted fg on GT fg; ``fg_bg``: predicted fg on GT bg.
        """
        n = self.gt_size
        n_gt_fg = self.gt_fg_numel
        pred_fg = fg_fg + fg_bg
        bg_fg = n_gt_fg - fg_fg              # predicted bg on GT fg
        bg_bg = (n - pred_fg) - bg_fg        # predicted bg on GT bg

        # Degenerate GTs: score is the fraction matching the single GT class.
        if n_gt_fg == 0:
            return (bg_fg + bg_bg) / (n - 1 + EPS)
        if n_gt_fg == n:
            return (fg_fg + fg_bg) / (n - 1 + EPS)

        mu_pred = pred_fg / n
        mu_gt = n_gt_fg / n
        pred_hi, pred_lo = 1 - mu_pred, 0 - mu_pred
        gt_hi, gt_lo = 1 - mu_gt, 0 - mu_gt
        cells = (
            (fg_fg, pred_hi, gt_hi),
            (fg_bg, pred_hi, gt_lo),
            (bg_fg, pred_lo, gt_hi),
            (bg_bg, pred_lo, gt_lo),
        )

        total = 0
        for count, dev_pred, dev_gt in cells:
            align = 2 * (dev_pred * dev_gt) / (dev_pred ** 2 + dev_gt ** 2 + EPS)
            total += (align + 1) ** 2 / 4 * count
        return total / (n - 1 + EPS)

    def get_results(self):
        """Mean adaptive score plus mean/max of the image-averaged curve."""
        mean_curve = np.mean(self.changeable_ems, axis=0)
        return {
            "Em_adp": np.mean(self.adaptive_ems),
            "Em_mean": mean_curve.mean(),
            "Em_max": mean_curve.max(),
        }


class SmeasureStrict:
    """Structure measure (S-measure) of a saliency map against a binary GT.

    ``alpha * object_score + (1 - alpha) * region_score``, clipped at 0. The
    score is ``1 - mean(pred)`` for an all-background GT and ``mean(pred)``
    for an all-foreground GT.
    """

    def __init__(self, alpha: float = 0.5):
        self.sms = []  # one score per image
        self.alpha = alpha

    def step(self, pred: np.ndarray, gt: np.ndarray):
        """Normalize one pair and record its S-measure."""
        pred, gt = normalize_input(pred, gt)
        y = np.mean(gt)
        if y == 0:
            sm = 1 - np.mean(pred)
        elif y == 1:
            sm = np.mean(pred)
        else:
            sm = self.alpha * self.object(pred, gt) + (1 - self.alpha) * self.region(pred, gt)
            # Note: Python's max(0, nan) returns 0, so a NaN here would be
            # silently recorded as 0. The helpers below avoid producing NaN.
            sm = max(0, sm)
        self.sms.append(sm)

    def ssim(self, pred, gt):
        """Single-window SSIM-style similarity between a quadrant of pred and gt.

        Callers must pass a non-empty quadrant (the mean of an empty array is NaN).
        """
        h, w = pred.shape
        N = h * w
        x, y = np.mean(pred), np.mean(gt)
        sigma_x = np.sum((pred - x) ** 2) / (N - 1 + EPS)
        sigma_y = np.sum((gt - y) ** 2) / (N - 1 + EPS)
        sigma_xy = np.sum((pred - x) * (gt - y)) / (N - 1 + EPS)
        alpha = 4 * x * y * sigma_xy
        beta = (x ** 2 + y ** 2) * (sigma_x + sigma_y)
        if alpha != 0:
            return alpha / (beta + EPS)
        # alpha == 0: identical flat quadrants score 1, otherwise 0.
        return 1 if beta == 0 else 0

    def object(self, pred, gt):
        """Object-aware score: foreground and background similarity, weighted by GT coverage."""
        gt_mean = np.mean(gt)
        fg = pred[gt]
        bg = (1 - pred)[~gt]

        def s_obj(x):
            m = np.mean(x)
            # MATLAB's std() of a single value is 0. np.std(ddof=1) would return
            # NaN and zero the whole image's score, e.g. for a one-pixel object.
            s = np.std(x, ddof=1) if x.size > 1 else 0.0
            return 2 * m / (m ** 2 + 1 + s + EPS)

        return s_obj(fg) * gt_mean + s_obj(bg) * (1 - gt_mean)

    def region(self, pred, gt):
        """Region-aware score: quadrant SSIM around the GT centroid.

        NOTE(review): each quadrant is weighted by its share of GT foreground
        pixels. The reference MATLAB S-measure code appears to weight
        quadrants by their share of the image area instead. Confirm which
        definition is intended before comparing with published numbers.
        """
        h, w = gt.shape
        if np.count_nonzero(gt) == 0:
            cy, cx = h // 2, w // 2
        else:
            cy, cx = np.argwhere(gt).mean(axis=0).round().astype(int)

        # 1-based split (MATLAB convention): the centroid's row/column belongs
        # to the upper/left quadrants.
        cx, cy = cx + 1, cy + 1

        total_fg = np.count_nonzero(gt) + EPS
        score = 0
        # Quadrant order: upper-left, upper-right, lower-left, lower-right.
        for rows in (slice(0, cy), slice(cy, h)):
            for cols in (slice(0, cx), slice(cx, w)):
                gt_q = gt[rows, cols]
                if gt_q.size == 0:
                    # A centroid on the last row/column leaves an empty quadrant.
                    # Its weight is 0 and its SSIM would be NaN, so skip it.
                    continue
                weight = np.count_nonzero(gt_q) / total_fg
                score += weight * self.ssim(pred[rows, cols], gt_q)
        return score

    def get_results(self):
        """Return the mean S-measure over all images."""
        return {"Sm": np.mean(self.sms)}


class Tmeasure:
    """Composite score: soft IoU scaled by a blend of connectivity and boundary quality.

    ``Tm = soft_iou * (alpha * connectivity + (1 - alpha) * boundary)``.
    """

    def __init__(self, alpha: float = 0.8):
        self.scores = []  # one score per image
        self.alpha = alpha

    def _get_adaptive_threshold(self, img):
        """Twice the mean of ``img``, capped at 0.95."""
        return min(2 * img.mean(), 0.95)

    def _cal_soft_iou(self, pred, gt):
        """Soft IoU: sum(pred * gt) / sum(pred + gt - pred * gt).

        An all-zero prediction on an empty GT is a perfect match and returns
        1.0, mirroring the ``union == 0 -> 1.0`` convention of IoU/Dice.
        Previously this returned 0, which forced Tm to 0 and made the
        "both empty -> 1.0" branches of the other two terms unreachable.
        """
        gt_float = gt.astype(TYPE)
        inter = pred * gt_float
        union_sum = np.sum(pred + gt_float - inter)
        if union_sum == 0:
            return 1.0
        return np.sum(inter) / (union_sum + EPS)

    def _cal_connectivity(self, pred, gt):
        """Fragmentation score of the binarized prediction inside each GT component.

        For every 8-connected GT component, take the largest 8-connected piece
        of predicted foreground inside it, divided by all predicted foreground
        inside it (0 if there is none). Average these ratios weighted by the
        component's share of GT area. Returns 1.0 if the GT and the binarized
        prediction are both empty, and 0.0 if only the GT is empty.
        """
        th = self._get_adaptive_threshold(pred)
        bin_pred = (pred > th)

        # Connected-component analysis of the GT.
        num_gt, labels_gt, stats_gt, _ = cv2.connectedComponentsWithStats(gt.astype(np.uint8), connectivity=8)
        if num_gt <= 1:
            return 1.0 if bin_pred.sum() == 0 else 0.0

        score_sum = 0
        gt_total = np.sum(gt)

        for i in range(1, num_gt):
            x, y, w, h, area = stats_gt[i]
            # Restrict work to the component's bounding box.
            gt_mask_roi = (labels_gt[y:y + h, x:x + w] == i)
            pred_in_region_roi = bin_pred[y:y + h, x:x + w] & gt_mask_roi

            inter_area = np.sum(pred_in_region_roi)
            if inter_area == 0:
                ratio = 0
            else:
                region_u8 = pred_in_region_roi.astype(np.uint8)
                num_labels, _, stats, _ = cv2.connectedComponentsWithStats(region_u8, connectivity=8)
                if num_labels <= 1:
                    ratio = 0
                else:
                    max_area = stats[1:, cv2.CC_STAT_AREA].max()
                    ratio = max_area / inter_area
            score_sum += ratio * (area / gt_total)
        return score_sum

    def _cal_boundary(self, pred, gt):
        """F1 between Sobel gradient energies of pred and gt, with a distance tolerance.

        Edges are gradient magnitudes above 0.5. Each side's gradient energy
        counts as matched where it falls inside the other side's edges,
        dilated by a disk of radius 0.5% of the image diagonal, clamped to
        [1, 5] px and truncated to an int.
        """
        h, w = gt.shape
        diagonal = np.sqrt(h ** 2 + w ** 2)
        tolerance_r = int(min(max(1, diagonal * 0.005), 5))

        # Gradient magnitudes.
        grad_mag_p = np.linalg.norm([cv2.Sobel(pred, cv2.CV_64F, 1, 0, ksize=3),
                                     cv2.Sobel(pred, cv2.CV_64F, 0, 1, ksize=3)], axis=0)
        grad_mag_g = np.linalg.norm([cv2.Sobel(gt.astype(TYPE), cv2.CV_64F, 1, 0, ksize=3),
                                     cv2.Sobel(gt.astype(TYPE), cv2.CV_64F, 0, 1, ksize=3)], axis=0)

        # Bidirectional tolerance masks.
        mask_gt_tolerance = cv2.dilate((grad_mag_g > 0.5).astype(np.uint8),
                                       morphology.disk(tolerance_r), iterations=1) > 0
        mask_pred_tolerance = cv2.dilate((grad_mag_p > 0.5).astype(np.uint8),
                                         morphology.disk(tolerance_r), iterations=1) > 0

        tp_precision = np.sum(grad_mag_p[mask_gt_tolerance])
        total_pred_energy = np.sum(grad_mag_p)
        tp_recall = np.sum(grad_mag_g[mask_pred_tolerance])
        total_gt_energy = np.sum(grad_mag_g)

        if total_pred_energy == 0 and total_gt_energy == 0:
            return 1.0
        if total_pred_energy == 0 or total_gt_energy == 0:
            return 0.0

        precision = tp_precision / (total_pred_energy + EPS)
        recall = tp_recall / (total_gt_energy + EPS)
        return 2 * precision * recall / (precision + recall + EPS)

    def step(self, pred: np.ndarray, gt: np.ndarray):
        """Normalize one pair and record its composite score."""
        pred, gt = normalize_input(pred, gt)
        iou = self._cal_soft_iou(pred, gt)
        conn = self._cal_connectivity(pred, gt)
        bnd = self._cal_boundary(pred, gt)
        tm = iou * (self.alpha * conn + (1 - self.alpha) * bnd)
        self.scores.append(tm)

    def get_results(self):
        """Return the mean composite score over all images."""
        return {"Tm": np.mean(self.scores)}


class IoU:
    """Intersection over union, with prediction and GT both binarized at 0.5.

    An image where both masks are empty scores 1.0.
    """

    def __init__(self):
        self.ious = []  # one score per image

    def step(self, pred, gt):
        """Normalize one pair and record its IoU."""
        pred, gt = normalize_input(pred, gt)
        pred_fg = pred > 0.5
        gt_fg = gt > 0.5
        inter = (pred_fg & gt_fg).sum()
        union = (pred_fg | gt_fg).sum()
        if union == 0:
            score = 1.0
        else:
            score = inter / (union + EPS)
        self.ious.append(score)

    def get_results(self):
        """Return the mean IoU over all images."""
        return {"IoU": np.mean(self.ious)}


class Dice:
    """Dice coefficient, with prediction and GT both binarized at 0.5.

    An image where both masks are empty scores 1.0.
    """

    def __init__(self):
        self.dices = []  # one score per image

    def step(self, pred, gt):
        """Normalize one pair and record its Dice coefficient."""
        pred, gt = normalize_input(pred, gt)
        pred_fg = pred > 0.5
        gt_fg = gt > 0.5
        overlap = (pred_fg & gt_fg).sum()
        size_sum = pred_fg.sum() + gt_fg.sum()
        if size_sum == 0:
            score = 1.0
        else:
            score = 2 * overlap / (size_sum + EPS)
        self.dices.append(score)

    def get_results(self):
        """Return the mean Dice coefficient over all images."""
        return {"Dice": np.mean(self.dices)}


class HD95:
    """95th-percentile Hausdorff distance (pixels) between binarized prediction and GT.

    Per-image convention:
      * both masks empty -> 0.0;
      * exactly one empty -> NaN (undefined, excluded from the final mean);
      * otherwise the 95th percentile of the pooled bidirectional surface
        distances from ``compute_surface_distances``.
    """

    def __init__(self):
        self.hd95s = []  # one value (possibly NaN) per image

    def step(self, pred, gt):
        """Normalize one pair (binarized at 0.5) and record its HD95."""
        pred, gt = normalize_input(pred, gt)
        bin_pred, bin_gt = pred > 0.5, gt > 0.5
        pred_empty, gt_empty = not bin_pred.any(), not bin_gt.any()
        if pred_empty and gt_empty:
            self.hd95s.append(0.0)
            return
        if pred_empty or gt_empty:
            self.hd95s.append(np.nan)
            return
        try:
            sd = compute_surface_distances(bin_pred, bin_gt)
        except Exception as err:
            # Best effort: one failing image must not abort the run. The old
            # bare `except:` also swallowed KeyboardInterrupt/SystemExit and
            # hid the failure entirely.
            warnings.warn(f"HD95: surface distance computation failed ({err!r}); recording NaN.",
                          RuntimeWarning)
            self.hd95s.append(np.nan)
            return
        self.hd95s.append(0.0 if len(sd) == 0 else np.percentile(sd, 95))

    def get_results(self):
        """Mean HD95 over images with a defined value (NaN if there are none)."""
        clean = [x for x in self.hd95s if not np.isnan(x)]
        return {"95HD": np.mean(clean) if clean else np.nan}


class ASD:
    """Average surface distance (pixels) between binarized prediction and GT.

    Per-image convention:
      * both masks empty -> 0.0;
      * exactly one empty -> NaN (undefined, excluded from the final mean);
      * otherwise the mean of the pooled bidirectional surface distances from
        ``compute_surface_distances``.

    NOTE(review): pooling weights each direction by its number of surface
    pixels. A symmetric ASSD would instead average the two directional means.
    Confirm which definition is intended.
    """

    def __init__(self):
        self.asds = []  # one value (possibly NaN) per image

    def step(self, pred, gt):
        """Normalize one pair (binarized at 0.5) and record its ASD."""
        pred, gt = normalize_input(pred, gt)
        bin_pred, bin_gt = pred > 0.5, gt > 0.5
        pred_empty, gt_empty = not bin_pred.any(), not bin_gt.any()
        if pred_empty and gt_empty:
            self.asds.append(0.0)
            return
        if pred_empty or gt_empty:
            self.asds.append(np.nan)
            return
        try:
            sd = compute_surface_distances(bin_pred, bin_gt)
        except Exception as err:
            # Best effort: one failing image must not abort the run. The old
            # bare `except:` also swallowed KeyboardInterrupt/SystemExit and
            # hid the failure entirely.
            warnings.warn(f"ASD: surface distance computation failed ({err!r}); recording NaN.",
                          RuntimeWarning)
            self.asds.append(np.nan)
            return
        self.asds.append(0.0 if len(sd) == 0 else np.mean(sd))

    def get_results(self):
        """Mean ASD over images with a defined value (NaN if there are none)."""
        clean = [x for x in self.asds if not np.isnan(x)]
        return {"ASD": np.mean(clean) if clean else np.nan}


class Evaluator:
    """Run every metric over all GT/prediction image pairs found in two folders.

    Metric objects are created once in ``__init__`` and accumulate state, so
    calling ``evaluate()`` twice on the same instance counts images twice.
    """

    # GT files with these (case-insensitive) extensions are evaluated. The
    # same extensions are tried, in this order, when looking up a prediction.
    IMAGE_EXTS = ('.png', '.jpg', '.bmp', '.tif')

    def __init__(self, gt_root, pred_root):
        self.gt_root = gt_root
        self.pred_root = pred_root
        self.metrics = [MAE(), Fmeasure(), WeightedFmeasure(), Emeasure(),
                        SmeasureStrict(), Tmeasure(), IoU(), Dice(), HD95(), ASD()]

    def _find_prediction(self, name):
        """Return the path of the prediction for GT file ``name``, or None.

        Tries the identical file name first, then the same stem with each of
        ``IMAGE_EXTS`` (``.png`` first, as before). Previously only ``.png``
        was tried, despite the stated intent to support various extensions.
        """
        stem = os.path.splitext(name)[0]
        for candidate in [name, *(stem + ext for ext in self.IMAGE_EXTS)]:
            path = os.path.join(self.pred_root, candidate)
            if os.path.isfile(path):
                return path
        return None

    def evaluate(self):
        """Score every matched pair and return ``{metric_name: value}``.

        Raises:
            FileNotFoundError: if either root folder does not exist.
        """
        if not os.path.exists(self.gt_root) or not os.path.exists(self.pred_root):
            raise FileNotFoundError("Input paths do not exist.")

        # Sorted for a deterministic processing order across filesystems.
        names = sorted(f for f in os.listdir(self.gt_root) if f.lower().endswith(self.IMAGE_EXTS))
        print(f"Starting evaluation on {len(names)} images...")

        skipped = 0
        for name in tqdm(names):
            pred_path = self._find_prediction(name)
            if pred_path is None:
                skipped += 1
                continue

            gt = cv2.imread(os.path.join(self.gt_root, name), cv2.IMREAD_GRAYSCALE)
            pred = cv2.imread(pred_path, cv2.IMREAD_GRAYSCALE)
            if gt is None or pred is None:
                skipped += 1
                continue

            for m in self.metrics:
                m.step(pred, gt)

        # Skipping used to be silent, so scores could cover a subset unnoticed.
        if skipped:
            print(f"Warning: skipped {skipped} of {len(names)} images "
                  f"(missing or unreadable prediction/GT).")

        results = {}
        for m in self.metrics:
            results.update(m.get_results())
        return results


if __name__ == '__main__':
    # Command-line entry point: evaluate a prediction folder against a GT folder
    # and print one row per metric.
    parser = argparse.ArgumentParser(description="Benchmarking Script for Segmentation Metrics")
    parser.add_argument('--gt', type=str, required=True, help='Path to Ground Truth folder')
    parser.add_argument('--pred', type=str, required=True, help='Path to Prediction folder')
    args = parser.parse_args()

    evaluator = Evaluator(args.gt, args.pred)
    final_results = evaluator.evaluate()

    if final_results:
        title = "=" * 25 + " Final Results " + "=" * 25
        print("\n" + title)
        print(f"{'Metric':<12} | {'Score':<10}")
        for k, v in final_results.items():
            print(f"{k:<12} | {v:.4f}")
        # Closing rule matches the title width (it was a fixed 46 vs a 65-wide title).
        print("=" * len(title))