import json
import os
import threading
import concurrent.futures
from tqdm import tqdm
from .api import AIClient
from prompt.clean_prompt import clean_article_prompt_zh, clean_article_prompt_en
import textwrap

def _split_into_chunks(article, num_chunks, window=200):
    """Split ``article`` into ``num_chunks`` contiguous pieces, cutting at sentence ends when possible.

    Each non-final cut starts at the nominal position ``(i + 1) * (len(article) // num_chunks)``.
    It is then moved to just after the closest sentence-ending character, found by scanning
    backwards at most ``window`` characters (never before the start of the current piece).
    Every piece begins exactly where the previous one ended. The pieces therefore always
    concatenate back to ``article``: no text is dropped or duplicated.

    Args:
        article: Text to split.
        num_chunks: Number of pieces to produce (>= 1). Pieces may be empty when the
            article is shorter than ``num_chunks``.
        window: Maximum backward search distance for a sentence boundary.

    Returns:
        list of ``(start, end, text)`` tuples with ``text == article[start:end]``.
    """
    sentence_endings = ('.', '?', '!', '。', '？', '！', '\n')
    length = len(article)
    chunk_size = length // num_chunks
    pieces = []
    start = 0
    for i in range(num_chunks):
        if i == num_chunks - 1:
            end = length  # the last piece absorbs the remainder
        else:
            # Never let a cut fall before the current start (possible after an earlier
            # cut was moved, or when chunk_size is 0 for very short articles).
            end = max(start, (i + 1) * chunk_size)
            for j in range(end, max(start, end - window), -1):
                if j < length and article[j] in sentence_endings:
                    end = j + 1  # cut just after the sentence terminator
                    break
        pieces.append((start, end, article[start:end]))
        start = end  # next piece starts exactly where this one ended
    return pieces


def chunk_clean_article(article, clean_agent, max_retries=3, language="zh"):
    """
    Clean a long article by splitting it into chunks, cleaning each chunk, and re-joining them.

    Uses a progressive chunking strategy. Attempt ``k`` (0-based) uses ``2 ** (k + 1)``
    chunks: 2, then 4, then 8, ... An attempt is abandoned as soon as one chunk fails with
    what looks like a token-limit error, and the next attempt uses twice as many chunks.

    Args:
        article: The article to clean
        clean_agent: LLM client exposing ``generate(user_prompt=..., system_prompt=...)``
        max_retries: Maximum number of chunking attempts
        language: "zh" selects the Chinese cleaning prompt; any other value selects the English one

    Returns:
        The concatenation of the cleaned chunks. A chunk is kept unchanged in three cases:
        its cleaned result is empty or shorter than 50 characters after stripping, it fails
        with any non-token-limit error, or it is empty/whitespace-only. If every attempt
        hits a token-limit error, the original article is returned.
    """
    # Choose cleaning prompt template based on language
    clean_prompt = clean_article_prompt_zh if language == "zh" else clean_article_prompt_en

    for attempt in range(max_retries):
        num_chunks = 2 ** (attempt + 1)  # First attempt 2 chunks, then 4, then 8...
        print(f"Attempting to process article in {num_chunks} chunks...")

        chunks = _split_into_chunks(article, num_chunks)

        success = True
        cleaned_chunks = []

        with tqdm(total=len(chunks), desc=f"Cleaning article ({num_chunks} chunks)") as chunk_pbar:
            for i, (start, end, chunk) in enumerate(chunks):
                if not chunk.strip():
                    # Nothing worth sending to the model (e.g. a very short article split
                    # into many pieces); keep the text as-is.
                    cleaned_chunks.append(chunk)
                    chunk_pbar.update(1)
                    continue
                try:
                    print(f"Processing chunk {i+1}/{len(chunks)} (characters {start}-{end}, length: {len(chunk)})")
                    user_prompt = clean_prompt.format(article=chunk)
                    cleaned_chunk = clean_agent.generate(user_prompt=user_prompt, system_prompt="")

                    # Reject empty or suspiciously short output and keep the original chunk
                    if not cleaned_chunk or len(cleaned_chunk.strip()) < 50:
                        print(f"Chunk {i+1} cleaning result too short or empty, using original chunk")
                        cleaned_chunks.append(chunk)
                    else:
                        cleaned_chunks.append(cleaned_chunk)

                    chunk_pbar.update(1)

                except Exception as e:
                    print(f"Chunk {i+1}/{len(chunks)} cleaning failed: {e}")

                    # Heuristic match on the provider's token-limit error message
                    error_text = str(e).lower()
                    if "tokens" in error_text and "less than" in error_text:
                        print(f"Chunk {i+1} still too long, need further chunking")
                        success = False
                        break  # Current chunking strategy failed, try more chunks

                    # Other errors: best effort, keep the original chunk
                    cleaned_chunks.append(chunk)
                    chunk_pbar.update(1)

        if success:
            print(f"All {len(chunks)} chunks processed, merging results")
            return "".join(cleaned_chunks)

    # If all attempts failed, return original article
    print(f"Warning: After {max_retries} chunking attempts, still unable to process article, returning original")
    return article

