import os
import sys
import glob
import json
import time 
import random
from collections import Counter, defaultdict

sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from evaluation.constants import *


# --- Configuration for Sampling and Reporting ---
# SAMPLE_SIZE caps how many JSON files analyze_experiment_files() opens per call.
# per_experiment() calls it once for the per-seed files and once for the
# cross-seed files of each experiment, so the cap applies to each group separately.
SAMPLE_SIZE = 64  # Max number of JSONs to check per experiment (per file group)
# The comparison is strict: a metric is reported only when its missing rate is
# greater than this fraction, not when it is exactly equal to it.
MISSING_THRESHOLD = 0.20 # Report if a metric is missing in >20% of samples


def check_metric_structure(data, expected_structure, path_prefix=""):
    """
    Recursively checks if the data from a JSON file contains the expected keys.
    This function can handle nested dictionaries of expected keys.

    Each value in ``expected_structure`` selects how the matching entry of
    ``data`` is validated:

    * ``dict`` -- the entry must itself be a dict and is checked recursively.
    * ``list`` -- the entry must be a dict containing every listed name as a key.
    * ``None`` (or anything else) -- only the presence of the key is checked.

    A value that is present but is not a dict (e.g. ``null``, a number or a
    string) cannot hold any keys, so every key expected inside it is reported
    as missing instead of raising ``TypeError`` or doing substring matching.

    Args:
        data (dict): The loaded JSON data.
        expected_structure (dict): A dictionary defining the expected keys and sub-structures.
        path_prefix (str): Used for recursive calls to build a human-readable path to missing keys.

    Returns:
        list: A list of missing metric paths (e.g., "Hessian_SAIL_Metric.visualizations.t50").
    """
    # Anything other than a dict has no keys at all: report every expected
    # key of this level as an entirely missing block.
    if not isinstance(data, dict):
        return [f"{path_prefix}{key} (entire block)" for key in expected_structure]

    missing = []
    for key, sub_structure in expected_structure.items():
        # Check if the key exists in the current data dictionary
        if key not in data:
            missing.append(f"{path_prefix}{key} (entire block)")
            continue

        value = data[key]
        if isinstance(sub_structure, dict):
            # Nested expectation: recurse one level deeper.
            missing.extend(
                check_metric_structure(value, sub_structure, path_prefix=f"{path_prefix}{key}.")
            )
        elif isinstance(sub_structure, list):
            # Flat list of required sub-keys; a non-dict value provides none of them.
            present_keys = value if isinstance(value, dict) else ()
            missing.extend(
                f"{path_prefix}{key}.{sub_key}"
                for sub_key in sub_structure
                if sub_key not in present_keys
            )
        # For None (existence-only) there is nothing further to verify.

    return missing


def analyze_experiment_files(file_paths, expected_metrics):
    """
    Analyzes a sample of metric files for a single experiment.

    At most ``SAMPLE_SIZE`` files are drawn at random. Files that cannot be
    read, parsed or validated are skipped and do NOT count towards the
    denominator of the missing rate, so a batch of unreadable files cannot
    dilute the rate and hide metrics that are genuinely missing.

    Args:
        file_paths (list): A list of paths to metric JSON files.
        expected_metrics (dict): The expected metric structure to validate against.

    Returns:
        dict: A dictionary of problematic metrics and their missing rates,
        formatted as "<missing>/<analyzed> (<percent>)". Empty when there are
        no files, when no sampled file could be analyzed, or when no metric
        exceeds ``MISSING_THRESHOLD``.
    """
    if not file_paths:
        return {}

    # Sample files to avoid checking thousands of files per experiment
    num_to_sample = min(len(file_paths), SAMPLE_SIZE)
    sampled_files = random.sample(file_paths, num_to_sample)

    missing_counter = Counter()
    num_analyzed = 0
    for json_path in sampled_files:
        try:
            with open(json_path, 'r') as f:
                data = json.load(f)
            missing_in_file = check_metric_structure(data, expected_metrics)
        except (ValueError, OSError, TypeError):
            # In a summary view, we are less concerned about single file errors.
            # ValueError covers json.JSONDecodeError and text decoding failures.
            continue
        num_analyzed += 1
        missing_counter.update(missing_in_file)

    # Nothing could be analyzed, so no rate can be computed.
    if not num_analyzed:
        return {}

    # Determine which metrics are missing frequently enough to be a problem
    problematic_metrics = {}
    for metric, count in missing_counter.items():
        missing_rate = count / num_analyzed
        if missing_rate > MISSING_THRESHOLD:
            problematic_metrics[metric] = f"{count}/{num_analyzed} ({missing_rate:.0%})"

    return problematic_metrics

def _is_per_seed_metrics_file(basename):
    """Return True if basename looks like 'prompt_<4 digits>_<2 digits>_..._metrics.json'."""
    if '_metrics.json' not in basename:
        return False
    parts = basename.split('_')
    return (
        len(parts) > 2
        and len(parts[1]) == 4 and parts[1].isdigit()
        and len(parts[2]) == 2 and parts[2].isdigit()
    )


