#!/usr/bin/env python3
"""
Script to process LinearizeLLM data using the LinearizeLLMWorkflow.
Reads .tex files from data/LinearizeLLM_data/instances_linearizellm/ and processes them.
"""

import os
import json
import sys
import argparse
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
from src.core.agent_pipeline import LinearizeLLMWorkflow
from pathlib import Path

def load_linearizellm_tex_files(instances_dir):
    """
    Collect the main .tex file of every problem folder under ``instances_dir``.

    A problem folder is any immediate subdirectory ``<name>`` that contains a
    file called ``<name>.tex``. Subdirectories without such a file are reported
    on stdout and skipped; non-directory entries are ignored silently.

    Args:
        instances_dir (Path): Directory whose immediate subfolders are scanned

    Returns:
        list: Paths of the discovered .tex files, ordered by file name
    """
    discovered = []

    for entry in instances_dir.iterdir():
        # Only subfolders can be problems; loose files are not considered.
        if not entry.is_dir():
            continue

        # The .tex file must carry the same name as its folder.
        candidate = entry / (entry.name + ".tex")
        if not candidate.exists():
            print(f"⚠️ No .tex file found in {entry}")
            continue

        discovered.append(candidate)
        print(f"📁 Found problem: {entry.name} -> {candidate}")

    # iterdir() order is arbitrary, so sort by file name for a stable order.
    return sorted(discovered, key=lambda path: path.name)

def parse_arguments():
    """Build the command line parser and parse ``sys.argv``.

    Returns:
        argparse.Namespace: Parsed options. Value options that were not given
        are ``None``, except ``results_dir`` ('data/results') and ``model``
        ('o3'); flag options default to ``False``.
    """
    cli = argparse.ArgumentParser(
        description='Process LinearizeLLM LaTeX files with LinearizeLLMWorkflow'
    )

    # (flag, add_argument keyword arguments), in the order shown by --help.
    option_table = [
        ('--all', dict(action='store_true', help='Process all .tex files')),
        ('--file', dict(type=str, help='Process specific .tex file (e.g., blend_problem.tex)')),
        ('--first', dict(type=int, help='Process first N .tex files')),
        ('--range', dict(type=str, help='Process range of files (e.g., 0-2)')),
        ('--files', dict(type=str, help='Process multiple specific files (e.g., blend_problem.tex,diet_problem.tex)')),
        ('--random', dict(type=int, help='Process random sample of N files')),
        ('--no-pause', dict(action='store_true', help="Don't pause between files")),
        ('--api-key', dict(type=str, help='OpenAI API key (overrides environment variables)')),
        ('--no-save', dict(action='store_true', help="Don't save results to files (only show in console)")),
        ('--results-dir', dict(type=str, default='data/results', help='Base directory for saving results (default: data/results)')),
        ('--model', dict(type=str, default='o3', help='LLM model to use (e.g., gpt-4o, claude-3-opus, gemini-pro)')),
        ('--model-config', dict(type=str, help='Path to JSON file with custom model configuration')),
        # Experimental arguments
        ('--experimental', dict(action='store_true', help='Run experimental setup with multiple seeds')),
        ('--seeds', dict(type=str, help='Comma-separated list of seeds for experimental setup (e.g., 42,123,456)')),
        ('--use-default-api-key', dict(action='store_true', help='Use default API key without prompting in experimental mode')),
    ]
    for flag, options in option_table:
        cli.add_argument(flag, **options)

    return cli.parse_args()

def _detect_provider(llm_model):
    """Return the provider name used to build the ``<PROVIDER>_API_KEY`` variable.

    String model names are mapped by prefix (``claude-`` -> "anthropic",
    ``gemini-`` -> "google", anything else -> "openai"). A non-string model
    object is expected to expose a ``provider`` attribute.
    """
    if not llm_model:
        return "openai"
    if not isinstance(llm_model, str):
        return llm_model.provider
    if llm_model.startswith("claude-"):
        return "anthropic"
    if llm_model.startswith("gemini-"):
        return "google"
    return "openai"


