import argparse
import json
import math
import os

def identify_confidence_field(annotation):
    """
    Return the first recognised confidence key present in ``annotation``.

    Candidate keys are tried in priority order:
    'score', 'confidence', 'conf', 'prob', 'probability'.

    Returns: the matching key, or None when none of the candidates is present
    """
    candidates = ('score', 'confidence', 'conf', 'prob', 'probability')
    return next((key for key in candidates if key in annotation), None)

def is_valid_confidence(value):
    """
    Check whether ``value`` can be used as a numeric confidence score.

    A value is valid when it is not None, ``float(value)`` succeeds, and the
    result is not NaN. Numeric strings such as "0.8" are accepted, because
    ``float()`` accepts them.

    Returns: True if valid, False otherwise
    """
    if value is None:
        return False
    try:
        number = float(value)
    except (ValueError, TypeError, OverflowError):
        # OverflowError: an int too large to convert to float (json.load
        # parses integer literals of any size as Python ints).
        return False
    # NaN compares False against every threshold, so it cannot be filtered
    # meaningfully; report it as invalid instead of as a real score.
    return not math.isnan(number)

def filter_annotations_by_confidence(annotations, threshold=0.5):
    """
    Drop annotations whose confidence score falls below ``threshold``.

    The confidence key is detected once. It is the key that
    ``identify_confidence_field`` reports for the first dict annotation that
    has any recognised key, and that key is then applied to every annotation.

    The following are always kept:
    - non-dict entries;
    - dicts missing the detected key;
    - dicts whose value fails ``is_valid_confidence``.
    Every other annotation is kept only when ``float(value) >= threshold``.

    Returns: (filtered_annotations, stats), where stats holds 'total' (input
    count), 'filtered' (kept count) and 'removed' (dropped count). If no
    confidence key is found at all, the input list itself is returned.
    """
    if not annotations:
        return [], {'total': 0, 'filtered': 0, 'removed': 0}

    # Not necessarily the first element: non-dicts and dicts without a
    # recognised key are skipped until some dict yields a key.
    candidates = (
        identify_confidence_field(item)
        for item in annotations
        if isinstance(item, dict)
    )
    confidence_field = next((name for name in candidates if name), None)

    if not confidence_field:
        print(f"Warning: No confidence field found. Keeping all {len(annotations)} annotations.")
        return annotations, {'total': len(annotations), 'filtered': len(annotations), 'removed': 0}

    print(f"Using confidence field: '{confidence_field}'")

    kept = []
    removed_total = 0
    no_confidence_count = 0
    for item in annotations:
        if isinstance(item, dict):
            # A missing key yields None, which is_valid_confidence rejects.
            raw = item.get(confidence_field)
            if not is_valid_confidence(raw):
                # Missing or unusable score: keep it, but count it for the warning.
                no_confidence_count += 1
            elif not (float(raw) >= threshold):
                # Negated `>=` rather than `<`: any value that does not compare
                # as >= threshold (NaN included) is dropped.
                removed_total += 1
                continue
        kept.append(item)

    if no_confidence_count:
        print(f"Warning: {no_confidence_count} annotations without valid confidence values were kept.")

    return kept, {
        'total': len(annotations),
        'filtered': len(kept),
        'removed': removed_total,
    }

def process_dict_format(data, threshold=0.5):
    """
    Filter the annotation list stored inside a dictionary-shaped COCO file.

    The first key among 'annotations', 'annotation', 'anns', 'detections'
    whose value is a list is treated as the annotation list.

    Returns: (processed_data, stats). processed_data is a shallow copy of
    ``data`` with that list replaced by the filtered one. When no such key
    exists, ``data`` itself is returned together with all-zero stats.
    """
    candidate_keys = ('annotations', 'annotation', 'anns', 'detections')
    annotation_key = next(
        (key for key in candidate_keys
         if key in data and isinstance(data[key], list)),
        None,
    )

    if annotation_key is None:
        print("Warning: No annotation list found in dictionary format.")
        return data, {'total': 0, 'filtered': 0, 'removed': 0}

    print(f"Found annotations under key: '{annotation_key}'")

    kept, stats = filter_annotations_by_confidence(data[annotation_key], threshold)

    # Shallow copy: every other top-level entry stays shared with the input;
    # only the annotation list is swapped out.
    result = data.copy()
    result[annotation_key] = kept
    return result, stats

def process_list_format(data, threshold=0.5):
    """
    Filter a COCO file whose top-level JSON value is a list.

    The list is assumed to hold the annotations themselves, so it is handed
    straight to filter_annotations_by_confidence.

    Returns: (processed_data, stats) -- the filtered list and its statistics
    """
    return filter_annotations_by_confidence(data, threshold)

