"""VG Zero-shot Scene Graph Evaluation Script

Evaluates scene graph predictions from scene_graphs.json against VG ground truth.
Only evaluates images that have predicted relations.
"""

import json
import os
import numpy as np
import h5py
from sentence_transformers import SentenceTransformer
import torch
import copy
from tqdm import tqdm

# ============== Configuration ==============
# Predicted scene graphs (JSON) loaded by main(); eval_results.json is written next to this file.
SCENE_GRAPH_PATH = "/home//hsgg/output/vg_zero/scene_graphs.json"
# HDF5 annotation file passed to load_graphs() as `roidb_file`.
H5_FILE = '/home//hsgg/dataset/vg/VG-SGG-with-attri.h5'
# Per-image metadata JSON passed to load_graphs() as `image_info_json`.
IMAGE_DATA_JSON = '/home//hsgg/dataset/vg/image_data.json'
# JSON holding the 'idx_to_label' and 'idx_to_predicate' dictionaries read by main().
VG_DICT_PATH = "/home//hsgg/dataset/vg/VG-SGG-dicts-with-attri.json"
# Directory of images; load_graphs() keeps only images whose '<image_id>.jpg' exists here.
IMG_DIR = "/home//hsgg/dataset/vg/VG_100K"
# Use sentence-transformers model (can be changed to local path)
EMBEDDING_MODEL_PATH = "/Qwen/Qwen3-Embedding-4B"  # Default model, change if needed


# ============== Load GT Data ==============
def load_graphs(roidb_file, split, num_im, num_val_im, filter_empty_rels, img_dir, BOX_SCALE=1024, image_info_json=None):
    """Load the GT boxes and relations from VG dataset.

    Args:
        roidb_file: Path to the HDF5 annotation file.
        split: 'test' selects rows whose split flag is 2; any other value
            selects flag 0. Within flag 0, 'val' keeps the first
            ``num_val_im`` images and 'train' keeps the remainder.
        num_im: If greater than -1, keep at most this many images.
        num_val_im: Number of flag-0 images reserved for validation
            (ignored when not positive).
        filter_empty_rels: If true, drop images without any relation.
        img_dir: Directory that must contain '<image_id>.jpg' for an image
            to be kept.
        BOX_SCALE: Suffix of the 'boxes_<BOX_SCALE>' dataset to read; also
            the divisor used when rescaling boxes to image size.
        image_info_json: Path to the image metadata JSON. Required in
            practice even though it defaults to None.

    Returns:
        A tuple ``(filenames, boxes, gt_classes, relationships)`` with one
        entry per selected image: the image path, an array of
        (x1, y1, x2, y2) boxes, an array of class indices, and an array of
        (subject_idx, object_idx, predicate) rows whose indices are relative
        to that image's own boxes.
    """
    with open(image_info_json, 'r') as f:
        im_data = json.load(f)

    corrupted_ims = {'1592.jpg', '1722.jpg', '4616.jpg', '4617.jpg'}
    fns = []
    img_info = []
    for img in im_data:
        basename = '{}.jpg'.format(img['image_id'])
        if basename in corrupted_ims:
            continue
        filename = os.path.join(img_dir, basename)
        if os.path.exists(filename):
            fns.append(filename)
            img_info.append(img)

    # Read everything needed into memory inside the context manager so the
    # HDF5 handle is released even if a dataset is missing or a read fails.
    with h5py.File(roidb_file, 'r') as roi_h5:
        data_split = roi_h5['split'][:]
        split_flag = 2 if split == 'test' else 0
        split_mask = data_split == split_flag

        split_mask &= roi_h5['img_to_first_box'][:] >= 0
        if filter_empty_rels:
            split_mask &= roi_h5['img_to_first_rel'][:] >= 0

        image_index = np.where(split_mask)[0]
        if num_im > -1:
            image_index = image_index[:num_im]
        if num_val_im > 0:
            if split == 'val':
                image_index = image_index[:num_val_im]
            elif split == 'train':
                image_index = image_index[num_val_im:]

        # Rebuild the mask so it reflects the truncated index list.
        split_mask = np.zeros_like(data_split).astype(bool)
        split_mask[image_index] = True

        all_labels = roi_h5['labels'][:, 0]
        all_boxes = roi_h5['boxes_{}'.format(BOX_SCALE)][:]

        im_to_first_box = roi_h5['img_to_first_box'][split_mask]
        im_to_last_box = roi_h5['img_to_last_box'][split_mask]
        im_to_first_rel = roi_h5['img_to_first_rel'][split_mask]
        im_to_last_rel = roi_h5['img_to_last_rel'][split_mask]

        _relations = roi_h5['relationships'][:]
        _relation_predicates = roi_h5['predicates'][:, 0]

    # convert from xc, yc, w, h to x1, y1, x2, y2
    all_boxes[:, :2] = all_boxes[:, :2] - all_boxes[:, 2:] / 2
    all_boxes[:, 2:] = all_boxes[:, :2] + all_boxes[:, 2:]

    boxes = []
    gt_classes = []
    relationships = []
    per_image = zip(image_index, im_to_first_box, im_to_last_box, im_to_first_rel, im_to_last_rel)
    for h5_row, i_obj_start, i_obj_end, i_rel_start, i_rel_end in per_image:
        # NOTE(review): img_info (and fns below) are indexed with HDF5 row
        # numbers. That is only correct if the filtered image list lines up
        # one-to-one with the HDF5 rows, i.e. no non-corrupted image is
        # missing from img_dir -- confirm for the dataset in use.
        w, h = img_info[h5_row]['width'], img_info[h5_row]['height']
        # Rescale from BOX_SCALE space using the longer image side.
        scale = max(w, h) / BOX_SCALE
        boxes_i = all_boxes[i_obj_start : i_obj_end + 1, :].copy() * scale
        gt_classes_i = all_labels[i_obj_start : i_obj_end + 1]

        if i_rel_start >= 0:
            predicates = _relation_predicates[i_rel_start : i_rel_end + 1]
            # Re-base global box indices so they index into boxes_i.
            obj_idx = _relations[i_rel_start : i_rel_end + 1] - i_obj_start
            rels = np.column_stack((obj_idx, predicates))
        else:
            rels = np.zeros((0, 3), dtype=np.int32)

        boxes.append(boxes_i)
        gt_classes.append(gt_classes_i)
        relationships.append(rels)

    filenames = [fns[i] for i in np.where(split_mask)[0]]
    return filenames, boxes, gt_classes, relationships


