#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import os
import json
import glob
from collections import defaultdict

# ---------------------------------------------------------------------------
# Path configuration.
# This script lives in supplementary_materials/ablation_study/utils/, so the
# ablation-study and supplementary-materials roots are one and two levels up.
# ---------------------------------------------------------------------------
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
ABLATION_STUDY_DIR = os.path.dirname(SCRIPT_DIR)  # .../supplementary_materials/ablation_study
SUPPLEMENTARY_MATERIALS_DIR = os.path.dirname(ABLATION_STUDY_DIR)  # .../supplementary_materials

# Results of every ablation run live in sub-directories of this folder, and the
# re-scored output files produced by this script are written here as well.
COMBINED_RESULTS_DIR = os.path.join(ABLATION_STUDY_DIR, "combined_results")

# Sub-directory holding the "full baseline" run whose raw LLM judgements are
# re-scored below. It must match a directory name created by run_ablation.sh
# (which creates e.g. "Baseline", "No_Weights", ...).
BASELINE_INPUT_SUBDIR_NAME = "Baseline"
BASELINE_INPUT_DIR = os.path.join(COMBINED_RESULTS_DIR, BASELINE_INPUT_SUBDIR_NAME)

# Per-prompt criteria (per-criterion item weights and per-dimension weights),
# stored as JSON Lines under supplementary_materials/data/criteria_data/.
CRITERIA_DATA_FILE = os.path.join(
    SUPPLEMENTARY_MATERIALS_DIR, "data", "criteria_data", "criteria.jsonl"
)

# Models whose baseline result files (<model>.jsonl) are re-scored.
TARGET_MODELS = ["openai", "gemini", "grok", "perplexity"]

# English dimension keys (as used in the LLM output) mapped to Chinese labels:
# comprehensiveness, depth/insight, instruction-following ability, readability.
DIMENSIONS_ZH = {
    "comprehensiveness": "全面性",
    "insight": "深入性",
    "instruction_following": "指令遵循能力",
    "readability": "可读性",
}
# Dimension keys in evaluation order; derived from the mapping above so the two
# cannot drift apart.
DIMENSIONS_EN = list(DIMENSIONS_ZH)

def ensure_dir_exists(directory):
    """Create *directory* (and any missing parents) if it does not exist yet.

    An empty path denotes the current directory, which always exists, so it is
    a no-op. Without this guard ``os.makedirs("")`` raises
    ``FileNotFoundError`` -- which is exactly what a caller gets from
    ``os.path.dirname("out.jsonl")`` for a bare filename.
    """
    if directory:
        os.makedirs(directory, exist_ok=True)

def load_jsonl(file_path):
    """Read a JSON Lines file and return its decoded records as a list.

    Blank lines are ignored; a line that is not valid JSON is reported and
    skipped. If the file cannot be opened or read, the error is reported and
    whatever records were decoded before the failure are returned (an empty
    list when the file could not be opened at all).
    """
    records = []
    try:
        with open(file_path, 'r', encoding='utf-8') as handle:
            for raw_line in handle:
                if not raw_line.strip():
                    continue
                try:
                    parsed = json.loads(raw_line)
                except json.JSONDecodeError:
                    print(f"JSONDecodeError: Failed to parse line in {file_path}: {raw_line[:70]}...")
                else:
                    records.append(parsed)
    except Exception as exc:
        print(f"Error: Failed to read file {file_path}: {str(exc)}")
    return records

def load_criteria_weights_map():
    """Load CRITERIA_DATA_FILE and index its entries for per-prompt lookup.

    Each entry is keyed by its ``prompt`` string when present, otherwise by
    ``str(id)``. The same prompt-then-id order is used when result rows look
    their criteria up. Entries with neither a prompt nor an id are skipped, so
    they cannot all collide under a spurious ``"None"`` key. Lines that are not
    JSON objects are skipped as well.

    Returns:
        dict mapping key -> criteria entry; an empty dict if the file is
        missing, unreadable or empty.
    """
    criteria_map = {}
    criteria_data_list = load_jsonl(CRITERIA_DATA_FILE)
    if not criteria_data_list:
        print(f"Warning: Criteria data file not found or empty: {CRITERIA_DATA_FILE}")
        return {}
    for item in criteria_data_list:
        if not isinstance(item, dict):
            continue  # a non-object line carries no criteria
        key = item.get('prompt')
        if not key:
            item_id = item.get('id')
            # str(None) would be the truthy string "None"; treat a missing id as no key.
            key = str(item_id) if item_id is not None else None
        if key:
            criteria_map[key] = item
    print(f"Successfully loaded {len(criteria_map)} criteria entries from {CRITERIA_DATA_FILE}")
    return criteria_map

