#!/usr/bin/env python3
"""
Cache Management Utility for BIRD Text-to-SQL Evaluation
"""

import os
import argparse
import shutil
from pathlib import Path


def clear_ground_truth_cache(cache_root: str = "cache"):
    """Delete the ground truth cache file stored under ``cache_root``.

    The file is looked up at
    ``<cache_root>/ground_truth/ground_truth_cache.json``. A status line is
    printed whether or not the file was present.

    Args:
        cache_root: Root directory that holds all caches.

    Returns:
        True if the cache file existed and was removed, False otherwise.
    """
    target = os.path.join(cache_root, "ground_truth", "ground_truth_cache.json")

    # Nothing to delete: report and bail out early.
    if not os.path.exists(target):
        print(f"ℹ️  Ground truth cache not found: {target}")
        return False

    os.remove(target)
    print(f"✅ Cleared ground truth cache: {target}")
    return True


def clear_prediction_caches(cache_root: str = "cache"):
    """Delete every prediction cache file under ``<cache_root>/predictions``.

    Only directory entries whose name ends with ``_cache.json`` are removed;
    anything else in the directory is left alone. One status line is printed
    per removed file, or a single informational line when there is nothing
    to remove.

    Args:
        cache_root: Root directory that holds all caches.

    Returns:
        The number of cache files removed (0 when the directory is missing
        or holds no matching files).
    """
    predictions_dir = os.path.join(cache_root, "predictions")

    if not os.path.exists(predictions_dir):
        print(f"ℹ️  Prediction cache directory not found: {predictions_dir}")
        return 0

    # Take a snapshot of the matching names before deleting anything.
    matches = tuple(
        entry for entry in os.listdir(predictions_dir)
        if entry.endswith('_cache.json')
    )

    if not matches:
        print(f"ℹ️  No prediction caches found in: {predictions_dir}")
        return 0

    for name in matches:
        os.remove(os.path.join(predictions_dir, name))
        print(f"✅ Cleared prediction cache: {name}")

    return len(matches)


def clear_specific_prediction_cache(run_name: str, cache_root: str = "cache"):
    """Delete the prediction cache belonging to a single run.

    The file is looked up at
    ``<cache_root>/predictions/<run_name>_cache.json``. A status line is
    printed whether or not the file was present.

    Args:
        run_name: Name of the run whose cache should be removed.
        cache_root: Root directory that holds all caches.

    Returns:
        True if the cache file existed and was removed, False otherwise.
    """
    cache_name = f"{run_name}_cache.json"
    target = os.path.join(cache_root, "predictions", cache_name)

    # Nothing to delete: report the full path that was checked.
    if not os.path.exists(target):
        print(f"ℹ️  Prediction cache not found: {target}")
        return False

    os.remove(target)
    print(f"✅ Cleared prediction cache: {cache_name}")
    return True


def clear_all_caches(cache_root: str = "cache"):
    """Remove the whole cache directory tree rooted at ``cache_root``.

    Args:
        cache_root: Root directory that holds all caches.

    Returns:
        True if the directory existed and was removed, False if it was
        not found.
    """
    # Nothing to delete: report and bail out early.
    if not os.path.exists(cache_root):
        print(f"ℹ️  Cache directory not found: {cache_root}")
        return False

    shutil.rmtree(cache_root)
    print(f"✅ Cleared entire cache directory: {cache_root}")
    return True


def _file_size_bytes(path: str) -> int:
    """Return the size of ``path`` in bytes, or 0 if it cannot be stat'ed."""
    try:
        return os.path.getsize(path)
    except OSError:
        # A dangling symlink or a file deleted while we are listing should
        # not abort the whole report; count it as empty instead.
        return 0


