import torch
from torch import Tensor
from torchmetrics import Metric

from . import analytics
from .epfl import (
    compute_3d_efficiency,
    compute_axial_efficiency,
    compute_lat_efficiency,
)
from .matching import match
from .rmse import compute_rmse


def _compute_sample_metrics(x: Tensor, x_gt: Tensor, s: Tensor):
    """
    Compute matching-based metrics for one sample.

    Predictions are matched against *all* ground truths. Matches whose ground
    truth is not significant are then discarded, and the predictions involved
    in them are excluded from the detection count.

    Args:
        x: Predictions, one row per point; only the first 3 columns are used.
        x_gt: Ground truths, one row per point; only the first 3 columns are used.
        s: Significance flag per ground-truth row. Any dtype is accepted;
            non-zero entries are treated as significant.

    Returns:
        dict with the keys:
            "nmatches": number of matches whose ground truth is significant.
            "jac", "prec", "rec": values from ``analytics.compute_jaccard``,
                ``analytics.compute_precision`` and ``analytics.compute_recall``.
            "n_detects": number of predictions, excluding those matched to a
                non-significant ground truth.
            "diff_nums": ``n_detects`` minus the number of significant ground truths.
            "rmse_lat", "rmse_axial", "rmse_vol": ``compute_rmse`` over
                coordinates 0-1, coordinate 2, and all three coordinates.
            "Elat", "Eax", "E": scores from the ``epfl`` efficiency helpers.
            "offset": per-coordinate sum of (ground truth - prediction) over
                the significant matches, as a tensor on ``x``'s device/dtype.
    """
    device, dtype = x.device, x.dtype

    # The matching and RMSE helpers are numpy-based. Detach first, because
    # ``Tensor.numpy()`` raises on tensors that require grad.
    pred_np = x[:, :3].detach().cpu().numpy()
    gt_np = x_gt[:, :3].detach().cpu().numpy()
    # Force a boolean mask: with an int/float mask, ``arr[mask]`` below would
    # do integer fancy-indexing instead of selecting the flagged rows.
    is_significant_np = s.detach().cpu().numpy().astype(bool)
    num_gts = int(is_significant_np.sum())

    # Perform matching between predicted and ground-truth coordinates.
    matched_pred_idx, matched_gt_idx = match(set1=pred_np, set2=gt_np)

    # Discount matches that correspond to non-significant ground truths.
    significant_match = is_significant_np[matched_gt_idx]
    num_matches = int(significant_match.sum())
    num_negligible_matches = len(significant_match) - num_matches
    num_preds = len(pred_np) - num_negligible_matches

    # Difference between the number of detections and significant ground truths.
    diff_nums = num_preds - num_gts

    tp, fp, fn = analytics.compute_tp_fp_fn(
        num_gts=num_gts, num_matches=num_matches, num_preds=num_preds
    )
    jaccard = analytics.compute_jaccard(tp=tp, fp=fp, fn=fn)
    precision = analytics.compute_precision(tp=tp, fp=fp)
    recall = analytics.compute_recall(tp=tp, fn=fn)

    # Keep only the matched pairs whose ground truth is significant.
    matched_gt = gt_np[matched_gt_idx][significant_match]
    matched_pred = pred_np[matched_pred_idx][significant_match]
    offset = (matched_gt - matched_pred).sum(axis=0)
    offset = torch.from_numpy(offset).to(device=device, dtype=dtype)

    # RMSE on the first two coordinates, the third one, and all three.
    rmse_lat = compute_rmse(
        matched_pred=matched_pred[:, :2], matched_gt=matched_gt[:, :2]
    )
    rmse_axial = compute_rmse(
        matched_pred=matched_pred[:, 2:], matched_gt=matched_gt[:, 2:]
    )
    rmse_vol = compute_rmse(matched_pred=matched_pred, matched_gt=matched_gt)

    # Composite efficiency metrics.
    Elat = compute_lat_efficiency(jaccard=jaccard, rmse_lat=rmse_lat)
    Eax = compute_axial_efficiency(jaccard=jaccard, rmse_axial=rmse_axial)
    E = compute_3d_efficiency(jaccard=jaccard, rmse_lat=rmse_lat, rmse_axial=rmse_axial)

    return {
        "nmatches": num_matches,
        "jac": jaccard,
        "prec": precision,
        "rec": recall,
        "n_detects": num_preds,
        "diff_nums": diff_nums,
        "rmse_lat": rmse_lat,
        "rmse_axial": rmse_axial,
        "rmse_vol": rmse_vol,
        "Elat": Elat,
        "Eax": Eax,
        "E": E,
        "offset": offset,
    }


