import json
import os
from collections import defaultdict, Counter
import matplotlib.pyplot as plt
from tqdm import tqdm  # Add tqdm for progress tracking

# --- Load label ---
# label.json is a JSON object iterated in Step 1 as {image id: list of objects}.
# JSON interchange is UTF-8, so decode explicitly instead of depending on the
# platform's locale encoding.
with open('label.json', 'r', encoding='utf-8') as f:
    labels = json.load(f)

# --- Settings ---
# Output directory for every plot and the filtered JSON produced below.
root_path = "/fs/ess/PAS2099/sooyoung/perception_system_v2_local/kitti_analysis/FINAL"  # Replace with your desired root path
os.makedirs(root_path, exist_ok=True)  # Ensure the directory exists

# Classes kept in Step 1. 'person_sitting' is remapped to 'pedestrian' before
# this membership test, so listing it here is redundant but harmless.
valid_classes = {'car', 'pedestrian', 'cyclist', 'van', 'person_sitting'}
# Class combinations allowed to form a pair; Step 2 matches them in either order
# and counts each accepted pair under the spelling used here.
pair_list = [
    ('car', 'car'),
    ('pedestrian', 'pedestrian'),
    ('cyclist', 'cyclist'),
    ('van', 'van'),
    ('car', 'pedestrian'),
    ('car', 'cyclist'),
    ('car', 'van'),
    ('pedestrian', 'cyclist'),
    ('pedestrian', 'van'),
    ('cyclist', 'van')
]
# Depths are reported in metres by the plots below ("...{depth_threshold}m").
depth_threshold = 2  # minimum absolute depth gap for a pair, in metres (2 m)
max_obj_depth = 100  # Step 1 drops objects whose closest_depth exceeds this
# A pair must also have a gap of at least this fraction of the nearer object's
# closest_depth: depth_diff_conf for mixed-class pairs, the stricter
# same_class_depth_diff_conf when both objects share a class.
depth_diff_conf = 0.3
same_class_depth_diff_conf = 0.5

# --- Step 1: Filter valid objects ---
def _select_objects(raw_objects):
    """Return the objects of one image that pass the Step 1 filters.

    After folding 'person_sitting' into 'pedestrian', an object is kept when
    its class is in ``valid_classes``, its ``occluded`` value equals 0, and its
    ``closest_depth`` is not None and at most ``max_obj_depth``. Each kept
    object is reduced to a small dict whose ``'id'`` is its index in the
    image's full, unfiltered object list.
    """
    selected = []
    for local_idx, raw in enumerate(raw_objects):
        category = raw['class'].lower()
        category = 'pedestrian' if category == 'person_sitting' else category
        # Keep the short-circuit order: later fields are only read once the
        # earlier checks have passed.
        keep = (
            category in valid_classes
            and raw['occluded'] == 0
            and raw['closest_depth'] is not None
            and raw['closest_depth'] <= max_obj_depth
        )
        if not keep:
            continue
        selected.append({
            'id': local_idx,
            'class': category,
            'bbox_2d': raw['bbox_2d'],
            'closest_depth': raw['closest_depth'],
        })
    return selected


image_objs = {}
for image_key, raw_objects in tqdm(labels.items(), desc="(1/2) Filtering valid objects"):
    selected = _select_objects(raw_objects)
    if selected:  # images without any usable object are omitted entirely
        image_objs[image_key] = selected

# --- Step 2: Pair Matching ---
# Within each image, test every unordered pair of kept objects whose class
# combination is listed in pair_list (in either order). A pair is accepted when
# its depth gap is >= depth_threshold AND >= the nearer object's depth times
# same_class_depth_diff_conf (same class) or depth_diff_conf (mixed classes).
pair_counts = Counter()
total_pairs = 0
image_pairs = {}
depth_differences = []  # accepted depth gaps, in acceptance order, for the histogram

