import pandas as pd
import matplotlib.pyplot as plt
import os
import random
from tqdm import tqdm
import random
# Seed the stdlib RNG so any random draws are reproducible from run to run.
random.seed(42)

# ---------------------------------------------------------------------------
# Settings
# ---------------------------------------------------------------------------
csv_path = "bbox_metadata_all.csv"  # per-bbox metadata table loaded with pd.read_csv
# Edges handed to pd.cut(right=True): 11 right-closed bins (1, 2], (2, 3], ..., (11, 12].
# Depth units are presumably meters (per the original comment); a depth <= 1 or > 12
# lands in no bin.
bin_edges = [*range(1, 13)]
filterLen = 3      # a label must span at least this many distinct depth bins
min_per_bin = 10   # depth bins with fewer rows than this are discarded
max_per_bin = 30   # depth bins with more rows than this are trimmed to this many
minSize = 500      # minimum bbox area after rescaling the 640x480 frame to 384x384

# ---------------------------------------------------------------------------
# Load the metadata and tag every row with its depth bin
# ---------------------------------------------------------------------------
df = pd.read_csv(csv_path)
df["depth_bin"] = pd.cut(df["depth"], bins=bin_edges, right=True)

# ---------------------------------------------------------------------------
# STEP 1: drop boxes that are too small after resizing, then keep only labels
# whose remaining boxes cover at least `filterLen` distinct depth bins.
# ---------------------------------------------------------------------------
# Boxes are measured on the 640x480 frame while the model input is 384x384, so
# width scales by 384/640, height by 384/480, and area by their product.
scale_w, scale_h = 384 / 640, 384 / 480
scale_factor = scale_w * scale_h

df["resized_bbox_area"] = df["bbox_area"] * scale_factor
df = df.loc[df["resized_bbox_area"] >= minSize].copy()

# nunique() ignores NaN, so rows whose depth fell outside every bin add no bin.
label_to_bin_count = df.groupby("label_id")["depth_bin"].nunique()
valid_label_ids = label_to_bin_count.loc[lambda counts: counts >= filterLen].index
df_step1 = df.loc[df["label_id"].isin(valid_label_ids)]

# ---------------------------------------------------------------------------
# STEP 2: per label, keep only well-populated depth bins, cap over-full bins
# at the largest boxes, and keep the label only if enough bins survive.
# ---------------------------------------------------------------------------
filtered_label_to_dist_imgset = {}  # label_id -> {depth_bin: DataFrame of kept rows}
rows = []  # one DataFrame per kept (label, depth bin), concatenated below

for label_id, label_df in tqdm(df_step1.groupby("label_id"), desc="Applying Step 2 filtering"):
    final_bins = {}
    # observed=True: visit only the depth bins this label actually occupies.
    # depth_bin is categorical, and the implicit default (observed=False before
    # pandas 3) also yields an empty group for every unused category and warns.
    for depth_bin, df_bin in label_df.groupby("depth_bin", observed=True):
        # Discard sparsely populated bins.
        if len(df_bin) < min_per_bin:
            continue
        # Trim over-full bins, keeping the boxes with the largest bbox_area.
        if len(df_bin) > max_per_bin:
            df_bin = df_bin.sort_values("bbox_area", ascending=False).head(max_per_bin)
        final_bins[depth_bin] = df_bin

    # Use the same bin-count threshold as Step 1 instead of a hard-coded 3.
    if len(final_bins) >= filterLen:
        filtered_label_to_dist_imgset[label_id] = final_bins
        rows.extend(final_bins.values())

# pd.concat([]) fails with an opaque "No objects to concatenate"; say why instead.
if not rows:
    raise ValueError(
        f"Step 2 kept no labels: none has >= {filterLen} depth bins with "
        f">= {min_per_bin} boxes each. Relax filterLen, min_per_bin or minSize."
    )

# Combine the kept rows into the final DataFrame.
df_step2 = pd.concat(rows, ignore_index=True)

# Bar chart of how many Step 2 rows each label retains, saved to a PNG.
fig, ax = plt.subplots(figsize=(12, 5))
df_step2["label_name"].value_counts().plot(kind="bar", color="orange", ax=ax)
ax.set_title("Step 2: New Image Count per Label (After Bin Filtering)")
ax.set_xlabel("Label Name")
ax.set_ylabel("New Image Count")
plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
fig.tight_layout()
ax.grid(axis="y")
fig.savefig("afterFiltering_summary.png")
plt.close(fig)


# Save the Step 2 selection. depth_bin holds pandas Interval categories; cast
# them to str so the CSV stores readable labels such as "(1, 2]".
df_step2["depth_bin"] = df_step2["depth_bin"].astype(str)
columns_to_save = [
    "new_image_id",
    "old_image_id",
    "label_id",
    "instance_id",
    "label_name",
    "depth",
    "bbox_width",
    "bbox_height",
    "bbox_area",
    "resized_bbox_area",
    "depth_bin",
]
filtered_csv_path = "bbox_metadata_filtered.csv"
df_step2.to_csv(filtered_csv_path, columns=columns_to_save, index=False)

print(f"Saved filtered CSV to '{filtered_csv_path}' with {len(df_step2)} rows")

# Report how many labels and unique new images remain after each stage.
print("\nSummary")
for stage_name, stage_df in (("Resize", df), ("Step 1", df_step1), ("Step 2", df_step2)):
    n_labels = stage_df["label_id"].nunique()
    n_images = stage_df["new_image_id"].nunique()
    print(f"{stage_name} → Labels left: {n_labels}, Unique new images: {n_images}")