def per_experiment():
    """
    Prints a sampled, per-experiment completeness report.

    For every method in ``ALL_METHODS`` and each of its expected datasets
    (``EXPECTED_DATASETS[method]``, falling back to the "default" entry) the
    directory ``OUTPUT_ROOT/<method>/<dataset>`` is inspected:

    * a missing directory is reported as a missing experiment;
    * a directory without any recognised metric file is reported as empty,
      instead of silently passing as complete;
    * otherwise the per-seed and cross-seed JSON files are sampled with
      ``analyze_experiment_files`` and frequently missing metrics are listed.

    The function only prints; it returns None.
    """
    print("="*60)
    print("Running Experiment Completeness Check...")
    print("="*60)

    missing_experiments = []
    empty_experiments = []
    experiment_reports = defaultdict(dict)

    # --- 1. Loop through all expected experiments ---
    print("1. Scanning all expected experiments...")
    for method in ALL_METHODS:
        required_datasets = EXPECTED_DATASETS.get(method, EXPECTED_DATASETS["default"])
        for dataset in required_datasets:
            exp_dir = os.path.join(OUTPUT_ROOT, method, dataset)

            if not os.path.isdir(exp_dir):
                missing_experiments.append(f"Directory not found for: {method}/{dataset}")
                continue

            # --- 2. & 3. Analyze Per-Seed and Cross-Seed Metrics ---
            # recursive=True only takes effect when the pattern contains '**';
            # '**' also matches zero directories, so files directly inside
            # exp_dir are still found. The directory part is escaped so that
            # glob metacharacters in method/dataset names are taken literally.
            pattern = os.path.join(glob.escape(exp_dir), '**', 'prompt_*.json')
            per_seed_files = []
            cross_seed_files = []

            # Categorize files based on their names, similar to analysis.py
            for f_path in glob.glob(pattern, recursive=True):
                basename = os.path.basename(f_path)
                # Use a flexible check for any file containing '_cross_seed'
                if '_cross_seed' in basename:
                    cross_seed_files.append(f_path)
                # Ensure per-seed files match the expected format 'prompt_xxxx_xx_'
                elif _is_per_seed_metrics_file(basename):
                    per_seed_files.append(f_path)

            # An existing but empty directory means the experiment produced no
            # metrics at all; without this check it would be reported as complete.
            if not per_seed_files and not cross_seed_files:
                empty_experiments.append(f"No metric files found for: {method}/{dataset}")
                continue

            per_seed_problems = analyze_experiment_files(per_seed_files, EXPECTED_PER_SEED_METRICS)
            if per_seed_problems:
                experiment_reports[(method, dataset)]['per_seed'] = per_seed_problems

            cross_seed_problems = analyze_experiment_files(cross_seed_files, EXPECTED_CROSS_SEED_METRICS)
            if cross_seed_problems:
                experiment_reports[(method, dataset)]['cross_seed'] = cross_seed_problems

    # --- 4. Generate and Print the Final Report ---
    print("\n" + "="*20 + " Completeness Report " + "="*20 + "\n")

    if not missing_experiments and not empty_experiments and not experiment_reports:
        print("✅ All expected experiments and metrics appear to be complete. Great job!")
    else:
        if missing_experiments:
            print("❌ MISSING EXPERIMENT DIRECTORIES (Rerun required):")
            for item in missing_experiments:
                print(f"  - {item}")
            print("-" * 50)

        if empty_experiments:
            print("\n❌ EXPERIMENTS WITH NO METRIC FILES (Rerun required):")
            for item in empty_experiments:
                print(f"  - {item}")
            print("-" * 50)

        if experiment_reports:
            print("\n❌ EXPERIMENTS WITH INCOMPLETE METRICS (Rerun recommended):")
            # Same layout for both file groups; only the heading differs.
            sections = (
                ('per_seed', "    - Per-Seed Issues (from _metrics.json):"),
                ('cross_seed', "    - Cross-Seed Issues (from _cross_seed.json):"),
            )
            for (method, dataset), issues in sorted(experiment_reports.items()):
                print(f"\n  ▶ Experiment: {method}/{dataset}")
                for group, heading in sections:
                    if group not in issues:
                        continue
                    print(heading)
                    for metric, rate in issues[group].items():
                        print(f"      - Metric '{metric}' missing in {rate} of samples.")
            print("-" * 50)

    print("\nReport finished.")
    print("="*60)



