import pandas as pd
import numpy as np
import random
from tqdm import tqdm
from collections import defaultdict, Counter
import matplotlib.pyplot as plt
import h5py

# Config
# Per-pair summary table; Step 1 reads its 'Tuple Pair',
# '#QuestionwithAnswerL1' and '#QuestionwithAnswerL2' columns.
summary_csv_path = 'pairwise_label_summary.csv'
# Candidate questions, one row per question; filtered by every step below.
questions_csv_path = 'filtered_object_instance_relative_depth_questions.csv'
# Destination of the final selection written in Step 6.
output_csv_path = 'final_selected_questions.csv'
# HDF5-readable .mat file; its "labels", "instances" and "depths" datasets
# are loaded in full below.
mat_file_path = '../../NYU_depth/nyu_metadata/nyu_depth_v2_labeled.mat'

# Minimum number of questions required on each answer side of a label pair
# (checked against the summary in Step 1 and against the data in Step 3).
min_threshold = 10
# Maximum number of questions sampled per answer side of a pair (Step 3).
max_threshold = 30
# Cap on how many kept pairs a single label may take part in (Step 4).
max_pairs_per_label = 5
# Seeds the `random` module (Step 4 shuffle) and is passed as `random_state`
# to DataFrame.sample (Step 3).
random_seed = 42
# Minimum absolute difference between the two objects' minimum depth values
# for a question to be kept (Step -1).
# NOTE(review): units are whatever the "depths" dataset stores, presumably
# metres; confirm against the dataset documentation.
depth_k = 0.5

random.seed(random_seed)

# Load the pair summary and the candidate questions
summary_df = pd.read_csv(summary_csv_path)
questions_df = pd.read_csv(questions_csv_path)
print(f"Loaded {len(summary_df)} pairs and {len(questions_df)} questions.")


def _read_transposed_frames(h5_file, dataset_name):
    """Read every entry along axis 0 of a dataset and transpose each one.

    Args:
        h5_file: Open ``h5py.File`` to read from.
        dataset_name: Name of the dataset inside the file.

    Returns:
        A list with one transposed array per index of the dataset's first
        axis, in index order.
    """
    dataset = h5_file[dataset_name]
    frames = []
    for frame_idx in range(dataset.shape[0]):
        frames.append(dataset[frame_idx][()].T)
    return frames


# Read the three .mat datasets into memory once, so the per-row checks below
# never touch the file again.
with h5py.File(mat_file_path, "r") as f:
    labels_all = _read_transposed_frames(f, "labels")
    instances_all = _read_transposed_frames(f, "instances")
    depths_all = _read_transposed_frames(f, "depths")
print("Preloaded all NYU label, instance, and depth maps")

# Step -1: Filter by depth difference
def passes_depth_check(row):
    """Return whether a question's two objects are far enough apart in depth.

    Each object is located as the pixels where the preloaded label map equals
    its label id and the instance map equals its instance id, in the frame
    selected by ``row.image_id``. An object's depth is the minimum value of
    the depth map over those pixels.

    Args:
        row: Record exposing ``image_id``, ``object1_label_id``,
            ``object1_instance_id``, ``object2_label_id`` and
            ``object2_instance_id`` as attributes (for example a row from
            ``DataFrame.itertuples``).

    Returns:
        True when both objects have at least one pixel in the frame and their
        minimum depths differ by at least ``depth_k``. False when the image id
        is outside the preloaded frames, when either object is absent, or when
        the row's values cannot be converted or used to index the frames.

    Raises:
        AttributeError: If ``row`` lacks one of the expected attributes. This
            is a schema problem affecting every row, so it is reported rather
            than silently rejecting the whole table.
    """
    try:
        image_id = int(row.image_id)
        # Negative ids must be rejected explicitly: list indexing would
        # otherwise wrap around and silently use a frame from the end.
        if not 0 <= image_id < len(labels_all):
            return False

        l1, i1 = int(row.object1_label_id), int(row.object1_instance_id)
        l2, i2 = int(row.object2_label_id), int(row.object2_instance_id)

        label_frame = labels_all[image_id]
        instance_frame = instances_all[image_id]
        depth_frame = depths_all[image_id]

        mask1 = (label_frame == l1) & (instance_frame == i1)
        mask2 = (label_frame == l2) & (instance_frame == i2)

        if not np.any(mask1) or not np.any(mask2):
            return False

        depth1 = np.min(depth_frame[mask1])
        depth2 = np.min(depth_frame[mask2])

        return bool(abs(depth1 - depth2) >= depth_k)
    except (TypeError, ValueError, OverflowError, IndexError):
        # Per-row data problems only (NaN / non-numeric ids, frames whose
        # shapes do not line up). Interrupts and programming errors are no
        # longer swallowed.
        return False