def filter_coco_annotations(json_file_path, output_file_path=None, threshold=0.5):
    """
    Main function to filter COCO annotations by confidence.

    Loads ``json_file_path`` and drops annotations whose confidence is below
    ``threshold``. Dict files go through process_dict_format and list files
    through process_list_format. The function then prints a summary and
    writes the result as indented JSON.

    Args:
        json_file_path: path of the input JSON file (read as UTF-8).
        output_file_path: destination path. Defaults to
            ``<input name>_filtered<input extension>`` next to the input.
        threshold: minimum confidence (inclusive) an annotation needs to be kept.

    Returns: None. Errors are reported with print() instead of being raised.
    This covers unreadable or invalid input, an unsupported top-level type,
    and a failed write.
    """
    # Load JSON file. JSON text is UTF-8, so don't depend on the locale's
    # default encoding.
    try:
        with open(json_file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except FileNotFoundError:
        print(f"Error: {json_file_path} not found!")
        return
    except OSError as e:
        # e.g. the path is a directory, or read permission is missing
        print(f"Error: could not read {json_file_path}: {e}")
        return
    except UnicodeDecodeError:
        print(f"Error: {json_file_path} is not valid UTF-8 text")
        return
    except json.JSONDecodeError:
        print(f"Error: Invalid JSON format in {json_file_path}")
        return

    print(f"Processing: {json_file_path}")
    print(f"Confidence threshold: {threshold}")
    print("-" * 50)

    # Process based on data type
    if isinstance(data, dict):
        print("Detected format: Dictionary (standard COCO)")
        processed_data, stats = process_dict_format(data, threshold)
    elif isinstance(data, list):
        print("Detected format: List")
        processed_data, stats = process_list_format(data, threshold)
    else:
        print("Error: Unsupported data format. Expected dict or list.")
        return

    # Print statistics
    print("\n" + "=" * 50)
    print("FILTERING RESULTS")
    print("=" * 50)
    print(f"Total annotations: {stats['total']}")
    print(f"Kept annotations: {stats['filtered']}")
    print(f"Removed annotations: {stats['removed']}")
    if stats['total'] > 0:
        removal_rate = (stats['removed'] / stats['total']) * 100
        print(f"Removal rate: {removal_rate:.1f}%")

    # Save filtered data
    if output_file_path is None:
        name, ext = os.path.splitext(json_file_path)
        output_file_path = f"{name}_filtered{ext}"

    try:
        with open(output_file_path, 'w', encoding='utf-8') as f:
            json.dump(processed_data, f, indent=2)
        print(f"\nFiltered data saved to: {output_file_path}")
    except (OSError, TypeError, ValueError) as e:
        # OSError: destination not writable; TypeError/ValueError: json.dump
        # could not serialise the data.
        print(f"Error saving file: {e}")

def main(argv=None):
    """
    Command-line entry point.

    Every option defaults to the value previously hard-coded here, so running
    the script with no arguments behaves exactly as before:

        python <script>.py [--input PATH] [--output PATH] [--threshold FLOAT]

    Args:
        argv: argument list to parse; None means sys.argv[1:].
    """
    default_input = "/data/xxx/segmentation/CutLER/coler_eval/imagenet_val/inference_1/self_train_cutler_r1.json"
    # NOTE(review): the default output lives under inference_base/ while the
    # default input is under inference_1/ -- confirm this pairing is intended.
    default_output = "/data/xxx/segmentation/CutLER/coler_eval/cls_agnostic_imagenet/inference_base/coler_self_train_r1.json"

    parser = argparse.ArgumentParser(
        description="Filter COCO-style annotations by a confidence threshold."
    )
    parser.add_argument(
        "--input", "-i", dest="json_file_path", default=default_input,
        help="input JSON file",
    )
    parser.add_argument(
        "--output", "-o", default=default_output,
        help="output JSON file; pass an empty string to auto-generate "
             "<input>_filtered<ext>",
    )
    parser.add_argument(
        "--threshold", "-t", type=float, default=0.5,
        help="minimum confidence to keep (default: 0.5)",
    )
    args = parser.parse_args(argv)

    json_file_path = args.json_file_path.strip()
    if not json_file_path:
        print("No file path provided!")
        return

    # An empty --output falls back to the auto-generated name.
    output_path = args.output.strip() or None

    # Process the file
    filter_coco_annotations(json_file_path, output_path, args.threshold)


if __name__ == "__main__":
    main()