import os
import json
import argparse
import glob
import concurrent.futures
from pathlib import Path
from typing import Dict, Any, List

import sys
# Append the directory one level above this file's own directory to the module
# search path. Presumably this is what lets the project-local `agents` and
# `utils` imports below resolve when the file is run as a script -- confirm
# against the repository layout.
sys.path.append(str(Path(__file__).resolve().parents[1]))

from agents import SummarizeAgent
from agents.backends import create_llm_client
from utils import setup_logging, read_jsonl


def save_json(path: str, data: Dict[str, Any]):
    """Write *data* to *path* as indented, UTF-8 encoded JSON.

    Any missing parent directories of *path* are created first. Non-ASCII
    characters are written as-is rather than escaped.

    Args:
        path: Destination file path. May be a bare filename, in which case
            the file is written to the current working directory.
        data: JSON-serializable mapping to write.
    """
    parent_dir = os.path.dirname(path)
    # A bare filename has an empty dirname, and os.makedirs("") raises
    # FileNotFoundError, so only create directories when there is one.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(path, 'w', encoding='utf-8') as f:
        json.dump(data, f, ensure_ascii=False, indent=2)


def parse_args() -> argparse.Namespace:
    """Parse the command-line options for the preprocessing script.

    Returns:
        The parsed namespace with attributes ``input``, ``provider``,
        ``model``, ``temperature``, ``max_tokens``, ``preprocess_dir``,
        ``max_workers`` and ``tokenizer_path``.
    """
    parser = argparse.ArgumentParser(description="Preprocess documents: run summarization and cache results")
    # --input is required: without it args.input would be None and the
    # loader would fail with an opaque TypeError instead of a usage message.
    parser.add_argument("--input", nargs='+', required=True,
                       help="JSONL file path(s) containing problem text and optional id. Can specify multiple files.")
    parser.add_argument("--provider", choices=["anthropic", "openai", "vllm", "deepseek"], default="anthropic")
    parser.add_argument("--model")
    parser.add_argument("--temperature", type=float, default=0.6)
    parser.add_argument("--max_tokens", type=int, default=65536)
    parser.add_argument("--preprocess_dir", default="./preprocessed", help="Directory to store preprocessed outputs")
    parser.add_argument("--max_workers", type=int, default=8, help="Maximum number of worker threads")
    parser.add_argument(
        "--tokenizer_path",
        default="your_tokenizer_path",
        help="Optional tokenizer path for local models",
    )
    return parser.parse_args()


def load_all_problems(input_paths: List[str]) -> List[Dict[str, Any]]:
    """Load and concatenate problems from one or more JSONL inputs.

    Each entry of *input_paths* is either a literal file path or, when it
    contains ``*``, a glob pattern. Missing files and patterns that match
    nothing are reported on stdout and skipped rather than raising.

    Args:
        input_paths: File paths and/or ``*`` glob patterns.

    Returns:
        All records returned by ``read_jsonl`` for every file found, in input
        order (glob matches are sorted so the order is reproducible).
    """
    all_problems: List[Dict[str, Any]] = []

    for input_path in input_paths:
        if '*' in input_path:
            # glob.glob returns matches in arbitrary filesystem order; sort
            # them so repeated runs see the same sequence of problems.
            file_paths = sorted(glob.glob(input_path))
            if not file_paths:
                # Report an unmatched pattern the same way a missing literal
                # path is reported, instead of silently dropping it.
                print(f"[PREPROCESS] No files matched pattern: {input_path}")
                continue
        else:
            file_paths = [input_path]

        for file_path in file_paths:
            if not os.path.exists(file_path):
                print(f"[PREPROCESS] File not found: {file_path}")
                continue
            print(f"[PREPROCESS] Loading {file_path}")
            problems = read_jsonl(file_path)
            print(f"[PREPROCESS] Found {len(problems)} problems")
            all_problems.extend(problems)

    return all_problems