def list_caches(cache_root: str = "cache"):
    """Print a summary of all existing caches with their sizes.

    The report covers the ground truth cache
    (``<cache_root>/ground_truth/ground_truth_cache.json``), every
    ``*_cache.json`` entry in ``<cache_root>/predictions`` (sorted by name),
    and the combined size of every file found under ``cache_root``.
    Entries whose size cannot be read are reported as 0 bytes.

    Args:
        cache_root: Root directory that holds all caches.

    Returns:
        None. All results are written to stdout.
    """
    bytes_per_mb = 1024 * 1024
    cache_suffix = '_cache.json'

    print(f"📁 Cache directory: {cache_root}")

    if not os.path.exists(cache_root):
        print("ℹ️  Cache directory does not exist")
        return

    # Ground truth cache
    gt_cache_file = os.path.join(cache_root, "ground_truth", "ground_truth_cache.json")
    if os.path.exists(gt_cache_file):
        size_mb = _file_size_bytes(gt_cache_file) / bytes_per_mb
        print(f"  🗂️  Ground truth cache: {size_mb:.2f} MB")
    else:
        print("  ❌ Ground truth cache: Not found")

    # Prediction caches
    pred_cache_dir = os.path.join(cache_root, "predictions")
    if os.path.exists(pred_cache_dir):
        cache_files = sorted(
            f for f in os.listdir(pred_cache_dir) if f.endswith(cache_suffix)
        )

        if cache_files:
            print(f"  🗂️  Prediction caches ({len(cache_files)}):")
            pred_total_bytes = 0
            for cache_file in cache_files:
                size_bytes = _file_size_bytes(os.path.join(pred_cache_dir, cache_file))
                pred_total_bytes += size_bytes
                # Strip only the trailing suffix: a run name may itself
                # contain "_cache.json", which str.replace would also remove.
                run_name = cache_file[:-len(cache_suffix)]
                print(f"    - {run_name}: {size_bytes / bytes_per_mb:.2f} MB")
            print(f"    Total prediction caches: {pred_total_bytes / bytes_per_mb:.2f} MB")
        else:
            print("  ❌ Prediction caches: None found")
    else:
        print("  ❌ Prediction cache directory: Not found")

    # Total cache size (cache_root is known to exist at this point)
    total_bytes = sum(
        _file_size_bytes(os.path.join(dirpath, filename))
        for dirpath, _dirnames, filenames in os.walk(cache_root)
        for filename in filenames
    )
    print(f"\n📊 Total cache size: {total_bytes / bytes_per_mb:.2f} MB")


def _confirm(prompt: str) -> bool:
    """Ask a yes/no question on stdin; return True only for an explicit yes.

    Surrounding whitespace and letter case are ignored, and both ``y`` and
    ``yes`` count as confirmation. A closed or exhausted stdin (EOF) is
    treated as "no" so destructive actions never run unattended.
    """
    try:
        answer = input(prompt)
    except EOFError:
        # The prompt was written without a newline; finish that line.
        print()
        return False
    return answer.strip().lower() in ('y', 'yes')


def main():
    """Parse command-line arguments and run the selected cache action.

    Exactly one of ``--list``, ``--clear-all``, ``--clear-gt``,
    ``--clear-pred`` or ``--clear-run RUN_NAME`` must be given. Every
    destructive action asks for interactive confirmation first and prints
    "Cancelled" when the user declines. An empty ``--clear-run`` value is
    rejected with a usage error.
    """
    parser = argparse.ArgumentParser(description='Cache Management for BIRD Evaluation')
    parser.add_argument('--cache_root', default='cache', help='Cache root directory (default: cache)')

    # Action arguments (mutually exclusive)
    action_group = parser.add_mutually_exclusive_group(required=True)
    action_group.add_argument('--list', action='store_true', help='List all caches with sizes')
    action_group.add_argument('--clear-all', action='store_true', help='Clear all caches')
    action_group.add_argument('--clear-gt', action='store_true', help='Clear ground truth cache only')
    action_group.add_argument('--clear-pred', action='store_true', help='Clear all prediction caches')
    action_group.add_argument('--clear-run', help='Clear specific prediction cache by run name')

    args = parser.parse_args()

    if args.list:
        list_caches(args.cache_root)

    elif args.clear_all:
        if _confirm("⚠️  Clear ALL caches? This cannot be undone. (y/N): "):
            clear_all_caches(args.cache_root)
        else:
            print("❌ Cancelled")

    elif args.clear_gt:
        if _confirm("⚠️  Clear ground truth cache? This will require rebuilding on next evaluation. (y/N): "):
            clear_ground_truth_cache(args.cache_root)
        else:
            print("❌ Cancelled")

    elif args.clear_pred:
        if _confirm("⚠️  Clear all prediction caches? (y/N): "):
            count = clear_prediction_caches(args.cache_root)
            if count > 0:
                print(f"✅ Cleared {count} prediction cache(s)")
        else:
            print("❌ Cancelled")

    # Compare against None: an empty string is falsy and would otherwise
    # make the command exit silently without doing anything.
    elif args.clear_run is not None:
        run_name = args.clear_run
        if not run_name:
            parser.error("--clear-run requires a non-empty run name")
        if _confirm(f"⚠️  Clear prediction cache for run '{run_name}'? (y/N): "):
            clear_specific_prediction_cache(run_name, args.cache_root)
        else:
            print("❌ Cancelled")

# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()