import numpy as np
import itertools
from collections import defaultdict, Counter


# Maps "_" and "/" to spaces; built once at import time so clean_label stays cheap.
_translation_table = str.maketrans("_/", "  ")


def clean_label(label: str) -> str:
    """Return ``label`` with every underscore and forward slash replaced by a space.

    >>> clean_label("sitting_on/near")
    'sitting on near'
    """
    replaced = label.translate(_translation_table)
    return replaced

def compute_object_metrics_unary(
    gt_object_dict,
    cate_pred,
    precision_thres_ls=(1, 5, 10),
    recall_thres_ls=(1, 5, 10),
):
    """
    Compute top-K precision and recall for object classification.

    Predictions are deduplicated per ``(oid, cleaned_label)`` (keeping the
    highest score), ranked by score, and a ranked prediction counts as a hit
    when its cleaned label equals the cleaned ground-truth label of that oid.

    Args:
        gt_object_dict: iterable of ``(vid, oid, label_str)`` triples (despite the
            name it is iterated, not looked up).
        cate_pred: mapping ``(oid, label_str) -> score``.
        precision_thres_ls: K values; precision@K = hits in top-K / K.
        recall_thres_ls: K values; recall@K = hits in top-K / number of GT oids.

    Returns:
        ``(precision, recall)`` dicts keyed by K. A value is 0.0 when there are
        no predictions, when K <= 0, or (recall only) when there is no ground truth.
    """
    # Ground truth: oid -> cleaned label (a later triple for the same oid wins).
    gt_object_labels = {
        oid: clean_label(label_str)
        for (_vid, oid, label_str) in gt_object_dict
    }

    # Keep only the highest score per (oid, cleaned label); raw labels that
    # clean to the same string are merged here.
    best_scores = {}
    for (oid, pred_label), score in cate_pred.items():
        key = (oid, clean_label(pred_label))
        if key not in best_scores or best_scores[key] < score:
            best_scores[key] = score

    # Rank by score, highest first; sorting is stable, so ties keep first-seen order.
    ranked = sorted(best_scores.items(), key=lambda item: item[1], reverse=True)

    is_hit = [
        1 if (oid in gt_object_labels and lbl == gt_object_labels[oid]) else 0
        for (oid, lbl), _score in ranked
    ]
    cumsum_hits = np.cumsum(is_hit)
    num_preds = len(cumsum_hits)
    total_objs = len(gt_object_labels)

    def hits_at(k):
        # Hits among the first k predictions (all of them if fewer than k exist).
        return float(cumsum_hits[min(k, num_preds) - 1])

    result_unary_precision = {}
    for k in precision_thres_ls:
        # k <= 0 would otherwise index cumsum[-1] and divide by zero/negative K.
        if num_preds == 0 or k <= 0:
            result_unary_precision[k] = 0.0
        else:
            # Denominator is k even when fewer than k predictions exist.
            result_unary_precision[k] = hits_at(k) / float(k)

    result_unary_recall = {}
    for k in recall_thres_ls:
        if num_preds == 0 or total_objs == 0 or k <= 0:
            result_unary_recall[k] = 0.0
        else:
            result_unary_recall[k] = hits_at(k) / float(total_objs)

    return result_unary_precision, result_unary_recall


