import json
from scipy.stats import spearmanr
import numpy as np
import re 
from pathlib import Path
from datetime import datetime
from collections import defaultdict # For easier nested dictionary creation

def extract_frame_number(filename):
    """Return the frame index embedded in a frame filename.

    The filename is searched for the pattern ``im_<digits>.jpg`` anywhere in
    the string; the digits are returned as an ``int``.

    Args:
        filename: Frame filename (or path) string, or ``None``.

    Returns:
        The parsed frame number, or ``-1`` when ``filename`` is ``None`` or
        does not contain the expected pattern.
    """
    if filename is None:
        return -1

    found = re.search(r'im_(\d+)\.jpg', filename)
    if found is None:
        return -1
    return int(found.group(1))

def calculate_voc_for_trajectory(traj_data_entry):
    """Compute the VOC score (Spearman rank correlation) for one trajectory.

    Each item under ``"parsed_gemini_outputs"`` contributes one data point
    when it has a filename from which a frame number can be extracted and a
    numeric ``"predicted_percentage"``. The points are ordered by frame number
    and the predicted percentages are rank-correlated against the frame
    numbers.

    Malformed input is skipped rather than raising: a missing/``null``/non-list
    ``"parsed_gemini_outputs"``, non-dict items, non-string filenames and
    percentages that cannot be converted to ``float`` are all ignored.

    Args:
        traj_data_entry: Dict describing one trajectory result.

    Returns:
        A ``(voc, num_frames)`` tuple. ``num_frames`` is the number of usable
        data points. ``voc`` is ``None`` when fewer than two points are usable
        or the correlation could not be computed, ``0.0`` when either sequence
        is constant (correlation undefined) or the result is NaN, and the
        Spearman coefficient otherwise.
    """
    # `or []` also covers an explicit JSON null stored under the key.
    parsed_outputs = traj_data_entry.get("parsed_gemini_outputs") or []
    if not isinstance(parsed_outputs, (list, tuple)):
        return None, 0

    frame_predictions = []
    for item in parsed_outputs:
        if not isinstance(item, dict):
            continue
        original_filename = item.get("original_filename")
        predicted_percentage = item.get("predicted_percentage")
        if not isinstance(original_filename, str) or predicted_percentage is None:
            continue
        try:
            percentage = float(predicted_percentage)
        except (TypeError, ValueError):
            # Non-numeric prediction (e.g. free text): unusable for ranking.
            continue
        frame_idx = extract_frame_number(original_filename)
        if frame_idx != -1:
            frame_predictions.append((frame_idx, percentage))

    num_frames = len(frame_predictions)
    if num_frames < 2:
        return None, num_frames

    # Chronological order is defined by the frame number in the filename.
    frame_predictions.sort(key=lambda pair: pair[0])
    frame_indices = [frame_idx for frame_idx, _ in frame_predictions]
    percentages = [percentage for _, percentage in frame_predictions]

    # Spearman correlation is undefined for a constant sequence; score it 0.
    if len(set(percentages)) == 1 or len(set(frame_indices)) == 1:
        return 0.0, num_frames

    try:
        voc_coefficient, _ = spearmanr(percentages, frame_indices)
    except Exception:
        # Deliberate best effort: one bad trajectory must not abort the run.
        return None, num_frames
    return (0.0 if np.isnan(voc_coefficient) else voc_coefficient), num_frames

def process_single_json_file_for_average_voc(json_file_path):
    """Return the mean VOC over all scorable trajectories in one results file.

    The file is expected to contain a JSON list of trajectory-result dicts.
    Entries whose ``"gemini_response_text"`` contains ``"Skipped:"`` or
    ``"An error occurred"`` are ignored, as are entries that are not dicts and
    trajectories for which no VOC could be computed from at least two frames.

    Args:
        json_file_path: Path to the JSON file (``str`` or ``pathlib.Path``).

    Returns:
        The average VOC as a float, or ``None`` when the file cannot be read
        or parsed, does not hold a list, or contains no scorable trajectory.
    """
    json_file_path = Path(json_file_path)  # accept plain strings as well
    try:
        with open(json_file_path, 'r', encoding="utf-8") as f:
            all_traj_results = json.load(f)
    except (OSError, ValueError):
        # OSError: missing/unreadable file or a directory.
        # ValueError: invalid JSON (JSONDecodeError) or undecodable bytes.
        print(f"  Error loading/decoding {json_file_path.name}")
        return None

    if not isinstance(all_traj_results, list):
        print(f"  Unexpected top-level JSON structure in {json_file_path.name}")
        return None

    voc_scores = []
    for trajectory_result in all_traj_results:
        if not isinstance(trajectory_result, dict):
            continue
        gemini_response = trajectory_result.get("gemini_response_text", "")
        if isinstance(gemini_response, str) and (
            "Skipped:" in gemini_response or "An error occurred" in gemini_response
        ):
            continue
        voc, num_frames = calculate_voc_for_trajectory(trajectory_result)
        # Only consider valid VOCs computed from sufficient frames.
        if voc is not None and num_frames >= 2:
            voc_scores.append(voc)

    if not voc_scores:
        return None
    return sum(voc_scores) / len(voc_scores)

