import json
import os
import threading
import concurrent.futures
import argparse
from tqdm import tqdm
import logging
import time
import re 

# Assuming utils and prompt are accessible from the project root or PYTHONPATH
from utils.api import AIClient 
from prompt.score_prompt_zh import generate_merged_score_prompt, generate_static_score_prompt, point_wise_score_prompt, vanilla_prompt
# Removed ApiCostCalculator, TokenUsageTracker as cost calculation is removed
from utils.score_calculator import calculate_weighted_scores
from utils.json_extractor import extract_json_from_markdown
from prompt.clean_prompt import clean_article_prompt_zh # Used by clean_single_article
from utils.clean_article import clean_single_article # clean_model is not used directly here

# Configure root logging once at import time: INFO level, timestamped "time - level - message" lines.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Fixed configuration parameters with relative paths.
# Paths are relative to the `supplementary_materials` directory,
# so for a script in `supplementary_materials/ablation_study/`, `../data` refers to `supplementary_materials/data`.
# NOTE(review): relative paths are resolved against the process working directory, so this
# presumably assumes the script is launched from its own directory -- confirm.
CONFIG = {
    "reference_model": "reference", # Name of the reference model; matches the file name reference.jsonl
    "query_file": "../data/prompt_data/query.jsonl", # Task prompts; read by clean_articles to look up each prompt's language (Chinese for this ablation)
    "criteria_file": "../data/criteria_data/criteria.jsonl", # Evaluation criteria (Chinese for this ablation)
    "raw_data_dir": "../data/test_data/raw_data", # Directory holding the raw <model>.jsonl article files
    "cleaned_data_dir": "../data/test_data/cleaned_data", # Directory holding the cleaned <model>.jsonl article files
    "results_dir": "./results",  # Relative to this script's location (supplementary_materials/ablation_study/results)
    "llm_api_model": "yuanshi/heiyan/gemini-2.5-pro-preview-03-25", # LLM model for scoring
    "max_workers": 25, # Default thread-pool size (used as the clean_articles default)
    "max_retries": 5 # Default per-article retry budget (used as the clean_articles default)
}

def load_jsonl(file_path):
    """Read a JSON Lines file and return its records as a list.

    A missing file yields an empty list (with a warning). Blank lines are
    ignored, and lines that are not valid JSON are logged and skipped rather
    than aborting the whole load.
    """
    records = []

    if not os.path.exists(file_path):
        logging.warning(f"File not found: {file_path}")
        return records

    with open(file_path, 'r', encoding='utf-8') as handle:
        for raw_line in handle:
            stripped = raw_line.strip()
            if not stripped:
                # Whitespace-only line: nothing to parse.
                continue
            try:
                parsed = json.loads(raw_line)
            except json.JSONDecodeError as e:
                logging.warning(f"Error parsing JSON in {file_path}: {e}, line content: {stripped}")
            else:
                records.append(parsed)

    return records

def _id_sort_key(item):
    """Return a sort key for an article record that never raises on mixed or missing IDs.

    Numeric IDs sort first (numerically), then string IDs (lexicographically),
    then any other ID type (by its repr), and records without an ID last.
    For homogeneous numeric or string IDs the resulting order is the same as
    sorting on the raw ID with ID-less records at the end.
    """
    article_id = item.get('id')
    if article_id is None:
        return (3, 0, "")
    if isinstance(article_id, (int, float)):
        return (0, article_id, "")
    if isinstance(article_id, str):
        return (1, 0, article_id)
    return (2, 0, repr(article_id))