# Lookup from a class tuple in either order to its spelling in pair_list.
# Forward entries are registered first so an exact match always wins.
_canonical_pair = {pair: pair for pair in pair_list}
for _pair in pair_list:
    _canonical_pair.setdefault((_pair[1], _pair[0]), _pair)

for image_key, candidates in tqdm(image_objs.items(), desc="(2/2) Matching pairs"):
    accepted = []
    for pos, first in enumerate(candidates):
        for second in candidates[pos + 1:]:
            key = _canonical_pair.get((first['class'], second['class']))
            if key is None:
                continue
            d_first = first['closest_depth']
            d_second = second['closest_depth']
            gap = abs(d_first - d_second)
            nearer = d_first if d_first < d_second else d_second
            ratio = (
                same_class_depth_diff_conf
                if first['class'] == second['class']
                else depth_diff_conf
            )
            if gap >= depth_threshold and gap >= nearer * ratio:
                accepted.append((first['id'], second['id']))
                pair_counts[key] += 1
                total_pairs += 1
                depth_differences.append(gap)
    if accepted:
        image_pairs[image_key] = {
            'objects': candidates,
            'pairs': accepted,
        }

print(f"Total pairs: {total_pairs}")

# --- Step 3: Plot pair counts ---
# Bar chart of accepted pairs per class combination, most frequent first.
plt.figure()

# most_common() sorts by count, descending, keeping first-seen order for ties,
# which is the same ordering as sorting the items by count with reverse=True.
sorted_pair_counts = pair_counts.most_common()

# One bar per class pair, labelled "classA-classB".
for pair, cnt in sorted_pair_counts:
    plt.bar(f"{pair[0]}-{pair[1]}", cnt)

# Right-align rotated tick labels so each one ends under its own bar.
plt.xticks(rotation=45, ha='right')
plt.title(f"Pair Counts (Min Depth Difference: {depth_threshold}m)")
# Axis labels, matching the object-count plot in Step 5.
plt.xlabel("Class Pair")
plt.ylabel("Count")
plt.tight_layout()
plt.savefig(os.path.join(root_path, "pair_counts.png"))
plt.close()

# --- Step 4: Plot depth difference histogram ---
# Distribution of the depth gaps of all accepted pairs collected in Step 2.
fig, ax = plt.subplots()
ax.hist(depth_differences, bins=20, edgecolor='black')  # Adjust bins as needed
ax.set_title(f"Depth Difference Histogram (Min Depth Difference: {depth_threshold}m)")
ax.set_xlabel("Depth Difference (m)")
ax.set_ylabel("Frequency")
fig.tight_layout()
fig.savefig(os.path.join(root_path, "depth_difference_histogram.png"))
plt.close(fig)

# --- Step 5: Count objects selected as pairs ---
# Per class, count the objects that take part in at least one accepted pair.
# An object that appears in several pairs of the same image is counted once.
object_counts = Counter()

for data in image_pairs.values():  # the image id itself is not needed here
    selected_ids = {obj_id for pair in data['pairs'] for obj_id in pair}
    for obj in data['objects']:
        if obj['id'] in selected_ids:
            object_counts[obj['class']] += 1

# Plot object counts, most frequent class first (ties keep first-seen order).
plt.figure()
sorted_object_counts = object_counts.most_common()
for obj_class, count in sorted_object_counts:
    plt.bar(obj_class, count)

# Right-align rotated tick labels so each one ends under its own bar.
plt.xticks(rotation=45, ha='right')
plt.title(f"Object Counts in Valid Pairs (Min Depth Difference: {depth_threshold}m)")
plt.xlabel("Object Class")
plt.ylabel("Count")
plt.tight_layout()
plt.savefig(os.path.join(root_path, "object_counts_in_pairs.png"))
plt.close()

# --- Step 6: Save result ---
# Write only the images that yielded at least one accepted pair: each entry
# holds its kept objects and the (id, id) pairs (tuples become JSON arrays).
output_file = os.path.join(root_path, 'filtered_label_with_pairs.json')
with open(output_file, 'w') as handle:
    json.dump(image_pairs, handle, indent=2)