def clean_single_article(item, clean_agent, output_file=None, processed_ids=None, file_lock=None, pbar_lock=None, pbar=None, max_retries=10, language="zh"):
    """
    Clean a single article with the LLM, removing citation formats, references, etc.

    If ``output_file``, ``file_lock`` and ``processed_ids`` are all provided, the result is
    appended to ``output_file`` as one JSON line, and its id is added to ``processed_ids``.
    Both happen while holding ``file_lock``. Otherwise the cleaned item dict is returned.

    When both ``pbar`` and ``pbar_lock`` are given, the progress bar is advanced by exactly
    one on every exit path: success, skip, or failure.

    Args:
        item: Dictionary containing id, prompt and article
        clean_agent: LLM client exposing ``generate(user_prompt=..., system_prompt=...)``
        output_file: Output JSONL file path (optional)
        processed_ids: Set of processed IDs (optional); items whose id is in it are skipped
        file_lock: Lock guarding file writes and ``processed_ids`` (optional)
        pbar_lock: Progress bar update lock (optional)
        pbar: Progress bar object with an ``update(n)`` method (optional)
        max_retries: Maximum API attempts while the result is empty or shorter than
            50 characters after stripping
        language: "zh" selects the Chinese cleaning prompt; any other value selects the English one

    Returns:
        dict: ``{"id", "prompt", "article"}`` when the file-writing parameters are not all given
        The item id: when the result was written to ``output_file``
        None: when clean_agent is missing, the item lacks id/prompt/article, the id was
            already processed, or an unexpected error occurred
    """

    def _too_short(text):
        # Empty/None output or fewer than 50 non-blank-edge characters counts as a failed clean
        return not text or len(text.strip()) < 50

    try:
        if not clean_agent:
            print("Error: clean_agent not provided, cannot clean article")
            return None

        item_id = item.get('id')
        prompt = item.get('prompt', '')
        article = item.get('article', '')

        # Skip if missing required fields or already processed
        if not item_id or not prompt or not article:
            return None
        if processed_ids is not None and item_id in processed_ids:
            return None

        # The prompt is the same on every attempt, so build it once. A broken template is
        # not a transient API error and should not be retried.
        clean_prompt = clean_article_prompt_zh if language == "zh" else clean_article_prompt_en
        user_prompt = clean_prompt.format(article=article)

        cleaned_article = ""
        attempts = 0
        while _too_short(cleaned_article) and attempts < max_retries:
            if attempts > 0:
                print(f"ID: {item_id} - Cleaning result empty or too short, retry #{attempts}")
            attempts += 1

            try:
                cleaned_article = clean_agent.generate(user_prompt=user_prompt, system_prompt="")
            except Exception as api_error:
                print(f"ID: {item_id} - API call error: {api_error}")
                # Heuristic match on the provider's token-limit error message
                error_text = str(api_error).lower()
                if "tokens" in error_text and "less than" in error_text:
                    print(f"ID: {item_id} - Article too long, trying chunk processing")
                    cleaned_article = chunk_clean_article(article, clean_agent, language=language)
                    if _too_short(cleaned_article):
                        print(f"ID: {item_id} - Chunk processing failed, using original article")
                        cleaned_article = article
                    break
                # Other errors: fall through and retry

        # If still empty after cleaning, use original article
        if _too_short(cleaned_article):
            print(f"Warning: ID: {item_id} - Cleaned article empty or too short after {attempts} attempts, using original article")
            cleaned_article = article

        result = {
            "id": item_id,
            "prompt": prompt,
            "article": cleaned_article
        }

        if output_file and file_lock and processed_ids is not None:
            # Append the record and mark it processed atomically with respect to other workers
            with file_lock:
                with open(output_file, 'a', encoding='utf-8') as f:
                    f.write(json.dumps(result, ensure_ascii=False) + '\n')
                processed_ids.add(item_id)
            return item_id

        # Simple mode, just return result object
        return result

    except Exception as e:
        print(f"Error cleaning article {item.get('id', 'unknown')}: {e}")
        return None

    finally:
        # Single place that advances the shared progress bar, whatever the outcome
        if pbar and pbar_lock:
            with pbar_lock:
                pbar.update(1)

