#!/usr/bin/env python3
"""
Experimental Setup for LinearizeLLM Information Scenarios

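This script runs experiments under two information scenarios:
1. No Context: the pattern-detection and reformulation agents see only the LaTeX
   optimization problem.
2. Partial Information: no parameter values are passed to the pattern-detection and
   reformulation agents; they receive only decision-variable names and forms.

In both scenarios, code generation always receives the extracted parameters.
Every instance is run with seeds 1-5, and results are stored under
<results-dir>/<llm-model>/<scenario>/<instance>/seed_<seed>/.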
"""

import os
import json
import argparse
import subprocess
import sys
from pathlib import Path
from datetime import datetime
import time
from typing import Dict, List, Any, Optional
import random

# Add src to path for imports
sys.path.append('src')

from src.core.agent_pipeline import LinearizeLLMWorkflow


class ContextExperimentSetup:
    """
    Context experiment setup for testing different information scenarios in LinearizeLLM.

    For every (instance, scenario, seed) combination, a scenario-specific subclass of
    ``LinearizeLLMWorkflow`` is run, and its outputs are written under
    ``<results_base_dir>/<llm_model>/<scenario>/<instance>/seed_<seed>/``.
    """

    def __init__(self, base_data_dir: str = "data/LinearizeLLM_data/instances_linearizellm",
                 results_base_dir: str = "data/context_experiment_results",
                 llm_model: str = "gemini-2.5-flash"):
        """
        Initialize the experimental setup.

        Creates the results directory tree as a side effect and prints a short
        configuration summary.

        Args:
            base_data_dir: Directory containing all problem instances
            results_base_dir: Base directory for storing experimental results
            llm_model: LLM model to use for experiments
        """
        self.base_data_dir = Path(base_data_dir)
        # Results are namespaced by LLM model so runs with different models never collide.
        self.results_base_dir = Path(results_base_dir) / llm_model
        self.llm_model = llm_model
        self.seeds = [1, 2, 3, 4, 5]
        self.scenarios = {
            'no_context': 'No Context Information',
            'partial_info': 'Partial Information (No Parameters to Agents)'
        }

        # Scan the data directory once and reuse the result for directory setup.
        self.instances = self._get_available_instances()
        self._setup_results_directory(self.instances)

        print("🔬 Context Experiment Setup Initialized")
        print(f"   📁 Data Directory: {self.base_data_dir}")
        print(f"   📁 Results Directory: {self.results_base_dir}")
        print(f"   🤖 LLM Model: {self.llm_model}")
        print(f"   📊 Instances Found: {len(self.instances)}")
        print(f"   🎲 Seeds: {self.seeds}")
        print(f"   🔍 Scenarios: {list(self.scenarios.keys())}")

    def _setup_results_directory(self, instances: Optional[List[str]] = None):
        """
        Create ``<results>/<scenario>/<instance>`` directories for every scenario.

        Args:
            instances: Instance names to create directories for. When omitted, the
                data directory is scanned via ``_get_available_instances``.
        """
        if instances is None:
            instances = self._get_available_instances()

        self.results_base_dir.mkdir(parents=True, exist_ok=True)
        for scenario in self.scenarios:
            scenario_dir = self.results_base_dir / scenario
            # Scenario directories are created even when there are no instances.
            scenario_dir.mkdir(exist_ok=True)
            for instance in instances:
                (scenario_dir / instance).mkdir(exist_ok=True)

    def _get_available_instances(self) -> List[str]:
        """
        Return the sorted names of all problem instances.

        An instance is a subdirectory ``<name>`` of ``base_data_dir`` that contains
        ``<name>.tex``. Returns an empty list when ``base_data_dir`` does not exist.
        """
        if not self.base_data_dir.exists():
            return []
        return sorted(
            item.name
            for item in self.base_data_dir.iterdir()
            if item.is_dir() and (item / f"{item.name}.tex").exists()
        )

    def _create_modified_workflow_class(self, scenario: str):
        """
        Create a modified workflow class based on the scenario.

        Args:
            scenario: The information scenario ('no_context', 'partial_info'). Used as
                the default ``scenario`` of the returned class's constructor.

        Returns:
            Modified LinearizeLLMWorkflow class
        """

        class ModifiedLinearizeLLMWorkflow(LinearizeLLMWorkflow):
            def __init__(self, tex_path, problem_id=None, save_results=True,
                         results_base_dir="data/results", llm_model="o3", scenario=scenario):
                super().__init__(tex_path, problem_id, save_results, results_base_dir, llm_model)
                self.scenario = scenario

            def run(self, verbose=True):
                """
                Run the pipeline with scenario-specific information filtering.

                Pattern detection (step 2) and reformulation (step 3) receive only the
                information permitted by ``self.scenario``. Code generation (step 4)
                always receives the extracted parameter values.

                Args:
                    verbose: Forwarded to the MIP reformulation coordinator.

                Returns:
                    ``self.results`` with every intermediate artifact, the optimization
                    results, and the scenario/timestamp tags.
                """
                # Step 1: Load and extract LaTeX model, parameters, and decisions
                from src.core.agent_pipeline import extract_latex_model_from_linearizellm_tex
                latex_model, parameters, decisions = extract_latex_model_from_linearizellm_tex(self.tex_path)
                self.results['latex_model'] = latex_model
                self.parameters = parameters
                self.decisions = decisions

                if parameters:
                    self._save_step_result('extracted_parameters', json.dumps(parameters, indent=2), 'json')
                self._save_step_result('latex_model', latex_model, 'tex')

                # Step 2: Extract non-linear patterns using only scenario-permitted context
                from src.core.nonlinear_detector import NonLinearPatternExtractor
                pattern_extractor = NonLinearPatternExtractor(self.llm_model)
                param_info, parameter_context = self._apply_scenario_filtering()

                # Persist exactly what the detection agent is given
                pattern_detection_input = {
                    "latex_model": latex_model,
                    "parameter_context": parameter_context,
                    "param_info": param_info,
                    "scenario": self.scenario
                }
                self._save_step_result('pattern_detection_llm_input', json.dumps(pattern_detection_input, indent=2), 'json')

                extracted_patterns = pattern_extractor.extract_patterns(
                    latex_model, parameter_context=parameter_context, param_info=param_info
                )
                self.results['extracted_patterns'] = extracted_patterns
                self._save_step_result('extracted_patterns', extracted_patterns)

                parsed_patterns = pattern_extractor.parse_patterns(extracted_patterns)

                # Step 3: Pattern-based MIP Linearization (if needed)
                if parsed_patterns['has_nonlinearities']:
                    from src.agents.mip_coordinator import MIPCoordinator
                    coordinator = MIPCoordinator(llm_model=self.llm_model, verbose=verbose)

                    # Neither scenario gives the reformulation agent parameter values
                    # (see module docstring), so they are withheld here. Code generation
                    # below still receives self.parameters.
                    reformulation_parameters = {}

                    # Persist exactly what the reformulation agent is given
                    reformulation_input = {
                        "latex_model": latex_model,
                        "extracted_patterns": extracted_patterns,
                        "problem_id": self.problem_id,
                        "parameters": reformulation_parameters,
                        "parameter_context": parameter_context,
                        "scenario": self.scenario
                    }
                    self._save_step_result('reformulation_llm_input', json.dumps(reformulation_input, indent=2), 'json')

                    reformulation_summary = coordinator.coordinate_reformulation(
                        latex_model, extracted_patterns, self.problem_id,
                        reformulation_parameters, parameter_context
                    )
                    self.results['reformulation_agent_raw_output'] = reformulation_summary
                    self._save_step_result('reformulation_agent_raw_output', reformulation_summary, 'tex')

                    self.results['linearized_model'] = reformulation_summary
                    self._save_step_result('linearized_model', reformulation_summary, 'tex')

                    # Keep only the final model for code generation. The longer marker
                    # must be checked first because it contains the shorter one's prefix.
                    model_for_code = reformulation_summary
                    for marker in ("FINAL LINEARIZED MODEL WITH REFORMULATION MARKINGS:",
                                   "FINAL LINEARIZED MODEL:"):
                        if marker in reformulation_summary:
                            model_for_code = reformulation_summary.split(marker)[1].strip()
                            break
                else:
                    self.results['reformulation_agent_raw_output'] = "No reformulation needed - model is already linear"
                    self._save_step_result('reformulation_agent_raw_output', self.results['reformulation_agent_raw_output'])
                    self.results['linearized_model'] = "No linearization needed - model is already linear"
                    self._save_step_result('linearized_model', self.results['linearized_model'])
                    model_for_code = latex_model

                # Step 4: Code Generation (always with the extracted parameter values)
                from src.core.code_converter import CodeConverter
                code_converter = CodeConverter(self.llm_model)

                # NOTE(review): the return value was never used; the call is kept in case
                # it primes converter state for regeneration — confirm and drop if pure.
                code_converter._extract_reformulation_info(model_for_code)

                code_generation_input = {
                    "model_for_code": model_for_code,
                    "problem_id": self.problem_id,
                    "parameters": self.parameters,
                    "parameter_context": parameter_context,
                    "scenario": self.scenario
                }
                self._save_step_result('code_generation_llm_input', json.dumps(code_generation_input, indent=2), 'json')

                gurobi_code = code_converter.convert_with_parameters(
                    model_for_code,
                    self.problem_id,
                    self.parameters,
                    parameter_context=parameter_context
                )
                self.results['code_generation_agent_raw_output'] = gurobi_code
                self._save_step_result('code_generation_agent_raw_output', gurobi_code, 'py')

                self.results['gurobi_code'] = gurobi_code
                self._save_step_result('gurobi_code', gurobi_code, 'py')

                # Step 5: Optimization Execution
                from src.core.optimization_executor import OptimizationExecutor
                optimizer = OptimizationExecutor(save_error_log=True)
                optimization_results = optimizer.execute(
                    gurobi_code,
                    self.problem_id
                )
                self.results['optimization_results'] = optimization_results

                self.results['experiment_scenario'] = self.scenario
                self.results['experiment_timestamp'] = datetime.now().isoformat()

                return self.results

            def _apply_scenario_filtering(self):
                """
                Apply scenario-specific information filtering.

                Returns:
                    tuple: (param_info, parameter_context). Both are empty strings for
                    'no_context'. For 'partial_info', parameter_context lists decision
                    variable names and forms (no parameter values).

                Raises:
                    ValueError: If ``self.scenario`` is not a known scenario.
                """
                if self.scenario == 'no_context':
                    # Only the raw LaTeX model is given: no parameters, no decisions.
                    return "", ""

                if self.scenario == 'partial_info':
                    if not self.decisions:
                        return "", ""
                    # assumes each value in self.decisions is an iterable of strings — TODO confirm
                    form_lines = "".join(
                        f"- {var}: {', '.join(forms)}\n" for var, forms in self.decisions.items()
                    )
                    parameter_context = (
                        f"\nDECISION VARIABLES: {', '.join(self.decisions)}\n"
                        "\nDECISION VARIABLE FORMS:\n"
                        + form_lines
                    )
                    return "", parameter_context

                raise ValueError(f"Unknown scenario: {self.scenario}")

        return ModifiedLinearizeLLMWorkflow

    def _experiment_metadata(self, instance: str, seed: int, scenario: str, status: str) -> Dict[str, Any]:
        """Build the metadata record attached to every per-experiment output."""
        return {
            'instance': instance,
            'seed': seed,
            'scenario': scenario,
            'llm_model': self.llm_model,
            'timestamp': datetime.now().isoformat(),
            'status': status
        }

    @staticmethod
    def _write_json(path: Path, payload: Dict[str, Any]) -> None:
        """Write *payload* as indented JSON; non-serializable values are stringified."""
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(payload, f, indent=2, default=str)

    def run_single_experiment(self, instance: str, seed: int, scenario: str, verbose: bool = True) -> Dict[str, Any]:
        """
        Run a single experiment for a specific instance, seed, and scenario.

        Writes ``context_experiment_results.json`` and ``context_experiment_error.json``
        into ``<results>/<scenario>/<instance>/seed_<seed>/``. Both files are rewritten
        on every run, so neither keeps stale content from an earlier run.

        Args:
            instance: Problem instance name
            seed: Random seed
            scenario: Information scenario
            verbose: Whether to print detailed output

        Returns:
            The workflow results on success, or an error record with keys ``error``,
            ``error_type``, ``traceback`` and ``context_experiment_metadata`` on failure.
            Both carry ``context_experiment_metadata['status']``.

        Raises:
            FileNotFoundError: If ``<base_data_dir>/<instance>/<instance>.tex`` is missing.
        """
        # NOTE: this seeds only Python's ``random`` module; whether it influences the
        # LLM calls cannot be determined from this file.
        random.seed(seed)

        # Validate the input before creating any directories, so a mistyped instance
        # name does not leave empty result folders behind.
        problem_file = self.base_data_dir / instance / f"{instance}.tex"
        if not problem_file.exists():
            raise FileNotFoundError(f"Problem file not found: {problem_file}")

        instance_results_dir = self.results_base_dir / scenario / instance / f"seed_{seed}"
        instance_results_dir.mkdir(parents=True, exist_ok=True)
        results_file = instance_results_dir / "context_experiment_results.json"
        error_file = instance_results_dir / "context_experiment_error.json"

        try:
            # Build the workflow inside the try so construction failures are recorded
            # like any other failure instead of aborting a whole suite run.
            workflow_cls = self._create_modified_workflow_class(scenario)
            workflow = workflow_cls(
                tex_path=str(problem_file),
                problem_id=instance,
                save_results=True,
                results_base_dir=str(instance_results_dir),
                llm_model=self.llm_model,
                scenario=scenario
            )
            results = workflow.run(verbose=verbose)

            results['context_experiment_metadata'] = self._experiment_metadata(
                instance, seed, scenario, 'success')
            self._write_json(results_file, results)
            self._write_json(error_file, {
                'status': 'success',
                'context_experiment_metadata': self._experiment_metadata(
                    instance, seed, scenario, 'success')
            })

            if verbose:
                print(f"✅ Context experiment completed: {instance} - Seed {seed} - {scenario}")
                print(f"   📁 Results saved to: {results_file}")
                print(f"   📁 Error file updated: {error_file}")

            return results

        except Exception as e:
            import traceback
            error_result = {
                'error': str(e),
                'error_type': type(e).__name__,
                'traceback': traceback.format_exc(),
                'context_experiment_metadata': self._experiment_metadata(
                    instance, seed, scenario, 'failed')
            }
            self._write_json(error_file, error_result)
            self._write_json(results_file, error_result)

            if verbose:
                print(f"❌ Context experiment failed: {instance} - Seed {seed} - {scenario}")
                print(f"   Error: {str(e)}")
                print(f"   📁 Error saved to: {error_file}")
                print(f"   📁 Results file updated: {results_file}")

            return error_result

    def run_full_context_experiment(self, instances: Optional[List[str]] = None,
                                   scenarios: Optional[List[str]] = None,
                                   seeds: Optional[List[int]] = None,
                                   verbose: bool = True,
                                   inter_experiment_delay: float = 1.0) -> Dict[str, Any]:
        """
        Run the full context experiment suite.

        The summary is written to ``<results>/context_experiment_summary.json`` even if
        the run is interrupted or an exception escapes.

        Args:
            instances: List of instances to run (None for all)
            scenarios: List of scenarios to run (None for all)
            seeds: List of seeds to run (None for all)
            verbose: Whether to print detailed output
            inter_experiment_delay: Seconds to wait between consecutive experiments
                (to avoid rate limiting). Not applied after the last one.

        Returns:
            Dictionary containing summary of all experiments
        """
        instances = self.instances if instances is None else instances
        scenarios = list(self.scenarios.keys()) if scenarios is None else scenarios
        seeds = self.seeds if seeds is None else seeds

        total_experiments = len(instances) * len(scenarios) * len(seeds)

        context_experiment_summary = {
            'context_experiment_metadata': {
                'total_experiments': total_experiments,
                'instances': instances,
                'scenarios': scenarios,
                'seeds': seeds,
                'llm_model': self.llm_model,
                'start_time': datetime.now().isoformat()
            },
            'results': {},
            'statistics': {
                'completed': 0,
                'failed': 0,
                'total_time': 0
            }
        }
        statistics = context_experiment_summary['statistics']

        # Monotonic clock: wall-clock adjustments cannot distort the measured duration.
        start_time = time.monotonic()
        experiment_count = 0

        try:
            for instance in instances:
                for scenario in scenarios:
                    for seed in seeds:
                        experiment_count += 1
                        if verbose:
                            print(f"🧪 [{experiment_count}/{total_experiments}] "
                                  f"{instance} - Seed {seed} - {scenario}")

                        result = self.run_single_experiment(instance, seed, scenario, verbose=verbose)

                        key = f"{instance}_{scenario}_seed_{seed}"
                        context_experiment_summary['results'][key] = result

                        # Classify by the recorded status rather than by the presence
                        # of an 'error' key, which a workflow result might also contain.
                        status = result.get('context_experiment_metadata', {}).get('status')
                        if status == 'failed':
                            statistics['failed'] += 1
                        else:
                            statistics['completed'] += 1

                        if inter_experiment_delay > 0 and experiment_count < total_experiments:
                            time.sleep(inter_experiment_delay)

        except KeyboardInterrupt:
            print("\n⚠️ Experiment interrupted by user")
            context_experiment_summary['context_experiment_metadata']['interrupted'] = True

        finally:
            statistics['total_time'] = time.monotonic() - start_time
            context_experiment_summary['context_experiment_metadata']['end_time'] = datetime.now().isoformat()

            summary_file = self.results_base_dir / "context_experiment_summary.json"
            self._write_json(summary_file, context_experiment_summary)

        return context_experiment_summary


def main():
    """
    Parse command-line arguments and run either one experiment or the full suite.

    With ``--single-experiment``, ``--instance``, ``--scenario`` and ``--seed`` are all
    required. Otherwise the full suite runs over the selected instances, scenarios,
    and seeds (defaults: all), and a short statistics line is printed.
    """
    parser = argparse.ArgumentParser(description='Run LinearizeLLM context experiment setup')
    parser.add_argument('--data-dir', type=str, default="data/LinearizeLLM_data/instances_linearizellm",
                       help='Directory containing problem instances')
    parser.add_argument('--results-dir', type=str, default="data/context_experiment_results",
                       help='Directory to store context experiment results')
    parser.add_argument('--llm-model', type=str, default="gemini-2.5-flash",
                       choices=['o3', 'gemini-2.5-flash'],
                       help='LLM model to use for context experiments (o3 or gemini-2.5-flash)')
    parser.add_argument('--instances', nargs='+', help='Specific instances to run (default: all)')
    parser.add_argument('--scenarios', nargs='+', choices=['no_context', 'partial_info'],
                       help='Specific scenarios to run (default: all)')
    parser.add_argument('--seeds', nargs='+', type=int, help='Specific seeds to run (default: 1-5)')
    parser.add_argument('--single-experiment', action='store_true',
                       help='Run a single experiment (requires --instance, --scenario, --seed)')
    parser.add_argument('--instance', type=str, help='Instance for single experiment')
    parser.add_argument('--scenario', type=str, choices=['no_context', 'partial_info'],
                       help='Scenario for single experiment')
    parser.add_argument('--seed', type=int, help='Seed for single experiment')
    parser.add_argument('--quiet', action='store_true', help='Reduce verbose output')

    args = parser.parse_args()

    # Check the seed against None explicitly: ``--seed 0`` is a valid seed but falsy.
    if args.single_experiment:
        if not args.instance or args.scenario is None or args.seed is None:
            parser.error("--single-experiment requires --instance, --scenario, and --seed")

    setup = ContextExperimentSetup(
        base_data_dir=args.data_dir,
        results_base_dir=args.results_dir,
        llm_model=args.llm_model
    )

    if args.single_experiment:
        result = setup.run_single_experiment(
            instance=args.instance,
            seed=args.seed,
            scenario=args.scenario,
            verbose=not args.quiet
        )
        print(f"Single experiment result: {result}")
    else:
        summary = setup.run_full_context_experiment(
            instances=args.instances,
            scenarios=args.scenarios,
            seeds=args.seeds,
            verbose=not args.quiet
        )
        # The full summary embeds every model and generated script; print only the
        # statistics and point to the saved JSON file for details.
        stats = summary['statistics']
        print(f"Context experiment suite completed. "
              f"Completed: {stats['completed']}, failed: {stats['failed']}, "
              f"total time: {stats['total_time']:.1f}s")
        print(f"Summary saved to: {setup.results_base_dir / 'context_experiment_summary.json'}")


if __name__ == "__main__":
    # Run the command-line interface only when executed as a script, not when imported.
    main() 