# Report the threshold actually in use instead of a hard-coded "0.5".
print(f"Step -1: Filtering by depth difference (|Δ| ≥ {depth_k})...")
# Build the mask as an explicit boolean array: a temporary column filled from
# an empty list is not boolean, so on an empty table the original
# filter-then-drop sequence raised KeyError instead of yielding zero rows.
depth_ok = np.array(
    [
        passes_depth_check(row)
        for row in tqdm(questions_df.itertuples(index=False), total=len(questions_df))
    ],
    dtype=bool,
)
# .copy() gives later steps an independent frame to add columns to.
questions_df = questions_df[depth_ok].copy()
print(f"Remaining after depth filtering: {len(questions_df)}\n")

# Step 0: Filter out overlapping or missing bboxes
def find_bbox(label_frame, instance_frame, label_id, instance_id):
    """Locate the bounding box of one labelled object instance.

    Args:
        label_frame: 2-D array of per-pixel label ids.
        instance_frame: Array of per-pixel instance ids, same shape as
            ``label_frame``.
        label_id: Label id to match; converted with ``int``.
        instance_id: Instance id to match; converted with ``int``.

    Returns:
        ``(x_min, y_min, x_max, y_max)`` in pixel indices, both ends
        inclusive, where x is the column index and y the row index; ``None``
        when no pixel matches both ids.
    """
    wanted_label = int(label_id)
    wanted_instance = int(instance_id)
    hits = np.argwhere((label_frame == wanted_label) & (instance_frame == wanted_instance))
    if hits.size == 0:
        return None
    # argwhere yields (row, col) pairs, so column 0 is y and column 1 is x.
    (top, left), (bottom, right) = hits.min(axis=0), hits.max(axis=0)
    return (left, top, right, bottom)

def bboxes_overlap(b1, b2):
    """Tell whether two axis-aligned boxes overlap.

    Args:
        b1: First box as ``(x_min, y_min, x_max, y_max)``.
        b2: Second box in the same layout.

    Returns:
        False when the boxes are separated along either axis, True otherwise.
        Boxes whose edges share the same coordinate (for example
        ``b1`` ending at x == 2 and ``b2`` starting at x == 2) count as
        separated.
    """
    left_a, top_a, right_a, bottom_a = b1
    left_b, top_b, right_b, bottom_b = b2

    # Separated horizontally: one box ends at or before the other begins.
    if right_a <= left_b or right_b <= left_a:
        return False
    # Separated vertically, by the same rule.
    if bottom_a <= top_b or bottom_b <= top_a:
        return False
    return True