def clean_articles(input_file, output_file, clean_agent, max_workers=5, max_retries=10, limit=None, language="zh"):
    """
    Read data from input_file, clean article content, and append results to output_file.

    Processing is resumable: ids already present in ``output_file`` are skipped. Work is
    spread over a thread pool, and each worker appends its own JSON line to ``output_file``.

    Args:
        input_file: Input jsonl file path
        output_file: Output jsonl file path (its parent directory is created if needed)
        clean_agent: LLM API instance
        max_workers: Maximum thread count
        max_retries: Maximum retry attempts per article
        limit: If a positive int, only the first ``limit`` input items are considered
        language: "zh" selects the Chinese cleaning prompt; any other value selects the English one
    """
    # Load input data
    print(f"Loading data from {input_file}...")
    with open(input_file, 'r', encoding='utf-8') as f:
        input_data = [json.loads(line) for line in f if line.strip()]

    # Apply limit (if specified)
    if limit is not None and limit > 0:
        print(f"Applying limit: processing only first {limit} items")
        input_data = input_data[:limit]

    # Set of processed IDs for deduplication
    processed_ids = set()

    # Ensure output directory exists. os.path.dirname() is '' for a bare filename,
    # and os.makedirs('') raises FileNotFoundError.
    output_dir = os.path.dirname(output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # If output file exists, read already processed IDs
    if os.path.exists(output_file):
        print(f"Found existing output file: {output_file}")
        with open(output_file, 'r', encoding='utf-8') as f:
            for line in f:
                if not line.strip():
                    continue
                try:
                    data = json.loads(line)
                except json.JSONDecodeError:
                    print("Warning: Invalid JSON line in output file, skipped")
                    continue
                if 'id' in data:
                    processed_ids.add(data['id'])
        print(f"Read {len(processed_ids)} already processed records from output file")
    else:
        # Create empty file
        open(output_file, 'w', encoding='utf-8').close()
        print(f"Created new output file: {output_file}")

    # Create thread locks
    file_lock = threading.Lock()  # File writing lock
    pbar_lock = threading.Lock()  # Progress bar update lock

    total_items = len(input_data)
    to_process = [data for data in input_data if data.get('id') not in processed_ids]
    print(f"Total of {total_items} items, {len(to_process)} to process, {len(processed_ids)} already processed")

    if not to_process:
        print("All items already processed, no further action needed")
        return

    # Start the bar at the number of *input* items already done. processed_ids may
    # contain ids outside this (possibly limited) input, which would overshoot the total.
    already_done = total_items - len(to_process)
    with tqdm(total=total_items, desc="Cleaning articles", initial=already_done) as pbar:
        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = [
                executor.submit(
                    clean_single_article, data, clean_agent, output_file, processed_ids,
                    file_lock, pbar_lock, pbar, max_retries=max_retries, language=language
                )
                for data in to_process
            ]

            # A truthy result means the article was cleaned and written
            processed_count = sum(
                1 for future in concurrent.futures.as_completed(futures) if future.result()
            )

    print(f"Processing complete! Cleaned {processed_count} new articles, total of {len(processed_ids)} articles processed")
    print(f"Results saved to {output_file}")

def _load_prompt_language_map(query_file):
    """Read the JSONL ``query_file`` and map each record's ``prompt`` to its ``language``.

    Records lacking either key are ignored. Unparsable lines and read errors are reported
    and skipped, and a missing file yields an empty mapping.
    """
    prompt_to_language = {}
    if not os.path.exists(query_file):
        return prompt_to_language
    try:
        with open(query_file, 'r', encoding='utf-8') as f:
            for line in f:
                if not line.strip():
                    continue
                try:
                    query = json.loads(line)
                except json.JSONDecodeError:
                    print(f"Warning: Error parsing JSON in query file, line content: {line.strip()}")
                    continue
                if 'prompt' in query and 'language' in query:
                    prompt_to_language[query['prompt']] = query['language']
    except Exception as e:
        print(f"Warning: Error reading query file: {e}")
    return prompt_to_language


def _load_processed_ids(output_file):
    """Return the set of ``id`` values already written to ``output_file``.

    Creates an empty ``output_file`` when it does not exist yet. Invalid JSON lines are
    skipped with a warning.
    """
    processed_ids = set()
    if os.path.exists(output_file):
        print(f"Found existing output file: {output_file}")
        with open(output_file, 'r', encoding='utf-8') as f:
            for line in f:
                if not line.strip():
                    continue
                try:
                    data = json.loads(line)
                except json.JSONDecodeError:
                    print("Warning: Invalid JSON line in output file, skipped")
                    continue
                if 'id' in data:
                    processed_ids.add(data['id'])
        print(f"Read {len(processed_ids)} already processed records from output file")
    else:
        open(output_file, 'w', encoding='utf-8').close()
        print(f"Created new output file: {output_file}")
    return processed_ids


def clean_model(model, raw_data_dir, cleaned_data_dir, clean_agent, max_workers=5, max_retries=10, limit=None, language=None, query_file=None):
    """
    Clean articles for a single model.

    Reads ``<raw_data_dir>/<model>.jsonl`` and appends cleaned records to
    ``<cleaned_data_dir>/<model>.jsonl``. Ids already present in the output are skipped.

    Args:
        model: Model name (used as the JSONL file stem)
        raw_data_dir: Raw data directory
        cleaned_data_dir: Cleaned data directory (created if needed)
        clean_agent: LLM API instance
        max_workers: Maximum thread count
        max_retries: Maximum retry attempts per article
        limit: If a positive int, only the first ``limit`` input items are considered
        language: Fallback language for prompts not found in ``query_file``;
            when None, such prompts default to 'en'
        query_file: JSONL file mapping ``prompt`` -> ``language``. Defaults to the
            historical hard-coded path for backward compatibility.
    """
    if not clean_agent:
        print("Error: clean_agent not provided, cannot clean articles")
        return

    if query_file is None:
        query_file = "/mnt/yscfs/dumingxuan/deepresearch_bench/data/prompt_data/query.jsonl"
    prompt_to_language = _load_prompt_language_map(query_file)

    # os.makedirs('') raises, so only create a directory when one is named
    if cleaned_data_dir:
        os.makedirs(cleaned_data_dir, exist_ok=True)

    input_file = os.path.join(raw_data_dir, f"{model}.jsonl")
    output_file = os.path.join(cleaned_data_dir, f"{model}.jsonl")

    if not os.path.exists(input_file):
        print(f"Warning: Input file for model {model} not found: {input_file}")
        return

    print(f"\n=== Cleaning {model} articles ===")

    all_items = []
    with open(input_file, 'r', encoding='utf-8') as f:
        for line in f:
            if not line.strip():
                continue
            try:
                all_items.append(json.loads(line))
            except json.JSONDecodeError:
                print(f"Warning: Error parsing JSON in input file, line content: {line.strip()}")

    if limit is not None and limit > 0:
        all_items = all_items[:limit]

    print(f"\n--- Cleaning all {model} articles ({len(all_items)}) ---")

    processed_ids = _load_processed_ids(output_file)

    to_process = [item for item in all_items if item.get('id') not in processed_ids]
    print(f"Total of {len(all_items)} items, {len(to_process)} to process, {len(processed_ids)} already processed")

    if not to_process:
        print("All items already processed, no further action needed")
        return

    file_lock = threading.Lock()  # File writing lock
    pbar_lock = threading.Lock()  # Progress bar update lock

    # Start the bar at the number of *input* items already done; processed_ids may
    # contain ids outside this (possibly limited) input, which would overshoot the total.
    already_done = len(all_items) - len(to_process)
    with tqdm(total=len(all_items), desc=f"Cleaning {model} articles", initial=already_done) as pbar:
        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = [
                executor.submit(
                    clean_single_article, item, clean_agent, output_file, processed_ids,
                    file_lock, pbar_lock, pbar, max_retries=max_retries,
                    # Per-prompt language from the query file wins over the fallback
                    language=prompt_to_language.get(item.get('prompt', ''), language or 'en'),
                )
                for item in to_process
            ]

            # A truthy result means the article was cleaned and written
            processed_count = sum(
                1 for future in concurrent.futures.as_completed(futures) if future.result()
            )

    print(f"\n=== {model} model cleaning complete, cleaned {processed_count} new articles, total of {len(processed_ids)} articles processed ===")

