import argparse
import json
import os
import sys
from datetime import datetime
from decimal import Decimal, ROUND_HALF_UP
from pathlib import Path
from typing import Dict, List, Optional, Tuple

# Make the parent directory importable so the project-local ``utils`` package
# (which lives one level above this script) can be imported below.
SCRIPT_DIR = Path(__file__).parent.resolve()
# PARENT_DIR is also used later as the root of ``perturbed_queries_revised/``.
PARENT_DIR = SCRIPT_DIR.parent
if str(PARENT_DIR) not in sys.path:
    # Insert at the front so this project's ``utils`` takes precedence over any
    # other module of the same name already on sys.path.
    sys.path.insert(0, str(PARENT_DIR))

# These imports must come after the sys.path tweak above.
from dotenv import load_dotenv
from utils.model_pricing import calculate_cost, format_cost

# Load environment variables (API keys / endpoints read by get_api_client)
# from a .env file, if one is found.
load_dotenv()

# Import API clients
try:
    from anthropic import Anthropic, AnthropicBedrock
except ImportError:
    print("Warning: anthropic library not installed. Install with: pip install anthropic")
    Anthropic = None
    AnthropicBedrock = None

try:
    from openai import OpenAI, AzureOpenAI
except ImportError:
    print("Warning: openai library not installed. Install with: pip install openai")
    OpenAI = None
    AzureOpenAI = None

def get_api_client(api_type: str):
    """Construct the SDK client used to request summaries.

    Supported ``api_type`` values and the environment variables each one needs:

    - ``"openai"``: ``OPENAI_API_KEY``
    - ``"azure"``: ``AZURE_API_KEY`` and ``AZURE_ENDPOINT``; ``AZURE_API_VERSION``
      is optional and falls back to ``"2024-02-15-preview"``
    - ``"anthropic"``: ``ANTHROPIC_API_KEY``
    - ``"anthropic_bedrock"``: ``AWS_ACCESS_KEY_ID`` and ``AWS_SECRET_ACCESS_KEY``;
      the region comes from ``AWS_DEFAULT_REGION``, then ``AWS_REGION``, then
      ``"us-east-1"``

    Raises:
        ImportError: if the SDK needed for ``api_type`` could not be imported.
        ValueError: if ``api_type`` is unknown, or a required variable is unset
            or empty.
    """
    if api_type == "openai":
        if OpenAI is None:
            raise ImportError("openai library not installed")
        openai_key = os.getenv("OPENAI_API_KEY")
        if not openai_key:
            raise ValueError("OPENAI_API_KEY not found in environment")
        return OpenAI(api_key=openai_key)

    if api_type == "azure":
        if AzureOpenAI is None:
            raise ImportError("openai library not installed")
        azure_key = os.getenv("AZURE_API_KEY")
        azure_endpoint = os.getenv("AZURE_ENDPOINT")
        if not (azure_key and azure_endpoint):
            raise ValueError("AZURE_API_KEY and AZURE_ENDPOINT must be set in environment")
        return AzureOpenAI(
            api_key=azure_key,
            api_version=os.getenv("AZURE_API_VERSION", "2024-02-15-preview"),
            azure_endpoint=azure_endpoint,
        )

    if api_type == "anthropic":
        if Anthropic is None:
            raise ImportError("anthropic library not installed")
        anthropic_key = os.getenv("ANTHROPIC_API_KEY")
        if not anthropic_key:
            raise ValueError("ANTHROPIC_API_KEY not found in environment")
        return Anthropic(api_key=anthropic_key)

    if api_type == "anthropic_bedrock":
        if AnthropicBedrock is None:
            raise ImportError("anthropic library not installed")
        # Same variable names the AWS SDK itself reads.
        access_key = os.getenv("AWS_ACCESS_KEY_ID")
        secret_key = os.getenv("AWS_SECRET_ACCESS_KEY")
        if not (access_key and secret_key):
            raise ValueError("AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY must be set")
        region = os.getenv("AWS_DEFAULT_REGION") or os.getenv("AWS_REGION", "us-east-1")
        return AnthropicBedrock(
            aws_access_key=access_key,
            aws_secret_key=secret_key,
            aws_region=region,
        )

    raise ValueError(f"Invalid API type: {api_type}")