def _aggregate_window_binary_preds(
    binary_pred_window,
    per_oid_labels,
    rel_weights,
    top_k_classes,
    frame_agg_strategy,
    label_agg_strategy,
):
    """
    Score ``(subject label, predicate, object label)`` triples for one frame window.

    Each prediction keyed by ``(fid, (sid, oid), rel_str)`` is expanded over the
    first ``top_k_classes`` candidate labels of its subject id and object id.
    Scores are then reduced twice:

      1. across frames, for the same id pair and label choice (``frame_agg_strategy``);
      2. across id pairs sharing the same labels (``label_agg_strategy``).

    Args:
        binary_pred_window: mapping ``(fid, (sid, oid), rel_str) -> score``. A score
            with a ``detach`` attribute is treated as a tensor and converted via
            ``.detach().cpu().item()``.
        per_oid_labels: mapping ``oid -> [(label, score), ...]``, indexed with ``[]``.
        rel_weights: accepted but not used; per-label weighting is disabled, so
            every expansion keeps the base score.
        top_k_classes: number of leading candidates taken per subject/object id.
        frame_agg_strategy, label_agg_strategy: "max", "sum", or "mean"/"avg".

    Returns:
        List of ``((subj_label, rel_str_clean, obj_label), score)`` sorted by
        score, highest first (ties keep first-seen order).

    Raises:
        ValueError: on an unrecognised strategy, as soon as a group is reduced.
    """
    def reduce_scores(scores, how):
        if not scores:
            return 0.0
        if how == "max":
            return max(scores)
        if how == "sum":
            return sum(scores)
        if how in ("mean", "avg"):
            return sum(scores) / len(scores)
        raise ValueError(f"Unsupported aggregator: {how}")

    def as_float(raw):
        return float(raw.detach().cpu().item()) if hasattr(raw, "detach") else float(raw)

    # Stage 0: best score per (frame, id pair, predicate, subject label, object label).
    best_per_frame = {}
    for (fid, (sid, oid), rel_str), raw_score in binary_pred_window.items():
        value = as_float(raw_score)
        subj_pool = per_oid_labels[sid][:top_k_classes]
        obj_pool = per_oid_labels[oid][:top_k_classes]
        rel = clean_label(rel_str)
        for subj_label, _subj_prob in subj_pool:
            for obj_label, _obj_prob in obj_pool:
                slot = (fid, (sid, oid), rel, subj_label, obj_label)
                if slot in best_per_frame and not (best_per_frame[slot] < value):
                    continue
                best_per_frame[slot] = value

    # Stage 1: collapse frames -> one score per (sid, rel, oid, subj_label, obj_label).
    per_track = {}
    for (_fid, (sid, oid), rel, subj_label, obj_label), value in best_per_frame.items():
        per_track.setdefault((sid, rel, oid, subj_label, obj_label), []).append(value)
    track_scores = {
        track: reduce_scores(values, frame_agg_strategy)
        for track, values in per_track.items()
    }

    # Stage 2: collapse id pairs -> one score per (subj_label, rel, obj_label).
    per_triple = {}
    for (_sid, rel, _oid, subj_label, obj_label), value in track_scores.items():
        per_triple.setdefault((subj_label, rel, obj_label), []).append(value)
    triple_scores = {
        triple: reduce_scores(values, label_agg_strategy)
        for triple, values in per_triple.items()
    }

    ranked = list(triple_scores.items())
    ranked.sort(key=lambda pair: pair[1], reverse=True)
    return ranked


def _build_full_updated_binary_pred(binary_pred, per_oid_labels, rel_weights, top_k_classes):
    """
    Expand binary predictions over candidate subject/object labels for the whole video.

    NOTE: this module defines ``_build_full_updated_binary_pred`` again further
    down; the later ``def`` rebinds the name, so this definition is not the one
    callers get at run time.

    Args:
        binary_pred: mapping ``(fid, (sid, oid), rel_str) -> score``. Keys that do
            not unpack into that shape are skipped. Scores with a ``detach``
            attribute are treated as tensors and converted via ``.detach().cpu().item()``.
        per_oid_labels: mapping ``oid -> [(label, score), ...]``, indexed with ``[]``.
        rel_weights: accepted for signature compatibility; not used.
        top_k_classes: number of leading candidates taken per id.

    Returns:
        dict keyed by ``(fid, (sid, oid), rel_str_clean, subj_label, obj_label)``
        holding the maximum float score seen for that key.
    """
    updated_binary_pred = {}
    for key, base_score in binary_pred.items():
        try:
            fid, (sid, oid), rel_str = key
        except (TypeError, ValueError):
            # Malformed key (wrong arity / not iterable): skip it, without
            # swallowing unrelated exceptions such as KeyboardInterrupt.
            continue

        if hasattr(base_score, "detach"):
            base_score_val = float(base_score.detach().cpu().item())
        else:
            base_score_val = float(base_score)

        subject_candidates = per_oid_labels[sid][:top_k_classes]
        object_candidates = per_oid_labels[oid][:top_k_classes]
        rel_str_clean = clean_label(rel_str)

        for (subj_label, _subj_prob) in subject_candidates:
            for (obj_label, _obj_prob) in object_candidates:
                new_key = (fid, (sid, oid), rel_str_clean, subj_label, obj_label)
                # Keep the max score per expanded key (no weighting applied).
                if new_key not in updated_binary_pred or updated_binary_pred[new_key] < base_score_val:
                    updated_binary_pred[new_key] = base_score_val

    return updated_binary_pred