def process_single_task(problem: Dict[str, Any], args) -> str:
    """Summarize one problem and cache the result, retrying on failure.

    Outputs are written to ``<args.preprocess_dir>/<task>/section_dict.json``
    and ``summary_dict.json``. If both files already exist the task is
    skipped without creating an LLM client.

    Args:
        problem: Record with a ``"task"`` key (used as the output directory
            name) and a ``"problem"`` key (the document to summarize).
        args: Parsed CLI namespace; reads ``preprocess_dir``, ``provider``,
            ``model``, ``temperature``, ``max_tokens`` and ``tokenizer_path``.

    Returns:
        A human-readable status line describing what was done.

    Raises:
        KeyError: If *problem* lacks the ``"task"`` or ``"problem"`` key.
        RuntimeError: If summarizing or saving still fails after the final
            attempt; the last underlying error is chained as ``__cause__``.
    """
    task_name = problem["task"]
    document = problem["problem"]

    out_dir = os.path.join(args.preprocess_dir, str(task_name))
    os.makedirs(out_dir, exist_ok=True)

    section_path = os.path.join(out_dir, "section_dict.json")
    summary_path = os.path.join(out_dir, "summary_dict.json")

    # Skip only if both outputs exist; a lone file means an earlier run was
    # interrupted between the two writes and must be redone.
    if os.path.exists(section_path) and os.path.exists(summary_path):
        return f"[PREPROCESS] Skip {task_name}: already exists"

    # Each call builds its own client and agent rather than sharing one
    # across worker threads.
    llm_client = create_llm_client(
        provider=args.provider,
        model=args.model,
        temperature=args.temperature,
        max_tokens=args.max_tokens,
        tokenizer_path=args.tokenizer_path,
    )

    logger = setup_logging()
    summarize_agent = SummarizeAgent(llm_client, logger)

    print(f"[PREPROCESS] Summarizing {task_name} (len={len(document)})")

    max_retries = 5
    for attempt in range(1, max_retries + 1):
        try:
            result = summarize_agent.run({"document": document, "task_name": task_name})
            save_json(section_path, result["section_dict"])
            save_json(summary_path, result["summary_dict"])
            return f"[PREPROCESS] Saved: {section_path}, {summary_path}"
        except Exception as e:
            if attempt == max_retries:
                error_msg = f"[PREPROCESS] Failed after {max_retries} attempts for {task_name}: {e}"
                print(error_msg)
                # Chain the original error so its traceback is not lost.
                raise RuntimeError(error_msg) from e
            print(f"[PREPROCESS] Attempt {attempt} failed for {task_name}: {e}, retrying...")


def main():
    """Run the preprocessing pipeline from the command line.

    Parses the CLI options, loads every problem from the input files, and
    runs ``process_single_task`` for each one on a thread pool. A failing
    task is reported and counted but does not stop the remaining tasks.
    """
    args = parse_args()
    # The returned logger is not used here; the call is kept in case
    # setup_logging configures logging as a side effect -- TODO confirm.
    setup_logging()

    # Load all problems from input files
    problems = load_all_problems(args.input)
    print(f"[PREPROCESS] Total loaded: {len(problems)} problems")

    if not problems:
        print("[PREPROCESS] No problems found in input files")
        return

    print(f"[PREPROCESS] Starting preprocessing with {args.max_workers} workers")

    total = len(problems)
    completed = 0
    failed = 0

    # Process problems using thread pool
    with concurrent.futures.ThreadPoolExecutor(max_workers=args.max_workers) as executor:
        future_to_problem = {
            executor.submit(process_single_task, problem, args): problem
            for problem in problems
        }

        # Handle results in completion order, not submission order.
        for future in concurrent.futures.as_completed(future_to_problem):
            problem = future_to_problem[future]
            completed += 1
            try:
                print(future.result())
            except Exception as e:
                failed += 1
                # The task may have failed precisely because "task" is
                # missing, so never index the record inside the handler.
                task_label = problem.get("task", "<unknown>") if isinstance(problem, dict) else "<unknown>"
                print(f"[PREPROCESS] Error processing {task_label}: {e}")
            # Report progress for failures too so the counter reaches total.
            print(f"[PREPROCESS] Progress: {completed}/{total} ({completed/total*100:.1f}%)")

    print(f"[PREPROCESS] Completed preprocessing of {total} problems")
    if failed:
        print(f"[PREPROCESS] {failed} of {total} problems failed")


# Run the preprocessing pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()


