#!/usr/bin/env python3
"""
Unified Agent Runner for DDR_Bench.

Single entry point for running data analysis agents across all scenarios:
- MIMIC: Patient data analysis using MIMIC-IV database
- 10-K: Financial report analysis using SEC 10-K filings
- GLOBEM: Behavioral data analysis using GLOBEM dataset

Usage:
    python run_agent.py --scenario mimic --db-path /path/to/mimic_iv.db --input /path/to/notes.json
    python run_agent.py --scenario 10k --db-path /path/to/10k.db
    python run_agent.py --scenario globem --data-path /path/to/globem/data

The IDs to process are read from the scenario's `id_file` entry in config.yaml.
See README.md for detailed usage instructions.
"""

import argparse
import asyncio
import json
import os
import sys
from pathlib import Path
from typing import Any, Dict, List, Optional, Set

from config import Config, get_config
from base_batch_analyzer import BaseBatchAnalyzer


class PatientBatchAnalyzer(BaseBatchAnalyzer):
    """Batch analyzer for MIMIC patient data.

    Builds one ``agent/data_agent.py`` command per patient, attaching the
    SQLite MCP tool server and a per-patient log directory under
    ``base_log_dir``.
    """

    def __init__(self, base_log_dir: str, model_name: str, vllm_port: int,
                 target_ids: Optional[Set[str]] = None, overwrite: bool = False,
                 llm_provider: Optional[str] = None):
        """Store the LLM settings used when building patient commands.

        Args:
            base_log_dir: Forwarded to ``BaseBatchAnalyzer``; per-patient log
                directories are created beneath it.
            model_name: Value passed to the agent as ``--model``.
            vllm_port: Port used to build the localhost ``--base-url``.
            target_ids: Forwarded to ``BaseBatchAnalyzer``.
            overwrite: Forwarded to ``BaseBatchAnalyzer``.
            llm_provider: Value passed to the agent as ``--llm-provider``.
        """
        super().__init__(base_log_dir, target_ids, overwrite)
        self.model_name = model_name
        self.vllm_port = vllm_port
        self.llm_provider = llm_provider

    def extract_identifiers(self, source_file: Path) -> List[Dict[str, Any]]:
        """Extract patient identifiers from a pre-defined ID list file.

        Args:
            source_file: JSON file whose top-level value is a list of
                patient IDs. Each ID is converted to ``str``.

        Returns:
            One dict per ID with the keys ``patient_id``, ``subject_id``,
            ``identifier`` and ``data``. An empty list is returned when the
            file cannot be read or parsed, or when it does not hold a list.
        """
        try:
            print(f"Reading patient IDs from: {source_file}")
            with open(source_file, 'r', encoding='utf-8') as f:
                patient_ids = json.load(f)

            if not isinstance(patient_ids, list):
                print("   Error: Expected a list of patient IDs")
                return []

            patients_list = [
                {
                    "patient_id": f"patient_{subject_id}",
                    "subject_id": subject_id,
                    "identifier": subject_id,
                    "data": {},
                }
                for subject_id in map(str, patient_ids)
            ]

            print(f"   Found {len(patients_list)} patients")
            return patients_list

        except Exception as e:
            # Best-effort: any read/parse problem is reported and treated as
            # "no patients" rather than aborting the caller.
            print(f"   Error reading file: {e}")
            return []

    def _prepare_analysis_command(self, identifier_info: Dict[str, Any], source_file: Path,
                                  subdir_name: str, db_path: Optional[str] = None,
                                  **kwargs) -> tuple:
        """Prepare the command for patient analysis.

        Args:
            identifier_info: Dict containing at least ``subject_id``.
            source_file: Unused here; part of the shared method signature.
            subdir_name: Name of the log subdirectory under ``base_log_dir``.
            db_path: Path to the SQLite database; required.
            **kwargs: ``max_turns`` (default 100) and ``auto_finish``
                (default True) are honoured; other keys are ignored.

        Returns:
            ``(cmd, env, label)``: the argv list, a copy of the current
            environment with ``CUSTOM_LOG_DIR`` set, and a display label.

        Raises:
            KeyError: If ``identifier_info`` has no ``subject_id``.
            ValueError: If ``db_path`` is missing or empty.
        """
        subject_id = identifier_info["subject_id"]

        # Validate before touching the filesystem so a misconfigured call
        # does not leave an empty per-patient log directory behind.
        if not db_path:
            raise ValueError("db_path is required")

        patient_log_dir = self.base_log_dir / subdir_name
        patient_log_dir.mkdir(parents=True, exist_ok=True)

        task = f"Analyze patient {subject_id}"

        # NOTE(review): llm_provider defaults to None in __init__; a None
        # entry is not a valid argv element -- confirm callers always set it.
        cmd = [
            sys.executable,
            "agent/data_agent.py",
            "--llm-provider", self.llm_provider,
            "--model", self.model_name,
            "--base-url", f"http://localhost:{self.vllm_port}",
            "--task", task,
            "--log-dir", str(patient_log_dir),
            "--max-turns", str(kwargs.get("max_turns", 100))
        ]

        if not kwargs.get("auto_finish", True):
            cmd.append("--no-auto-finish")

        cmd.extend(["--sql-server", "tool_server/sqlite_mcp.py", "--data-path", db_path])

        env = os.environ.copy()
        env['CUSTOM_LOG_DIR'] = str(patient_log_dir)

        return cmd, env, f"Patient {subject_id}"

    def get_subdir_name(self, identifier: str) -> str:
        """Get subdirectory name for the patient identifier."""
        return f"patient_{identifier}"

    def _create_identifier_from_logs(self, identifier: str, dirname: str) -> Optional[Dict[str, Any]]:
        """Create patient identifier structure from log directory.

        ``dirname`` is not used; the result has the same keys as the entries
        produced by ``extract_identifiers``.
        """
        return {
            "patient_id": f"patient_{identifier}",
            "subject_id": identifier,
            "identifier": identifier,
            "data": {}
        }