def save_jsonl(data, file_path):
    """Write *data* to *file_path* as JSON Lines (one record per line, UTF-8).

    The parent directory is created first. Non-ASCII text is written as-is
    (``ensure_ascii=False``). Any failure is reported and swallowed so the
    caller can carry on with its remaining outputs.
    """
    try:
        ensure_dir_exists(os.path.dirname(file_path))
        with open(file_path, 'w', encoding='utf-8') as out:
            out.writelines(json.dumps(record, ensure_ascii=False) + '\n' for record in data)
        print(f"Successfully saved {len(data)} items to {file_path}")
    except Exception as err:
        print(f"Error: Failed to save file {file_path}: {str(err)}")

def normalize_score_0_100(score):
    """Coerce *score* to a float on the 0-100 scale.

    Values in [0, 10] are assumed to come from a 0-10 (Likert-style) scale and
    are multiplied by 10; anything else is clamped to [0, 100]. Unparseable
    input and NaN both map to 0.0.

    Note: the 0-10 detection is a heuristic, so a genuine 0-100 score of, say,
    8 is indistinguishable from a 0-10 score of 8 and becomes 80.
    """
    import math

    try:
        s = float(score)
    except (ValueError, TypeError):
        return 0.0  # default for unparseable scores
    if math.isnan(s):
        # NaN compares False against everything, so the clamp below would
        # silently turn it into 100; treat it as an unusable score instead.
        return 0.0
    if 0 <= s <= 10:
        return s * 10
    # float() keeps the return type consistent: the clamp alone yields int 0/100.
    return float(max(0.0, min(100.0, s)))

# --- Score Calculation Logic --- 
# This section mirrors the logic from the original process_ablation_baseline.py
# It recalculates scores based on different weighting schemes using the raw LLM output JSON.

def _coerce_float(value):
    """Return *value* as a float, or None when it is missing or not numeric."""
    if value is None:
        return None
    try:
        return float(value)
    except (TypeError, ValueError):
        return None


def get_dim_scores_from_llm_output(llm_output_json, criteria_data_for_prompt, use_criteria_item_weights, dimensions=None):
    """Aggregate the LLM's per-criterion judgements into one score per dimension.

    ``llm_output_json[dim]`` is read as a list of judgement dicts. Each one is
    either pairwise (``article_1_score`` for the target, ``article_2_score``
    for the reference) or pointwise (``target_score``). A dimension's score is
    the weighted mean of its parseable judgements. With
    *use_criteria_item_weights*, judgement ``i`` is weighted by
    ``criteria_data_for_prompt["criterions"][dim][i]["weight"]``; missing,
    non-numeric, NaN or non-positive weights count as 1.0.

    Args:
        llm_output_json: parsed LLM output, a dict keyed by dimension name.
        criteria_data_for_prompt: criteria entry for this prompt, or None.
        use_criteria_item_weights: whether to apply per-criterion weights.
        dimensions: dimension keys to score; defaults to ``DIMENSIONS_EN``.

    Returns:
        ``(target_dim_scores, reference_dim_scores, is_pointwise)``.
        - Dimensions with no judgements score 0.0.
        - ``is_pointwise`` becomes False as soon as any pairwise judgement is
          seen. Reference scores are 0.0 for dimensions processed while it is
          still True.
        - If no judgement in a dimension is usable, that dimension falls back to
          a plain average in which missing or unparseable values count as 0.
    """
    if dimensions is None:
        dimensions = DIMENSIONS_EN

    target_dim_scores = {}
    reference_dim_scores = {}
    is_pointwise = True  # default assumption until a pairwise judgement shows up

    for dim_key in dimensions:
        dim_evaluations = llm_output_json.get(dim_key, [])
        if not dim_evaluations:
            target_dim_scores[dim_key] = 0.0
            reference_dim_scores[dim_key] = 0.0
            continue

        dim_criteria_list = []
        if use_criteria_item_weights and criteria_data_for_prompt:
            criterions = criteria_data_for_prompt.get("criterions") or {}
            if isinstance(criterions, dict):
                dim_criteria_list = criterions.get(dim_key) or []
        if not isinstance(dim_criteria_list, list):
            dim_criteria_list = []

        weighted_target_sum = 0.0
        weighted_reference_sum = 0.0
        total_weight = 0.0

        for i, eval_item in enumerate(dim_evaluations):
            if not isinstance(eval_item, dict):
                continue  # malformed judgement; nothing to score
            article_1_s = eval_item.get("article_1_score")
            article_2_s = eval_item.get("article_2_score")
            target_s = eval_item.get("target_score")  # pointwise

            item_weight = 1.0
            if i < len(dim_criteria_list) and isinstance(dim_criteria_list[i], dict):
                parsed_weight = _coerce_float(dim_criteria_list[i].get("weight", 1.0))
                # `> 0` also rejects NaN; zero/negative weights are not allowed.
                if parsed_weight is not None and parsed_weight > 0:
                    item_weight = parsed_weight

            if article_1_s is not None and article_2_s is not None:  # pairwise
                is_pointwise = False
                score_1 = _coerce_float(article_1_s)
                score_2 = _coerce_float(article_2_s)
                # Both sides must parse; otherwise the target sum would include
                # a score whose weight is never added to total_weight.
                if score_1 is None or score_2 is None:
                    continue
                weighted_target_sum += score_1 * item_weight
                weighted_reference_sum += score_2 * item_weight
                total_weight += item_weight
            elif target_s is not None:  # pointwise
                score = _coerce_float(target_s)
                if score is None:
                    continue
                weighted_target_sum += score * item_weight
                total_weight += item_weight

        if total_weight > 0:
            target_dim_scores[dim_key] = weighted_target_sum / total_weight
            reference_dim_scores[dim_key] = 0.0 if is_pointwise else weighted_reference_sum / total_weight
            continue

        # No usable judgement: plain average with missing/unparseable values as 0.
        dict_items = [e for e in dim_evaluations if isinstance(e, dict)]
        target_key = "target_score" if is_pointwise else "article_1_score"
        target_values = [_coerce_float(e.get(target_key)) or 0.0 for e in dict_items]
        target_dim_scores[dim_key] = sum(target_values) / len(target_values) if target_values else 0.0
        if is_pointwise:
            reference_dim_scores[dim_key] = 0.0
        else:
            reference_values = [_coerce_float(e.get("article_2_score")) or 0.0 for e in dict_items]
            reference_dim_scores[dim_key] = sum(reference_values) / len(reference_values) if reference_values else 0.0

    return target_dim_scores, reference_dim_scores, is_pointwise