def _category_for_run(method_run_key):
    """Map a run key to its summary category based on the filename prefix."""
    if method_run_key.startswith("gvl_fixed_fewshot_results_"):
        return "few_shot"
    if method_run_key.startswith("gvl_results_"):
        return "zero_shot"
    return "uncategorized"


def main(results_dir="results"):
    """Summarise average VOC per subdirectory and run type, then save as JSON.

    Every immediate subdirectory of ``results_dir`` is scanned for files
    matching ``*_CUMULATIVE.json``. Each file's average VOC is stored under
    ``few_shot``, ``zero_shot`` or ``uncategorized`` (chosen from the filename
    prefix), keyed by the filename without the ``_CUMULATIVE.json`` suffix;
    files with no computable VOC are recorded as ``None``. Per-subdirectory
    averages for few-shot and zero-shot runs are added when available. The
    summary is printed and written to a timestamped
    ``VOC_SUMMARY_BY_DIR_AND_TYPE_<timestamp>.json`` inside ``results_dir``.

    Args:
        results_dir: Base directory containing the result subdirectories
            (e.g. ``ood_emb``). Defaults to ``"results"``.

    Returns:
        None. Returns early, after printing an error, when ``results_dir`` is
        not a directory.
    """
    # --- Configuration ---
    final_results_base_dir = Path(results_dir)
    file_pattern = "*_CUMULATIVE.json"
    summary_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    # The summary is saved inside the base results directory itself.
    overall_summary_filename = final_results_base_dir / f"VOC_SUMMARY_BY_DIR_AND_TYPE_{summary_timestamp}.json"

    if not final_results_base_dir.is_dir():
        print(f"Error: Base results directory not found: {final_results_base_dir}")
        return

    # Final JSON structure, e.g. summary_output["script"]["few_shot"]["filename1"] = 0.5
    summary_output = {}

    # Sorted so the summary order does not depend on filesystem listing order.
    for subdir in sorted(final_results_base_dir.iterdir()):
        if not subdir.is_dir():
            continue
        print(f"\nProcessing subdirectory: {subdir.name}")

        json_files_to_process = sorted(subdir.glob(file_pattern))
        if not json_files_to_process:
            print(f"  No files matching '{file_pattern}' found in '{subdir.name}'.")
            continue

        print(f"  Found {len(json_files_to_process)} files to process in {subdir.name}.")

        subdir_summary = {"few_shot": {}, "zero_shot": {}}
        # Unrounded scores, kept so the per-category mean is not built from
        # already-rounded values.
        scores_by_category = {"few_shot": [], "zero_shot": []}

        for json_file_path in json_files_to_process:
            avg_voc_for_file = process_single_json_file_for_average_voc(json_file_path)

            # Use the filename (without _CUMULATIVE.json) as the method/run key.
            method_run_key = json_file_path.name.replace("_CUMULATIVE.json", "")
            category = _category_for_run(method_run_key)

            # None is stored when no VOC could be calculated for the file.
            rounded_avg_voc = None if avg_voc_for_file is None else round(avg_voc_for_file, 4)
            if category == "uncategorized" and rounded_avg_voc is not None:
                print(f"    Uncategorized file (based on prefix): {method_run_key}, VOC: {rounded_avg_voc}")

            subdir_summary.setdefault(category, {})[method_run_key] = rounded_avg_voc
            if avg_voc_for_file is not None and category in scores_by_category:
                scores_by_category[category].append(avg_voc_for_file)

        # Average VOC for few-shot and zero-shot runs within the subdirectory.
        for category in ("few_shot", "zero_shot"):
            scores = scores_by_category[category]
            if scores:
                subdir_summary[f"average_voc_{category}"] = round(sum(scores) / len(scores), 4)

        # Drop categories that were never populated.
        for category in ("few_shot", "zero_shot", "uncategorized"):
            if category in subdir_summary and not subdir_summary[category]:
                del subdir_summary[category]

        if subdir_summary:
            summary_output[subdir.name] = subdir_summary

    summary_output["overall_calculation_timestamp"] = datetime.now().isoformat()

    print("\n--- Overall VOC Summary by Directory and Type ---")
    print(json.dumps(summary_output, indent=2))

    try:
        with open(overall_summary_filename, 'w', encoding="utf-8") as f_out:
            json.dump(summary_output, f_out, indent=2)
        print(f"\nOverall VOC summary saved to: {overall_summary_filename}")
    except (OSError, TypeError, ValueError) as e:
        # OSError: cannot write the file; TypeError/ValueError: unserialisable data.
        print(f"Error saving overall VOC summary: {e}")

# Script entry point: build the VOC summary only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()