def compute_metrics_top_k(
    gt_object_dict,
    gt_object_rels,
    cate_pred,
    binary_pred,
    rel_weights=None,
    frame_agg_strategy="avg",
    label_agg_strategy="max",
    precision_thres_ls=(1, 5, 10),
    recall_thres_ls=(1, 5, 10),
    top_k_classes=3,
    window_size=30,
    window_stride=15,
    html_path=None,
    all_objects=None,      # complete list of object classes
    all_predicates=None,   # complete list of predicate classes
):
    """
    Compute object and relation top-K metrics plus optional confusion matrices.

    Unary (object) metrics are delegated to ``compute_object_metrics_unary``.

    Binary (relation) metrics use window-level "mode" voting instead of merging
    scores across windows:

      1. Slide inclusive frame windows ``[start, start + window_size]`` (clipped
         to the largest frame id) with ``start = 0, window_stride, ...``.
      2. For every non-empty window, ``_aggregate_window_binary_preds`` yields
         ``((s_label, rel, o_label), score)`` pairs sorted by score.
      3. For each threshold K, each triple in a window's top-K gets one vote.
      4. Triples are ranked by vote count (ties keep first-seen order), and
         precision@K = hits@K / K, recall@K = hits@K / #GT triples. GT triples
         are ``(GT subject label, cleaned rel, GT object label)`` over all frames.

    Args:
        gt_object_dict: iterable of ``(vid, oid, label)``. It is iterated
            several times, so it must not be a one-shot iterator.
        gt_object_rels: iterable of per-frame lists of ``(sid, oid, rel_str)``.
        cate_pred: mapping ``(oid, label) -> score``.
        binary_pred: mapping ``(fid, (sid, oid), rel_str) -> score``.
        rel_weights: forwarded to the expansion helpers, which in this module
            do not use it.
        frame_agg_strategy, label_agg_strategy: "max", "sum", or "mean"/"avg".
        precision_thres_ls, recall_thres_ls: iterables of K values (lists and
            tuples may be mixed).
        top_k_classes: candidate labels considered per subject/object id.
        window_size: window extent in frame-id units (both ends inclusive).
        window_stride: step between window starts.
        html_path: accepted but not used by this function.
        all_objects: object class names; enables the object confusion matrix.
        all_predicates: predicate names; enables the predicate confusion matrix.

    Returns:
        ``(results, object_confusion, binary_confusion)``. ``results`` is
        ``{"precision": {"cate": ..., "binary": ...}, "recall": {...}}`` with
        per-K dicts. Each confusion matrix is an ndarray, or ``None`` when not
        requested.

    Raises:
        ValueError: if ``window_stride <= 0`` while more than one window is
            needed (previously an infinite loop).
    """
    # 1. Unary (object) metrics.
    result_unary_precision, result_unary_recall = compute_object_metrics_unary(
        gt_object_dict=gt_object_dict,
        cate_pred=cate_pred,
        precision_thres_ls=precision_thres_ls,
        recall_thres_ls=recall_thres_ls,
    )

    # 2. GT oid -> label, and predicted oid -> [(label, score), ...] best-first.
    gt_oid_to_label = {oid: clean_label(lbl) for (_vid, oid, lbl) in gt_object_dict}

    per_oid_labels = defaultdict(list)
    for (oid, raw_label), sc in cate_pred.items():
        per_oid_labels[oid].append((clean_label(raw_label), sc))
    for label_list in per_oid_labels.values():
        label_list.sort(key=lambda x: x[1], reverse=True)

    # Top-1 predicted label per oid (used by the object confusion matrix).
    predicted_oid_to_label = {oid: labels[0][0] for oid, labels in per_oid_labels.items()}

    # 3. Ground-truth relation triples, ignoring frames.
    gt_rel_set = {
        (
            gt_oid_to_label.get(sid, "UNKNOWN"),
            clean_label(rel_str),
            gt_oid_to_label.get(oid, "UNKNOWN"),
        )
        for frame_rels in gt_object_rels
        for (sid, oid, rel_str) in frame_rels
    }

    # 4. Binary metrics via window "mode" voting.
    if not binary_pred:
        result_binary_precision = {k: 0.0 for k in precision_thres_ls}
        result_binary_recall = {k: 0.0 for k in recall_thres_ls}
    else:
        window_outputs = _collect_window_outputs(
            binary_pred,
            per_oid_labels,
            rel_weights,
            top_k_classes,
            frame_agg_strategy,
            label_agg_strategy,
            window_size,
            window_stride,
        )
        # Set union (not list "+") so list and tuple thresholds can be mixed.
        agg_results = {
            k: _mode_precision_recall(window_outputs, gt_rel_set, k)
            for k in set(precision_thres_ls) | set(recall_thres_ls)
        }
        result_binary_precision = {k: agg_results[k][0] for k in precision_thres_ls}
        result_binary_recall = {k: agg_results[k][1] for k in recall_thres_ls}

    # 5. Optional confusion matrices.
    object_confusion = None
    if all_objects:
        object_confusion = _object_confusion_matrix(
            gt_object_dict, predicted_oid_to_label, all_objects
        )

    binary_confusion = None
    if all_predicates:
        binary_confusion = _predicate_confusion_matrix(
            binary_pred or {},  # tolerate binary_pred=None
            gt_object_rels,
            per_oid_labels,
            rel_weights,
            top_k_classes,
            all_predicates,
        )

    results = {
        "precision": {
            "cate": result_unary_precision,
            "binary": result_binary_precision
        },
        "recall": {
            "cate": result_unary_recall,
            "binary": result_binary_recall
        }
    }

    return results, object_confusion, binary_confusion


