import os
import json
import argparse
import numpy as np
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor, as_completed
import glob


# Keys looked up inside each result record's "scores" dict, one per
# evaluated dimension.
PROMPT_FOLLOWING, CONSISTENCY, OVERALL = "prompt_following", "consistency", "overall"
# Dimensions in the order they are averaged and reported.
SCORE_CATEGORIES = [PROMPT_FOLLOWING, CONSISTENCY, OVERALL]


# Default benchmark path - override with --benchmark_file argument
BENCHMARK_JSONL_PATH = "./data/editscore_bench.jsonl"


def parse_args(argv=None):
    """Parse and validate the command-line arguments.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, i.e. the normal CLI behavior.

    Returns:
        argparse.Namespace with ``result_file``, ``result_files``, ``avg_n``,
        ``auto_latest`` and ``benchmark_file``.

    Exits via ``parser.error`` (usage message, non-zero status) when neither
    or both of --result_file / --result_files are given, or when --avg_n is
    not a positive integer.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--result_file", type=str, default=None, help="Path to the result JSON file (single file)")
    parser.add_argument("--result_files", type=str, nargs='+', default=None, help="Paths to multiple result JSON files for averaging")
    parser.add_argument("--avg_n", type=int, default=1, help="Number of inference passes to average (1=single pass, 4=avg4)")
    parser.add_argument("--auto_latest", action="store_true", help="Automatically find and use the latest timestamped result file(s)")
    parser.add_argument("--benchmark_file", type=str, default=BENCHMARK_JSONL_PATH, help="Path to benchmark JSONL file")
    args = parser.parse_args(argv)

    # Exactly one of the two result-file options must be supplied.
    if args.result_file is None and args.result_files is None:
        parser.error("Either --result_file or --result_files must be provided")
    if args.result_file and args.result_files:
        parser.error("Cannot use both --result_file and --result_files")
    # avg_n is used as a slice length for --auto_latest and in output names:
    # 0 would select no files at all, a negative value the wrong ones.
    if args.avg_n < 1:
        parser.error("--avg_n must be a positive integer")

    return args


def load_results(result_files, avg_n=1):
    """Load result JSON file(s), averaging per-key scores across several files.

    Args:
        result_files: A single path (str) or a sequence of paths. Each file is
            read with ``json.load``; for averaging, each top-level value must
            carry a ``"scores"`` dict (``task_type`` / ``dimension`` /
            ``instruction`` are optional).
        avg_n: Expected number of passes; only used to warn when the number of
            files differs.

    Returns:
        For a single file, the parsed JSON object as-is. For several files, a
        new dict containing every key present in all files, in the first
        file's key order, with:

        - ``scores``: the mean of each dimension in ``SCORE_CATEGORIES`` over
          the files that report it;
        - the metadata fields copied from the first file in ``result_files``.

    Raises:
        ValueError: if ``result_files`` is empty.
    """
    if isinstance(result_files, str):
        result_files = [result_files]
    else:
        result_files = list(result_files)

    if not result_files:
        raise ValueError("load_results() requires at least one result file")

    if len(result_files) != avg_n and avg_n > 1:
        print(f"Warning: Expected {avg_n} files for avg{avg_n}, but got {len(result_files)} files")

    def load_single_file(result_file):
        with open(result_file, 'r', encoding='utf-8') as f:
            results = json.load(f)
        print(f"Loaded {len(results)} results from {result_file}")
        return results

    # Load files concurrently. Executor.map yields results in input order
    # (as_completed does not), so all_results[0] really is the first file.
    # The averaging below relies on that for key order and copied metadata.
    with ThreadPoolExecutor(max_workers=min(8, len(result_files))) as executor:
        all_results = list(executor.map(load_single_file, result_files))

    if len(all_results) == 1:
        return all_results[0]

    print(f"Averaging scores across {len(all_results)} inference passes...")
    first = all_results[0]
    averaged_results = {}

    # Walk the first file's keys in their stored order (not a set) so the
    # output order is deterministic across runs.
    for key in first:
        if not all(key in results for results in all_results):
            print(f"Warning: Key {key} not found in all result files, skipping")
            continue

        scores_per_pass = [results[key]["scores"] for results in all_results]
        averaged_scores = {}
        for dim in SCORE_CATEGORIES:
            values = [scores[dim] for scores in scores_per_pass if dim in scores]
            if values:
                averaged_scores[dim] = float(np.mean(values))

        # Non-score fields come from the first file.
        record = first[key]
        averaged_results[key] = {
            "scores": averaged_scores,
            "task_type": record.get("task_type", ""),
            "dimension": record.get("dimension", ""),
            "instruction": record.get("instruction", ""),
        }

    print(f"Averaged {len(averaged_results)} results")
    return averaged_results


def calculate_statistics(results, benchmark_dataset):
    """Compute pairwise accuracy per task type and dimension.

    Each benchmark sample names two result keys (``sample["key"]``) plus a
    ``task_type`` and a ``dimension``. The pair counts as correct only when
    the first key's score on that dimension is strictly greater than the
    second's (ties count as incorrect). Only the dimension named by the
    sample is compared.

    Args:
        results: Mapping result key -> record with a ``"scores"`` dict.
        benchmark_dataset: Iterable of benchmark sample dicts.

    Returns:
        A tuple ``(accuracies, all_scores, task_results)``:

        - ``accuracies``: task_type -> {dim: accuracy}, with 0.0 for a
          dimension in ``SCORE_CATEGORIES`` that has no samples.
        - ``all_scores``: dim -> list of both compared scores of every pair.
        - ``task_results``: task_type -> dim -> list of per-sample dicts, in
          benchmark order.

    Samples whose keys or compared dimension are missing from ``results``
    are skipped with a printed warning.
    """
    task_results = defaultdict(lambda: defaultdict(list))
    missing_keys = []

    def process_benchmark_sample(sample):
        """Score one sample.

        Returns ``(task_type, dimension, sample_result), None`` on success,
        or ``None, (key1, key2)`` when the sample cannot be scored.
        """
        key1, key2 = sample["key"]
        task_type = sample["task_type"]
        dimension = sample["dimension"]  # the only dimension compared

        if key1 not in results or key2 not in results:
            return None, (key1, key2)  # Missing keys

        scores1 = results[key1]["scores"]
        scores2 = results[key2]["scores"]

        if dimension not in scores1 or dimension not in scores2:
            return None, (key1, key2)  # Missing dimension

        score1 = scores1[dimension]
        score2 = scores2[dimension]
        sample_result = {
            "correct": 1 if score1 > score2 else 0,
            "score1": score1,
            "score2": score2,
            "key1": key1,
            "key2": key2,
        }
        return (task_type, dimension, sample_result), None

    # A plain loop replaces the former thread pool. This work is pure-Python
    # dict lookups, so threads bought no speedup. The pool also crashed on an
    # empty dataset (max_workers=0) and made per-task sample order depend on
    # thread scheduling.
    print("Calculating statistics...")
    for sample in benchmark_dataset:
        result, missing = process_benchmark_sample(sample)
        if missing:
            missing_keys.append(missing)
            print(f"Warning: Missing results for keys {missing[0]} or {missing[1]}")
        elif result:
            task_type, dimension, sample_result = result
            task_results[task_type][dimension].append(sample_result)

    if missing_keys:
        print(f"Total missing key pairs: {len(missing_keys)}")

    accuracies = defaultdict(dict)
    all_scores = defaultdict(list)

    for task_type in task_results:
        for dim in SCORE_CATEGORIES:
            results_list = task_results[task_type][dim]
            if not results_list:
                accuracies[task_type][dim] = 0.0
                continue

            correct_array = np.array([r["correct"] for r in results_list])
            accuracies[task_type][dim] = float(np.mean(correct_array))

            # Both sides of every pair feed the score statistics.
            all_scores[dim].extend(r["score1"] for r in results_list)
            all_scores[dim].extend(r["score2"] for r in results_list)

    return accuracies, all_scores, task_results


def print_results(accuracies, all_scores):
    """Print per-task, per-group and score-statistic tables to stdout.

    Args:
        accuracies: Mapping task_type -> {dimension: accuracy}.
        all_scores: Mapping dimension -> list of raw scores. Dimensions with
            no scores are left out of the statistics table.
    """
    task_types = sorted(accuracies.keys())

    # Numeric columns must be at least as wide as the longest header
    # ("Prompt Following", 16 chars). At width 15 that header overflowed by
    # one character and pushed every column to its right out of line with
    # its values.
    col = len('Prompt Following')
    dim_header = f"{'Prompt Following':>{col}} {'Consistency':>{col}} {'Overall':>{col}}"

    # Average over the tasks that report each dimension; 0.0 if none do.
    avg_accuracies = {}
    for dim in SCORE_CATEGORIES:
        scores = [accuracies[task][dim] for task in task_types if dim in accuracies[task]]
        avg_accuracies[dim] = np.mean(scores) if scores else 0.0

    print("\n" + "="*100)
    print("EditScore Benchmark Results")
    print("="*100)

    # Per-task table
    print("\nPer-Task Accuracies:")
    print("-"*100)
    print(f"{'Task Type':<25} {dim_header}")
    print("-"*100)

    for task in task_types:
        pf = accuracies[task].get(PROMPT_FOLLOWING, 0.0)
        cons = accuracies[task].get(CONSISTENCY, 0.0)
        overall = accuracies[task].get(OVERALL, 0.0)
        print(f"{task:<25} {pf:>{col}.3f} {cons:>{col}.3f} {overall:>{col}.3f}")

    print("-"*100)
    print(
        f"{'Average':<25} {avg_accuracies[PROMPT_FOLLOWING]:>{col}.3f} "
        f"{avg_accuracies[CONSISTENCY]:>{col}.3f} {avg_accuracies[OVERALL]:>{col}.3f}"
    )
    print("-"*100)

    # Grouped table: unweighted mean over the group's tasks present in the
    # results. Tasks not listed here do not appear in any group.
    groups = {
        'object': ['subject-add', 'subject-remove', 'subject-replace'],
        'appearance': ['color_alter', 'material_alter', 'style_change', 'tone_transfer'],
        'scene': ['background_change', 'extract'],
        'advanced': ['ps_human', 'text_change', 'motion_change', 'compose'],
    }

    print("\nGrouped Accuracies:")
    print("-"*100)
    print(f"{'Group':<25} {dim_header}")
    print("-"*100)

    for group_name, group_tasks in groups.items():
        valid_tasks = [t for t in group_tasks if t in accuracies]
        if not valid_tasks:
            continue

        pf_mean = np.mean([accuracies[t].get(PROMPT_FOLLOWING, 0.0) for t in valid_tasks])
        cons_mean = np.mean([accuracies[t].get(CONSISTENCY, 0.0) for t in valid_tasks])
        overall_mean = np.mean([accuracies[t].get(OVERALL, 0.0) for t in valid_tasks])

        print(f"{group_name:<25} {pf_mean:>{col}.3f} {cons_mean:>{col}.3f} {overall_mean:>{col}.3f}")

    print("-"*100)

    # Score statistics table
    print("\nScore Statistics (0-10 scale):")
    print("-"*100)
    print(f"{'Dimension':<25} {'Min':>10} {'Max':>10} {'Mean':>10} {'Std':>10}")
    print("-"*100)

    for dim in SCORE_CATEGORIES:
        scores = all_scores[dim]
        if len(scores) == 0:
            continue
        print(f"{dim:<25} {np.min(scores):>10.3f} {np.max(scores):>10.3f} {np.mean(scores):>10.3f} {np.std(scores):>10.3f}")

    print("-"*100)
    print("\n")


def save_detailed_results(result_file, accuracies, all_scores, task_results):
    """Write accuracies and score statistics to a JSON summary file.

    The summary path is derived from ``result_file`` as follows:

    - a trailing ``.json`` is replaced by ``_summary.json``;
    - any other path gets ``_summary.json`` appended, so the summary can
      never overwrite ``result_file`` itself.

    Args:
        result_file: Path used to derive the summary file name.
        accuracies: Mapping task_type -> {dimension: accuracy}.
        all_scores: Mapping dimension -> list of raw scores. Dimensions with
            no scores are omitted from ``score_statistics``.
        task_results: Per-sample results. Currently not written to the
            summary; the parameter is kept for interface compatibility.
    """
    # Only rewrite a trailing ".json". str.replace would also hit ".json"
    # inside directory names, and a path without ".json" mapped onto itself.
    suffix = '.json'
    if result_file.endswith(suffix):
        output_file = result_file[:-len(suffix)] + '_summary.json'
    else:
        output_file = result_file + '_summary.json'

    def _average_accuracy(dim):
        # Average over the tasks that report this dimension, the same rule
        # print_results uses for its "Average" row. Return 0.0 rather than
        # NaN (np.mean of []), which is not valid strict JSON.
        values = [accs[dim] for accs in accuracies.values() if dim in accs]
        return float(np.mean(values)) if values else 0.0

    summary = {
        "accuracies": {task: dict(accs) for task, accs in accuracies.items()},
        "score_statistics": {
            dim: {
                "min": float(np.min(scores)),
                "max": float(np.max(scores)),
                "mean": float(np.mean(scores)),
                "std": float(np.std(scores)),
            } for dim, scores in all_scores.items() if len(scores) > 0
        },
        "average_accuracies": {dim: _average_accuracy(dim) for dim in SCORE_CATEGORIES},
    }

    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(summary, f, ensure_ascii=False, indent=2)

    print(f"Detailed results saved to {output_file}")


def _resolve_latest_files(args):
    """Replace the glob pattern(s) in ``args`` with the newest matching file(s).

    ``--result_file`` becomes its single most recently modified match.
    ``--result_files`` becomes the ``avg_n`` most recently modified matches
    across all patterns; each file is counted once even if several patterns
    match it. ``args`` is mutated in place.

    Raises:
        SystemExit: if nothing matches (message on stderr, exit status 1).
    """
    if args.result_file:
        matching_files = glob.glob(args.result_file)
        if not matching_files:
            raise SystemExit(f"Error: No files found matching pattern: {args.result_file}")
        latest_file = max(matching_files, key=os.path.getmtime)
        print(f"Auto-found latest result file: {latest_file}")
        args.result_file = latest_file
    elif args.result_files:
        # dict.fromkeys de-duplicates while keeping order, so overlapping
        # patterns cannot make one pass count twice in the average.
        all_matching = list(dict.fromkeys(
            path for pattern in args.result_files for path in glob.glob(pattern)
        ))
        if not all_matching:
            raise SystemExit(f"Error: No files found matching patterns: {args.result_files}")
        all_matching.sort(key=os.path.getmtime, reverse=True)
        latest_files = all_matching[:args.avg_n]
        print(f"Auto-found {len(latest_files)} latest result file(s):")
        for path in latest_files:
            print(f"  - {path}")
        args.result_files = latest_files


def _load_benchmark(benchmark_path):
    """Return the JSON objects of a JSONL file, skipping blank lines."""
    with open(benchmark_path, 'r', encoding='utf-8') as f:
        return [json.loads(line) for line in f if line.strip()]


def main():
    """Command-line entry point: load results, score the benchmark, report, save."""
    args = parse_args()

    if args.auto_latest:
        _resolve_latest_files(args)

    if args.result_file:
        result_files = [args.result_file]
        output_file = args.result_file
    else:
        result_files = args.result_files
        if not result_files:
            raise SystemExit("Error: No result files to evaluate")
        # Name the summary after the first file minus any "_passN" suffix,
        # e.g. results_pass1.json -> results_avg4.json.
        # save_detailed_results then maps that to results_avg4_summary.json.
        # Handing it a name that already ended in "_summary.json" produced
        # "..._summary_summary.json".
        stem = os.path.splitext(result_files[0])[0]
        head, sep, tail = stem.rpartition('_pass')
        base_path = head if sep and tail.isdigit() else stem
        output_file = f"{base_path}_avg{args.avg_n}.json"

    # Load results (averaged when several files are given).
    results = load_results(result_files, avg_n=args.avg_n)

    benchmark_path = args.benchmark_file
    print(f"Loading benchmark from {benchmark_path}...")
    benchmark_dataset = _load_benchmark(benchmark_path)
    print(f"Loaded {len(benchmark_dataset)} benchmark samples")

    accuracies, all_scores, task_results = calculate_statistics(results, benchmark_dataset)

    if args.avg_n > 1:
        print(f"\n{'='*50}")
        print(f"Results with avg{args.avg_n} ({len(result_files)} passes)")
        print(f"{'='*50}")
    print_results(accuracies, all_scores)

    save_detailed_results(output_file, accuracies, all_scores, task_results)


# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == "__main__":
    main()