def is_bbox_valid_and_nonoverlapping(row):
    """Return whether both objects of a question have disjoint bounding boxes.

    The boxes are computed by ``find_bbox`` from the preloaded label and
    instance maps of the frame selected by ``row.image_id`` and compared with
    ``bboxes_overlap``.

    Args:
        row: Record exposing ``image_id``, ``object1_label_id``,
            ``object1_instance_id``, ``object2_label_id`` and
            ``object2_instance_id`` as attributes (for example a row from
            ``DataFrame.itertuples``).

    Returns:
        True when both objects are found in the frame and their boxes do not
        overlap. False when the image id is outside the preloaded frames, when
        either object has no pixels, when the boxes overlap, or when the row's
        values cannot be converted or used to index the frames.

    Raises:
        AttributeError: If ``row`` lacks one of the expected attributes. This
            is a schema problem affecting every row, so it is reported rather
            than silently rejecting the whole table.
    """
    try:
        image_id = int(row.image_id)
        # Negative ids must be rejected explicitly: list indexing would
        # otherwise wrap around and silently use a frame from the end.
        if not 0 <= image_id < len(labels_all):
            return False

        l1, i1 = int(row.object1_label_id), int(row.object1_instance_id)
        l2, i2 = int(row.object2_label_id), int(row.object2_instance_id)

        label_frame = labels_all[image_id]
        instance_frame = instances_all[image_id]

        bbox1 = find_bbox(label_frame, instance_frame, l1, i1)
        bbox2 = find_bbox(label_frame, instance_frame, l2, i2)
    except (TypeError, ValueError, OverflowError, IndexError):
        # Per-row data problems only (NaN / non-numeric ids, frames whose
        # shapes do not line up). Interrupts and programming errors are no
        # longer swallowed.
        return False

    if bbox1 is None or bbox2 is None:
        return False
    return not bboxes_overlap(bbox1, bbox2)

print("Step 0: Filtering out overlapping or missing bboxes...")
# Build the mask as an explicit boolean array: a temporary column filled from
# an empty list is not boolean, so when the previous step left no rows the
# original filter-then-drop sequence raised KeyError instead of yielding an
# empty table.
bbox_ok = np.array(
    [
        is_bbox_valid_and_nonoverlapping(row)
        for row in tqdm(
            questions_df.itertuples(index=False),
            total=len(questions_df),
            desc="Checking BBoxes",
        )
    ],
    dtype=bool,
)
# .copy() gives later steps an independent frame to add columns to.
questions_df = questions_df[bbox_ok].copy()
print(f"Remaining after bbox check: {len(questions_df)}\n")


# Step 1: Collect the label pairs whose two answer counts both reach
# min_threshold according to the summary table.
valid_pairs = set()
for record in summary_df.to_dict('records'):
    # 'Tuple Pair' is stored as text such as "(a, b)": drop the parentheses
    # and split on ", ". Parsing happens before the count check, so a
    # malformed entry fails loudly even when its counts are too low.
    # NOTE(review): any quote characters around the labels would be kept
    # as-is; confirm the summary file writes bare names.
    first_label, second_label = record['Tuple Pair'].strip("()").split(", ")
    if record['#QuestionwithAnswerL1'] >= min_threshold and record['#QuestionwithAnswerL2'] >= min_threshold:
        # Sorted so the key does not depend on the order within the pair.
        pair_key = tuple(sorted([first_label, second_label]))
        valid_pairs.add(pair_key)

print(f"Valid pairs (≥{min_threshold} per label): {len(valid_pairs)}")


# Step 2: Keep only valid pairs
def make_tuple(row):
    """Build the order-independent pair key for a question row.

    Args:
        row: Mapping-like record with ``'object1'`` and ``'object2'`` entries.

    Returns:
        A 2-tuple holding both values in ascending order.
    """
    names = [row['object1'], row['object2']]
    names.sort()
    return tuple(names)

# Tag every question with its sorted label pair, then keep only the
# questions whose pair passed the summary check in Step 1.
pair_keys = questions_df.apply(make_tuple, axis=1)
questions_df['Tuple Pair'] = pair_keys
has_valid_pair = questions_df['Tuple Pair'].isin(valid_pairs)
filtered_questions = questions_df[has_valid_pair]
print(f"Filtered questions after valid pairs: {len(filtered_questions)}")


# Step 3: For every label pair, split its questions by which object is the
# answer, require min_threshold questions on both sides, and keep at most
# max_threshold seeded-random questions per side.
pair_groups = filtered_questions.groupby('Tuple Pair')