def compute_iou(box1, box2):
    """Return the intersection-over-union of two (x1, y1, x2, y2) boxes.

    Yields 0.0 when the union area is zero.
    """
    ax1, ay1, ax2, ay2 = box1
    bx1, by1, bx2, by2 = box2

    # Overlap extents are clamped at zero so disjoint boxes contribute no area.
    overlap_w = max(0, min(ax2, bx2) - max(ax1, bx1))
    overlap_h = max(0, min(ay2, by2) - max(ay1, by1))
    overlap = overlap_w * overlap_h

    total = (ax2 - ax1) * (ay2 - ay1) + (bx2 - bx1) * (by2 - by1) - overlap
    return overlap / total if total != 0 else 0.0


def _get_predicate(rel):
    """Extract predicate from relation."""
    if isinstance(rel, np.ndarray):
        rel = rel.tolist()
    if isinstance(rel, (list, tuple)) and len(rel) >= 2:
        return str(rel[1])
    return str(rel)


def recall_at_k(gt_boxes, gt_rels, pred_boxes, pred_rels, pred_scores, k, iou_threshold=0.5):
    """Compute Recall@K for scene graph generation.

    Each row of ``gt_boxes`` / ``pred_boxes`` is a subject box followed by an
    object box (the first four values and the remaining values). A GT
    relation counts as recalled when one of the ``k`` highest-scoring
    predictions has the same predicate and both boxes overlap the GT pair
    with IoU >= ``iou_threshold``, in either subject/object order.

    Returns the fraction of GT relations recalled, or 0.0 when either side
    is empty.
    """
    if len(pred_boxes) == 0 or len(gt_boxes) == 0:
        return 0.0

    # Rank predictions by descending score and keep at most k of them.
    scores = np.array(pred_scores) if isinstance(pred_scores, list) else pred_scores
    ranking = np.argsort(-scores)
    keep = ranking[:min(k, len(ranking))]

    if len(keep) > 0:
        top_boxes = np.array(pred_boxes)[keep]
        top_rels = np.array(pred_rels)[keep]
    else:
        top_boxes = np.array([])
        top_rels = np.array([])

    def _is_hit(gt_pair, gt_predicate, cand_pair, cand_rel):
        """Tell whether one candidate prediction recalls one GT relation."""
        cand_predicate = _get_predicate(cand_rel)
        cand_subj, cand_obj = cand_pair[:4], cand_pair[4:]
        gt_subj, gt_obj = gt_pair[:4], gt_pair[4:]

        # Direct match
        direct = (compute_iou(cand_subj, gt_subj), compute_iou(cand_obj, gt_obj))
        if all(iou >= iou_threshold for iou in direct) and cand_predicate == gt_predicate:
            return True

        # Swapped match (undirected relation)
        swapped = (compute_iou(cand_subj, gt_obj), compute_iou(cand_obj, gt_subj))
        return all(iou >= iou_threshold for iou in swapped) and cand_predicate == gt_predicate

    hits = 0
    for gi, gt_pair in enumerate(gt_boxes):
        gt_predicate = _get_predicate(gt_rels[gi])
        # A GT relation is counted once, at the first prediction that hits it.
        if any(_is_hit(gt_pair, gt_predicate, cand, top_rels[j]) for j, cand in enumerate(top_boxes)):
            hits += 1

    return hits / max(1, len(gt_boxes))