class CompanyBatchAnalyzer(BaseBatchAnalyzer):
    """Batch analyzer for 10-K company data.

    Builds one ``agent/data_agent.py`` command per company CIK, attaching the
    SQLite MCP tool server when a database path is available.
    """

    def __init__(self, base_log_dir: str, db_path: str,
                 target_ids: Optional[Set[str]] = None, overwrite: bool = False):
        """Store the database path and default LLM settings.

        ``model_name``, ``llm_provider`` and ``vllm_port`` start with
        placeholder defaults; callers may overwrite these attributes after
        construction.

        Args:
            base_log_dir: Forwarded to ``BaseBatchAnalyzer``; per-company log
                directories are created beneath it.
            db_path: Path to the SQLite database, used as the fallback when a
                command is prepared without an explicit ``db_path``.
            target_ids: Forwarded to ``BaseBatchAnalyzer``.
            overwrite: Forwarded to ``BaseBatchAnalyzer``.
        """
        super().__init__(base_log_dir, target_ids, overwrite)
        self.db_path = db_path
        self.model_name = ""
        self.llm_provider = ""
        self.vllm_port = 8000

    def extract_identifiers(self, source_file: Path) -> List[Dict[str, Any]]:
        """Extract company identifiers (CIKs) from a pre-defined ID list file.

        Args:
            source_file: JSON file whose top-level value is a list of CIKs.
                Each CIK is converted to ``str``.

        Returns:
            One dict per CIK with the keys ``cik`` and ``identifier``. An
            empty list is returned when the file cannot be read or parsed, or
            when it does not hold a list.
        """
        companies: List[Dict[str, Any]] = []
        try:
            print(f"Reading company CIKs from: {source_file}")
            with open(source_file, 'r', encoding='utf-8') as f:
                ciks = json.load(f)

            if not isinstance(ciks, list):
                print("   Error: Expected a list of company CIKs")
                return []

            companies = [
                {"cik": cik, "identifier": cik}
                for cik in map(str, ciks)
            ]

            print(f"Found {len(companies)} companies")

        except Exception as e:
            # Best-effort: report the problem and fall through with no companies.
            print(f"Error reading ID file: {e}")

        return companies

    def _prepare_analysis_command(self, identifier_info: Dict[str, Any], source_file: Path,
                                  subdir_name: str, model: str = "", llm_provider: str = "",
                                  vllm_port: Optional[int] = None,
                                  db_path: Optional[str] = None, **kwargs) -> tuple:
        """Prepare the command for company analysis.

        Every per-call setting falls back to the matching instance attribute
        when it is not supplied, so the command is the same whether the
        settings arrive as keyword arguments or were set on the analyzer.

        Args:
            identifier_info: Dict containing at least ``cik``.
            source_file: Unused here; part of the shared method signature.
            subdir_name: Name of the log subdirectory under ``base_log_dir``.
            model: Model name; falls back to ``self.model_name`` when empty.
            llm_provider: Provider name; falls back to ``self.llm_provider``
                when empty.
            vllm_port: Port for the localhost ``--base-url``; falls back to
                ``self.vllm_port`` (8000 unless overridden) when ``None``.
            db_path: SQLite database path; falls back to ``self.db_path``
                when missing or empty.
            **kwargs: ``max_turns`` (default 100) and ``auto_finish``
                (default True) are honoured; other keys are ignored.

        Returns:
            ``(cmd, env, label)``: the argv list, a copy of the current
            environment with ``CUSTOM_LOG_DIR`` set, and a display label.

        Raises:
            KeyError: If ``identifier_info`` has no ``cik``.
        """
        cik = identifier_info["cik"]

        company_log_dir = self.base_log_dir / subdir_name
        company_log_dir.mkdir(parents=True, exist_ok=True)

        task = f"Analyze company with CIK {cik}"

        # Fall back to the values stored on the instance; previously the port
        # and database path given to the analyzer were silently ignored.
        effective_port = self.vllm_port if vllm_port is None else vllm_port
        effective_db_path = db_path or self.db_path

        cmd = [
            sys.executable,
            "agent/data_agent.py",
            "--llm-provider", llm_provider or self.llm_provider,
            "--model", model or self.model_name,
            "--base-url", f"http://localhost:{effective_port}",
            "--task", task,
            "--log-dir", str(company_log_dir),
            "--max-turns", str(kwargs.get("max_turns", 100))
        ]

        if not kwargs.get("auto_finish", True):
            cmd.append("--no-auto-finish")

        if effective_db_path:
            cmd.extend(["--sql-server", "tool_server/sqlite_mcp.py",
                        "--data-path", effective_db_path])

        env = os.environ.copy()
        env['CUSTOM_LOG_DIR'] = str(company_log_dir)

        return cmd, env, f"Company CIK {cik}"

    def get_subdir_name(self, identifier: str) -> str:
        """Get subdirectory name for the company identifier."""
        return f"company_{identifier}"

    def _create_identifier_from_logs(self, identifier: str, dirname: str) -> Optional[Dict[str, Any]]:
        """Create company identifier structure from log directory.

        ``dirname`` is not used; the result has the same keys as the entries
        produced by ``extract_identifiers``.
        """
        return {
            "cik": identifier,
            "identifier": identifier
        }


