import json
import os
import threading
import concurrent.futures
import argparse
from tqdm import tqdm
import logging
import time
import re 
from utils.api import AIClient
import glob

# Import scoring prompts for Chinese and English
from prompt.score_prompt_zh import generate_merged_score_prompt as zh_merged_score_prompt
from prompt.score_prompt_en import generate_merged_score_prompt as en_merged_score_prompt
from utils.score_calculator import calculate_weighted_scores
from utils.json_extractor import extract_json_from_markdown
from prompt.clean_prompt import clean_article_prompt_zh as clean_article_prompt
from utils.clean_article import clean_model, clean_single_article

# Configure root logging for the whole script: INFO level and above, each
# line prefixed with a timestamp and the level name.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Fixed run configuration. All paths are relative, so they resolve against
# the working directory the script is launched from.
CONFIG = dict(
    query_file="data/prompt_data/query.jsonl",
    criteria_file="data/criteria_data/criteria.jsonl",
    raw_data_dir="data/test_data/raw_data",
    cleaned_data_dir="data/test_data/cleaned_data",
    reference_file="data/test_data/cleaned_data/reference.jsonl",
    results_dir="results",
    max_workers=5,   # thread-pool size used for both cleaning and scoring
    max_retries=3,   # retry budget handed to the cleaning and scoring steps
)

def load_jsonl(file_path):
    """Read a JSON Lines file and return its records as a list.

    Blank (whitespace-only) lines are ignored. Lines that are not valid JSON
    are logged as warnings and skipped.

    Args:
        file_path: Path of the ``.jsonl`` file to read (decoded as UTF-8).

    Returns:
        list: The parsed records, in file order.

    Raises:
        FileNotFoundError: If ``file_path`` does not exist.
        ValueError: If the file yields no parseable record.
    """
    if not os.path.exists(file_path):
        logging.error(f"File not found: {file_path}")
        raise FileNotFoundError(f"Required file not found: {file_path}")

    records = []
    with open(file_path, 'r', encoding='utf-8') as handle:
        for raw_line in handle:
            stripped = raw_line.strip()
            if not stripped:
                continue
            try:
                records.append(json.loads(raw_line))
            except json.JSONDecodeError as err:
                logging.warning(f"Error parsing JSON in {file_path}: {err}, line content: {stripped}")

    if records:
        return records

    logging.error(f"No valid data found in file: {file_path}")
    raise ValueError(f"File contains no valid data: {file_path}")

def _load_existing_cleaned(cleaned_file):
    """Return the records already saved in ``cleaned_file``, or [] if there are none.

    ``load_jsonl`` raises ValueError for a file with no parseable records. A run
    in which every article failed to clean leaves exactly such a file behind.
    Treating it as "nothing cleaned yet" keeps that one failure from aborting
    every later run.
    """
    if not os.path.exists(cleaned_file):
        return []
    try:
        return load_jsonl(cleaned_file)
    except ValueError:
        logging.warning(f"Cleaned data file {cleaned_file} contains no valid records; treating it as empty")
        return []


def _write_jsonl(file_path, records):
    """Overwrite ``file_path`` with ``records`` as UTF-8 JSON Lines (one object per line)."""
    with open(file_path, 'w', encoding='utf-8') as f:
        for record in records:
            f.write(json.dumps(record, ensure_ascii=False) + '\n')