def _dedup_relations(rels):
    """Drop repeated relation rows, keeping the first occurrence of each in order."""
    seen = set()
    unique = []
    for row in rels:
        key = tuple(map(int, row))
        if key not in seen:
            seen.add(key)
            unique.append(row)
    if not unique:
        return np.zeros((0, 3), dtype=np.int32)
    return np.array(unique, dtype=np.int32)


def _gt_triplets(gt_rel, gt_box, gt_cls, idx_to_label, idx_to_predicate):
    """Turn GT relation rows into (subject box + object box) lists and text triplets."""
    rel_boxes = []
    rel_text = []
    for rel in gt_rel:
        subj_idx, obj_idx, pred_idx = int(rel[0]), int(rel[1]), int(rel[2])
        rel_boxes.append(gt_box[subj_idx].tolist() + gt_box[obj_idx].tolist())
        # Look names up by their JSON key instead of by position in
        # dict.values(), so the result does not depend on key order in the file.
        rel_text.append([
            idx_to_label[str(int(gt_cls[subj_idx]))],
            idx_to_predicate[str(pred_idx)],
            idx_to_label[str(int(gt_cls[obj_idx]))],
        ])
    return rel_boxes, rel_text


def _pred_triplets(pred_sg):
    """Turn one predicted scene graph into box pairs, text triplets and scores."""
    pred_objects = pred_sg['objects']
    rel_boxes = []
    rel_text = []
    rel_scores = []
    for rel in pred_sg['relations']:
        subj_idx = rel['idx'][0]
        obj_idx = rel['idx'][1]

        rel_boxes.append(pred_objects['boxes'][subj_idx] + pred_objects['boxes'][obj_idx])

        # Get labels (remove numeric suffix like "person1" -> "person")
        subj_label = ''.join(c for c in rel['subject_label'] if not c.isdigit())
        obj_label = ''.join(c for c in rel['object_label'] if not c.isdigit())
        rel_text.append([subj_label, rel['predicate'], obj_label])

        # Relation score is the mean of the two detection scores; an object
        # without a recorded score contributes a neutral 0.5.
        det_scores = pred_objects['scores']
        subj_score = det_scores[subj_idx] if subj_idx < len(det_scores) else 0.5
        obj_score = det_scores[obj_idx] if obj_idx < len(det_scores) else 0.5
        rel_scores.append((subj_score + obj_score) / 2)
    return rel_boxes, rel_text, rel_scores