class UserBatchAnalyzer(BaseBatchAnalyzer):
    """Batch analyzer for GLOBEM user data.

    Builds one ``agent/data_agent.py`` command per user, attaching the code
    MCP tool server when a data path is available.
    """

    def __init__(self, base_log_dir: str, data_path: str, vllm_model: str, vllm_port: int,
                 target_ids: Optional[Set[str]] = None, overwrite: bool = False,
                 llm_provider: Optional[str] = None):
        """Store the data path and LLM settings used when building user commands.

        Args:
            base_log_dir: Forwarded to ``BaseBatchAnalyzer``; per-user log
                directories are created beneath it.
            data_path: Data directory, used as the fallback when a command is
                prepared without an explicit ``data_path``.
            vllm_model: Value passed to the agent as ``--model``.
            vllm_port: Port used to build the localhost ``--base-url``.
            target_ids: Forwarded to ``BaseBatchAnalyzer``.
            overwrite: Forwarded to ``BaseBatchAnalyzer``.
            llm_provider: Value passed to the agent as ``--llm-provider``.
        """
        super().__init__(base_log_dir, target_ids, overwrite)
        self.data_path = data_path
        self.vllm_model = vllm_model
        self.vllm_port = vllm_port
        self.llm_provider = llm_provider

    def extract_identifiers(self, source_file: Path) -> List[Dict[str, Any]]:
        """Extract user identifiers from a pre-defined ID list file.

        Args:
            source_file: JSON file whose top-level value is a list of user
                IDs. Each ID is converted to ``str``.

        Returns:
            One dict per ID with the keys ``pid`` and ``identifier``. An empty
            list is returned when the file cannot be read or parsed, or when
            it does not hold a list.
        """
        users: List[Dict[str, Any]] = []
        try:
            print(f"Reading user IDs from: {source_file}")
            with open(source_file, 'r', encoding='utf-8') as f:
                user_ids = json.load(f)

            if not isinstance(user_ids, list):
                print("   Error: Expected a list of user IDs")
                return []

            users = [
                {"pid": pid, "identifier": pid}
                for pid in map(str, user_ids)
            ]

            print(f"Found {len(users)} users")

        except Exception as e:
            # Best-effort: report the problem and fall through with no users.
            print(f"Error reading ID file: {e}")

        return users

    def _prepare_analysis_command(self, identifier_info: Dict[str, Any], source_file: Path,
                                  subdir_name: str, data_path: Optional[str] = None,
                                  **kwargs) -> tuple:
        """Prepare the command for user analysis.

        Args:
            identifier_info: Dict containing at least ``pid``.
            source_file: Unused here; part of the shared method signature.
            subdir_name: Name of the log subdirectory under ``base_log_dir``.
            data_path: Data directory; falls back to ``self.data_path`` when
                missing or empty.
            **kwargs: ``max_turns`` (default 100) and ``auto_finish``
                (default True) are honoured; other keys are ignored.

        Returns:
            ``(cmd, env, label)``: the argv list, a copy of the current
            environment with ``CUSTOM_LOG_DIR`` set, and a display label.

        Raises:
            KeyError: If ``identifier_info`` has no ``pid``.
        """
        pid = identifier_info["pid"]

        user_log_dir = self.base_log_dir / subdir_name
        user_log_dir.mkdir(parents=True, exist_ok=True)

        task = f"Analyze user {pid}"

        # Fall back to the path given at construction; previously it was
        # stored but never used, so omitting the kwarg dropped the code server.
        effective_data_path = data_path or self.data_path

        # NOTE(review): llm_provider defaults to None in __init__; a None
        # entry is not a valid argv element -- confirm callers always set it.
        cmd = [
            sys.executable,
            "agent/data_agent.py",
            "--llm-provider", self.llm_provider,
            "--model", self.vllm_model,
            "--base-url", f"http://localhost:{self.vllm_port}",
            "--task", task,
            "--log-dir", str(user_log_dir),
            "--max-turns", str(kwargs.get("max_turns", 100))
        ]

        if not kwargs.get("auto_finish", True):
            cmd.append("--no-auto-finish")

        if effective_data_path:
            cmd.extend(["--code-server", "tool_server/code_mcp.py",
                        "--data-path", effective_data_path])

        env = os.environ.copy()
        env['CUSTOM_LOG_DIR'] = str(user_log_dir)

        return cmd, env, f"User {pid}"

    def get_subdir_name(self, identifier: str) -> str:
        """Get subdirectory name for the user identifier."""
        return f"user_{identifier}"

    def _create_identifier_from_logs(self, identifier: str, dirname: str) -> Optional[Dict[str, Any]]:
        """Create user identifier structure from log directory.

        ``dirname`` is not used; the result has the same keys as the entries
        produced by ``extract_identifiers``.
        """
        return {
            "pid": identifier,
            "identifier": identifier
        }