def _find_incomplete_metric_files(json_paths, expected_metrics):
    """Return a {"file", "missing"} entry for each readable file that lacks expected metrics."""
    incomplete = []
    for json_path in json_paths:
        try:
            with open(json_path, 'r') as f:
                data = json.load(f)
            # Use the recursive checker function to validate structure
            missing_in_file = check_metric_structure(data, expected_metrics)
        except (json.JSONDecodeError, IOError, TypeError) as e:
            # TypeError: the JSON parsed but has an unexpected shape. Warn and
            # keep going rather than aborting the whole scan on one bad file.
            print(f"   [Warning] Could not read or parse {json_path}: {e}")
            continue
        if missing_in_file:
            incomplete.append({"file": json_path, "missing": missing_in_file})
    return incomplete


def _print_incomplete_files(title, items):
    """Print one report section listing each incomplete file and its missing metrics."""
    print(title)
    for item in items:
        # Use relpath for cleaner output
        print(f"  - File: {os.path.relpath(item['file'])}")
        print(f"    Missing: {', '.join(item['missing'])}")
    print("-" * 40)


def per_file():
    """
    Prints an exhaustive, per-file completeness report.

    Unlike the sampled per-experiment summary, every per-seed
    ('prompt_????_??_*_metrics.json') and cross-seed ('prompt_*_cross_seed.json')
    file found anywhere below ``OUTPUT_ROOT`` is opened and validated against
    ``EXPECTED_PER_SEED_METRICS`` / ``EXPECTED_CROSS_SEED_METRICS``. Expected
    method/dataset directories that do not exist are listed as well. Files
    that cannot be read or parsed only produce a warning line.

    The function only prints; it returns None.
    """
    print("="*60)
    print("Running Experiment Completeness Check...")
    print("="*60)

    # --- 1. Check for missing experiment directories ---
    print("1. Checking for missing experiment directories...")
    missing_experiments = []
    for method in ALL_METHODS:
        required_datasets = EXPECTED_DATASETS.get(method, EXPECTED_DATASETS["default"])
        for dataset in required_datasets:
            exp_dir = os.path.join(OUTPUT_ROOT, method, dataset)
            if not os.path.exists(exp_dir):
                missing_experiments.append(f"Directory not found for: {method}/{dataset}")
    print(f"   -> Found {len(missing_experiments)} missing experiment directories.")

    # --- 2. Check for metric completeness in per-seed files ---
    print("\n2. Checking per-seed metric files...")
    # Corrected glob pattern based on run_sail.py output naming convention:
    # e.g., output/sail/laion_memorized/prompt_0000_00_A_high-quality_pho_metrics.json
    per_seed_glob_pattern = os.path.join(OUTPUT_ROOT, '**', 'prompt_????_??_*_metrics.json')
    all_per_seed_jsons = glob.glob(per_seed_glob_pattern, recursive=True)
    print(f"   -> Found {len(all_per_seed_jsons)} per-seed metric files to check.")

    incomplete_per_seed_metrics = _find_incomplete_metric_files(
        all_per_seed_jsons, EXPECTED_PER_SEED_METRICS
    )
    print(f"   -> Found {len(incomplete_per_seed_metrics)} incomplete per-seed files.")

    # --- 3. Check for metric completeness in cross-seed files ---
    print("\n3. Checking cross-seed metric files...")
    # Corrected glob pattern based on run_sail.py output:
    # e.g., output/sail/laion_memorized/prompt_0000_A_high-quality_pho_cross_seed.json
    cross_seed_glob_pattern = os.path.join(OUTPUT_ROOT, '**', 'prompt_*_cross_seed.json')
    all_cross_seed_jsons = glob.glob(cross_seed_glob_pattern, recursive=True)
    print(f"   -> Found {len(all_cross_seed_jsons)} cross-seed metric files to check.")

    incomplete_cross_seed_metrics = _find_incomplete_metric_files(
        all_cross_seed_jsons, EXPECTED_CROSS_SEED_METRICS
    )
    print(f"   -> Found {len(incomplete_cross_seed_metrics)} incomplete cross-seed files.")

    # --- 4. Generate and Print the Final Report ---
    print("\n" + "="*20 + " Completeness Report " + "="*20 + "\n")

    if not missing_experiments and not incomplete_per_seed_metrics and not incomplete_cross_seed_metrics:
        print("✅ All expected experiments and metrics are present. Great job!")
    else:
        if missing_experiments:
            print("❌ MISSING EXPERIMENTS (Directory not found):")
            for item in missing_experiments:
                print(f"  - {item}")
            print("-" * 40)

        if incomplete_per_seed_metrics:
            _print_incomplete_files(
                "\n❌ INCOMPLETE PER-SEED FILES (Missing metrics):",
                incomplete_per_seed_metrics,
            )

        if incomplete_cross_seed_metrics:
            _print_incomplete_files(
                "\n❌ INCOMPLETE CROSS-SEED FILES (Missing metrics):",
                incomplete_cross_seed_metrics,
            )

    print("\nReport finished.")
    print("="*60)



# Script entry point: run the sampled per-experiment summary.
# per_file() is defined in this file as the exhaustive file-by-file check,
# but it is not invoked here.
if __name__ == "__main__":
    per_experiment()