def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse command-line options for single-task or batch summary generation.

    Args:
        argv: Arguments to parse, excluding the program name. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(
        description="Generate elicitation run summaries for successful perturbations",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
    # Single task mode
    python elicitation_run_summary.py \\
        --task_id abc123 --domain multi_apps_test --perturbed_id def456 \\
        --perturbation_model o4-mini-2025-04-16 \\
        --refinement_model gpt-5-2025-08-07 \\
        --agent claude-haiku-4-5-20251001

    # Batch mode with task list file
    python elicitation_run_summary.py \\
        --task_list_file task_list_claude_haiku_0pct_baseline_human_filtered.json \\
        --perturbation_model o4-mini-2025-04-16 \\
        --agent claude-haiku-4-5-20251001
        """
    )

    # Task specification (single task mode)
    parser.add_argument('--task_id', type=str, default=None, help='The task ID for the elicitation run.')
    parser.add_argument('--domain', type=str, default=None, help='The domain for the elicitation run.')
    parser.add_argument('--perturbed_id', type=str, default=None, help='The perturbed ID for the elicitation run.')

    # Batch mode with task list file
    parser.add_argument('--task_list_file', type=str, default=None,
                       help='Path to task list JSON file (generated by generate_successful_task_list.py)')
    parser.add_argument('--refinement_model_filter', type=str, default=None,
                       help='Filter to process only tasks for a specific refinement model (batch mode only)')

    # Model configuration
    parser.add_argument('--perturbation_model', type=str, default="o4-mini-2025-04-16",
                       help='The original perturbation model used for the elicitation run. (default: o4-mini-2025-04-16)')
    parser.add_argument('--refinement_model', type=str, default=None,
                       help='The refinement model used for the elicitation run. (e.g., "gpt-5-2025-08-07")')
    parser.add_argument('--agent', type=str, default=None,
                       help='The agent used for the elicitation run. (e.g., "claude-haiku-4-5-20251001")')

    # API configuration
    parser.add_argument("--api", type=str, choices=["openai", "azure", "anthropic", "anthropic_bedrock"],
                       default="openai",
                       help="API provider to use for summary generation (default: openai)")
    parser.add_argument("--model", type=str,
                       choices=[
                           # Models used in example_scripts
                           "gpt-5-2025-08-07",
                           "gpt-5-pro-2025-10-06",
                           "gpt-5-mini-2025-08-07",
                           "o4-mini-2025-04-16",
                           "us.anthropic.claude-sonnet-4-20250514-v1:0",
                           "us.anthropic.claude-sonnet-4-5-20250929-v1:0",
                           "us.anthropic.claude-haiku-4-5-20251001-v1:0",
                           "us.anthropic.claude-opus-4-1-20250805-v1:0",
                           # Open Source models (e.g, vllm, huggingface)
                           "openai/gpt-oss-20b",
                           "Qwen/Qwen3-30B-A3B-Instruct-2507",
                           "Qwen/Qwen3-Next-80B-A3B-Instruct"
                       ],
                       default="gpt-5-2025-08-07",
                       help="Model name for summary generation")
    parser.add_argument("--max_tokens", type=int, default=32768,
                       help="Maximum tokens for LLM response (default: 32768)")
    parser.add_argument("--temperature", type=float, default=1.0,
                       help="Temperature for LLM sampling (default: 1.0)")

    # Output options
    # NOTE: ``store_true`` combined with ``default=True`` means this flag can
    # never switch skipping off; it is accepted only for backward
    # compatibility. ``--force`` is the way to regenerate existing summaries.
    parser.add_argument("--skip_existing", action="store_true", default=True,
                       help="Skip tasks that already have summaries (always on; use --force to regenerate)")
    parser.add_argument("--force", action="store_true", default=False,
                       help="Regenerate summaries even if they already exist")

    return parser.parse_args(argv)

def load_original_instruction(perturbed_queries_dir, task_id, perturbation_model, perturbed_id):
    """Return the original (unperturbed) instruction stored with a perturbed query.

    Reads ``<perturbed_queries_dir>/<task_id>/<perturbation_model>/
    perturbed_query_<perturbed_id>/perturbed_query_<perturbed_id>.json`` and
    returns its ``"original_instruction"`` field.

    Args:
        perturbed_queries_dir: Domain-level directory of perturbed queries
            (``Path`` or ``str``).
        task_id: The task identifier.
        perturbation_model: Name of the model that produced the perturbation.
        perturbed_id: The perturbed query identifier.

    Returns:
        The ``"original_instruction"`` value from the JSON file.

    Raises:
        FileNotFoundError: if the perturbed query file does not exist.
        KeyError: if the file has no ``"original_instruction"`` key.
    """
    query_dir = (
        Path(perturbed_queries_dir) / task_id / perturbation_model /
        f"perturbed_query_{perturbed_id}"
    )
    perturbed_query_path = query_dir / f"perturbed_query_{perturbed_id}.json"

    # Explicit UTF-8: the platform default encoding (e.g. cp1252 on Windows)
    # can fail on non-ASCII instruction text.
    with open(perturbed_query_path, "r", encoding="utf-8") as f:
        perturbed_query = json.load(f)

    return perturbed_query["original_instruction"]


def format_iteration_entry(iteration: dict, is_successful: bool = False) -> str:
    """
    Format a single iteration entry for the elicitation history or successful elicitation.

    Every field is indented two spaces under its ``Iteration N:`` header, and
    action bullets are indented four spaces.

    Args:
        iteration: One entry of ``execution_history``. Must contain
            ``execution_iteration``, ``perturbed_instruction``,
            ``elicitation_strategy`` and a ``trajectory_evaluation`` dict with
            ``behavior_elicitation_score``, ``behavior_elicitation_explanation``,
            ``severity_assessment``, ``severity_explanation``,
            ``recommendation`` and ``recommendation_reasoning``. The
            ``harmful_actions_observed`` / ``safe_actions_observed`` lists are
            optional.
        is_successful: Whether this is the successful (COLLECT) iteration.
            History entries (``False``) get an extra blank line after the
            instruction and a trailing blank line plus a 40-dash separator.

    Returns:
        Formatted, newline-joined string for the iteration.

    Raises:
        KeyError: if a required field is missing.
    """
    traj_eval = iteration["trajectory_evaluation"]

    lines = [
        f"Iteration {iteration['execution_iteration']}:",
        f"  Perturbed Instruction: {iteration['perturbed_instruction']}",
    ]
    if not is_successful:
        lines.append("")  # Extra newline for history entries
    lines += [
        f"  Elicitation Strategy: {iteration['elicitation_strategy']}",
        f"  Behavior Elicitation Score: {traj_eval['behavior_elicitation_score']}",
        f"  Behavior Elicitation Explanation: {traj_eval['behavior_elicitation_explanation']}",
    ]

    # Bullet lists of observed actions, with an explicit placeholder when empty.
    for heading, key, placeholder in (
        ("Harmful Actions Observed", "harmful_actions_observed", "No harmful actions observed"),
        ("Safe Actions Observed", "safe_actions_observed", "No safe actions observed"),
    ):
        lines.append(f"  {heading}: ")
        actions = traj_eval.get(key, [])
        if actions:
            lines.extend(f"    - {action}" for action in actions)
        else:
            lines.append(f"    - {placeholder}")

    # Severity and recommendation belong to the same iteration, so they share
    # the two-space indent of the other fields.
    lines += [
        f"  Severity: {traj_eval['severity_assessment']}",
        f"  Severity Explanation: {traj_eval['severity_explanation']}",
        f"  Recommendation: {traj_eval['recommendation']}",
        f"  Recommendation Explanation: {traj_eval['recommendation_reasoning']}",
    ]

    # Add separator for history entries (not successful)
    if not is_successful:
        lines += ["", "-" * 40]

    return "\n".join(lines)


def load_elicitation_history(
    perturbed_queries_dir: Path,
    task_id: str,
    perturbation_model: str,
    perturbed_id: str,
    refinement_folder: str,
    agent_folder: str
) -> Tuple[str, str, str]:
    """
    Load and format the elicitation history from a refinement summary.

    Reads ``refinement_summary.json`` from
    ``<perturbed_queries_dir>/<task_id>/<perturbation_model>/perturbed_query_<perturbed_id>/
    <refinement_folder>/<agent_folder>/`` and splits its ``execution_history``
    into failed attempts and the successful (``recommendation == "COLLECT"``)
    iteration.

    Iterations whose ``trajectory_evaluation`` is missing, empty/null, or
    contains an ``"error"`` key are skipped, since they cannot be formatted.
    If several iterations are COLLECT, the last one wins; if none is, the
    second and third tuple entries are empty strings.

    Args:
        perturbed_queries_dir: Base directory for perturbed queries (``Path`` or ``str``)
        task_id: The task identifier
        perturbation_model: The perturbation model used
        perturbed_id: The perturbed query identifier
        refinement_folder: The refinement model folder name
        agent_folder: The agent folder name

    Returns:
        Tuple of (elicitation_history, successful_elicitation, successful_perturbation)

    Raises:
        FileNotFoundError: if the refinement summary does not exist.
        KeyError: if it has no ``execution_history``.
    """
    result_path = (
        Path(perturbed_queries_dir) / task_id / perturbation_model /
        f"perturbed_query_{perturbed_id}" / refinement_folder /
        agent_folder / "refinement_summary.json"
    )

    with open(result_path, "r", encoding="utf-8") as f:
        refinement_summary = json.load(f)

    history_entries = []
    successful_elicitation = ""
    successful_perturbation = ""

    for iteration in refinement_summary["execution_history"]:
        # ``or {}`` also covers an explicit JSON null.
        traj_eval = iteration.get("trajectory_evaluation") or {}

        # Skip iterations without a usable evaluation: errored ones, and empty
        # ones that format_iteration_entry would reject with KeyError.
        if not traj_eval or "error" in traj_eval:
            continue

        if traj_eval.get("recommendation", "") == "COLLECT":
            # This is the successful iteration
            successful_elicitation = format_iteration_entry(iteration, is_successful=True)
            successful_perturbation = iteration["perturbed_instruction"]
        else:
            # Failed attempts go into the history
            history_entries.append(format_iteration_entry(iteration, is_successful=False))

    return "\n".join(history_entries), successful_elicitation, successful_perturbation

def load_summary_prompt(original_instruction, elicitation_history, successful_elicitation, base_dir: Optional[Path] = None):
    """
    Load and populate the summary prompt template.

    Reads ``<base_dir>/prompts/elicitation_run_summary_prompt.md`` (UTF-8) and
    fills the ``{ORIGINAL_INSTRUCTION}``, ``{ELICITATION_HISTORY}`` and
    ``{SUCCESSFUL_ELICITATION}`` placeholders. Substitution happens in one
    pass, so placeholder-like text inside an inserted value is left as-is
    rather than being expanded again. Other ``{...}`` text in the template is
    untouched.

    Args:
        original_instruction: The original benign instruction
        elicitation_history: Formatted history of failed attempts
        successful_elicitation: Formatted successful elicitation details
        base_dir: Base directory containing the prompts folder (``Path`` or
            ``str``; defaults to the script directory)

    Returns:
        The populated prompt text.

    Raises:
        FileNotFoundError: if the template file is missing.
    """
    import re

    if base_dir is None:
        base_dir = Path(__file__).parent

    summary_prompt_file = Path(base_dir) / "prompts" / "elicitation_run_summary_prompt.md"
    with open(summary_prompt_file, "r", encoding="utf-8") as f:
        template = f.read()

    values = {
        "ORIGINAL_INSTRUCTION": original_instruction,
        "ELICITATION_HISTORY": elicitation_history,
        "SUCCESSFUL_ELICITATION": successful_elicitation,
    }
    # A single regex pass (instead of chained str.replace) keeps an inserted
    # value such as an instruction containing "{SUCCESSFUL_ELICITATION}" from
    # being substituted a second time. A callable replacement is used so that
    # backslashes in the values are inserted literally.
    placeholder = re.compile(r"\{(" + "|".join(map(re.escape, values)) + r")\}")
    return placeholder.sub(lambda match: values[match.group(1)], template)

def generate_summary(summary_api_client, api_type, model_name, summary_prompt, max_tokens, temperature):
    """Send the summary prompt as a single user message and return the reply.

    Args:
        summary_api_client: Client returned by ``get_api_client`` for ``api_type``.
        api_type: One of ``"openai"``, ``"azure"``, ``"anthropic"`` or
            ``"anthropic_bedrock"``.
        model_name: Model identifier passed to the provider.
        summary_prompt: Fully populated prompt text.
        max_tokens: Output-token limit for the response.
        temperature: Sampling temperature.

    Returns:
        Tuple ``(summary, input_tokens, output_tokens, total_tokens)``. For the
        Anthropic APIs, ``total_tokens`` is computed as input + output.

    Raises:
        ValueError: if ``api_type`` is not a supported provider.
    """
    messages = [{"role": "user", "content": summary_prompt}]

    if api_type in ("openai", "azure"):
        # gpt-5 / o4 reasoning models take ``max_completion_tokens``; other
        # chat models get the classic ``max_tokens`` parameter.
        # NOTE(review): OpenAI documents that reasoning models accept only the
        # default temperature (1); --temperature defaults to 1.0.
        lowered = model_name.lower()
        token_param = "max_completion_tokens" if ("gpt-5" in lowered or "o4" in lowered) else "max_tokens"
        response = summary_api_client.chat.completions.create(
            model=model_name,
            messages=messages,
            temperature=temperature,
            **{token_param: max_tokens},
        )
        usage = response.usage
        return (
            response.choices[0].message.content,
            usage.prompt_tokens,
            usage.completion_tokens,
            usage.total_tokens,
        )

    if api_type in ("anthropic", "anthropic_bedrock"):
        response = summary_api_client.messages.create(
            model=model_name,
            messages=messages,
            max_tokens=max_tokens,
            temperature=temperature,
        )
        input_tokens = response.usage.input_tokens
        output_tokens = response.usage.output_tokens
        return response.content[0].text, input_tokens, output_tokens, input_tokens + output_tokens

    # Previously this fell through to a return of unbound locals.
    raise ValueError(f"Invalid API type: {api_type}")

def process_single_task(
    domain: str,
    task_id: str,
    perturbed_id: str,
    perturbation_model: str,
    refinement_model: str,
    agent: str,
    api_client,
    api_type: str,
    model_name: str,
    max_tokens: int,
    temperature: float,
    skip_existing: bool = True,
    base_dir: Optional[Path] = None
) -> Tuple[bool, Decimal, dict]:
    """
    Process a single task and generate its summary.

    Inputs are read from, and ``elicitation_run_summary.json`` is written to,
    ``PARENT_DIR/perturbed_queries_revised/<domain>/<task_id>/<perturbation_model>/
    perturbed_query_<perturbed_id>/iterative_refinement_<refinement_model>/agent_<agent>/``.

    Args:
        domain, task_id, perturbed_id: Identify the perturbed query.
        perturbation_model, refinement_model, agent: Select the run folders.
        api_client, api_type, model_name, max_tokens, temperature: Passed to
            ``generate_summary``.
        skip_existing: If True and the output file already exists, do not call
            the API; report the cost stored in that file (0 if unreadable).
        base_dir: Directory containing ``prompts/`` (defaults to this script's
            directory).

    Returns:
        Tuple of (success: bool, cost: Decimal, result_info: dict).
        ``success`` is True for both generated and skipped summaries;
        ``result_info["status"]`` is ``"success"``, ``"skipped"`` or ``"error"``.

    Exceptions are not propagated: any failure while loading inputs, calling
    the API or writing the output is printed and reported as
    ``status == "error"`` with cost 0.
    """
    if base_dir is None:
        base_dir = Path(__file__).parent

    # perturbed_queries is in the parent directory (perturbation_generation), not in meta_analysis_package
    perturbed_queries_dir = PARENT_DIR / "perturbed_queries_revised" / domain

    # Folder names for path construction
    refinement_folder = f"iterative_refinement_{refinement_model}"
    agent_folder = f"agent_{agent}"

    # Output path
    output_json_path = (
        perturbed_queries_dir / task_id / perturbation_model /
        f"perturbed_query_{perturbed_id}" / refinement_folder /
        agent_folder / "elicitation_run_summary.json"
    )

    result_info = {
        "domain": domain,
        "task_id": task_id,
        "perturbed_id": perturbed_id,
        "refinement_model": refinement_model,
        "output_path": str(output_json_path),
    }

    # Check if already exists
    if skip_existing and output_json_path.exists():
        print(f"  [SKIP] Summary already exists: {task_id}:{perturbed_id}")
        result_info["status"] = "skipped"
        # Best effort: report the cost recorded when the summary was made.
        try:
            with open(output_json_path, 'r') as f:
                existing = json.load(f)
            existing_cost = Decimal(str(existing.get("cost", "0")))
        except Exception:
            result_info["cost"] = "0"
            return True, Decimal("0"), result_info
        result_info["cost"] = str(existing_cost)
        return True, existing_cost, result_info

    try:
        original_instruction = load_original_instruction(
            perturbed_queries_dir, task_id, perturbation_model, perturbed_id
        )

        elicitation_history, successful_elicitation, successful_perturbation = load_elicitation_history(
            perturbed_queries_dir, task_id, perturbation_model, perturbed_id,
            refinement_folder, agent_folder
        )

        # Without a COLLECT iteration there is nothing to summarize; fail here
        # rather than paying for an API call whose output would be misleading.
        if not successful_elicitation:
            raise ValueError("No successful (COLLECT) iteration found in refinement_summary.json")

        summary_prompt = load_summary_prompt(
            original_instruction, elicitation_history, successful_elicitation, base_dir
        )

        summary, input_tokens, output_tokens, total_tokens = generate_summary(
            api_client, api_type, model_name, summary_prompt, max_tokens, temperature
        )

        # Decimal(str(...)) avoids binary-float artifacts in the stored cost.
        cost = Decimal(str(calculate_cost(model_name, input_tokens, output_tokens)))

        print(f"  [OK] {task_id}:{perturbed_id} - {input_tokens}+{output_tokens} tokens, ${cost:.8f}")

        output_json = {
            "original_instruction": original_instruction,
            "perturbed_instruction": successful_perturbation,
            "summary": summary,
            "input_tokens": input_tokens,
            "output_tokens": output_tokens,
            "total_tokens": total_tokens,
            "cost": f"{cost:.8f}",
            "summary_metadata": {
                "api_type": api_type,
                "model_name": model_name,
                "max_tokens": max_tokens,
                "temperature": temperature,
                "generated_at": datetime.now().isoformat()
            }
        }

        # Write to a temporary sibling and then swap it into place, so an
        # interrupted run never leaves a truncated file that a later
        # skip_existing run would treat as already done.
        output_json_path.parent.mkdir(parents=True, exist_ok=True)
        tmp_path = output_json_path.with_name(output_json_path.name + ".tmp")
        with open(tmp_path, 'w') as f:
            json.dump(output_json, f, indent=4)
        tmp_path.replace(output_json_path)

        result_info["status"] = "success"
        result_info["cost"] = str(cost)
        result_info["input_tokens"] = input_tokens
        result_info["output_tokens"] = output_tokens

        return True, cost, result_info

    except Exception as e:
        # Deliberately broad: one failing task must not abort a batch run.
        print(f"  [ERROR] {task_id}:{perturbed_id} - {str(e)}")
        result_info["status"] = "error"
        result_info["error"] = str(e)
        result_info["cost"] = "0"
        return False, Decimal("0"), result_info


def load_task_list(task_list_file: str) -> Dict[str, List[dict]]:
    """
    Load task list from JSON file.

    The file must be a JSON object with a ``task_details_by_refinement_model``
    mapping of refinement model name to a list of task dicts.

    Args:
        task_list_file: Path to the task list JSON file (read as UTF-8).

    Returns:
        Dictionary mapping refinement_model -> list of task details

    Raises:
        ValueError: if the file uses the unsupported legacy ``task_list``
            format, or is not a JSON object holding such a mapping.
    """
    with open(task_list_file, "r", encoding="utf-8") as f:
        data = json.load(f)

    # Guard first: for a JSON list or string, ``key in data`` would do a
    # membership/substring test and the lookup below would raise TypeError.
    if not isinstance(data, dict):
        raise ValueError(f"Invalid task list format in {task_list_file}")

    if "task_details_by_refinement_model" in data:
        details = data["task_details_by_refinement_model"]
        if not isinstance(details, dict):
            raise ValueError(
                f"Invalid task list format in {task_list_file}: "
                "'task_details_by_refinement_model' must be an object"
            )
        return details

    if "task_list" in data:
        # Regrouping a flat task_list would require knowing each task's
        # refinement model, and this loader does not implement that.
        raise ValueError(
            f"Legacy 'task_list' format in {task_list_file} is not supported; "
            "expected a 'task_details_by_refinement_model' mapping"
        )

    raise ValueError(f"Invalid task list format in {task_list_file}")


def main():
    """Command-line entry point: dispatch to batch or single-task mode.

    ``--task_list_file`` selects batch mode and takes precedence. Otherwise
    ``--task_id``, ``--domain`` and ``--perturbed_id`` must all be given for
    single-task mode. With neither, usage guidance is printed and the process
    exits with status 1.
    """
    args = parse_args()

    wants_batch = bool(args.task_list_file)
    wants_single = all((args.task_id, args.domain, args.perturbed_id))

    if wants_batch:
        run_batch_mode(args)
        return
    if wants_single:
        run_single_mode(args)
        return

    for message in (
        "Error: Either provide --task_list_file for batch mode, or",
        "       provide --task_id, --domain, and --perturbed_id for single task mode",
    ):
        print(message)
    sys.exit(1)


def run_single_mode(args):
    """Generate the summary for the single task named on the command line.

    Requires ``--refinement_model`` and ``--agent`` in addition to
    ``--task_id``, ``--domain`` and ``--perturbed_id``. Exits with status 1 when
    a required option is missing or when summary generation fails, so shell
    scripts and CI can detect the failure.
    """
    print("Single Task Mode")
    print("=" * 60)
    print(f"Domain: {args.domain}")
    print(f"Task ID: {args.task_id}")
    print(f"Perturbed ID: {args.perturbed_id}")
    print(f"Perturbation Model: {args.perturbation_model}")
    print(f"Refinement Model: {args.refinement_model}")
    print(f"Agent: {args.agent}")
    print(f"Summary Model: {args.model}")
    print("=" * 60)

    if not args.refinement_model:
        print("Error: --refinement_model is required for single task mode")
        sys.exit(1)

    if not args.agent:
        print("Error: --agent is required for single task mode")
        sys.exit(1)

    api_client = get_api_client(args.api)

    # --force overrides skipping of existing summaries.
    skip_existing = args.skip_existing and not args.force

    success, cost, result_info = process_single_task(
        domain=args.domain,
        task_id=args.task_id,
        perturbed_id=args.perturbed_id,
        perturbation_model=args.perturbation_model,
        refinement_model=args.refinement_model,
        agent=args.agent,
        api_client=api_client,
        api_type=args.api,
        model_name=args.model,
        max_tokens=args.max_tokens,
        temperature=args.temperature,
        skip_existing=skip_existing
    )

    print(f"\nResult: {'Success' if success else 'Failed'}")
    print(f"Cost: ${cost:.8f}")
    if result_info.get("output_path"):
        print(f"Output: {result_info['output_path']}")

    # Propagate failure through the exit status instead of exiting 0.
    if not success:
        sys.exit(1)


def run_batch_mode(args):
    """Run summary generation for all tasks in a task list file.

    Loads ``args.task_list_file`` with ``load_task_list``, optionally restricts
    it to ``args.refinement_model_filter``, processes every task with
    ``process_single_task``, prints aggregate statistics, and writes a
    ``batch_summary_<timestamp>.json`` report into this script's directory.

    Task entries lacking ``task_id``, ``perturbed_id`` or ``domain`` are
    recorded as errors instead of aborting the batch.

    Exits with status 1 if ``--agent`` is missing, the task list cannot be
    loaded, or the filtered refinement model is not in the list.
    """
    print("Batch Mode")
    print("=" * 70)
    print(f"Task List File: {args.task_list_file}")
    print(f"Perturbation Model: {args.perturbation_model}")
    print(f"Agent: {args.agent}")
    print(f"Summary Model: {args.model}")
    print(f"Refinement Model Filter: {args.refinement_model_filter or 'All'}")
    print(f"Skip Existing: {args.skip_existing and not args.force}")
    print("=" * 70)

    if not args.agent:
        print("Error: --agent is required for batch mode")
        sys.exit(1)

    # Load task list
    try:
        task_details_by_refinement = load_task_list(args.task_list_file)
    except Exception as e:
        print(f"Error loading task list: {e}")
        sys.exit(1)

    # Filter refinement models if specified
    if args.refinement_model_filter:
        if args.refinement_model_filter not in task_details_by_refinement:
            print(f"Error: Refinement model '{args.refinement_model_filter}' not found in task list")
            print(f"Available: {list(task_details_by_refinement.keys())}")
            sys.exit(1)
        task_details_by_refinement = {
            args.refinement_model_filter: task_details_by_refinement[args.refinement_model_filter]
        }

    # Count total tasks
    total_tasks = sum(len(tasks) for tasks in task_details_by_refinement.values())
    print(f"\nTotal tasks to process: {total_tasks}")
    for rm, tasks in task_details_by_refinement.items():
        print(f"  - {rm}: {len(tasks)} tasks")
    print()

    api_client = get_api_client(args.api)

    # Costs are accumulated as Decimal to avoid float rounding drift.
    total_cost = Decimal("0")
    total_input_tokens = 0
    total_output_tokens = 0
    success_count = 0
    skip_count = 0
    error_count = 0
    results = []

    skip_existing = args.skip_existing and not args.force
    base_dir = Path(__file__).parent

    for refinement_model, tasks in task_details_by_refinement.items():
        print(f"\n[{refinement_model}] Processing {len(tasks)} tasks...")
        print("-" * 60)

        for i, task in enumerate(tasks, 1):
            try:
                task_id = task["task_id"]
                perturbed_id = task["perturbed_id"]
                domain = task["domain"]
            except (KeyError, TypeError) as e:
                # A malformed entry must not abort the whole batch (losing the
                # report for tasks already processed); record it and move on.
                print(f"[{i}/{len(tasks)}] [ERROR] Malformed task entry {task!r}: {e!r}")
                error_count += 1
                results.append({
                    "refinement_model": refinement_model,
                    "status": "error",
                    "error": f"Malformed task entry: {task!r}",
                    "cost": "0",
                })
                continue

            print(f"[{i}/{len(tasks)}] {domain}:{task_id}:{perturbed_id}")

            success, cost, result_info = process_single_task(
                domain=domain,
                task_id=task_id,
                perturbed_id=perturbed_id,
                perturbation_model=args.perturbation_model,
                refinement_model=refinement_model,
                agent=args.agent,
                api_client=api_client,
                api_type=args.api,
                model_name=args.model,
                max_tokens=args.max_tokens,
                temperature=args.temperature,
                skip_existing=skip_existing,
                base_dir=base_dir
            )

            # Skipped tasks contribute the cost recorded in their existing file.
            total_cost += cost

            if result_info.get("status") == "skipped":
                skip_count += 1
            elif success:
                success_count += 1
                total_input_tokens += result_info.get("input_tokens", 0)
                total_output_tokens += result_info.get("output_tokens", 0)
            else:
                error_count += 1

            results.append(result_info)

    # Print summary
    print()
    print("=" * 70)
    print("BATCH SUMMARY")
    print("=" * 70)
    print(f"Total tasks:        {total_tasks}")
    print(f"  Successful:       {success_count}")
    print(f"  Skipped:          {skip_count}")
    print(f"  Errors:           {error_count}")
    print()
    print(f"Total input tokens:  {total_input_tokens:,}")
    print(f"Total output tokens: {total_output_tokens:,}")
    print(f"Total cost:          ${total_cost:.8f}")
    print("=" * 70)

    # Capture the clock once so the filename and the recorded timestamp agree.
    finished_at = datetime.now()
    timestamp = finished_at.strftime("%Y%m%d_%H%M%S")
    summary_file = base_dir / f"batch_summary_{timestamp}.json"

    batch_summary = {
        "timestamp": finished_at.isoformat(),
        "task_list_file": args.task_list_file,
        "perturbation_model": args.perturbation_model,
        "agent": args.agent,
        "summary_model": args.model,
        "api_type": args.api,
        "statistics": {
            "total_tasks": total_tasks,
            "successful": success_count,
            "skipped": skip_count,
            "errors": error_count,
            "total_input_tokens": total_input_tokens,
            "total_output_tokens": total_output_tokens,
            "total_cost": f"{total_cost:.8f}"
        },
        "results": results
    }

    with open(summary_file, 'w') as f:
        json.dump(batch_summary, f, indent=2)

    print(f"\nBatch summary saved to: {summary_file}")


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()