def calculate_final_score_from_dim_scores(target_dim_scores, reference_dim_scores, criteria_data_for_prompt, use_dimension_weights, is_pointwise, dimensions=None):
    """Combine per-dimension scores into one overall score on a 0-100 scale.

    Dimensions are averaged with the prompt's ``dimension_weight`` values when
    *use_dimension_weights* is set and criteria are available, otherwise with
    equal weights. Missing, non-numeric, NaN or non-positive weights count as 1.0.

    * Pointwise: the weighted target average passed through
      ``normalize_score_0_100``.
    * Pairwise: the target's share of the combined total,
      ``100 * target / (target + reference)``. It is 0.0 when that total is
      not positive.

    Args:
        target_dim_scores: dimension key -> target score.
        reference_dim_scores: dimension key -> reference score (ignored when pointwise).
        criteria_data_for_prompt: criteria entry for this prompt, or None.
        use_dimension_weights: whether to apply per-dimension weights.
        is_pointwise: whether the judgements had no reference article.
        dimensions: dimension keys to combine; defaults to ``DIMENSIONS_EN``.
    """
    if dimensions is None:
        dimensions = DIMENSIONS_EN

    dimension_weights = {}
    if use_dimension_weights and criteria_data_for_prompt:
        dimension_weights = criteria_data_for_prompt.get("dimension_weight") or {}
    if not isinstance(dimension_weights, dict):
        dimension_weights = {}

    overall_target_score = 0.0
    overall_reference_score = 0.0
    dim_total_weight = 0.0

    for dim_key in dimensions:
        try:
            dim_weight = float(dimension_weights.get(dim_key, 1.0))
        except (TypeError, ValueError):
            dim_weight = 1.0
        if not dim_weight > 0:  # also rejects NaN
            dim_weight = 1.0

        overall_target_score += target_dim_scores.get(dim_key, 0.0) * dim_weight
        if not is_pointwise:
            overall_reference_score += reference_dim_scores.get(dim_key, 0.0) * dim_weight
        dim_total_weight += dim_weight

    if dim_total_weight > 0:
        final_target_avg = overall_target_score / dim_total_weight
        final_reference_avg = 0.0 if is_pointwise else overall_reference_score / dim_total_weight
    else:
        # Every weight is forced positive above, so this is only reached when
        # there are no dimensions to combine: fall back to a plain average.
        target_values = list(target_dim_scores.values())
        final_target_avg = sum(target_values) / len(target_values) if target_values else 0.0
        if is_pointwise:
            final_reference_avg = 0.0
        else:
            reference_values = list(reference_dim_scores.values())
            final_reference_avg = sum(reference_values) / len(reference_values) if reference_values else 0.0

    if is_pointwise:
        return normalize_score_0_100(final_target_avg)  # pointwise score is just the target's score
    combined = final_target_avg + final_reference_avg
    if combined > 0:
        return 100 * final_target_avg / combined
    return 0.0