def _choose_model(args, llm_model):
    """Offer interactive model selection when the default model is in use.

    Returns the model to use. ``llm_model`` is only replaced when both the
    selection and its API-key setup succeed, so a failure falls back to the
    model named in the printed message.
    """
    if args.model and args.model != 'o3':
        print(f"🤖 Using LLM model: {args.model}")
        return llm_model

    try:
        from src.utils.model_selector import select_model_interactive, setup_model_with_api_key

        print("\n🤖 Would you like to select a different model?")
        choice = input("Enter 'y' to select model, or press Enter to use default (o3): ").strip().lower()

        if choice == 'y':
            selected_model = setup_model_with_api_key(select_model_interactive())
            llm_model = selected_model
            print(f"✅ Selected model: {llm_model}")
        else:
            print(f"🤖 Using default LLM model: {args.model}")
    except ImportError:
        print(f"🤖 Using LLM model: {args.model}")
    except Exception as e:
        print(f"⚠️ Model selection failed: {e}")
        print(f"🤖 Using LLM model: {args.model}")
    return llm_model


def _resolve_problem_tex(instances_dir, raw_name):
    """Map a user-supplied problem name to its .tex file.

    Surrounding whitespace and quotes are stripped. The name is tried as
    given first; if it ends in ``.tex`` (the form shown in the help text) the
    name without that suffix is tried as well.

    Returns:
        tuple: ``(problem_name, tex_path)``; ``tex_path`` is ``None`` when no
        matching ``<instances_dir>/<name>/<name>.tex`` exists.
    """
    name = raw_name.strip().strip("'\"")
    candidates = [name]
    if name.endswith(".tex"):
        candidates.append(name[:-len(".tex")])
    for candidate in candidates:
        tex_path = instances_dir / candidate / f"{candidate}.tex"
        if tex_path.exists():
            return candidate, tex_path
    return name, None


def _resolve_problem_list(instances_dir, raw_names):
    """Resolve a comma-separated list of problem names, skipping unknown ones."""
    selected = []
    for raw_name in raw_names.split(','):
        name, tex_path = _resolve_problem_tex(instances_dir, raw_name)
        if tex_path is None:
            print(f"⚠️ Problem {name} not found, skipping...")
        else:
            selected.append(tex_path)
    return selected


def _select_range(tex_files, range_text):
    """Return ``tex_files[start:end+1]`` for a "start-end" string, or None if malformed."""
    if '-' not in range_text:
        print("❌ Invalid range format. Use format: 0-2")
        return None
    try:
        start, end = map(int, range_text.split('-'))
    except ValueError:
        print("❌ Invalid range format")
        return None
    selected = tex_files[start:end + 1]
    print(f"Processing files {start}-{end}: {[f.name for f in selected]}")
    return selected


def _sample_problems(tex_files, n, too_many_message):
    """Return ``n`` randomly chosen files sorted by name.

    When ``n`` exceeds the number of files, ``too_many_message`` is printed
    and all files are returned. ``n`` must be non-negative.
    """
    import random

    if n > len(tex_files):
        print(too_many_message)
        return tex_files
    picked = random.sample(tex_files, n)
    picked.sort(key=lambda x: x.name)
    return picked


def _read_count(prompt):
    """Prompt for a non-negative integer; print an error and return None otherwise."""
    try:
        count = int(input(prompt).strip())
    except ValueError:
        count = -1
    if count < 0:
        print("❌ Invalid number")
        return None
    return count


def _run_experiment(files_to_process, instances_dir, llm_model, args):
    """Run the experimental setup; return True if it completed, False on failure."""
    try:
        from src.scripts.experimental_runner import run_experimental_setup
        run_experimental_setup(files_to_process, instances_dir, llm_model, args)
        return True
    except ImportError as e:
        print(f"❌ Failed to import experimental runner: {e}")
        print("Falling back to normal processing...")
    except Exception as e:
        print(f"❌ Experimental setup failed: {e}")
        print("Falling back to normal processing...")
    return False