def clean_articles(raw_data_dir, cleaned_data_dir, model_name, clean_agent, 
                   max_workers=CONFIG["max_workers"], max_retries=CONFIG["max_retries"], 
                   limit=None):
    """
    Clean the articles of one model that have not been cleaned yet.

    Reads `<raw_data_dir>/<model_name>.jsonl`, determines which article IDs are
    missing from `<cleaned_data_dir>/<model_name>.jsonl`, cleans those in a
    thread pool via `clean_single_article`, merges them with the already
    cleaned records and rewrites the cleaned file sorted by ID.

    Only the given model is cleaned; reference-model articles are assumed to be
    pre-cleaned if they are used. The language passed to the cleaner is looked
    up by prompt in CONFIG["query_file"] and defaults to "zh".

    Args:
        raw_data_dir: Directory containing the raw `<model_name>.jsonl` file.
        cleaned_data_dir: Directory for the cleaned `<model_name>.jsonl` file.
        model_name: Model whose articles are cleaned (also the file stem).
        clean_agent: Client object handed to `clean_single_article`.
        max_workers: Thread-pool size.
        max_retries: Retry budget forwarded to `clean_single_article`.
        limit: If a positive int, only the first `limit` raw records are considered.

    Returns:
        True when nothing needed cleaning or every targeted article was cleaned;
        False when the raw file is missing or at least one article failed.

    Raises:
        ValueError: If `clean_agent` is falsy.
    """
    if not clean_agent:
        raise ValueError("clean_agent not provided, cannot clean articles")

    raw_file = os.path.join(raw_data_dir, f"{model_name}.jsonl")
    if not os.path.exists(raw_file):
        logging.error(f"Raw data file for model {model_name} not found: {raw_file}")
        return False # Indicate failure

    raw_data_all = load_jsonl(raw_file)
    
    items_to_process = raw_data_all
    if limit is not None and limit > 0:
        items_to_process = raw_data_all[:limit]
        logging.info(f"Applied limit: processing first {len(items_to_process)} articles for model {model_name}")
        if not items_to_process:
            logging.warning(f"No data left after applying limit for model {model_name}")
            return True # No data to process is not a failure of cleaning itself

    raw_ids = {item.get('id') for item in items_to_process if 'id' in item}
    if not raw_ids:
        logging.warning(f"No valid IDs found in the data to process for model {model_name}")
        return True # No items with ID to process

    cleaned_file_path = os.path.join(cleaned_data_dir, f"{model_name}.jsonl")
    
    items_needing_cleaning = []
    existing_cleaned_data = []

    if os.path.exists(cleaned_file_path):
        existing_cleaned_data = load_jsonl(cleaned_file_path)
        cleaned_ids = {item.get('id') for item in existing_cleaned_data if 'id' in item}
        
        missing_ids = raw_ids - cleaned_ids
        if not missing_ids:
            logging.info(f"All articles for model {model_name} (within limit if applied) are already cleaned.")
            return True
            
        items_needing_cleaning = [item for item in items_to_process if 'id' in item and item['id'] in missing_ids]
        logging.info(f"Need to clean {len(missing_ids)} articles for model {model_name}.")
    else:
        # The merge step below can only persist records that carry an 'id', so
        # cleaning ID-less records would spend LLM calls on results that are
        # then discarded. Filter them out, as the branch above already does.
        items_needing_cleaning = [item for item in items_to_process if 'id' in item]
        skipped_without_id = len(items_to_process) - len(items_needing_cleaning)
        if skipped_without_id:
            logging.warning(f"Skipping {skipped_without_id} articles without an ID for model {model_name}.")
        logging.info(f"Cleaned file for model {model_name} not found. Need to clean all {len(items_needing_cleaning)} articles (within limit).")

    if not items_needing_cleaning:
        logging.info(f"No articles require cleaning for model {model_name}.")
        return True

    logging.info(f"Starting to clean {len(items_needing_cleaning)} articles for model {model_name}...")
    
    cleaned_results_new = []
    # Build a prompt -> language lookup from the query file named in CONFIG.
    # Queries without an explicit 'language' field are treated as Chinese ('zh').
    query_file_path = CONFIG["query_file"]
    all_queries = load_jsonl(query_file_path)
    prompt_to_language = {query.get('prompt'): query.get('language', 'zh')
                          for query in all_queries if 'prompt' in query}

    with tqdm(total=len(items_needing_cleaning), desc=f"Cleaning {model_name}") as pbar:
        def process_article_for_cleaning(item):
            """Clean one article; return the cleaned record, or None on any error."""
            try:
                current_language = "zh" # Default for this ablation script
                item_prompt = item.get('prompt', '')
                if item_prompt in prompt_to_language:
                    current_language = prompt_to_language[item_prompt]
                
                # clean_single_article is expected to take 'language'
                cleaned_article_data = clean_single_article(
                    item, clean_agent, max_retries=max_retries, language=current_language
                )
                return cleaned_article_data
            except Exception as e:
                # Best effort: one failing article must not abort the batch.
                logging.error(f"Error cleaning article ID {item.get('id', 'Unknown')}: {e}")
                return None
            finally:
                # Advance the bar whether the article succeeded or failed.
                pbar.update(1)
        
        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = [executor.submit(process_article_for_cleaning, item) for item in items_needing_cleaning]
            for future in concurrent.futures.as_completed(futures):
                result = future.result()
                if result:
                    cleaned_results_new.append(result)
    
    os.makedirs(os.path.dirname(cleaned_file_path), exist_ok=True)
    
    # Merge new results with existing ones, never adding an ID twice.
    final_cleaned_data = existing_cleaned_data
    existing_ids_for_merge = {item.get('id') for item in final_cleaned_data if 'id' in item}
    
    for item in cleaned_results_new:
        if 'id' in item and item['id'] not in existing_ids_for_merge:
            final_cleaned_data.append(item)
            existing_ids_for_merge.add(item['id'])
    
    # Sort results by ID for consistency. A type-ranked key is used because
    # comparing str IDs with numbers (or with a numeric placeholder for missing
    # IDs) raises TypeError, which would discard all cleaning work done above
    # before anything is written to disk.
    final_cleaned_data.sort(key=_id_sort_key)
    
    with open(cleaned_file_path, 'w', encoding='utf-8') as f:
        for item in final_cleaned_data:
            f.write(json.dumps(item, ensure_ascii=False) + '\n')
    
    logging.info(f"Cleaning for model {model_name} complete. Saved/Updated {cleaned_file_path} with {len(final_cleaned_data)} total articles.")
    
    if len(cleaned_results_new) < len(items_needing_cleaning):
        logging.warning(f"Failed to clean all intended articles for {model_name}: "
                        f"Targeted {len(items_needing_cleaning)}, but cleaned {len(cleaned_results_new)} new ones.")
        return False # Indicate some failures

    return True