def _collect_window_outputs(
    binary_pred,
    per_oid_labels,
    rel_weights,
    top_k_classes,
    frame_agg_strategy,
    label_agg_strategy,
    window_size,
    window_stride,
):
    """Return the score-sorted triple list of every non-empty frame window, in window order."""
    max_fid = max(key[0] for key in binary_pred)  # key = (fid, (sid, oid), rel_str)
    window_outputs = []
    win_start = 0
    while win_start <= max_fid:
        win_end = min(win_start + window_size, max_fid)

        # Predictions whose frame lies in [win_start, win_end] (both inclusive).
        binary_pred_window = {
            pkey: pscore
            for pkey, pscore in binary_pred.items()
            if win_start <= pkey[0] <= win_end
        }
        if binary_pred_window:
            window_outputs.append(
                _aggregate_window_binary_preds(
                    binary_pred_window,
                    per_oid_labels,
                    rel_weights,
                    top_k_classes,
                    frame_agg_strategy,
                    label_agg_strategy,
                )
            )

        if win_end == max_fid:
            break
        if window_stride <= 0:
            # Advancing by a non-positive stride would never reach max_fid.
            raise ValueError(
                f"window_stride must be positive to cover frames up to {max_fid}, "
                f"got {window_stride!r}"
            )
        win_start += window_stride
    return window_outputs


def _mode_precision_recall(window_outputs, gt_rel_set, k):
    """Precision@k and recall@k of triples ranked by how many windows have them in their top-k."""
    if k <= 0:
        return 0.0, 0.0
    votes = Counter(
        triple
        for window_output in window_outputs
        for triple, _score in window_output[:k]
    )
    # most_common() sorts by count descending; equal counts keep first-seen order.
    ranked = votes.most_common()
    if not ranked:
        return 0.0, 0.0
    hits_at_k = sum(1 for triple, _count in ranked[:k] if triple in gt_rel_set)
    prec_val = float(hits_at_k) / float(k)
    rec_val = 0.0 if not gt_rel_set else float(hits_at_k) / float(len(gt_rel_set))
    return prec_val, rec_val


def _object_confusion_matrix(gt_object_dict, predicted_oid_to_label, all_objects):
    """Count (GT label, top-1 predicted label) pairs; rows are GT, columns are predictions."""
    object_to_index = {obj: idx for idx, obj in enumerate(all_objects)}
    confusion = np.zeros((len(all_objects), len(all_objects)), dtype=int)
    for (_vid, oid, label) in gt_object_dict:
        gt_lbl = clean_label(label)
        pred_lbl = predicted_oid_to_label.get(oid, "UNKNOWN")
        # Labels outside all_objects (including "UNKNOWN" unless listed) are ignored.
        if gt_lbl in object_to_index and pred_lbl in object_to_index:
            confusion[object_to_index[gt_lbl], object_to_index[pred_lbl]] += 1
    return confusion