def _select_from_cli(args, tex_files, instances_dir, llm_model):
    """Pick files according to the command line selection flags.

    Flags are honoured in this order: --all, --file, --first, --range,
    --files, --random, --experimental.

    Returns:
        list or None: Files for normal processing, or ``None`` when main()
        should stop (an error was reported or the experiment completed).
    """
    if args.all:
        print(f"Processing all {len(tex_files)} .tex files...")
        return tex_files

    if args.file:
        file_name, target_file = _resolve_problem_tex(instances_dir, args.file)
        if target_file is None:
            print(f"❌ Problem {file_name} not found")
            print(f"Available problems: {[f.parent.name for f in tex_files]}")
            return None
        print(f"Processing problem {file_name}...")
        return [target_file]

    # Compare with None so that an explicit 0 still counts as a selection.
    if args.first is not None:
        if args.first < 0:
            print("❌ Invalid number")
            return None
        selected = tex_files[:args.first]
        print(f"Processing first {len(selected)} .tex files...")
        return selected

    if args.range:
        return _select_range(tex_files, args.range)

    if args.files:
        selected = _resolve_problem_list(instances_dir, args.files)
        print(f"Processing problems: {[f.parent.name for f in selected]}")
        return selected

    if args.random is not None:
        n = args.random
        if n < 0:
            # random.sample() raises ValueError for a negative sample size.
            print("❌ Invalid number")
            return None
        selected = _sample_problems(
            tex_files, n,
            f"⚠️ Requested {n} files, but only {len(tex_files)} available. Processing all.",
        )
        print(f"Processing {len(selected)} random files: {[f.name for f in selected]}")
        return selected

    if args.experimental:
        print("\n🔬 EXPERIMENTAL SETUP (Command Line)")
        print("="*50)
        print(f"✅ Using all {len(tex_files)} problems for experiment")
        print("\n🚀 Starting experimental setup...")
        if _run_experiment(tex_files, instances_dir, llm_model, args):
            return None  # Exit after experimental setup
        return tex_files

    return []


def _select_for_experiment(tex_files, instances_dir):
    """Interactively pick the problems for an experimental run.

    Returns:
        list or None: The selected files (possibly empty), or ``None`` when
        the input was invalid and an error has already been printed.
    """
    print("\n🔬 EXPERIMENTAL SETUP")
    print("="*50)
    print(f"Available problems: {[f.parent.name for f in tex_files]}")

    print("\nProblem Selection Options:")
    print("1. Use all problems")
    print("2. Select specific problems")
    print("3. Select first N problems")
    print("4. Select random N problems")

    problem_choice = input("\nEnter problem selection choice (1-4): ").strip()

    if problem_choice == "1":
        print(f"✅ Selected all {len(tex_files)} problems for experiment")
        return tex_files

    if problem_choice == "2":
        print(f"Available problems: {[f.parent.name for f in tex_files]}")
        problems_input = input("Enter problem names separated by commas: ").strip()
        selected = _resolve_problem_list(instances_dir, problems_input)
        print(f"✅ Selected {len(selected)} problems for experiment")
        return selected

    if problem_choice == "3":
        n = _read_count("Enter number of problems to select: ")
        if n is None:
            return None
        selected = tex_files[:n]
        print(f"✅ Selected first {len(selected)} problems for experiment")
        return selected

    if problem_choice == "4":
        n = _read_count("Enter number of random problems to select: ")
        if n is None:
            return None
        selected = _sample_problems(
            tex_files, n,
            f"⚠️ Requested {n} problems, but only {len(tex_files)} available. Selecting all.",
        )
        print(f"✅ Selected {len(selected)} random problems for experiment")
        return selected

    print("❌ Invalid choice")
    return None