def format_criteria_list(criteria_data):
    """Serialize one task's evaluation criteria to a JSON string, omitting weights.

    Only the "criterion" and "explanation" fields of each entry under
    `criteria_data["criterions"]` are kept, grouped by dimension. Dimensions
    whose value is not a list and entries lacking either field are logged and
    skipped. Returns the string "{}" when there are no criterions or when
    serialization fails.
    """
    criteria_id = criteria_data.get('id', 'unknown')
    dimensions = criteria_data.get("criterions", {})

    if not dimensions:
        logging.warning(f"Criterions not found in criteria_data with ID {criteria_id}")
        return "{}"

    prompt_payload = {}
    for dim, entries in dimensions.items():
        if not isinstance(entries, list):
            logging.warning(f"In criteria_data ID {criteria_id}, value for dimension '{dim}' is not a list. Skipping.")
            continue

        # A list-valued dimension always appears in the output, even if every
        # one of its entries turns out to be malformed.
        kept = []
        prompt_payload[dim] = kept
        for entry in entries:
            well_formed = isinstance(entry, dict) and "criterion" in entry and "explanation" in entry
            if not well_formed:
                logging.warning(f"In criteria_data ID {criteria_id}, invalid criteria format in dimension '{dim}': {entry}. Skipping item.")
                continue
            kept.append({
                "criterion": entry["criterion"],
                "explanation": entry["explanation"],
            })

    try:
        serialized = json.dumps(prompt_payload, ensure_ascii=False, indent=2)
    except TypeError as e:
        logging.error(f"Failed to serialize criteria to JSON for ID {criteria_id}: {e}")
        return "{}"
    return serialized

def calculate_unweighted_scores(llm_output_json):
    """Average LLM scores with equal weight for every criterion and dimension.

    Three output shapes are handled:
      * vanilla: an "overall_score" key is present -> it becomes the target total;
      * pairwise: items carry "article_1_score" (target) and "article_2_score" (reference);
      * point-wise: items carry only "target_score".

    Each dimension score is the mean of its parsable item scores, and the total
    is the mean over all non-empty list-valued dimensions. A dimension without
    any parsable score is recorded as 0.0 and still counts in that mean.
    Reference scores are only filled in once a pairwise item has been seen.
    """
    scores = {
        "target": {"dims": {}, "total": 0.0},
        "reference": {"dims": {}, "total": 0.0},  # stays empty for point-wise / vanilla output
    }

    # Vanilla prompt output: a single overall score, no per-dimension breakdown.
    if "overall_score" in llm_output_json:
        logging.debug("Calculating unweighted scores for vanilla prompt output.")
        scores["target"]["total"] = float(llm_output_json.get("overall_score", 0.0))
        return scores

    pairwise_seen = False  # flips to True at the first item with both article scores
    counted_dims = 0
    target_avg_sum = 0.0
    reference_avg_sum = 0.0

    for dim_name, evaluations in llm_output_json.items():
        if not (isinstance(evaluations, list) and evaluations):
            logging.debug(f"Dimension '{dim_name}' is empty or not a list, skipping for unweighted score calculation.")
            continue

        counted_dims += 1
        target_sum = 0.0
        reference_sum = 0.0
        scored_items = 0

        for entry in evaluations:
            if not isinstance(entry, dict):
                continue

            first_score = entry.get("article_1_score")
            second_score = entry.get("article_2_score")
            single_score = entry.get("target_score")

            if first_score is not None and second_score is not None:
                pairwise_seen = True
                try:
                    # Statement order matters: a bad second score leaves the
                    # first one already added but the item uncounted.
                    target_sum += float(first_score)
                    reference_sum += float(second_score)
                    scored_items += 1
                except (ValueError, TypeError):
                    logging.warning(f"Could not parse scores in dimension {dim_name} for item {entry}")
            elif single_score is not None:
                try:
                    target_sum += float(single_score)
                    scored_items += 1
                except (ValueError, TypeError):
                     logging.warning(f"Could not parse target_score in dimension {dim_name} for item {entry}")

        # No parsable item -> the dimension scores 0.0 (and still counts toward the total).
        has_scores = scored_items > 0
        target_avg = target_sum / scored_items if has_scores else 0.0
        scores["target"]["dims"][dim_name] = target_avg
        target_avg_sum += target_avg

        if pairwise_seen:
            reference_avg = reference_sum / scored_items if has_scores else 0.0
            scores["reference"]["dims"][dim_name] = reference_avg
            reference_avg_sum += reference_avg

    if counted_dims > 0:
        scores["target"]["total"] = target_avg_sum / counted_dims
        if pairwise_seen:
            scores["reference"]["total"] = reference_avg_sum / counted_dims

    return scores

# Dimensions every non-vanilla (comparative or point-wise) scoring response must contain.
_REQUIRED_SCORE_DIMENSIONS = ("comprehensiveness", "insight", "instruction_following", "readability")


def _preview_text(value, limit):
    """Return at most `limit` characters of `value` for log messages; None becomes ""."""
    if value is None:
        return ""
    return str(value)[:limit]


