#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import glob
import json
import math
import os
import re
from collections import defaultdict

# Directory layout, resolved from this file's own location:
#   supplementary_materials/ablation_study/utils/<this script>
#   supplementary_materials/ablation_study/combined_results/
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
# Parent of utils/, i.e. supplementary_materials/ablation_study/
ABLATION_STUDY_DIR = os.path.dirname(SCRIPT_DIR)
COMBINED_RESULTS_DIR = os.path.join(ABLATION_STUDY_DIR, "combined_results")

# Only "<model>.jsonl" files whose stem appears in this list are aggregated.
TARGET_MODELS = [
    "openai",
    "gemini",
    "grok",
    "perplexity",
]

def ensure_dir_exists(directory):
    """Create *directory* (and any missing parents) if it does not exist yet.

    An empty path -- what ``os.path.dirname`` returns for a bare file name --
    denotes the current working directory, which always exists. It is
    therefore treated as a no-op instead of being passed to ``os.makedirs``,
    which raises ``FileNotFoundError`` for ``""``.

    Args:
        directory: Path of the directory to create.
    """
    if not directory:
        return
    os.makedirs(directory, exist_ok=True)

def load_jsonl(file_path):
    """Read a JSON Lines file and return its parsed records as a list.

    Whitespace-only lines are ignored. A line that is not valid JSON is
    reported on stdout and skipped. If the file cannot be opened or read,
    the error is reported and the records parsed so far (possibly none) are
    returned instead of raising.
    """
    records = []
    try:
        with open(file_path, 'r', encoding='utf-8') as handle:
            for raw_line in handle:
                if not raw_line.strip():
                    continue
                try:
                    parsed = json.loads(raw_line)
                except json.JSONDecodeError:
                    print(f"Failed to parse line in {file_path}: {raw_line[:50]}...")
                    continue
                records.append(parsed)
    except Exception as exc:
        print(f"Failed to read file {file_path}: {str(exc)}")
    return records

def save_jsonl(data, file_path):
    """Write *data* to *file_path* in JSON Lines format (one object per line).

    Best-effort, like ``load_jsonl``: any failure (unwritable path,
    unserializable item, ...) is reported on stdout rather than raised.

    Args:
        data: Iterable of JSON-serializable objects. Any iterable is accepted;
            the reported count is the number of items actually written.
        file_path: Destination path. Missing parent directories are created;
            a bare file name is written to the current working directory.
    """
    try:
        parent_dir = os.path.dirname(file_path)
        # dirname() is "" for a bare file name, and os.makedirs("") raises,
        # so only create a directory when there actually is one.
        if parent_dir:
            ensure_dir_exists(parent_dir)
        written = 0
        with open(file_path, 'w', encoding='utf-8') as f:
            for item in data:
                # ensure_ascii=False keeps non-ASCII text readable; the file is
                # UTF-8 so it is read back correctly by load_jsonl.
                f.write(json.dumps(item, ensure_ascii=False) + '\n')
                written += 1
        print(f"Successfully saved {written} items to {file_path}")
    except Exception as e:
        print(f"Failed to save file {file_path}: {str(e)}")

def normalize_score(score):
    """Coerce *score* to a float, scaling small-scale ratings up by 10.

    Values in the closed interval [0, 10] are assumed to come from a small
    rating scale and are multiplied by 10. Any other finite value is returned
    unchanged; values below 0 or above 100 are NOT clamped.
    Unparseable inputs (``None``, non-numeric strings, ...) and non-finite
    values (NaN, +/-inf) yield 0.0, so they do not propagate into the
    aggregated output.
    """
    try:
        score_float = float(score)
    except (ValueError, TypeError):
        return 0.0  # Default for unparseable scores
    if not math.isfinite(score_float):
        # json.loads accepts NaN/Infinity, and json.dumps would write them back
        # as non-standard JSON tokens; treat them as unparseable instead.
        return 0.0
    if 0 <= score_float <= 10:
        # NOTE(review): a 1-5 scale maps to 10-50 here, not 0-100 -- confirm
        # that only 0-10 scales reach this branch.
        return score_float * 10
    return score_float

def _as_float(value):
    """Return *value* as a finite float, or 0.0 if it is missing, unparseable or non-finite."""
    try:
        result = float(value)
    except (ValueError, TypeError):
        return 0.0
    return result if math.isfinite(result) else 0.0