def recalculate_scores_for_item(item_data, criteria_map, use_criteria_item_weights, use_dimension_weights):
    """Recompute one result row's overall score under a given weighting scheme.

    Args:
        item_data: one record from a baseline ``<model>.jsonl`` file. Reads
            ``prompt``/``id`` (criteria lookup key), ``llm_output_json_parsed``
            (raw per-criterion judgements), ``use_vanilla_prompt``,
            ``overall_score`` and ``use_reference``.
        criteria_map: prompt key -> criteria entry.
        use_criteria_item_weights: weight individual criteria by their ``weight``.
        use_dimension_weights: weight dimensions by ``dimension_weight``.

    Returns:
        The recalculated score. It is 0.0 when the raw judgements are missing or
        are not a dict. If the prompt has no criteria entry, both weighting
        options are disabled for this item.
    """
    llm_output_json = item_data.get("llm_output_json_parsed")
    if not llm_output_json or not isinstance(llm_output_json, dict):
        return 0.0  # cannot calculate without raw per-criterion scores

    # Vanilla-prompt rows already carry an overall score from the LLM.
    # NOTE(review): this check runs only for rows that also have parsed raw
    # judgements; vanilla rows without them score 0.0 above -- confirm intended.
    if item_data.get("use_vanilla_prompt") is True:
        return normalize_score_0_100(item_data.get("overall_score", 0.0))

    # Look criteria up by prompt first, then by id. A missing id must not turn
    # into the literal key "None".
    prompt_key = item_data.get("prompt")
    if not prompt_key and item_data.get("id") is not None:
        prompt_key = str(item_data.get("id"))
    criteria_data_for_prompt = criteria_map.get(prompt_key) if prompt_key else None

    # Without criteria for this prompt there is nothing to weight by.
    has_criteria = bool(criteria_data_for_prompt)
    item_weights_on = bool(use_criteria_item_weights) and has_criteria
    dimension_weights_on = bool(use_dimension_weights) and has_criteria

    target_dim_scores, ref_dim_scores, is_pointwise = get_dim_scores_from_llm_output(
        llm_output_json,
        criteria_data_for_prompt,
        item_weights_on,
    )

    # Score as pointwise if the run was configured without a reference
    # (use_reference is False) or the judgements contain no pairwise items.
    effective_is_pointwise = item_data.get("use_reference") is False or is_pointwise

    return calculate_final_score_from_dim_scores(
        target_dim_scores,
        ref_dim_scores,
        criteria_data_for_prompt,
        dimension_weights_on,
        effective_is_pointwise,
    )

def _load_baseline_model_data():
    """Read ``<model>.jsonl`` from BASELINE_INPUT_DIR for every target model.

    A missing file is reported and yields an empty list for that model.
    """
    model_raw_data = {}
    for model_name in TARGET_MODELS:
        # Baseline files are directly in BASELINE_INPUT_DIR, e.g., gemini.jsonl
        file_path = os.path.join(BASELINE_INPUT_DIR, f"{model_name}.jsonl")
        if os.path.exists(file_path):
            model_raw_data[model_name] = load_jsonl(file_path)
            print(f"  Loaded {len(model_raw_data[model_name])} raw baseline entries for model: {model_name}")
        else:
            print(f"  Warning: Baseline data file not found for model {model_name} at {file_path}")
            model_raw_data[model_name] = []
    return model_raw_data


def _baseline_prompt_id(item_data):
    """Return a row's grouping key: its id, else its prompt, as a string ("" if neither is set)."""
    raw_id = item_data.get('id') or item_data.get('prompt')
    # A present-but-None value must not become the string "None".
    return str(raw_id) if raw_id else ''