def _score_single_item(task_data, target_articles_map, reference_articles_map, criteria_map, llm_client,
                       max_retries, use_dynamic_criteria, use_reference, use_weights, reference_first,
                       use_vanilla_prompt):
    """Validate inputs, query the LLM with retries and compute scores for one task (no progress accounting)."""
    task_id = task_data.get('id')
    prompt = task_data.get('prompt')
    # A task without a prompt must yield an error record, not a TypeError from slicing None.
    prompt_preview = _preview_text(prompt, 50)

    def error_result(message, **extra):
        """Build an error record with the same key order as the success record."""
        record = {"id": task_id, "prompt": prompt, "error": message}
        record.update(extra)
        record["use_vanilla_prompt"] = use_vanilla_prompt
        return record

    # --- Data Retrieval and Validation ---
    if prompt not in target_articles_map:
        logging.warning(f"Target article not found for ID {task_id}: {prompt_preview}...")
        return error_result("Target article not found")

    if not use_vanilla_prompt and use_reference and prompt not in reference_articles_map:
        logging.warning(f"Reference article not found for ID {task_id}: {prompt_preview}...")
        return error_result("Reference article not found")

    # Criteria are needed either for their content (dynamic criteria) or for their weights.
    needs_criteria_data = not use_vanilla_prompt and (use_dynamic_criteria or use_weights)
    if needs_criteria_data and prompt not in criteria_map:
        logging.warning(f"Evaluation criteria not found for ID {task_id}: {prompt_preview}...")
        return error_result("Evaluation criteria not found")

    target_article = target_articles_map[prompt].get('article', '')

    reference_article = ""
    if not use_vanilla_prompt and use_reference:
        reference_article = reference_articles_map[prompt].get('article', '')

    # --- Build LLM Prompt ---
    if use_vanilla_prompt:
        user_llm_prompt = vanilla_prompt.format(
            task_prompt=prompt,
            article=target_article
        )
        # The other ablation switches are not applicable in vanilla mode.
        use_dynamic_criteria = False
        use_reference = False
        use_weights = False
        reference_first = False

    elif use_reference: # Comparative scoring logic
        criteria_list_str = "{}" # Default to empty JSON string
        if use_dynamic_criteria:
            criteria_list_str = format_criteria_list(criteria_map[prompt])
            if criteria_list_str == "{}": # Indicates an issue with formatting or empty criteria
                logging.error(f"Formatted criteria list is empty or invalid JSON for ID {task_id}. Cannot proceed with dynamic criteria.")
                return error_result("Failed to format criteria list to JSON")

        prompt_template = generate_merged_score_prompt if use_dynamic_criteria else generate_static_score_prompt

        # reference_first swaps the presentation order; the scores are swapped back after scoring.
        if reference_first:
            article_1, article_2 = reference_article, target_article
        else:
            article_1, article_2 = target_article, reference_article

        # str.format ignores unused keyword arguments, so passing criteria_list
        # is harmless if the static template has no such placeholder.
        user_llm_prompt = prompt_template.format(
            task_prompt=prompt,
            article_1=article_1,
            article_2=article_2,
            criteria_list=criteria_list_str
        )
    else: # Pointwise scoring logic (no reference article)
        criteria_list_str = "{}"
        if use_dynamic_criteria:
            criteria_list_str = format_criteria_list(criteria_map[prompt])
            if criteria_list_str == "{}":
                logging.error(f"Formatted criteria list is empty or invalid JSON for ID {task_id} for point-wise dynamic scoring.")
                return error_result("Failed to format criteria list for point-wise scoring")
        else:
            # There is no dedicated static point-wise template: the point-wise
            # template is reused with an empty criteria list ("{}").
            logging.warning(f"ID: {task_id} - Configuration 'static_criteria' + 'no_reference' (point-wise) "
                            f"might require a specific static point-wise prompt. Falling back to dynamic point-wise prompt structure, "
                            f"but effective criteria will be fixed by the prompt itself if it's static.")

        user_llm_prompt = point_wise_score_prompt.format(
            task_prompt=prompt,
            article=target_article,
            criteria_list=criteria_list_str
        )

    # --- LLM Call ---
    llm_response_str = None
    llm_output_json = None
    success = False
    retry_count = 0

    while retry_count < max_retries and not success:
        # Reset per-attempt state so the handlers below never see an unbound
        # name or a value left over from a previous attempt.
        llm_response_str = None
        json_str_extracted = None
        llm_output_json = None
        try:
            # AIClient.generate is expected to take user_prompt and system_prompt
            llm_response_str = llm_client.generate(
                user_prompt=user_llm_prompt,
                system_prompt="" # Default system prompt
            )

            json_str_extracted = extract_json_from_markdown(llm_response_str)

            if not json_str_extracted:
                logging.warning(f"ID: {task_id} - Failed to extract valid JSON from LLM response. Response: {_preview_text(llm_response_str, 300)}... Retrying...")
            else:
                llm_output_json = json.loads(json_str_extracted)
                if not isinstance(llm_output_json, dict):
                    # Valid JSON that is not an object (list, number, ...) cannot hold scores.
                    logging.warning(f"ID: {task_id} - LLM output JSON is not an object (got {type(llm_output_json).__name__}). Retrying...")
                elif use_vanilla_prompt:
                    if "overall_score" in llm_output_json and "justification" in llm_output_json:
                        success = True
                    else:
                        logging.warning(f"ID: {task_id} - Vanilla Prompt output JSON missing 'overall_score' or 'justification'. Keys: {list(llm_output_json.keys())}. Retrying...")
                elif all(dim in llm_output_json for dim in _REQUIRED_SCORE_DIMENSIONS):
                    # Only the presence of the main dimensions is checked here.
                    success = True
                else:
                    scoring_mode = "Comparative" if use_reference else "Pointwise"
                    logging.warning(f"ID: {task_id} - {scoring_mode} scoring JSON missing expected dimensions. Keys: {list(llm_output_json.keys())}. Retrying...")

        except json.JSONDecodeError as e:
            logging.error(f"ID: {task_id} - Failed to parse extracted JSON. Error: {e}. Extracted: {_preview_text(json_str_extracted, 500)}... Response: {_preview_text(llm_response_str, 500)}... Retrying...")
        except Exception as e:
            # Best effort: any API/processing error consumes one retry instead of aborting the task.
            logging.error(f"ID: {task_id} - Error during LLM API call or processing. Error: {e}. Retrying...")

        if not success:
            llm_output_json = None
            retry_count += 1
            if retry_count < max_retries:
                logging.info(f"ID: {task_id} - Retrying ({retry_count}/{max_retries})...")
                time.sleep(1.5 ** retry_count) # Exponential back-off between attempts
            else:
                logging.error(f"ID: {task_id} - Failed to get valid response/JSON after {max_retries} retries.")

    if not success:
        return error_result(
            f"Failed to get valid LLM response/JSON after {max_retries} retries.",
            model_output=llm_response_str or "No response obtained",
        )

    # --- Calculate Scores ---
    scores = {}
    justification_text = "" # For vanilla prompt

    try:
        if use_vanilla_prompt:
            overall_score_val = float(llm_output_json.get("overall_score", 0.0))
            justification_text = llm_output_json.get("justification", "")
            scores = {"target": {"total": overall_score_val}} # Same shape as the multi-dimensional result
        else:
            criteria_data_for_scoring = criteria_map.get(prompt)
            if criteria_data_for_scoring is None and (use_dynamic_criteria or use_weights):
                raise ValueError(f"ID {task_id}: Criteria data not found for scoring, but required for current configuration.")

            if use_weights:
                # calculate_weighted_scores expects a language param; this ablation script is Chinese-only.
                scores_raw = calculate_weighted_scores(llm_output_json, criteria_data_for_scoring, language="zh")
            else:
                scores_raw = calculate_unweighted_scores(llm_output_json)

            if use_reference and reference_first:
                # The reference was shown as article 1, so the raw "target" scores belong to it: swap back.
                scores = {"target": scores_raw.get("reference", {}), "reference": scores_raw.get("target", {})}
            else:
                scores = scores_raw

    except Exception as e:
        logging.error(f"ID: {task_id} - Error calculating scores: {e}", exc_info=True)
        return error_result(
            f"Error calculating scores: {e}",
            model_output=llm_response_str,
            llm_output_json_parsed=llm_output_json,
        )

    # --- Prepare Final Result ---
    final_result = {
        "id": task_id,
        "prompt": prompt,
        "model_output": llm_response_str,
        "llm_output_json_parsed": llm_output_json,
        "use_vanilla_prompt": use_vanilla_prompt
    }

    if use_vanilla_prompt:
        final_result["overall_score"] = scores.get("target", {}).get("total", 0.0)
        final_result["justification"] = justification_text
        final_result["use_dynamic_criteria"] = None # Mark as not applicable
        final_result["use_reference"] = None
        final_result["use_weights"] = None
        final_result["reference_first"] = None
    else:
        final_result["use_dynamic_criteria"] = use_dynamic_criteria
        final_result["use_reference"] = use_reference
        final_result["use_weights"] = use_weights
        final_result["reference_first"] = reference_first if use_reference else None # Only relevant if use_reference

        target_scores = scores.get("target", {})
        for dim_name, dim_score in target_scores.get("dims", {}).items():
            final_result[f"target_{dim_name}"] = dim_score
        final_result["target_total"] = target_scores.get("total", 0.0)

        if use_reference:
            reference_scores = scores.get("reference", {})
            for dim_name, dim_score in reference_scores.get("dims", {}).items():
                final_result[f"reference_{dim_name}"] = dim_score
            final_result["reference_total"] = reference_scores.get("total", 0.0)

    return final_result