def _select_interactively(args, tex_files, instances_dir, llm_model):
    """Show the selection menu and pick files according to the user's answer.

    Returns:
        list or None: Files for normal processing, or ``None`` when main()
        should stop (an error was reported or the experiment completed).
    """
    print("\nSubset Selection Options:")
    print("1. Process all .tex files")
    print("2. Process specific .tex file")
    print("3. Process first N .tex files")
    print("4. Process range of files (e.g., 0-2)")
    print("5. Process multiple specific files (e.g., blend_problem.tex,diet_problem.tex)")
    print("6. Process random sample of N files")
    print("7. Select Experimental setup (multiple seeds, robust evaluation)")

    choice = input("\nEnter your choice (1-7): ").strip()

    if choice == "1":
        print(f"Processing all {len(tex_files)} .tex files...")
        return tex_files

    if choice == "2":
        print(f"Available problems: {[f.parent.name for f in tex_files]}")
        problem_name, target_file = _resolve_problem_tex(instances_dir, input("Enter problem name: "))
        if target_file is None:
            print(f"❌ Problem {problem_name} not found")
            return None
        print(f"Processing problem {problem_name}...")
        return [target_file]

    if choice == "3":
        n = _read_count("Enter number of files to process: ")
        if n is None:
            return None
        selected = tex_files[:n]
        print(f"Processing first {len(selected)} .tex files...")
        return selected

    if choice == "4":
        return _select_range(tex_files, input("Enter range (e.g., 0-2): ").strip())

    if choice == "5":
        print(f"Available problems: {[f.parent.name for f in tex_files]}")
        problems_input = input("Enter problem names separated by commas (e.g., aircraft_problem,diet_problem): ").strip()
        selected = _resolve_problem_list(instances_dir, problems_input)
        print(f"Processing problems: {[f.parent.name for f in selected]}")
        return selected

    if choice == "6":
        n = _read_count("Enter number of random files to process: ")
        if n is None:
            return None
        selected = _sample_problems(
            tex_files, n,
            f"⚠️ Requested {n} files, but only {len(tex_files)} available. Processing all.",
        )
        print(f"Processing {len(selected)} random files: {[f.name for f in selected]}")
        return selected

    if choice == "7":
        selected = _select_for_experiment(tex_files, instances_dir)
        if selected is None:
            return None
        if not selected:
            print("❌ No problems selected for experiment")
            return None
        print(f"\n🚀 Starting experimental setup with {len(selected)} problems...")
        if _run_experiment(selected, instances_dir, llm_model, args):
            return None  # Exit after experimental setup
        return selected

    print("❌ Invalid choice")
    return None


def _process_problem(tex_file, args, llm_model):
    """Run LinearizeLLMWorkflow on one .tex file and return its summary dict.

    The summary always carries 'problem_id', 'problem_name', 'success',
    'steps_completed' and 'final_status'; depending on the outcome it also
    carries 'objective_value' and 'variables_count', or 'error'.
    """
    # The problem ID is the name of the folder holding the .tex file.
    problem_id = tex_file.parent.name

    print(f"📄 Processing LaTeX file: {tex_file.name}")
    print(f"📁 File path: {tex_file}")
    print(f"📊 Problem ID: {problem_id}")

    workflow = LinearizeLLMWorkflow(
        tex_path=str(tex_file),
        problem_id=problem_id,
        save_results=not args.no_save,
        results_base_dir=args.results_dir,
        llm_model=llm_model
    )

    print(f"\n🚀 Starting LinearizeLLMWorkflow for {tex_file.name}...")
    results = workflow.run(verbose=True)

    optimization = results.get('optimization_results', {})
    summary = {
        'problem_id': problem_id,
        'problem_name': tex_file.parent.name,
        'success': optimization.get('success', False),
        'steps_completed': sum(1 for key in results if key != 'error'),
        'final_status': None
    }

    if optimization.get('success'):
        opt_results = optimization['optimization_results']
        summary['final_status'] = opt_results['status']
        summary['objective_value'] = opt_results.get('objective_value')
        summary['variables_count'] = len(opt_results.get('variables', {}))
    # Deliberately "== False" rather than "not ...": a missing/None flag must
    # fall through to the "no results" branch below.
    elif optimization.get('success') == False:  # noqa: E712
        summary['error'] = optimization.get('error', 'Unknown optimization error')
    else:
        summary['error'] = 'No optimization results found'

    return summary


def _print_final_summary(results_summary):
    """Print the overall success count followed by one line per problem."""
    print("\n" + "="*100)
    print("FINAL SUMMARY - LINEARIZELLM FILES PROCESSING")
    print("="*100)

    successful = sum(1 for r in results_summary if r['success'])
    total = len(results_summary)

    print(f"📊 Overall Results: {successful}/{total} files processed successfully")

    for result in results_summary:
        status_icon = "✅" if result['success'] else "❌"
        if result['success'] and 'final_status' in result:
            print(f"{status_icon} {result['problem_name']}: {result['final_status']} "
                  f"(Objective: {result.get('objective_value', 'N/A')}, "
                  f"Variables: {result.get('variables_count', 'N/A')})")
        elif result['success']:
            print(f"{status_icon} {result['problem_name']}: Completed {result['steps_completed']} steps")
        else:
            print(f"{status_icon} {result['problem_name']}: {result.get('error', 'Unknown error')}")

    print("="*100)