def clean_articles(raw_data_dir, cleaned_data_dir, target_model, clean_agent,
                   max_workers=CONFIG["max_workers"], max_retries=CONFIG["max_retries"],
                   limit=None):
    """
    Clean the target model's raw articles that are not yet in the cleaned file.

    Reads ``<raw_data_dir>/<target_model>.jsonl`` and compares its IDs with
    ``<cleaned_data_dir>/<target_model>.jsonl``. Only the missing articles are
    cleaned, in parallel. Results are then merged into the cleaned file
    (existing records win on duplicate IDs) and the file is sorted by ``id``.
    If the cleaned file did not exist or was empty, the new results are written
    as-is.

    Args:
        raw_data_dir: Directory containing raw article data.
        cleaned_data_dir: Directory to store cleaned articles.
        target_model: Target model to evaluate (file stem of the JSONL files).
        clean_agent: API client passed to ``clean_single_article``.
        max_workers: Maximum number of threads for cleaning.
        max_retries: Retry budget passed to ``clean_single_article``.
        limit: If a positive int, only the first ``limit`` raw records are considered.

    Returns:
        bool: True if nothing needed cleaning or every article was cleaned.
        False if any article failed; the successful ones are still saved.

    Raises:
        ValueError: If ``clean_agent`` is falsy or the raw data has no IDs.
        FileNotFoundError: If the raw data file (or the query file) is missing.
    """
    if not clean_agent:
        raise ValueError("clean_agent not provided, cannot clean articles")

    raw_file = os.path.join(raw_data_dir, f"{target_model}.jsonl")
    if not os.path.exists(raw_file):
        raise FileNotFoundError(f"Raw data file for model {target_model} not found: {raw_file}")

    raw_data = load_jsonl(raw_file)

    if limit is not None and limit > 0:
        raw_data = raw_data[:limit]
        if not raw_data:
            raise ValueError(f"No data left after applying limit for model {target_model}")

    raw_ids = {item.get('id') for item in raw_data if 'id' in item}
    if not raw_ids:
        raise ValueError(f"No valid IDs found in raw data for model {target_model}")

    cleaned_file = os.path.join(cleaned_data_dir, f"{target_model}.jsonl")
    existing_data = _load_existing_cleaned(cleaned_file)

    # Decide which raw items still need cleaning.
    if existing_data:
        cleaned_ids = {item.get('id') for item in existing_data if 'id' in item}
        missing_ids = raw_ids - cleaned_ids
        if not missing_ids:
            logging.info(f"All articles for model {target_model} already cleaned")
            return True
        items_to_clean = [item for item in raw_data if 'id' in item and item['id'] in missing_ids]
        logging.info(f"Need to clean {len(missing_ids)} articles for model {target_model}")
    else:
        items_to_clean = raw_data
        logging.info(f"Need to clean all {len(raw_ids)} articles for model {target_model}")

    logging.info(f"Starting to clean {len(items_to_clean)} articles for model {target_model}...")

    # Map each query prompt to its language so the cleaner can be told which
    # language an article is in. Articles whose prompt is not a known query
    # default to English.
    prompt_to_language = {query.get('prompt'): query.get('language')
                          for query in load_jsonl(CONFIG["query_file"])}

    def clean_one(item):
        """Clean a single article; return the cleaned record, or None on failure."""
        try:
            item_prompt = item.get('prompt', '')
            language = prompt_to_language[item_prompt] if item_prompt in prompt_to_language else "en"
            return clean_single_article(
                item, clean_agent, max_retries=max_retries, language=language
            )
        except Exception as e:
            logging.error(f"Error cleaning article ID {item.get('id')}: {e}")
            return None

    cleaned_results = []
    with tqdm(total=len(items_to_clean), desc=f"Cleaning {target_model}") as pbar:
        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = [executor.submit(clean_one, item) for item in items_to_clean]
            for future in concurrent.futures.as_completed(futures):
                result = future.result()
                if result:
                    cleaned_results.append(result)
                # Advance the bar only from this thread, so worker threads never
                # mutate tqdm state concurrently.
                pbar.update(1)

    # os.path.dirname is "" when cleaned_data_dir is "", and makedirs("") raises.
    output_dir = os.path.dirname(cleaned_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    if existing_data:
        # Existing records take precedence; add only new IDs, then sort by ID.
        known_ids = {item.get('id') for item in existing_data if 'id' in item}
        merged_results = list(existing_data)
        for item in cleaned_results:
            if 'id' in item and item['id'] not in known_ids:
                merged_results.append(item)
                known_ids.add(item['id'])

        merged_results.sort(key=lambda x: x.get('id', float('inf')))
        _write_jsonl(cleaned_file, merged_results)

        logging.info(f"Saved {len(cleaned_results)} newly cleaned articles to {cleaned_file}, "
                     f"total {len(merged_results)}")
    else:
        _write_jsonl(cleaned_file, cleaned_results)
        logging.info(f"Created file {cleaned_file} with {len(cleaned_results)} cleaned articles")

    if len(cleaned_results) < len(items_to_clean):
        logging.error(f"Failed to clean all articles: expected {len(items_to_clean)}, "
                      f"but cleaned only {len(cleaned_results)}")
        return False

    return True

def format_criteria_list(criteria_data):
    """Serialize evaluation criteria for the scoring prompt, dropping weights.

    Only the ``criterion`` and ``explanation`` fields of each item are kept.
    Dimensions whose value is not a list, and items missing either field, are
    skipped with a warning.

    Args:
        criteria_data: Record whose ``criterions`` value maps a dimension name
            to a list of criterion dicts.

    Returns:
        str: Pretty-printed JSON (``ensure_ascii=False``) of
        ``{dimension: [{"criterion": ..., "explanation": ...}, ...]}``.

    Raises:
        ValueError: If ``criterions`` is missing, empty or not a dict; if no
            dimension has a valid item; or if the result cannot be serialized.
            Callers rely on getting ValueError (not AttributeError) for all of these.
    """
    criterions_dict = criteria_data.get("criterions", {})

    if not criterions_dict:
        raise ValueError(f"'criterions' not found in criteria_data with ID {criteria_data.get('id', 'unknown')}")
    if not isinstance(criterions_dict, dict):
        raise ValueError(
            f"'criterions' must map dimension names to lists in criteria_data with ID "
            f"{criteria_data.get('id', 'unknown')}"
        )

    criteria_for_prompt = {}
    for dim, criterions_list in criterions_dict.items():
        if not isinstance(criterions_list, list):
            logging.warning(f"Value for dimension '{dim}' is not a list. Skipping.")
            continue

        criteria_for_prompt[dim] = []
        for crit_item in criterions_list:
            if isinstance(crit_item, dict) and "criterion" in crit_item and "explanation" in crit_item:
                criteria_for_prompt[dim].append({
                    "criterion": crit_item["criterion"],
                    "explanation": crit_item["explanation"],
                })
            else:
                logging.warning(f"Invalid criteria format in dimension '{dim}'. Skipping item.")

    if not any(criteria_for_prompt.values()):
        raise ValueError("No valid criteria found in any dimension")

    try:
        return json.dumps(criteria_for_prompt, ensure_ascii=False, indent=2)
    except TypeError as e:
        raise ValueError(f"Failed to serialize criteria to JSON: {e}") from e

def _request_llm_scores(llm_client, user_prompt, task_id, max_retries):
    """Ask the LLM to score the articles, retrying until the reply is usable.

    A reply is usable when JSON can be extracted from it, the JSON is an
    object, and it contains every expected scoring dimension. After each
    failed attempt except the last, the function sleeps ``1.5 ** attempt`` seconds.

    Returns:
        tuple: ``(parsed_json, raw_response)``. ``parsed_json`` is None when
        all ``max_retries`` attempts failed. ``raw_response`` is the last
        response text received, or None if no call returned.
    """
    expected_dims = ["comprehensiveness", "insight", "instruction_following", "readability"]
    llm_response_str = None

    for attempt in range(1, max_retries + 1):
        try:
            llm_response_str = llm_client.generate(
                user_prompt=user_prompt,
                system_prompt=""
            )

            json_str_extracted = extract_json_from_markdown(llm_response_str)
            if not json_str_extracted:
                raise ValueError("Failed to extract JSON from LLM response")

            llm_output_json = json.loads(json_str_extracted)
            # A JSON array or scalar would make the dimension check meaningless
            # (or raise), so require an object explicitly.
            if not isinstance(llm_output_json, dict):
                raise ValueError("LLM response JSON is not an object")

            missing_dims = [dim for dim in expected_dims if dim not in llm_output_json]
            if missing_dims:
                raise ValueError(f"Missing expected dimensions: {missing_dims}")

            return llm_output_json, llm_response_str

        except Exception as e:
            if attempt < max_retries:
                logging.warning(f"ID {task_id}: Retry {attempt}/{max_retries} - {str(e)}")
                time.sleep(1.5 ** attempt)
            else:
                logging.error(f"ID {task_id}: Failed after {max_retries} retries - {str(e)}")

    return None, llm_response_str


def _score_single_item(task_id, prompt, target_articles_map, reference_articles_map,
                       criteria_map, llm_client, max_retries, language):
    """Look up inputs, query the LLM and compute scores for one task (no progress handling)."""
    # Validate that every input for this prompt is available.
    for mapping, label in ((target_articles_map, "Target article"),
                           (reference_articles_map, "Reference article"),
                           (criteria_map, "Evaluation criteria")):
        if prompt not in mapping:
            logging.error(f"{label} not found for ID {task_id}")
            return {"id": task_id, "prompt": prompt, "error": f"{label} not found"}

    criteria_data = criteria_map[prompt]
    target_article = target_articles_map[prompt].get('article', '')
    reference_article = reference_articles_map[prompt].get('article', '')

    try:
        criteria_list_str = format_criteria_list(criteria_data)
    except ValueError as e:
        logging.error(f"ID {task_id}: {str(e)}")
        return {"id": task_id, "prompt": prompt, "error": f"Failed to format criteria: {str(e)}"}

    # article_1 is always the target, article_2 the reference.
    merged_score_prompt = zh_merged_score_prompt if language == "zh" else en_merged_score_prompt
    user_prompt = merged_score_prompt.format(
        task_prompt=prompt,
        article_1=target_article,
        article_2=reference_article,
        criteria_list=criteria_list_str
    )

    llm_output_json, llm_response_str = _request_llm_scores(
        llm_client, user_prompt, task_id, max_retries
    )
    truncated_output = llm_response_str[:500] if llm_response_str else "No response"

    if llm_output_json is None:
        return {
            "id": task_id,
            "prompt": prompt,
            "error": f"Failed to get valid response after {max_retries} retries",
            "model_output": truncated_output
        }

    # Everything that reads the structure returned by calculate_weighted_scores
    # stays inside the try, so an unexpected shape becomes an error record.
    try:
        scores = calculate_weighted_scores(llm_output_json, criteria_data, language)
        target_total = scores["target"]["total"]
        reference_total = scores["reference"]["total"]

        # overall = target / (target + reference); 0 when both totals are 0.
        denominator = target_total + reference_total
        overall_score = target_total / denominator if denominator > 0 else 0

        return {
            "id": task_id,
            "prompt": prompt,
            "model_output": llm_response_str,
            **{f"target_{k}": v for k, v in scores["target"]["dims"].items()},
            "target_total_weighted_avg": target_total,
            **{f"reference_{k}": v for k, v in scores["reference"]["dims"].items()},
            "reference_total_weighted_avg": reference_total,
            "overall_score": overall_score
        }
    except Exception as e:
        logging.error(f"ID {task_id}: Error calculating scores - {str(e)}")
        return {
            "id": task_id,
            "prompt": prompt,
            "error": f"Error calculating scores: {str(e)}",
            "model_output": truncated_output
        }


def process_single_item(task_data, target_articles_map, reference_articles_map, criteria_map,
                        llm_client, lock, pbar, max_retries=3,
                        llm_model_name="", language="en"):
    """Score one task's target article against its reference article.

    The shared progress bar is advanced exactly once per call, under ``lock``,
    whatever the outcome.

    Args:
        task_data: Task record; its ``prompt`` keys the three maps below.
        target_articles_map: prompt -> cleaned target-model article record.
        reference_articles_map: prompt -> reference article record.
        criteria_map: prompt -> criteria record for ``format_criteria_list``.
        llm_client: Client exposing ``generate(user_prompt=..., system_prompt=...)``.
        lock: Lock guarding ``pbar``.
        pbar: Progress bar advanced by one per call.
        max_retries: Number of LLM attempts before giving up.
        llm_model_name: Unused; kept for call-signature compatibility.
        language: "zh" selects the Chinese scoring prompt, anything else the
            English one. Also forwarded to ``calculate_weighted_scores``.

    Returns:
        dict: On success, ``target_<dim>``/``reference_<dim>`` scores, both
        weighted totals, ``overall_score`` and the raw ``model_output``. On
        failure, ``id``, ``prompt`` and an ``error`` message, sometimes with a
        truncated ``model_output``.
    """
    task_id = task_data.get('id')
    prompt = task_data.get('prompt')
    try:
        return _score_single_item(
            task_id, prompt, target_articles_map, reference_articles_map,
            criteria_map, llm_client, max_retries, language
        )
    except Exception as e:
        # Per-task boundary: an unexpected failure becomes an error record
        # instead of propagating through future.result() and aborting the batch.
        logging.exception(f"ID {task_id}: Unexpected error while scoring")
        return {"id": task_id, "prompt": prompt, "error": f"Unexpected error: {str(e)}"}
    finally:
        with lock:
            pbar.update(1)

def _load_scoring_inputs(language, target_model, limit):
    """Load the tasks, criteria and articles needed to score one language.

    Args:
        language: Matched against each query's ``language`` field.
        target_model: Model whose cleaned articles are loaded.
        limit: If a positive int, keep only the first ``limit`` tasks of this language.

    Returns:
        tuple | None: ``(tasks_to_process, target_articles_map,
        reference_articles_map, criteria_map)``, with every map keyed by task
        prompt. None if any required input is missing or fails to load (the
        reason is logged).
    """
    try:
        all_tasks = [task for task in load_jsonl(CONFIG["query_file"])
                     if task.get('language') == language]
        if not all_tasks:
            logging.error(f"No {language} tasks found in query file")
            return None

        if limit is not None and limit > 0:
            all_tasks = all_tasks[:limit]

        task_prompts = {task['prompt'] for task in all_tasks if 'prompt' in task}

        criteria_list = [c for c in load_jsonl(CONFIG["criteria_file"])
                         if c.get('prompt') in task_prompts]
        if not criteria_list:
            logging.error(f"No evaluation criteria found for {language} tasks")
            return None

        target_file = os.path.join(CONFIG["cleaned_data_dir"], f"{target_model}.jsonl")
        target_articles_list = [a for a in load_jsonl(target_file)
                                if a.get('prompt') in task_prompts]
        if not target_articles_list:
            logging.error(f"No target articles found for model {target_model} in {language}")
            return None

        reference_articles_list = [a for a in load_jsonl(CONFIG["reference_file"])
                                   if a.get('prompt') in task_prompts]
        if not reference_articles_list:
            logging.error(f"No reference articles found for {language} tasks")
            return None

        criteria_map = {item['prompt']: item for item in criteria_list}
        target_articles_map = {item['prompt']: item for item in target_articles_list}
        reference_articles_map = {item['prompt']: item for item in reference_articles_list}

        # Warn about every missing piece, then keep only the complete tasks.
        tasks_to_process = []
        for task in all_tasks:
            prompt = task.get('prompt')
            # str() keeps a task without a prompt from raising here and
            # aborting data loading for the whole language.
            preview = str(prompt)[:50]
            complete = True
            if prompt not in criteria_map:
                logging.warning(f"No criteria found for task prompt: {preview}...")
                complete = False
            if prompt not in target_articles_map:
                logging.warning(f"No target article found for task prompt: {preview}...")
                complete = False
            if prompt not in reference_articles_map:
                logging.warning(f"No reference article found for task prompt: {preview}...")
                complete = False
            if complete:
                tasks_to_process.append(task)

        if not tasks_to_process:
            logging.error(f"No complete task data found for {language}")
            return None

        return tasks_to_process, target_articles_map, reference_articles_map, criteria_map

    except Exception as e:
        logging.error(f"Error loading data: {str(e)}")
        return None


def process_language_data(language, target_model, llm_client, clean_agent, limit=None,
                          skip_cleaning=False):
    """Clean (unless skipped) and score the target model's articles for one language.

    Args:
        language: Language code ("zh" or "en") selecting tasks and scoring prompt.
        target_model: Name of the model being evaluated.
        llm_client: Client used for scoring.
        clean_agent: Client used for cleaning (ignored when ``skip_cleaning``).
        limit: Optional positive cap on the number of records processed.
        skip_cleaning: If True, score the already-cleaned articles without
            running the cleaning step.

    Returns:
        list | None: The successful result dicts (possibly empty), or None if
        cleaning failed or the scoring inputs could not be loaded.
    """
    if not skip_cleaning:
        # NOTE(review): clean_articles applies ``limit`` to the first records of
        # the raw file regardless of language, whereas scoring takes the first
        # ``limit`` tasks of *this* language. Under --limit some tasks may
        # therefore have no cleaned article. Confirm whether that is intended.
        logging.info(f"Checking if {target_model} articles need cleaning...")
        cleaning_success = clean_articles(
            CONFIG["raw_data_dir"],
            CONFIG["cleaned_data_dir"],
            target_model,
            clean_agent,
            CONFIG["max_workers"],
            CONFIG["max_retries"],
            limit
        )
        if not cleaning_success:
            logging.error(f"Article cleaning failed for {target_model}, cannot continue.")
            return None

    logging.info(f"Loading {language} data for evaluation...")
    inputs = _load_scoring_inputs(language, target_model, limit)
    if inputs is None:
        return None
    tasks_to_process, target_articles_map, reference_articles_map, criteria_map = inputs

    logging.info(f"Processing {len(tasks_to_process)} {language} tasks...")

    lock = threading.Lock()
    results_list = []

    with tqdm(total=len(tasks_to_process), desc=f"Scoring {language} {target_model}") as pbar:
        with concurrent.futures.ThreadPoolExecutor(max_workers=CONFIG["max_workers"]) as executor:
            future_to_task = {
                executor.submit(
                    process_single_item,
                    task,
                    target_articles_map,
                    reference_articles_map,
                    criteria_map,
                    llm_client,
                    lock,
                    pbar,
                    max_retries=CONFIG["max_retries"],
                    # Must be passed by keyword: the 9th positional parameter of
                    # process_single_item is ``llm_model_name``, not ``language``.
                    language=language,
                ): task
                for task in tasks_to_process
            }

            for future in concurrent.futures.as_completed(future_to_task):
                try:
                    result = future.result()
                except Exception as e:
                    # One failed task must not discard the rest of the language.
                    task_id = future_to_task[future].get('id')
                    logging.error(f"ID {task_id}: Unexpected error while scoring - {e}")
                    continue
                if result:
                    results_list.append(result)

    successful_results = [res for res in results_list if "error" not in res]

    logging.info(f"{language} evaluation complete. Successfully scored {len(successful_results)} "
                 f"out of {len(tasks_to_process)} tasks.")

    return successful_results


def main():
    """Command-line entry point: score one model's articles and save the results.

    Successful results for the selected languages are sorted by ``id`` and
    written to ``<results_dir>/<target_model>.jsonl``.
    """
    parser = argparse.ArgumentParser(description='Score model articles against reference articles using detailed evaluation criteria and LLM.')
    parser.add_argument('target_model', type=str, help='Name of target model to evaluate (e.g., claude-3-7-sonnet-20250219).')
    parser.add_argument('--limit', type=int, default=None, help='Limit on number of prompts to process (for testing).')
    parser.add_argument('--skip_cleaning', action='store_true', help='Skip article cleaning step.')
    # Passing both language filters would exclude every language, so argparse
    # rejects the combination instead of silently processing nothing.
    language_group = parser.add_mutually_exclusive_group()
    language_group.add_argument('--only_zh', action='store_true', help='Only process Chinese data.')
    language_group.add_argument('--only_en', action='store_true', help='Only process English data.')

    args = parser.parse_args()

    target_model = args.target_model
    results_dir = CONFIG["results_dir"]
    limit = args.limit
    skip_cleaning = args.skip_cleaning

    os.makedirs(results_dir, exist_ok=True)

    llm_client = AIClient()
    clean_agent = llm_client  # Use same client instance for cleaning

    languages = []
    if not args.only_en:
        languages.append(("zh", "Chinese"))
    if not args.only_zh:
        languages.append(("en", "English"))

    all_results = []
    for language, language_name in languages:
        logging.info(f"Starting {language_name} data processing...")
        if skip_cleaning:
            # Only the cleaning step is skipped; scoring still runs on the
            # articles already present in the cleaned data directory.
            logging.info(f"Skipping article cleaning step for {language_name} data.")
        results = process_language_data(
            language, target_model, llm_client, clean_agent, limit,
            skip_cleaning=skip_cleaning
        )
        if results:
            all_results.extend(results)

    output_file = os.path.join(results_dir, f"{target_model}.jsonl")

    if all_results:
        all_results.sort(key=lambda x: x.get('id', float('inf')))

        logging.info(f"Saving {len(all_results)} results to {output_file}...")
        try:
            with open(output_file, 'w', encoding='utf-8') as f:
                for result in all_results:
                    f.write(json.dumps(result, ensure_ascii=False) + '\n')
            logging.info("Results saved successfully.")
        except IOError as e:
            logging.error(f"Failed to write results to {output_file}: {e}")
    else:
        logging.warning("No results to save.")

    logging.info("--- Run Summary ---")
    logging.info(f"Target model: {target_model}")
    logging.info(f"Total tasks successfully processed: {len(all_results)}")
    logging.info(f"Results file: {output_file}")
    logging.info("-------------------")


if __name__ == "__main__":
    main()