def _total_with_fallback(item_data, key, legacy_key):
    """Read ``item_data[key]`` as a float, using ``legacy_key`` when the primary value is 0 or missing."""
    total = _as_float(item_data.get(key, 0.0))
    if total == 0.0 and legacy_key in item_data:
        # Older result files stored the totals under *_weighted_avg names.
        total = _as_float(item_data.get(legacy_key, 0.0))
    return total


def calculate_score(item_data, is_no_reference=False, is_vanilla_prompt=False):
    """Compute one model's score for a single evaluated prompt.

    Args:
        item_data: One parsed record (dict) from a model's results .jsonl file.
        is_no_reference: Pointwise run -- the score is ``target_total``
            (falling back to ``target_total_weighted_avg``), passed through
            ``normalize_score``.
        is_vanilla_prompt: Vanilla-prompt run -- the score is
            ``overall_score`` passed through ``normalize_score``. Takes
            precedence over ``is_no_reference``.

    Returns:
        float. For the default comparative run the score is
        ``100 * target / (target + reference)``, using the ``*_weighted_avg``
        fields when the primary totals are 0 or missing. It is 0.0 when the
        two totals do not sum to a positive number. Missing, non-numeric or
        non-finite totals count as 0.0 instead of raising.
    """
    if is_vanilla_prompt:
        return normalize_score(item_data.get("overall_score", 0.0))

    target_total = _total_with_fallback(item_data, "target_total", "target_total_weighted_avg")
    if is_no_reference:
        return normalize_score(target_total)

    # Comparative scoring: target's share of the combined total, as a percentage.
    reference_total = _total_with_fallback(item_data, "reference_total", "reference_total_weighted_avg")
    combined = target_total + reference_total
    if combined > 0:
        return 100 * target_total / combined
    return 0.0

def _model_name_for_file(model_file_name):
    """Return the target model encoded in a '<model>.jsonl' file name, or None if it is not one."""
    # fullmatch: reject names like "gemini.jsonl.jsonl" that re.match would accept.
    match = re.fullmatch(r"([a-zA-Z0-9_\-]+)\.jsonl", model_file_name)
    if not match or match.group(1) not in TARGET_MODELS:
        return None
    return match.group(1)


def _prompt_key(item_data):
    """Return the aggregation key for a record: its 'id' if present, else its 'prompt', as a string ("" if neither)."""
    item_id = item_data.get('id')
    # Explicit None/"" check so that a legitimate id of 0 is not treated as missing.
    if item_id is not None and item_id != '':
        return str(item_id)
    # `or ''` avoids turning a null prompt into the bogus key "None".
    return str(item_data.get('prompt', '') or '')