class MatchingMetrics(Metric):
    """
    A TorchMetrics metric that computes several statistics based on a common matching process.

    Each sample of a batch is matched and scored independently with
    ``_compute_sample_metrics``, and the per-sample values are summed into the
    metric states. ``compute`` returns their mean over samples, except for
    ``offset``, which is averaged over the total number of significant matches.

    ``update`` expects:
      - x: predictions of shape (B, N, D) with D >= 3 (the first 3 columns are coordinates),
      - x_gt: ground truth of shape (B, M, D) with D >= 3 (only the first 3 columns are used),
      - s: a boolean significance mask of shape (B, M).
    """

    # ``update`` only adds per-sample values to "sum"-reduced states and never
    # reads the accumulated values, so the full-state forward path is not needed.
    full_state_update = False
    # Values are computed from numpy arrays, so no gradient flows through them.
    is_differentiable = False

    # States summed per sample and averaged over the sample count in ``compute``.
    _MEAN_STATES = (
        "jac",
        "prec",
        "rec",
        "rmse_lat",
        "rmse_axial",
        "rmse_vol",
        "Elat",
        "Eax",
        "E",
    )

    def __init__(self, sync_on_compute: bool = True, **kwargs):
        """
        Args:
            sync_on_compute: Forwarded to ``torchmetrics.Metric``.
            **kwargs: Any further ``torchmetrics.Metric`` keyword arguments.
        """
        super().__init__(sync_on_compute=sync_on_compute, **kwargs)

        # Number of *samples* seen (incremented once per batch element); the
        # name is kept for backward compatibility.
        self.add_state("nbatches", default=torch.tensor(0.0), dist_reduce_fx="sum")
        # Total number of matches whose ground truth is significant.
        self.add_state("nmatches", default=torch.tensor(0.0), dist_reduce_fx="sum")

        # Running sums of the per-sample metrics.
        for name in self._MEAN_STATES:
            self.add_state(name, default=torch.tensor(0.0), dist_reduce_fx="sum")
        # Per-coordinate sum of (ground truth - prediction) over significant matches.
        self.add_state(
            "offset", default=torch.tensor([0.0, 0.0, 0.0]), dist_reduce_fx="sum"
        )

    def update(self, x: Tensor, x_gt: Tensor, s: Tensor):
        """
        Update state with a batch of predictions.

        Args:
            x: predictions, shape (B, N, D) with D >= 3.
            x_gt: ground truth, shape (B, M, D) with D >= 3.
            s: significance mask, shape (B, M).

        Raises:
            ValueError: if ``x``, ``x_gt`` and ``s`` have different batch sizes.
        """
        if not (len(x) == len(x_gt) == len(s)):
            raise ValueError(
                "Batch sizes differ: "
                f"len(x)={len(x)}, len(x_gt)={len(x_gt)}, len(s)={len(s)}"
            )

        for pred, gt, sig in zip(x, x_gt, s):
            m = _compute_sample_metrics(x=pred, x_gt=gt, s=sig)
            self.nbatches += 1
            self.nmatches += m["nmatches"]
            # In-place add keeps each state's dtype and device unchanged.
            for name in self._MEAN_STATES:
                getattr(self, name).add_(m[name])
            self.offset += m["offset"]

    def compute(self):
        """
        Return per-sample means of the accumulated metrics.

        ``offset`` is divided by the number of significant matches rather than
        the number of samples. With no samples (or no matches for ``offset``)
        the division is 0/0 and yields NaN.
        """
        result = {name: getattr(self, name) / self.nbatches for name in self._MEAN_STATES}
        result["offset"] = self.offset / self.nmatches
        return result