def _score_items_for_scheme(model_raw_data, criteria_map, scheme_name, scheme_params):
    """Recalculate every model's scores under one weighting scheme.

    Returns:
        ``(rows, error_count)``.
        - ``rows`` has one ``{"prompt_id", "prompt", "overall_scores", "costs"}``
          dict per prompt, with scores and costs keyed by model name.
        - ``error_count`` counts items whose recalculation raised; those items
          score 0.
    """
    per_prompt = defaultdict(lambda: {"prompt": "", "overall_scores": {}, "costs": {}})
    error_count = 0

    for model_name, items_list in model_raw_data.items():
        for item_data in items_list:
            prompt_id = _baseline_prompt_id(item_data)
            if not prompt_id:
                continue
            if "error" in item_data:
                # Items that already failed during generation are left out entirely.
                continue

            try:
                score = recalculate_scores_for_item(
                    item_data,
                    criteria_map,
                    scheme_params["use_criteria_item_weights"],
                    scheme_params["use_dimension_weights"],
                )
            except Exception as e:
                print(f"    ERROR processing item ID {item_data.get('id', 'N/A')} for model {model_name} with scheme {scheme_name}: {e}")
                error_count += 1
                score = 0  # error score

            entry = per_prompt[prompt_id]
            # Record the prompt text for error rows too, but never blank out a
            # prompt already seen from another model's row.
            prompt_text = item_data.get("prompt")
            if prompt_text:
                entry["prompt"] = prompt_text
            entry["overall_scores"][model_name] = score
            entry["costs"][model_name] = item_data.get("cost", 0.0)  # preserve cost

    rows = [
        {
            "prompt_id": pid,
            "prompt": pdata["prompt"],
            "overall_scores": pdata["overall_scores"],
            "costs": pdata["costs"],
        }
        for pid, pdata in per_prompt.items()
    ]
    return rows, error_count


def process_baseline_data_with_varied_weights(criteria_map):
    """Re-score the baseline run under four weighting schemes and save each result.

    For every scheme, one JSON Lines file is written to COMBINED_RESULTS_DIR.
    The file has one row per prompt with per-model scores and costs. A scheme
    with no rows to write is reported and its file is not written.
    """
    print(f"Processing baseline data from: {BASELINE_INPUT_DIR}")
    ensure_dir_exists(COMBINED_RESULTS_DIR)  # ensure output parent dir exists

    model_raw_data = _load_baseline_model_data()

    # scheme name -> (weighting flags, output file name in COMBINED_RESULTS_DIR).
    # The fully weighted scheme keeps the original baseline file name.
    weighting_schemes = {
        "baseline_fully_weighted": (
            {"use_criteria_item_weights": True, "use_dimension_weights": True},
            "ablation_exp_v2.jsonl",
        ),
        "no_dimension_weights": (
            {"use_criteria_item_weights": True, "use_dimension_weights": False},
            "no_dim_weights_ablation_exp_v2.jsonl",
        ),
        "no_criteria_item_weights": (
            {"use_criteria_item_weights": False, "use_dimension_weights": True},
            "no_criteria_weights_ablation_exp_v2.jsonl",
        ),
        "no_weights_at_all": (
            {"use_criteria_item_weights": False, "use_dimension_weights": False},
            "no_weights_ablation_exp_v2.jsonl",
        ),
    }

    for scheme_name, (scheme_params, output_filename) in weighting_schemes.items():
        print(f"\n  Processing with weighting scheme: {scheme_name}")
        rows, error_count = _score_items_for_scheme(model_raw_data, criteria_map, scheme_name, scheme_params)

        output_file_path = os.path.join(COMBINED_RESULTS_DIR, output_filename)
        if rows:
            save_jsonl(rows, output_file_path)
            if error_count > 0:
                print(f"    Scheme {scheme_name} completed with {error_count} item processing errors.")
        else:
            print(f"    No data processed for scheme {scheme_name}. Output file not written.")

def main():
    """Script entry point: load criteria, verify the baseline directory, re-score.

    Prints a message and stops early (without raising) when the criteria file
    yields no entries or BASELINE_INPUT_DIR does not exist.
    """
    criteria_map = load_criteria_weights_map()
    if not criteria_map:
        print("Critical Error: Criteria data could not be loaded. Aborting baseline processing.")
    elif not os.path.exists(BASELINE_INPUT_DIR):
        print(f"Critical Error: Baseline input directory not found: {BASELINE_INPUT_DIR}. Aborting.")
        print("Please ensure that the results from the 'Baseline' or 'ablation_exp_v2' run of run_ablation.sh are present in that directory.")
    else:
        process_baseline_data_with_varied_weights(criteria_map)
        print("\nBaseline processing with varied weights finished.")


if __name__ == "__main__":
    # Only run the pipeline when executed as a script, not when imported.
    main()