def process_single_item(task_data, target_articles_map, reference_articles_map, criteria_map, llm_client, 
                        lock, pbar, max_retries=3,
                        use_dynamic_criteria=True, use_reference=True, use_weights=True, reference_first=False,
                        use_vanilla_prompt=False):
    """Score one task under the selected ablation configuration (no cost calculation).

    Args:
        task_data: Task record; its 'id' and 'prompt' entries are read.
        target_articles_map: Mapping from prompt to the target model's record (its 'article' entry is scored).
        reference_articles_map: Mapping from prompt to the reference record; consulted only for comparative scoring.
        criteria_map: Mapping from prompt to criteria data; required for dynamic criteria and/or weights.
        llm_client: Object exposing `generate(user_prompt=..., system_prompt=...)`.
        lock: Context manager guarding `pbar` updates.
        pbar: Progress object exposing `update(n)`.
        max_retries: Maximum number of LLM attempts.
        use_dynamic_criteria: Use per-task criteria in the prompt instead of the static template.
        use_reference: Comparative scoring against the reference article (otherwise point-wise).
        use_weights: Use `calculate_weighted_scores` instead of the plain average.
        reference_first: Present the reference as article 1; scores are swapped back afterwards.
        use_vanilla_prompt: Ask for a single overall score; overrides the four switches above.

    Returns:
        A result dict. On failure it contains an "error" entry; on success it
        contains the raw and parsed model output plus either "overall_score" /
        "justification" (vanilla) or "target_*" / "reference_*" scores.

    The progress bar is advanced exactly once per call, including when an
    unexpected exception propagates.
    """
    try:
        return _score_single_item(
            task_data, target_articles_map, reference_articles_map, criteria_map, llm_client,
            max_retries, use_dynamic_criteria, use_reference, use_weights, reference_first,
            use_vanilla_prompt,
        )
    finally:
        with lock:
            pbar.update(1)

def _result_sort_key(result):
    """Return the sort key for a result record: its 'id', or +inf when the id is missing or None."""
    result_id = result.get('id')
    return float('inf') if result_id is None else result_id