def _predicate_confusion_matrix(
    binary_pred,
    gt_object_rels,
    per_oid_labels,
    rel_weights,
    top_k_classes,
    all_predicates,
):
    """Soft predicate confusion: sum over (sid, oid) pairs of outer(GT dist, predicted dist)."""
    updated_binary_pred = _build_full_updated_binary_pred(
        binary_pred, per_oid_labels, rel_weights, top_k_classes,
    )

    # Predicted distribution per pair: score mass per predicate, normalised.
    pred_per_pair = defaultdict(list)
    for (_fid, (sid, oid), rel_str_clean, _subj, _obj), score in updated_binary_pred.items():
        pred_per_pair[(sid, oid)].append((rel_str_clean, score))

    pred_pair_dist = {}
    for pair, pred_list in pred_per_pair.items():
        score_counter = defaultdict(float)
        for rel, sc in pred_list:
            score_counter[rel] += sc
        total_score = sum(score_counter.values())
        if total_score > 0:
            pred_pair_dist[pair] = {rel: sc / total_score for rel, sc in score_counter.items()}
        else:
            pred_pair_dist[pair] = {}

    # GT distribution per pair: predicate frequency across frames.
    gt_per_pair = defaultdict(list)
    for frame_rels in gt_object_rels:
        for (sid, oid, rel_str) in frame_rels:
            gt_per_pair[(sid, oid)].append(clean_label(rel_str))

    gt_pair_dist = {}
    for pair, rel_list in gt_per_pair.items():
        counts = Counter(rel_list)
        total_counts = sum(counts.values())
        if total_counts > 0:
            gt_pair_dist[pair] = {rel: counts[rel] / total_counts for rel in counts}
        else:
            gt_pair_dist[pair] = {}

    predicate_to_index = {pred: idx for idx, pred in enumerate(all_predicates)}
    num_preds = len(all_predicates)
    confusion_matrix = np.zeros((num_preds, num_preds), dtype=float)

    for pair, gt_dist in gt_pair_dist.items():
        if pair not in pred_pair_dist:
            continue
        pred_dist = pred_pair_dist[pair]
        # Dense vectors over all predicates; unknown predicates are dropped.
        gt_vec = np.zeros(num_preds)
        pred_vec = np.zeros(num_preds)
        for rel, prob in gt_dist.items():
            if rel in predicate_to_index:
                gt_vec[predicate_to_index[rel]] = prob
        for rel, prob in pred_dist.items():
            if rel in predicate_to_index:
                pred_vec[predicate_to_index[rel]] = prob
        confusion_matrix += np.outer(gt_vec, pred_vec)

    return confusion_matrix


def _build_full_updated_binary_pred(binary_pred, per_oid_labels, rel_weights, top_k_classes):
    """
    Expand binary predictions over candidate subject/object labels for the whole video.

    Used for the predicate confusion matrix, not for the window "mode" ranking.
    Each ``(fid, (sid, oid), rel_str) -> score`` entry is expanded over the first
    ``top_k_classes`` entries of ``per_oid_labels[sid]`` and ``per_oid_labels[oid]``.

    Args:
        binary_pred: mapping ``(fid, (sid, oid), rel_str) -> score``. Keys that do
            not unpack into that shape are skipped. Scores with a ``detach``
            attribute are treated as tensors and converted via ``.detach().cpu().item()``.
        per_oid_labels: mapping ``oid -> [(label, score), ...]``, indexed with ``[]``
            (a ``defaultdict(list)`` yields no candidates for unknown ids).
        rel_weights: accepted for signature compatibility; not used.
        top_k_classes: number of leading candidates taken per id.

    Returns:
        dict keyed by ``(fid, (sid, oid), rel_str_clean, subj_label, obj_label)``
        holding the maximum float score seen for that key.
    """
    updated_binary_pred = {}
    for key, base_score in binary_pred.items():
        try:
            fid, (sid, oid), rel_str = key
        except (TypeError, ValueError):
            # Malformed key (wrong arity / not iterable): skip it, without
            # swallowing unrelated exceptions such as KeyboardInterrupt.
            continue

        if hasattr(base_score, "detach"):
            base_score_val = float(base_score.detach().cpu().item())
        else:
            base_score_val = float(base_score)

        subject_candidates = per_oid_labels[sid][:top_k_classes]
        object_candidates = per_oid_labels[oid][:top_k_classes]
        rel_str_clean = clean_label(rel_str)

        for (subj_label, _subj_prob) in subject_candidates:
            for (obj_label, _obj_prob) in object_candidates:
                new_key = (fid, (sid, oid), rel_str_clean, subj_label, obj_label)
                # Keep the max score per expanded key (no weighting applied).
                if new_key not in updated_binary_pred or updated_binary_pred[new_key] < base_score_val:
                    updated_binary_pred[new_key] = base_score_val

    return updated_binary_pred