def main():
    """
    Main function to process LinearizeLLM LaTeX files.

    Steps: parse the command line, resolve the LLM model (optionally from a
    JSON config or an interactive prompt), export a command-line API key to
    the provider's environment variable, discover the problem .tex files,
    select a subset (via flags or an interactive menu), run the workflow on
    each selected file and print a final summary. Errors are reported on
    stdout and end the run early; the function returns None in all cases.
    """
    print("="*80)
    print("LINEARIZELLM LATEX FILES PROCESSING WITH LINEARIZELLMWORKFLOW")
    print("="*80)

    # Parse command line arguments
    args = parse_arguments()

    # Handle model configuration first to determine provider
    llm_model = args.model
    if args.model_config:
        try:
            with open(args.model_config, 'r', encoding='utf-8') as f:
                config_data = json.load(f)
            from src.utils.llm_model_manager import LLMConfig
            llm_model = LLMConfig(**config_data)
            print(f"✅ Loaded custom model configuration from: {args.model_config}")
        except Exception as e:
            print(f"❌ Failed to load model configuration: {str(e)}")
            return

    # Set API key if provided via command line (provider-aware)
    provider = _detect_provider(llm_model)
    if args.api_key:
        env_var_name = f"{provider.upper()}_API_KEY"
        os.environ[env_var_name] = args.api_key
        print(f"✅ {provider.capitalize()} API key set from command line argument")

    # Without a config file, the user may pick a model interactively
    if not args.model_config:
        llm_model = _choose_model(args, llm_model)

    # Path to LinearizeLLM instances directory
    instances_dir = Path("data/LinearizeLLM_data/instances_linearizellm")

    if not instances_dir.exists():
        print(f"❌ Instances directory not found: {instances_dir}")
        return

    # Find all .tex files
    tex_files = load_linearizellm_tex_files(instances_dir)

    if not tex_files:
        print(f"❌ No .tex files found in {instances_dir}")
        return

    print(f"Found {len(tex_files)} problems in LinearizeLLM instances")
    print(f"Problems: {[f.parent.name for f in tex_files]}")

    # Integer options are compared with None so that an explicit 0 still
    # counts as a command line selection instead of opening the menu.
    cli_selection_given = any([
        args.all, args.file, args.first is not None, args.range, args.files,
        args.random is not None, args.experimental,
    ])
    if cli_selection_given:
        files_to_process = _select_from_cli(args, tex_files, instances_dir, llm_model)
    else:
        files_to_process = _select_interactively(args, tex_files, instances_dir, llm_model)

    # None means the helper already reported an error or finished the experiment.
    if files_to_process is None:
        return

    if not files_to_process:
        print("❌ No files selected for processing")
        return

    # Process each selected file
    results_summary = []
    total_files = len(files_to_process)

    for i, tex_file in enumerate(files_to_process, 1):
        print("\n" + "="*100)
        print(f"PROCESSING PROBLEM {tex_file.parent.name} ({i}/{total_files})")
        print("="*100)

        try:
            results_summary.append(_process_problem(tex_file, args, llm_model))
            print(f"\n✅ Problem {tex_file.parent.name} processing completed")
        except Exception as e:
            # One failing problem must not abort the remaining ones.
            print(f"❌ Error processing problem {tex_file.parent.name}: {str(e)}")
            results_summary.append({
                'problem_id': tex_file.parent.name,
                'problem_name': tex_file.parent.name,
                'success': False,
                'error': str(e)
            })

        # Pause between files (except for the last one)
        # Skip pause if --no-pause is set or if running in experimental mode
        if i < total_files and not args.no_pause and not args.experimental:
            input("\nPress Enter to continue to next file...")

    _print_final_summary(results_summary)

# Run the processing pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()