def _map_to_vocabulary(model, queries, vocabulary):
    """Map each query string to the vocabulary entry with the most similar embedding."""
    if not queries:
        return {}
    similarity = model.similarity(model.encode(queries), model.encode(vocabulary))
    if isinstance(similarity, torch.Tensor):
        similarity = similarity.cpu().numpy()
    return {query: vocabulary[similarity[i].argmax()] for i, query in enumerate(queries)}


def main():
    """Evaluate predicted scene graphs against VG ground truth and report Recall@K.

    Steps: load predictions (keeping only images with at least one predicted
    relation), load the GT test split, build box-pair/text triplets for both,
    map free-form predicted labels and predicates onto the VG vocabulary via
    embedding similarity, then average per-image recall over the evaluated
    images. Results are printed and written to 'eval_results.json' next to
    SCENE_GRAPH_PATH.
    """
    print("=" * 60)
    print("VG Zero-shot Scene Graph Evaluation")
    print("=" * 60)

    # Load predictions
    print("\n[1/5] Loading scene graph predictions...")
    with open(SCENE_GRAPH_PATH, 'r') as f:
        scene_graphs = json.load(f)

    # Filter images with relations
    scene_graphs_with_relations = [sg for sg in scene_graphs if sg.get('relations')]
    print(f"Total images: {len(scene_graphs)}, Images with relations: {len(scene_graphs_with_relations)}")

    # Create filename to prediction mapping
    pred_by_filename = {os.path.basename(sg['image_path']): sg for sg in scene_graphs_with_relations}

    # Load GT data
    print("\n[2/5] Loading VG ground truth...")
    gt_filenames, gt_boxes_all, gt_classes_all, gt_relationships_all = load_graphs(
        H5_FILE, "test", -1, num_val_im=5000,
        filter_empty_rels=True, img_dir=IMG_DIR, image_info_json=IMAGE_DATA_JSON
    )
    gt_filenames = [os.path.basename(f) for f in gt_filenames]
    print(f"GT test images: {len(gt_filenames)}")

    # Load VG label dictionaries
    with open(VG_DICT_PATH, 'r') as f:
        vg_dict = json.load(f)
    idx_to_label = vg_dict['idx_to_label']
    idx_to_predicate = vg_dict['idx_to_predicate']
    gt_labels_list = list(idx_to_label.values())
    gt_rels_list = list(idx_to_predicate.values())

    # Find common images
    common_filenames = [f for f in gt_filenames if f in pred_by_filename]
    print(f"Common images for evaluation: {len(common_filenames)}")

    if len(common_filenames) == 0:
        print("ERROR: No common images found!")
        return

    # Built once so the per-image lookup below is O(1) instead of a list scan;
    # setdefault keeps the first position of a name, as list.index() would.
    gt_index_by_filename = {}
    for position, name in enumerate(gt_filenames):
        gt_index_by_filename.setdefault(name, position)

    # Prepare evaluation data
    print("\n[3/5] Preparing evaluation data...")
    eval_gt_boxes = []
    eval_gt_rels = []
    eval_pred_boxes = []
    eval_pred_rels = []
    eval_pred_scores = []

    for filename in tqdm(common_filenames, desc="Processing images"):
        gt_idx = gt_index_by_filename[filename]

        gt_rel = _dedup_relations(gt_relationships_all[gt_idx])
        gt_rel_boxes, gt_rel_text = _gt_triplets(
            gt_rel, gt_boxes_all[gt_idx], gt_classes_all[gt_idx], idx_to_label, idx_to_predicate
        )
        pred_rel_boxes, pred_rel_text, pred_rel_scores = _pred_triplets(pred_by_filename[filename])

        eval_gt_boxes.append(np.array(gt_rel_boxes))
        eval_gt_rels.append(np.array(gt_rel_text))
        eval_pred_boxes.append(np.array(pred_rel_boxes) if pred_rel_boxes else np.array([]).reshape(0, 8))
        eval_pred_rels.append(np.array(pred_rel_text) if pred_rel_text else np.array([]).reshape(0, 3))
        eval_pred_scores.append(np.array(pred_rel_scores))

    # Load embedding model for label mapping
    print("\n[4/5] Loading embedding model for label mapping...")
    # Fall back to CPU instead of failing outright on machines without CUDA.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = SentenceTransformer(EMBEDDING_MODEL_PATH, device=device, tokenizer_kwargs={"padding_side": "left"})

    # Collect all predicted labels and predicates
    all_pred_labels = set()
    all_pred_predicates = set()
    for rels in eval_pred_rels:
        for rel in rels:
            if len(rel) >= 3:
                all_pred_labels.add(rel[0])
                all_pred_labels.add(rel[2])
                all_pred_predicates.add(rel[1])

    all_pred_labels = list(all_pred_labels)
    all_pred_predicates = list(all_pred_predicates)

    # Compute label mappings
    print("Computing label embeddings...")
    label_mapping = _map_to_vocabulary(model, all_pred_labels, gt_labels_list)

    print("Computing predicate embeddings...")
    predicate_mapping = _map_to_vocabulary(model, all_pred_predicates, gt_rels_list)

    # Apply mappings
    print("Applying label mappings...")
    for i, rels in enumerate(eval_pred_rels):
        mapped_rels = [
            [
                label_mapping.get(rel[0], rel[0]),
                predicate_mapping.get(rel[1], rel[1]),
                label_mapping.get(rel[2], rel[2]),
            ]
            for rel in rels
            if len(rel) >= 3
        ]
        eval_pred_rels[i] = np.array(mapped_rels) if mapped_rels else np.array([]).reshape(0, 3)

    # Compute metrics
    print("\n[5/5] Computing evaluation metrics...")
    k_values = [20, 50, 100, 150, 200]
    recall_scores = {k: 0.0 for k in k_values}
    recall_all = 0.0
    n_images = 0

    per_image = zip(eval_gt_boxes, eval_gt_rels, eval_pred_boxes, eval_pred_rels, eval_pred_scores)
    for gt_boxes, gt_rels, pred_boxes, pred_rels, pred_scores in tqdm(
        per_image, total=len(common_filenames), desc="Evaluating"
    ):
        if len(gt_boxes) == 0:
            continue
        # Only images that actually contribute a recall value are counted, so
        # skipped images no longer dilute the average.
        n_images += 1

        # Recall@All
        recall_all += recall_at_k(gt_boxes, gt_rels, pred_boxes, pred_rels, pred_scores, len(pred_scores), iou_threshold=0.5)

        # Recall@K
        for k in k_values:
            recall_scores[k] += recall_at_k(gt_boxes, gt_rels, pred_boxes, pred_rels, pred_scores, k, iou_threshold=0.5)

    # Guard against division by zero when every image was skipped.
    denominator = max(1, n_images)
    recall_all /= denominator
    for k in k_values:
        recall_scores[k] /= denominator

    # Print results
    print("\n" + "=" * 60)
    print("EVALUATION RESULTS")
    print("=" * 60)
    print(f"Number of evaluated images: {n_images}")
    print(f"\nRecall@All: {recall_all:.4f} ({recall_all * 100:.2f}%)")
    for k in k_values:
        print(f"Recall@{k}: {recall_scores[k]:.4f} ({recall_scores[k] * 100:.2f}%)")

    # Save results
    results = {
        "num_images": n_images,
        "recall_all": recall_all,
        **{f"recall_{k}": recall_scores[k] for k in k_values}
    }
    output_path = os.path.join(os.path.dirname(SCENE_GRAPH_PATH), "eval_results.json")
    with open(output_path, 'w') as f:
        json.dump(results, f, indent=2)
    print(f"\nResults saved to: {output_path}")


# Run the evaluation only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