pair_to_rows = {}
label_to_candidate_pairs = defaultdict(set)

for pair, group in tqdm(pair_groups, desc="Processing pairs"):
    obj1, obj2 = pair
    answers = group['answer(object1)']
    # One frame per answer side, in (obj1, obj2) order.
    sides = [group[answers == obj1], group[answers == obj2]]

    if any(len(side) < min_threshold for side in sides):
        continue

    # Same random_state for both sides keeps the selection reproducible.
    sampled_sides = [
        side.sample(min(len(side), max_threshold), random_state=random_seed)
        for side in sides
    ]

    pair_data = pd.concat(sampled_sides)
    pair_to_rows[pair] = pair_data

    for label in (obj1, obj2):
        label_to_candidate_pairs[label].add(pair)

total_rows_after_sampling = sum(len(rows) for rows in pair_to_rows.values())

print(f"Candidate pairs after sampling: {len(pair_to_rows)}")
print(f"Total questions kept after sampling: {total_rows_after_sampling}")


# Step 4: Enforce max pairs per label (greedy strategy)
# Candidate pairs are visited in a seeded random order; a pair is kept only
# while every label in it is still below the cap.
label_pair_counts = defaultdict(int)
final_kept_pairs = set()

all_pairs = list(pair_to_rows.keys())
random.shuffle(all_pairs)

for pair in all_pairs:
    l1, l2 = pair
    # A pair made of one label twice is still a single pair for that label.
    labels_in_pair = {l1, l2}
    # Strict '<': with '<=' a label already holding max_pairs_per_label pairs
    # was accepted once more and ended up with max_pairs_per_label + 1.
    if all(label_pair_counts[label] < max_pairs_per_label for label in labels_in_pair):
        final_kept_pairs.add(pair)
        for label in labels_in_pair:
            label_pair_counts[label] += 1

print(f"Final kept pairs (≤{max_pairs_per_label} per label): {len(final_kept_pairs)}")


# Step 5: Aggregate final questions
# Iterate the kept pairs in sorted order: a set of string tuples has no
# stable iteration order across interpreter runs (string hashing is
# randomized per process), which would make the output row order differ from
# run to run despite the fixed seeds.
final_rows = [pair_to_rows[pair] for pair in sorted(final_kept_pairs)]
if final_rows:
    final_questions = pd.concat(final_rows).drop_duplicates().reset_index(drop=True)
else:
    # pd.concat raises ValueError on an empty list; fall back to an empty
    # frame with the same columns so the remaining steps still run.
    final_questions = filtered_questions.iloc[0:0].reset_index(drop=True)
print(f"Final selected questions after enforcing caps: {len(final_questions)}")


# Step 6: Write the selected questions to disk, restricted to the columns
# listed below and in this order.
save_columns = [
    'image_id',
    'question',
    'object1',
    'object2',
    'object1_label_id',
    'object1_instance_id',
    'object2_label_id',
    'object2_instance_id',
    'object1_dis',
    'object2_dis',
    'answer(object1)',
]

export_frame = final_questions[save_columns]
export_frame.to_csv(output_csv_path, index=False)
print(f"Saved final questions to: {output_csv_path}")


# Step 7: Bar chart of how often each value appears in the 'answer(object1)'
# column of the final selection, most frequent first.
answer_counts = final_questions['answer(object1)'].value_counts()
answer_counts = answer_counts.sort_values(ascending=False)

plt.figure(figsize=(14, 6))
plt.bar(answer_counts.index, answer_counts.values)
plt.title("Correct Answer Frequency (After Filtering and BBox Checks)")
plt.xlabel("Label (Object Name)")
plt.ylabel("Correct Answer Frequency")
# Rotate the tick labels so long names do not collide.
plt.xticks(rotation=90)
plt.tight_layout()
plt.savefig("afterFiltering_summary.jpg")

