#!/usr/bin/env python3
"""
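Compute token compression rates for the Tau-bench retail system prompts.

Compares two system prompts:
1. Regular prompt (with the policy document included)
2. Internalization prompt (without the policy document)

Also reports how much the policy document alone shrinks when it is replaced
by a short reference header.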
"""

import os
import sys
import importlib
from transformers import AutoTokenizer


def load_policy_document(split="retail", *,
                         tau_bench_root="/code/jiateng-sandbox/taubench_application/tau-bench"):
    """Load the policy document (``wiki.md``) for a Tau-bench domain.

    Args:
        split: Domain directory name under ``tau_bench/envs`` (e.g. ``"retail"``).
        tau_bench_root: Root of the tau-bench checkout. Defaults to the
            sandbox path this script was written against.

    Returns:
        The file contents as a string, or ``""`` if the file cannot be read or
        decoded (the error is printed, not raised).
    """
    wiki_path = os.path.join(tau_bench_root, "tau_bench", "envs", split, "wiki.md")

    try:
        # Explicit UTF-8 so decoding does not depend on the platform locale
        # (assumes wiki.md is UTF-8 encoded).
        with open(wiki_path, 'r', encoding='utf-8') as f:
            policy = f.read()
    except (OSError, UnicodeDecodeError) as e:
        print(f"Error loading policy: {e}")
        return ""

    print(f"Loaded policy from {wiki_path}")
    return policy


def _format_tool_spec(info, fallback_name):
    """Render one tool's ``get_info()`` dict as a plain-text specification.

    Args:
        info: Dict returned by a tool class's ``get_info()``; the metadata is
            read from its ``'function'`` entry.
        fallback_name: Name shown when ``info`` carries no function name.

    Returns:
        A newline-terminated block with ``Tool:`` and ``Description:`` lines
        and, when the tool declares parameter ``properties``, one
        ``  - name (type[*]): desc`` line per parameter (``*`` = required).
    """
    func_info = info.get('function', {})
    name = func_info.get('name', fallback_name)
    description = func_info.get('description', 'No description available')
    parameters = func_info.get('parameters', {})

    lines = [f"Tool: {name}", f"Description: {description}"]
    if parameters and 'properties' in parameters:
        # Build the required-name lookup once per tool, not once per parameter.
        required_params = set(parameters.get('required', []))
        lines.append("Parameters:")
        for param_name, param_info in parameters['properties'].items():
            param_type = param_info.get('type', 'unknown')
            param_desc = param_info.get('description', 'No description')
            marker = '*' if param_name in required_params else ''
            lines.append(f"  - {param_name} ({param_type}{marker}): {param_desc}")
    return "\n".join(lines) + "\n"


def load_tools_and_get_specifications(split="retail", *,
                                      tau_bench_root="/code/jiateng-sandbox/taubench_application/tau-bench"):
    """Load a domain's tools and render their specifications as text.

    Imports ``tau_bench.envs.<split>.tools`` and formats every class listed in
    its ``ALL_TOOLS`` via the class's ``get_info()``.

    Args:
        split: Domain directory name under ``tau_bench/envs`` (e.g. ``"retail"``).
        tau_bench_root: Root of the tau-bench checkout. Defaults to the
            sandbox path this script was written against.

    Returns:
        The per-tool specification blocks joined by a blank line, or ``""``
        if the tools module cannot be imported. Tools whose ``get_info()``
        fails are reported and skipped.
    """
    domain_path = os.path.join(tau_bench_root, "tau_bench", "envs", split)

    # Make the checkout importable.
    # NOTE(review): append() puts these paths last, so an installed
    # ``tau_bench`` package elsewhere on sys.path would shadow this
    # checkout. Confirm that is intended.
    for path in (domain_path, tau_bench_root):
        if path not in sys.path:
            sys.path.append(path)

    tools_module_path = f"tau_bench.envs.{split}.tools"
    try:
        tools_module = importlib.import_module(tools_module_path)
        # Keyed by class name, which doubles as the fallback display name.
        tools_map = {tool_class.__name__: tool_class for tool_class in tools_module.ALL_TOOLS}
    except Exception as e:  # best-effort: any import/attribute failure yields ""
        print(f"Error loading tools: {e}")
        return ""

    specifications = []
    for tool_name, tool_class in tools_map.items():
        try:
            specifications.append(_format_tool_spec(tool_class.get_info(), tool_name))
        except Exception as e:  # one broken tool should not drop the others
            print(f"Error getting info for tool {tool_name}: {e}")

    return "\n".join(specifications)


def build_regular_system_prompt(policy_document, tool_specifications, *,
                                template_path="/code/jiateng-sandbox/taubench_application/Single_turn_tau_bench/prompt_template.txt"):
    """Build the regular system prompt (with policy document).

    Fills every ``{Policy Document}`` and ``{Tool Specifications}``
    placeholder of the template read from ``template_path``.

    Args:
        policy_document: Text substituted for ``{Policy Document}``.
        tool_specifications: Text substituted for ``{Tool Specifications}``.
        template_path: Path to the UTF-8 prompt template.

    Returns:
        The filled prompt, or ``""`` if the template cannot be read or decoded
        (the error is printed, not raised).
    """
    try:
        with open(template_path, 'r', encoding='utf-8') as f:
            template = f.read()
    except (OSError, UnicodeDecodeError) as e:
        print(f"Error loading prompt template: {e}")
        return ""

    # Substitute in a single pass over the template. Split on one placeholder,
    # fill the other in each template fragment, then join with the tool specs.
    # Chained str.replace() calls would also rewrite a "{Tool Specifications}"
    # that happens to occur inside the inserted policy text.
    fragments = template.split("{Tool Specifications}")
    fragments = [fragment.replace("{Policy Document}", policy_document) for fragment in fragments]
    return tool_specifications.join(fragments)


def build_internalization_system_prompt(tool_specifications, *,
                                        template_path="/code/jiateng-sandbox/taubench_application/Single_turn_tau_bench/prompt_template_none.txt"):
    """Build the internalization system prompt (without policy document).

    Fills every ``{Tool Specifications}`` placeholder of the template read
    from ``template_path``. No policy document is substituted.

    Args:
        tool_specifications: Text substituted for ``{Tool Specifications}``.
        template_path: Path to the UTF-8 prompt template.

    Returns:
        The filled prompt, or ``""`` if the template cannot be read or decoded
        (the error is printed, not raised).
    """
    try:
        with open(template_path, 'r', encoding='utf-8') as f:
            template = f.read()
    except (OSError, UnicodeDecodeError) as e:
        print(f"Error loading prompt template: {e}")
        return ""

    return template.replace("{Tool Specifications}", tool_specifications)


def compute_policy_document_compression(tokenizer, policy_document,
                                        compressed_policy_ref="### Retail-Policy-Document (Tau-bench) ###"):
    """Compute compression rate of the policy document itself.

    Compares the token count of the full policy document with that of the
    short reference string that replaces it in the internalization prompt,
    and prints a report.

    Args:
        tokenizer: Any object whose ``encode(text)`` returns a sized sequence
            of tokens (e.g. a Hugging Face tokenizer).
        policy_document: Full policy text.
        compressed_policy_ref: Short stand-in text for the policy.

    Returns:
        The token reduction as a percentage, ``(1 - ref/full) * 100``. Returns
        NaN when the policy tokenizes to zero tokens (e.g. it failed to load),
        because the ratio is undefined.
    """
    full_policy_token_count = len(tokenizer.encode(policy_document))
    compressed_policy_token_count = len(tokenizer.encode(compressed_policy_ref))

    if full_policy_token_count == 0:
        # Avoid ZeroDivisionError when the policy is empty.
        print("WARNING: policy document is empty; policy compression rate is undefined")
        policy_compression_rate = float("nan")
    else:
        policy_compression_rate = (1 - compressed_policy_token_count / full_policy_token_count) * 100

    separator = "=" * 60
    print(separator)
    print("POLICY DOCUMENT COMPRESSION ANALYSIS")
    print(separator)
    print("Full policy document:")
    print(f"  Token count: {full_policy_token_count:,}")
    print(f"  Character count: {len(policy_document):,}")
    print()
    print("Compressed policy reference:")
    print(f"  Text: '{compressed_policy_ref}'")
    print(f"  Token count: {compressed_policy_token_count:,}")
    print(f"  Character count: {len(compressed_policy_ref):,}")
    print()
    print(f"Policy token reduction: {full_policy_token_count - compressed_policy_token_count:,} tokens")
    print(f"Policy compression rate: {policy_compression_rate:.2f}%")
    print(separator)

    return policy_compression_rate


def compute_compression_rate():
    """Compute token compression rate between the two system prompts.

    Steps:
    1. Load the Qwen3-32B tokenizer, the retail policy document and the tool
       specifications.
    2. Report the policy-only compression.
    3. Build both system prompts and tokenize them.
    4. Print a comparison.
    5. Write both prompts to ``regular_prompt.txt`` and
       ``internalization_prompt.txt`` in the current directory.

    Returns:
        tuple: ``(compression_rate, policy_compression_rate)`` as percentages.
        ``compression_rate`` is NaN when the regular prompt tokenizes to zero
        tokens (e.g. its template failed to load).
    """
    print("Loading Qwen-3-32B tokenizer...")
    model_name = "Qwen/Qwen3-32B"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    print("Tokenizer loaded successfully\n")

    # Load policy document and tool specifications
    print("Loading policy document and tool specifications...")
    policy_document = load_policy_document("retail")
    tool_specifications = load_tools_and_get_specifications("retail")
    # Both loaders print their error and return "" instead of raising, so flag
    # empty inputs here rather than silently reporting rates computed from them.
    missing = [
        label
        for label, text in (("policy document", policy_document),
                            ("tool specifications", tool_specifications))
        if not text
    ]
    if missing:
        print(f"WARNING: empty {', '.join(missing)}; results below may be meaningless\n")
    else:
        print("Data loaded successfully\n")

    # Compute policy document compression first
    policy_compression_rate = compute_policy_document_compression(tokenizer, policy_document)
    print()

    print("Building system prompts...")
    regular_prompt = build_regular_system_prompt(policy_document, tool_specifications)
    internalization_prompt = build_internalization_system_prompt(tool_specifications)

    print("Tokenizing prompts...")
    regular_token_count = len(tokenizer.encode(regular_prompt))
    internalization_token_count = len(tokenizer.encode(internalization_prompt))

    if regular_token_count == 0:
        # e.g. the regular template failed to load; avoid ZeroDivisionError.
        print("WARNING: regular prompt is empty; compression rate is undefined")
        compression_rate = float("nan")
    else:
        compression_rate = (1 - internalization_token_count / regular_token_count) * 100

    separator = "=" * 60
    print(separator)
    print("FULL SYSTEM PROMPT COMPRESSION ANALYSIS")
    print(separator)
    print("Regular prompt (with policy document):")
    print(f"  Token count: {regular_token_count:,}")
    print(f"  Character count: {len(regular_prompt):,}")
    print()
    print("Internalization prompt (without policy document):")
    print(f"  Token count: {internalization_token_count:,}")
    print(f"  Character count: {len(internalization_prompt):,}")
    print()
    print(f"Total token reduction: {regular_token_count - internalization_token_count:,} tokens")
    print(f"Full system prompt compression rate: {compression_rate:.2f}%")
    print(separator)

    # Save prompts for inspection. Explicit UTF-8 so non-ASCII prompt text
    # cannot fail on a non-UTF-8 locale.
    with open("regular_prompt.txt", "w", encoding="utf-8") as f:
        f.write(regular_prompt)
    with open("internalization_prompt.txt", "w", encoding="utf-8") as f:
        f.write(internalization_prompt)
    print("\nPrompts saved to regular_prompt.txt and internalization_prompt.txt for inspection.")

    return compression_rate, policy_compression_rate


if __name__ == "__main__":
    # Run the full analysis, then echo both headline numbers in a summary box.
    system_compression_rate, policy_compression_rate = compute_compression_rate()
    rule = "=" * 60
    summary = (
        "\n" + rule,
        "SUMMARY",
        rule,
        f"Policy document compression rate: {policy_compression_rate:.2f}%",
        f"Full system prompt compression rate: {system_compression_rate:.2f}%",
        rule,
    )
    for summary_line in summary:
        print(summary_line)