def process_experiment_dir(exp_dir_path, output_jsonl_path):
    """Aggregate per-model scores for one experiment directory into one JSONL file.

    The scoring mode is inferred from the directory name. A name containing
    ``vanilla_prompt`` uses ``overall_score``. A name containing
    ``no_reference`` uses pointwise ``target_total``. Anything else uses the
    comparative target / (target + reference) ratio (see ``calculate_score``).

    Each ``<model>.jsonl`` file in the directory whose ``<model>`` is listed in
    ``TARGET_MODELS`` is read. Records that are not JSON objects, carry an
    ``error`` key, or have neither an ``id`` nor a ``prompt`` are skipped. The
    output has one line per prompt::

        {"prompt_id": ..., "prompt": ..., "overall_scores": {model: score, ...}}

    Nothing is written if no scores were collected.

    Args:
        exp_dir_path: Directory holding the per-model result files.
        output_jsonl_path: Destination of the aggregated JSONL file.
    """
    print(f"\nProcessing experiment directory: {exp_dir_path}")
    dir_name = os.path.basename(os.path.normpath(exp_dir_path))

    # Experiment type is encoded in the directory name.
    is_no_reference = "no_reference" in dir_name
    is_vanilla_prompt = "vanilla_prompt" in dir_name

    if is_vanilla_prompt:
        print("  Directory type: vanilla_prompt (using overall_score from parsed JSON)")
    elif is_no_reference:
        print("  Directory type: no_reference/pointwise (using target_total from parsed JSON)")
    else:
        print("  Directory type: standard comparative (using target_total / (target+reference) from parsed JSON)")

    # glob order is filesystem-dependent; sort so processing order, logs and
    # the order of records in the output file are reproducible.
    model_results_files = sorted(glob.glob(os.path.join(exp_dir_path, "*.jsonl")))
    if not model_results_files:
        print(f"  No model result .jsonl files found in {exp_dir_path}")
        return

    print(f"  Found {len(model_results_files)} potential model result files.")

    aggregated_prompt_data = defaultdict(lambda: {"prompt": "", "overall_scores": {}})

    for model_file_path in model_results_files:
        model_file_name = os.path.basename(model_file_path)
        model_name = _model_name_for_file(model_file_name)
        if model_name is None:
            print(f"    Skipping file (does not match target model pattern): {model_file_name}")
            continue

        print(f"    Processing model file: {model_file_name} for model: {model_name}")

        model_run_data = load_jsonl(model_file_path)
        if not model_run_data:
            print(f"      No data loaded from {model_file_path}")
            continue

        print(f"      Loaded {len(model_run_data)} entries for model {model_name}.")

        for item_data in model_run_data:
            # Skip non-object lines and records that had processing errors.
            if not isinstance(item_data, dict) or "error" in item_data:
                continue

            prompt_id = _prompt_key(item_data)
            if not prompt_id:
                continue

            score = calculate_score(item_data, is_no_reference, is_vanilla_prompt)

            entry = aggregated_prompt_data[prompt_id]
            # Only store prompt text that is present, so a later model's record
            # without a "prompt" field cannot blank out text stored earlier.
            prompt_text = item_data.get("prompt", "")
            if prompt_text:
                entry["prompt"] = prompt_text
            entry["overall_scores"][model_name] = score

    output_data_list = [
        {
            "prompt_id": prompt_id,
            "prompt": data["prompt"],
            "overall_scores": data["overall_scores"],
        }
        for prompt_id, data in aggregated_prompt_data.items()
    ]

    if output_data_list:
        print(f"  Aggregated data for {len(output_data_list)} prompts for experiment '{dir_name}'.")
        save_jsonl(output_data_list, output_jsonl_path)
    else:
        print(f"  No data aggregated for experiment '{dir_name}'. Output file not written.")

def main():
    """Aggregate every experiment subdirectory of COMBINED_RESULTS_DIR.

    Each immediate subdirectory is handed to ``process_experiment_dir``, and
    its aggregate is written next to it as
    ``COMBINED_RESULTS_DIR/<subdirectory name>.jsonl``. Every subdirectory is
    processed; none is excluded by name. COMBINED_RESULTS_DIR is created first
    if it does not exist.
    """
    print(f"COMBINED_RESULTS_DIR is set to: {COMBINED_RESULTS_DIR}")
    ensure_dir_exists(COMBINED_RESULTS_DIR)

    # os.listdir returns entries in arbitrary order; sort so runs and logs are
    # reproducible across machines and filesystems.
    experiment_subdirs = [
        os.path.join(COMBINED_RESULTS_DIR, item_name)
        for item_name in sorted(os.listdir(COMBINED_RESULTS_DIR))
        if os.path.isdir(os.path.join(COMBINED_RESULTS_DIR, item_name))
    ]

    if not experiment_subdirs:
        print(f"No experiment subdirectories found in {COMBINED_RESULTS_DIR}. Nothing to process.")
        return

    print(f"Found {len(experiment_subdirs)} experiment directories to process:")
    for d_path in experiment_subdirs:
        print(f"  - {os.path.basename(d_path)}")

    for exp_dir_path in experiment_subdirs:
        dir_name = os.path.basename(os.path.normpath(exp_dir_path))
        # The aggregate is named after its directory (e.g. Baseline.jsonl) and
        # placed directly in COMBINED_RESULTS_DIR, beside the directory itself.
        output_jsonl_path = os.path.join(COMBINED_RESULTS_DIR, f"{dir_name}.jsonl")
        process_experiment_dir(exp_dir_path, output_jsonl_path)

    print("\nAll experiment directories processed.")

if __name__ == "__main__":
    # Run the aggregation only when executed as a script, not on import.
    main()