def _sort_results_by_id(results):
    """Sort result records in place by 'id' without letting an ordering error abort the run.

    Records without a usable id sort last. If the ids cannot be compared with
    each other (e.g. a mix of strings and numbers), fall back to ordering by
    the string form of the id so the subsequent save still happens.
    """
    try:
        results.sort(key=_result_sort_key)
    except TypeError:
        logging.warning("Result IDs are not mutually comparable; falling back to string ordering.")
        results.sort(key=lambda item: str(item.get('id')))


def _write_results_jsonl(output_path, results):
    """Write result records to `output_path` as JSONL (one JSON object per line).

    Records are serialized before the file is opened, so a serialization error
    cannot truncate an existing results file.

    Returns:
        True when the file was written; False when an I/O error occurred
        (the error is logged).
    """
    serialized_lines = [json.dumps(res_item, ensure_ascii=False) + '\n' for res_item in results]
    try:
        with open(output_path, 'w', encoding='utf-8') as f:
            f.writelines(serialized_lines)
    except IOError as e:
        logging.error(f"Failed to write results to {output_path}: {e}")
        return False
    return True


def main():
    """Command-line entry point for the ablation scoring experiments.

    Steps performed:
      1. Parse CLI flags and derive the experiment configuration and suffix.
      2. Resolve the output file path (creating directories as needed).
      3. Optionally clean the target model's articles via `clean_articles`.
      4. Load queries, criteria, target articles and (if needed) reference articles.
      5. Drop tasks whose required data is missing.
      6. Skip tasks whose IDs already appear in the existing results file.
      7. Score the remaining tasks concurrently with `process_single_item`.
      8. Merge, sort and save all results, then log a run summary.

    The function returns early (after logging) whenever a required input is
    missing or there is nothing left to process.
    """
    parser = argparse.ArgumentParser(description='Run ablation scoring experiments for model articles.')
    parser.add_argument('target_model', type=str, help='Name of the target model to evaluate (e.g., gemini).')

    # Ablation experiment parameters
    parser.add_argument('--no_dynamic_criteria', action='store_true', help='Use static evaluation criteria instead of dynamic criteria from file.')
    parser.add_argument('--no_reference', action='store_true', help='Do not use a reference article for comparative scoring (pointwise scoring).')
    parser.add_argument('--no_weights', action='store_true', help='Do not use weights for scoring; calculate simple average.')
    parser.add_argument('--reference_first', action='store_true', help='In comparative scoring, present reference article as Article 1, target as Article 2.')
    parser.add_argument('--vanilla_prompt', action='store_true', help='Use a simple vanilla prompt for overall quality scoring of the target article.')

    # Other optional parameters
    parser.add_argument('--limit', type=int, default=None, help='Limit the number of prompts to process (for testing).')
    parser.add_argument('--skip_cleaning', action='store_true', help='Skip the article cleaning step.')
    parser.add_argument('--exp_suffix', type=str, default="", help='Suffix for the experiment results file/directory, to differentiate configurations.')

    args = parser.parse_args()

    # Determine experiment parameters from args
    use_dynamic_criteria = not args.no_dynamic_criteria
    use_reference = not args.no_reference
    use_weights = not args.no_weights
    reference_first = args.reference_first
    use_vanilla_prompt = args.vanilla_prompt

    experiment_desc_parts = []
    if use_vanilla_prompt:
        experiment_desc_parts.append("vanilla_prompt")
        # Override other settings for vanilla mode as they are not applicable
        use_dynamic_criteria = False
        use_reference = False
        use_weights = False
        reference_first = False
    else:
        if not use_dynamic_criteria:
            experiment_desc_parts.append("static_criteria")
        if not use_reference:
            experiment_desc_parts.append("no_reference")  # i.e. pointwise
        if not use_weights:
            experiment_desc_parts.append("no_weights")
        if use_reference and reference_first:
            experiment_desc_parts.append("reference_first")

    current_experiment_suffix = "_".join(experiment_desc_parts)
    if args.exp_suffix:
        current_experiment_suffix = f"{current_experiment_suffix}_{args.exp_suffix}" if current_experiment_suffix else args.exp_suffix

    if not current_experiment_suffix:
        current_experiment_suffix = "baseline"  # Default if all flags are standard (dynamic, reference, weights, target first)

    target_model_name = args.target_model
    # Reference model name from CONFIG, used if use_reference is true and not vanilla.
    reference_model_name = CONFIG["reference_model"] if not use_vanilla_prompt and use_reference else None

    query_file_path = CONFIG["query_file"]
    criteria_file_path = CONFIG["criteria_file"]  # Loaded even if not always used by all configs, for weights etc.
    raw_data_base_dir = CONFIG["raw_data_dir"]
    cleaned_data_base_dir = CONFIG["cleaned_data_dir"]

    # Output path logic.
    # The "ablation_exp_v2" suffix is special-cased: its results go to a shared
    # "combined_results" directory two levels up, one file per target model.
    if current_experiment_suffix == "ablation_exp_v2" or args.exp_suffix == "ablation_exp_v2":
        custom_results_base_dir = os.path.join("..", "..", "ablation_study", "combined_results", "ablation_exp_v2")
        os.makedirs(custom_results_base_dir, exist_ok=True)
        output_results_file = os.path.join(custom_results_base_dir, f"{target_model_name}.jsonl")
    else:
        # Standard results path: {results_dir}/{target_model}/{experiment_suffix}/{target_model}_scores.jsonl
        results_output_dir = os.path.join(CONFIG["results_dir"], target_model_name, current_experiment_suffix)
        os.makedirs(results_output_dir, exist_ok=True)
        output_results_file = os.path.join(results_output_dir, f"{target_model_name}_scores.jsonl")

    llm_scoring_model = CONFIG["llm_api_model"]  # Model for the LLM judge (only logged here)
    max_processing_workers = CONFIG["max_workers"]
    max_api_retries = CONFIG["max_retries"]
    processing_limit = args.limit
    skip_article_cleaning = args.skip_cleaning

    logging.info("--- Experiment Configuration ---")
    logging.info(f"Target Model: {target_model_name}")
    logging.info(f"Experiment Mode: {'Vanilla Prompt' if use_vanilla_prompt else 'Comparative/Pointwise Scoring'}")
    if not use_vanilla_prompt:
        logging.info(f"  Use Dynamic Criteria: {use_dynamic_criteria}")
        logging.info(f"  Use Reference Article: {use_reference}")
        if use_reference:
            logging.info(f"    Reference Model: {reference_model_name}")
            logging.info(f"    Reference Article First: {reference_first}")
        logging.info(f"  Use Weights: {use_weights}")
    logging.info(f"Processing Limit: {processing_limit if processing_limit else 'None'}")
    logging.info(f"Skip Cleaning: {skip_article_cleaning}")
    logging.info(f"Experiment Suffix: {current_experiment_suffix}")
    logging.info(f"Output File: {output_results_file}")
    logging.info(f"Criteria File Used: {criteria_file_path}")
    logging.info("----------------------------")

    # Initialize LLM client; it is constructed without a model argument here.
    llm_judge_client = AIClient()
    logging.info(f"Initialized LLM client for scoring (using model: {llm_scoring_model} based on AIClient configuration).")

    # 1. Ensure articles are clean (target model only)
    if not skip_article_cleaning:
        logging.info("Checking and cleaning articles if necessary...")
        cleaning_successful = clean_articles(
            raw_data_base_dir,
            cleaned_data_base_dir,
            target_model_name,
            clean_agent=llm_judge_client,  # Using the same client for cleaning
            max_workers=max_processing_workers,
            max_retries=max_api_retries,
            limit=processing_limit  # Apply limit to cleaning as well
        )

        if not cleaning_successful:
            logging.error("Article cleaning process reported issues. Exiting scoring.")
            return
        # Note: Reference model articles (if used) are assumed to be pre-cleaned and available.
    else:
        logging.info("Skipping article cleaning step as per arguments.")

    # 2. Load data for scoring
    logging.info("Loading data for scoring...")
    all_tasks_data = load_jsonl(query_file_path)
    all_criteria_data = load_jsonl(criteria_file_path)  # Loaded for weights or dynamic content

    target_model_articles_path = os.path.join(cleaned_data_base_dir, f"{target_model_name}.jsonl")
    all_target_model_articles = load_jsonl(target_model_articles_path)
    if not all_target_model_articles:
        logging.error(f"No cleaned articles found for target model {target_model_name} at {target_model_articles_path}. Exiting.")
        return

    all_reference_model_articles = []
    if not use_vanilla_prompt and use_reference and reference_model_name:
        reference_model_articles_path = os.path.join(cleaned_data_base_dir, f"{reference_model_name}.jsonl")
        all_reference_model_articles = load_jsonl(reference_model_articles_path)
        if not all_reference_model_articles:
            logging.error(f"No cleaned articles found for reference model {reference_model_name} at {reference_model_articles_path}, but reference articles are required. Exiting.")
            return

    # Apply limit if specified
    tasks_to_run = all_tasks_data
    if processing_limit is not None and processing_limit > 0:
        tasks_to_run = all_tasks_data[:processing_limit]
        logging.info(f"Applied limit: using first {len(tasks_to_run)} tasks for scoring.")

    if not tasks_to_run:
        logging.error("No tasks to process after applying limit or initial load. Exiting.")
        return

    # Create lookup maps keyed by prompt text, restricted to the selected tasks
    task_prompts_set = {task['prompt'] for task in tasks_to_run if 'prompt' in task}

    criteria_data_map = {item['prompt']: item for item in all_criteria_data if item.get('prompt') in task_prompts_set}
    target_articles_map = {item['prompt']: item for item in all_target_model_articles if item.get('prompt') in task_prompts_set}
    reference_articles_map = {}
    if not use_vanilla_prompt and use_reference:
        reference_articles_map = {item['prompt']: item for item in all_reference_model_articles if item.get('prompt') in task_prompts_set}

    # Further filter tasks to ensure all required data is present for each task
    valid_tasks_for_processing = []
    for task in tasks_to_run:
        prompt = task.get('prompt')
        if not prompt:
            continue

        if prompt not in target_articles_map:
            logging.warning(f"Skipping task ID {task.get('id','N/A')} (Prompt: {prompt[:30]}...) as target article is missing.")
            continue
        if not use_vanilla_prompt:
            if use_reference and prompt not in reference_articles_map:
                logging.warning(f"Skipping task ID {task.get('id','N/A')} (Prompt: {prompt[:30]}...) as reference article is missing for comparative scoring.")
                continue
            if (use_dynamic_criteria or use_weights) and prompt not in criteria_data_map:
                logging.warning(f"Skipping task ID {task.get('id','N/A')} (Prompt: {prompt[:30]}...) as criteria data is missing for dynamic/weighted scoring.")
                continue
        valid_tasks_for_processing.append(task)

    if not valid_tasks_for_processing:
        logging.error("No tasks remain after filtering for available data. Check data files and paths. Exiting.")
        return

    logging.info(f"Processing {len(valid_tasks_for_processing)} tasks after data validation.")
    logging.info(f"Found {len(target_articles_map)} relevant target articles.")
    if not use_vanilla_prompt and use_reference:
        logging.info(f"Found {len(reference_articles_map)} relevant reference articles.")
    if not use_vanilla_prompt and (use_dynamic_criteria or use_weights):
        logging.info(f"Found {len(criteria_data_map)} relevant criteria sets.")

    # Check for existing results to avoid re-processing
    processed_task_ids = set()
    final_results_list = []  # This will hold both old and new results

    if os.path.exists(output_results_file):
        logging.info(f"Loading existing results from {output_results_file} to avoid re-processing.")
        existing_results_data = load_jsonl(output_results_file)
        for item in existing_results_data:
            # Simplified compatibility check: an existing record counts as
            # "processed" when it has an id and its 'use_vanilla_prompt' flag
            # matches the current run. Other flags are not compared.
            if item.get('id') is not None and item.get('use_vanilla_prompt') == use_vanilla_prompt:
                processed_task_ids.add(item['id'])
                final_results_list.append(item)  # Add to list to preserve them
        logging.info(f"Loaded {len(processed_task_ids)} processed task IDs from existing results file matching current vanilla_prompt mode.")

        # The results file is rewritten below from final_results_list only, so
        # make the loss of non-matching records visible instead of silent.
        dropped_existing_count = len(existing_results_data) - len(final_results_list)
        if dropped_existing_count > 0:
            logging.warning(
                f"{dropped_existing_count} existing record(s) in {output_results_file} have no id or a different "
                f"vanilla_prompt mode and will not be kept when the file is rewritten."
            )

    tasks_needing_processing_now = [task for task in valid_tasks_for_processing if task.get('id') not in processed_task_ids]

    if not tasks_needing_processing_now:
        logging.info("All tasks for the current configuration have already been processed according to the existing results file.")
        # Re-save sorted results if any were loaded, to ensure order.
        if final_results_list:
            _sort_results_by_id(final_results_list)
            if _write_results_jsonl(output_results_file, final_results_list):
                logging.info(f"Existing results re-saved to {output_results_file} with consistent sorting.")
        return

    logging.info(f"Starting to process {len(tasks_needing_processing_now)} new tasks with {max_processing_workers} workers.")

    thread_lock = threading.Lock()
    newly_processed_results = []

    start_time = time.time()
    with tqdm(total=len(tasks_needing_processing_now), desc=f"Scoring {target_model_name} ({current_experiment_suffix})") as pbar_instance:
        with concurrent.futures.ThreadPoolExecutor(max_workers=max_processing_workers) as executor:
            # Map each future back to its task so a crashed worker can be reported by task.
            future_to_task = {
                executor.submit(
                    process_single_item,
                    task_item,
                    target_articles_map,
                    reference_articles_map,  # Will be empty if not use_reference
                    criteria_data_map,
                    llm_judge_client,
                    thread_lock,
                    pbar_instance,
                    max_api_retries,
                    # Pass experiment flags
                    use_dynamic_criteria,
                    use_reference,
                    use_weights,
                    reference_first,
                    use_vanilla_prompt
                ): task_item
                for task_item in tasks_needing_processing_now
            }

            for future in concurrent.futures.as_completed(future_to_task):
                source_task = future_to_task[future]
                try:
                    result = future.result()
                except Exception as worker_exc:
                    # future.result() re-raises whatever escaped the worker. Without this
                    # guard a single crash would abort the run before any result is saved.
                    logging.error(
                        f"ID: {source_task.get('id', 'N/A')} - Unhandled exception in worker: {worker_exc}",
                        exc_info=True
                    )
                    result = {
                        "id": source_task.get('id'),
                        "prompt": source_task.get('prompt'),
                        "error": f"Unhandled exception in worker: {worker_exc}",
                        "use_vanilla_prompt": use_vanilla_prompt
                    }
                if result:
                    newly_processed_results.append(result)

    end_time = time.time()
    processing_duration = end_time - start_time

    # Combine new results with previously loaded valid results
    final_results_list.extend(newly_processed_results)
    # Sort all results by ID before saving
    _sort_results_by_id(final_results_list)

    # Results carrying an "error" key are saved alongside successful ones.
    logging.info(f"Processing complete. Saving {len(final_results_list)} total results to {output_results_file}...")
    if _write_results_jsonl(output_results_file, final_results_list):
        logging.info("Results saved successfully.")

    # --- Run Summary ---
    successful_new_count = sum(1 for res in newly_processed_results if "error" not in res)
    failed_new_count = len(newly_processed_results) - successful_new_count

    logging.info("--- Run Summary ---")
    logging.info(f"Experiment Suffix: {current_experiment_suffix}")
    logging.info(f"Target Model: {target_model_name}")
    if not use_vanilla_prompt and use_reference:
        logging.info(f"Reference Model: {reference_model_name}")
    logging.info(f"Scoring LLM: (Configured in AIClient, e.g. {CONFIG['llm_api_model']})")
    logging.info(f"Tasks Processed in this run: {successful_new_count} successful, {failed_new_count} failed.")
    logging.info(f"Total processing time for new tasks: {processing_duration:.2f} seconds.")
    if successful_new_count > 0:
        avg_time_per_task = processing_duration / successful_new_count
        logging.info(f"Average time per newly successful task: {avg_time_per_task:.2f} seconds.")
    logging.info(f"Total results in file (old+new): {len(final_results_list)}")
    logging.info(f"Results file: {output_results_file}")
    logging.info("-------------------")

# Run the ablation scoring CLI only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main() 