def _first_not_none(*candidates: Any, default: Any = None) -> Any:
    """Return the first candidate that is not ``None``, else ``default``."""
    return next((value for value in candidates if value is not None), default)


def main():
    """Main entry point for running the agent.

    Parses the command line, loads the configuration, builds the analyzer for
    the selected scenario and runs either the full batch or, with
    ``--retry-only``, only the retry pass. Command-line values take
    precedence over values from the configuration.
    """
    parser = argparse.ArgumentParser(
        description="DDR_Bench Unified Agent Runner",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Run MIMIC patient analysis
  python run_agent.py --scenario mimic --db-path /path/to/mimic_iv.db --input /path/to/notes.json

  # Run 10-K company analysis
  python run_agent.py --scenario 10k --db-path /path/to/10k.db

  # Run GLOBEM user analysis
  python run_agent.py --scenario globem --data-path /path/to/globem/data
        """
    )

    # Required arguments
    parser.add_argument("--scenario", required=True, choices=["mimic", "10k", "globem"],
                        help="Analysis scenario to run")

    # Data paths (scenario-specific)
    parser.add_argument("--db-path", help="Path to SQLite database (mimic, 10k)")
    parser.add_argument("--data-path", help="Path to data directory (globem)")
    parser.add_argument("--input", help="Path to input JSON file (mimic)")
    parser.add_argument("--qa-file", help="Path to QA file for evaluation reference")

    # Output configuration
    parser.add_argument("--log-dir", help="Output log directory")

    # Provider configuration
    parser.add_argument("--provider",
                        choices=["gemini", "vllm", "minimax", "openai"],
                        help="LLM provider to use (default: from config)")
    parser.add_argument("--model", help="Model name to use")
    parser.add_argument("--api-key", help="API key (or set via environment variable)")
    parser.add_argument("--vllm-port", type=int, help="VLLM server port")

    # Agent configuration
    parser.add_argument("--max-turns", type=int, help="Maximum conversation turns")
    parser.add_argument("--max-retries", type=int, help="Maximum retry attempts")
    parser.add_argument("--max-samples", type=int, help="Limit number of samples to process")

    # Execution options
    parser.add_argument("--target-ids", help="Comma-separated list of specific IDs to process")
    parser.add_argument("--overwrite", action="store_true", help="Overwrite existing results")
    parser.add_argument("--retry-only", action="store_true", help="Only retry failed analyses")

    # Configuration file
    parser.add_argument("--config", help="Path to config.yaml file")

    args = parser.parse_args()

    # These options are parsed but nothing below reads them. Say so instead of
    # silently dropping them (e.g. --max-samples would not limit the run).
    ignored_flags = [
        flag
        for flag, value in (
            ("--input", args.input),
            ("--qa-file", args.qa_file),
            ("--api-key", args.api_key),
            ("--max-samples", args.max_samples),
        )
        if value is not None
    ]
    if ignored_flags:
        print(f"Warning: {', '.join(ignored_flags)} currently has no effect in this runner",
              file=sys.stderr)

    # Load configuration
    config = get_config(args.config)
    scenario_config = config.get_scenario(args.scenario)

    # Override with CLI arguments
    log_dir = args.log_dir or scenario_config.log_dir
    db_path = args.db_path or scenario_config.db_path
    data_path = args.data_path or scenario_config.data_path
    id_file = scenario_config.id_file  # Pre-defined identifier list file

    model = args.model or config.provider.default_model
    provider = args.provider or config.provider.default_provider
    # A port or turn budget of 0 is not meaningful, so `or` fallbacks are kept.
    vllm_port = args.vllm_port or config.provider.vllm_port or 8000
    max_turns = args.max_turns or config.agent.max_turns or 100
    # 0 is a legitimate retry count ("never retry"), so only None counts as
    # unset; an `or` chain would replace an explicit 0 with the fallback.
    max_retries = _first_not_none(args.max_retries, config.agent.max_retries, default=2)
    auto_finish = getattr(config.agent, 'auto_finish', True)
    log_level = config.agent.log_level or "INFO"

    # Set log level for subprocesses and current process
    os.environ["DDR_LOG_LEVEL"] = log_level

    # Process target IDs
    target_ids = None
    if args.target_ids:
        stripped_ids = (raw_id.strip() for raw_id in args.target_ids.split(','))
        target_ids = {target for target in stripped_ids if target}
        print(f"Target IDs: {sorted(target_ids)}")

    # Validate that id_file points at a regular file; it is opened with open()
    # later, so a directory is rejected here as well.
    if not id_file or not Path(id_file).is_file():
        parser.error(f"ID file not found: {id_file}. Please check config.yaml.")

    source_file = Path(id_file)

    # Create analyzer based on scenario (argparse `choices` guarantees that
    # exactly one branch matches).
    if args.scenario == "mimic":
        if not db_path:
            parser.error("--db-path is required for mimic scenario")

        analyzer = PatientBatchAnalyzer(
            log_dir, model, vllm_port,
            target_ids, args.overwrite, provider
        )
        run_kwargs = {"db_path": db_path, "max_turns": max_turns, "auto_finish": auto_finish}

    elif args.scenario == "10k":
        if not db_path:
            parser.error("--db-path is required for 10k scenario")

        analyzer = CompanyBatchAnalyzer(
            log_dir, db_path,
            target_ids, args.overwrite
        )
        analyzer.model_name = model
        analyzer.llm_provider = provider
        analyzer.vllm_port = vllm_port
        run_kwargs = {"model": model, "llm_provider": provider,
                      "vllm_port": vllm_port, "db_path": db_path,
                      "max_turns": max_turns, "auto_finish": auto_finish}

    elif args.scenario == "globem":
        if not data_path:
            parser.error("--data-path is required for globem scenario")

        analyzer = UserBatchAnalyzer(
            log_dir, data_path, model, vllm_port,
            target_ids, args.overwrite, provider
        )
        run_kwargs = {"data_path": data_path, "max_turns": max_turns, "auto_finish": auto_finish}

    # Run analysis
    print(f"\n{'='*60}")
    print("DDR_Bench Agent Runner")
    print(f"Scenario: {args.scenario}")
    print(f"Provider: {provider}")
    print(f"Model: {model}")
    print(f"Log Directory: {log_dir}")
    print(f"{'='*60}\n")

    if args.retry_only:
        analyzer.retry_failed_analyses(max_retries=max_retries, **run_kwargs)
    else:
        analyzer.run_batch_analysis(source_file, max_retries=max_retries, **run_kwargs)

# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
