#!/usr/bin/env python3

import os
import sys
import json
import time
import logging
from pathlib import Path
from config import *

def setup_logging(db_id, output_dir):
    """Route log output to the console and to a per-database log file.

    Creates ``<output_dir>/<db_id>/logs/`` if needed and (re)configures the
    root logger at INFO level with a bare ``%(message)s`` format. Output goes
    to ``<db_id>_generation.log`` (truncated on every call) and to stdout.

    Args:
        db_id: Database identifier; names the sub-directory and the log file.
        output_dir: Root output directory.

    Returns:
        The root logger.
    """
    log_dir = Path(output_dir) / db_id / "logs"
    log_dir.mkdir(parents=True, exist_ok=True)
    log_file = log_dir / f"{db_id}_generation.log"

    # logging.basicConfig() silently does nothing when the root logger already
    # has handlers. Without this, a second call (e.g. for the next database in
    # the same process) would keep writing into the previous database's file.
    # Detach and close the old handlers first (what force=True does on 3.8+).
    root = logging.getLogger()
    for handler in list(root.handlers):
        root.removeHandler(handler)
        handler.close()

    logging.basicConfig(
        level=logging.INFO,
        format='%(message)s',
        handlers=[
            logging.FileHandler(log_file, mode='w', encoding='utf-8'),
            logging.StreamHandler(sys.stdout),
        ],
    )

    return root

def load_tables_from_dir(path):
    """Find and parse the Spider ``tables.json`` located somewhere under *path*.

    The tree is walked top-down with os.walk, and the first directory that
    contains a ``tables.json`` wins. If none is found, an error is logged and
    the process exits with status 1.

    Returns:
        The decoded JSON content of the file.
    """
    logging.info(f"Loading Spider database schemas: {path}")

    # Lazy generator: os.walk stops as soon as the first match is taken.
    candidates = (
        os.path.join(dirpath, "tables.json")
        for dirpath, _subdirs, filenames in os.walk(path)
        if "tables.json" in filenames
    )
    tables_path = next(candidates, None)

    if not tables_path:
        logging.error("ERROR: tables.json not found")
        sys.exit(1)

    with open(tables_path, "r", encoding="utf-8") as handle:
        return json.load(handle)

def setup_iteration_directory(output_dir, db_id, iteration):
    """Create the directory for *iteration* and point the cache links at it.

    Layout under ``<output_dir>/<db_id>/schema_cache/``:
      * ``iteration_<n>/`` - created if missing
      * ``current`` and ``latest`` - relative symlinks re-pointed to
        ``iteration_<n>``

    Returns:
        Path of the iteration directory.
    """
    base_dir = Path(output_dir) / db_id / "schema_cache"
    iteration_dir = base_dir / f"iteration_{iteration}"
    links = (base_dir / "current", base_dir / "latest")

    iteration_dir.mkdir(parents=True, exist_ok=True)

    # Path.exists() follows the link, so it returns False for a dangling
    # symlink (its old target was removed). Checking is_symlink() as well
    # ensures such a link is removed instead of making symlink_to() fail
    # with FileExistsError.
    for link in links:
        if link.is_symlink() or link.exists():
            link.unlink()

    # Relative targets keep the cache relocatable.
    for link in links:
        link.symlink_to(f"iteration_{iteration}", target_is_directory=True)

    return iteration_dir

def save_iteration_state(iteration_dir, schema_sql, data_sql, state_info, summary_text, data_dictionary=None, business_configuration_logic=None):
    """Persist one iteration's artifacts into *iteration_dir*.

    Always writes schema.sql, data.sql, state.json and summary.txt.
    data_dictionary.json and business_configuration_logic.json are written
    only when the corresponding argument is truthy. Each file is first
    written to a sibling ``<name>.tmp`` file and then moved over the target
    (an atomic operation on POSIX filesystems), so the final path never holds
    a half-written file.

    Args:
        iteration_dir: pathlib.Path of an existing directory.
        schema_sql, data_sql, summary_text: Text contents to write.
        state_info: JSON-serializable state object.
        data_dictionary: Optional JSON-serializable object.
        business_configuration_logic: Optional JSON-serializable object.
    """
    def atomic_write(filepath, content, mode='w'):
        temp_path = filepath.with_suffix(filepath.suffix + '.tmp')
        try:
            with open(temp_path, mode, encoding='utf-8') as f:
                f.write(content)
            # Path.replace (os.replace) overwrites an existing target on every
            # platform; Path.rename raises FileExistsError on Windows.
            temp_path.replace(filepath)
        except BaseException:
            # Do not leave a stray .tmp file behind when the write/move fails.
            try:
                temp_path.unlink()
            except OSError:
                pass
            raise

    atomic_write(iteration_dir / "schema.sql", schema_sql)
    atomic_write(iteration_dir / "data.sql", data_sql)
    atomic_write(iteration_dir / "state.json", json.dumps(state_info, indent=2, ensure_ascii=False))
    atomic_write(iteration_dir / "summary.txt", summary_text)

    if data_dictionary:
        atomic_write(iteration_dir / "data_dictionary.json", json.dumps(data_dictionary, indent=2, ensure_ascii=False))

    if business_configuration_logic:
        atomic_write(iteration_dir / "business_configuration_logic.json", json.dumps(business_configuration_logic, indent=2, ensure_ascii=False))

def load_current_state(output_dir, db_id):
    """Read the iteration that ``schema_cache/current`` points to.

    Returns a 5-tuple ``(state_info, schema_sql, data_sql, data_dictionary,
    business_configuration_logic)``. The two trailing items are None when
    their JSON files are absent. If ``current`` does not exist, or any
    required file cannot be read or parsed, a warning is logged (for the
    failure case) and a tuple of five Nones is returned.
    """
    current_dir = Path(output_dir) / db_id / "schema_cache" / "current"
    nothing = (None,) * 5

    if not current_dir.exists():
        return nothing

    def read_text(name):
        with open(current_dir / name, 'r', encoding='utf-8') as handle:
            return handle.read()

    def read_json(name):
        with open(current_dir / name, 'r', encoding='utf-8') as handle:
            return json.load(handle)

    def read_optional_json(name):
        # Optional artifacts are only written when non-empty, so absence is normal.
        return read_json(name) if (current_dir / name).exists() else None

    try:
        state_info = read_json("state.json")
        schema_sql = read_text("schema.sql")
        data_sql = read_text("data.sql")
        data_dictionary = read_optional_json("data_dictionary.json")
        business_configuration_logic = read_optional_json("business_configuration_logic.json")
    except Exception as e:
        # Best-effort: a broken cache just means "start without prior state".
        logging.warning(f"    WARNING: Failed to load current state: {e}")
        return nothing

    return state_info, schema_sql, data_sql, data_dictionary, business_configuration_logic

def save_alternating_dataset_instance(output_dir, db_id, table_info, complete_documentation, 
                                   problem_description, mathematical_solution,
                                   final_or_analysis, final_implementation, 
                                   iteration_history, debug_history, verification_history,
                                   expected_alternating_formulation, final_verification_result, consistency_score,
                                   solver_results, solver_codes, triple_expert_result, solver_analysis_result):
    """Write every artifact of one alternating-optimization dataset instance.

    All files go to ``<output_dir>/<db_id>/``: the debug prompts (via
    save_debug_prompts), three markdown documents, the OR analysis, the
    iteration history, a verification summary, a solver execution summary,
    the solver analysis, one ``<solver>_code.py`` per non-empty extracted
    code, the expected formulation, the triple-expert result, and a
    schema-evolution digest built from the "data_engineer" history entries.

    ``table_info`` and ``final_implementation`` are accepted but not read here.
    """
    instance_dir = os.path.join(output_dir, db_id)
    os.makedirs(instance_dir, exist_ok=True)

    def write_text(filename, text):
        with open(os.path.join(instance_dir, filename), "w", encoding="utf-8") as handle:
            handle.write(text)

    def write_json(filename, payload):
        with open(os.path.join(instance_dir, filename), "w", encoding="utf-8") as handle:
            json.dump(payload, handle, indent=2, ensure_ascii=False)

    save_debug_prompts(output_dir, db_id, debug_history)

    # Markdown: complete document, problem (sections 1-3), solution (sections 4-8)
    for filename, text in (
        ("problem_solution_description.md", complete_documentation),
        ("problem_description.md", problem_description),
        ("mathematical_solution.md", mathematical_solution),
    ):
        write_text(filename, text)

    write_json("or_analysis.json", final_or_analysis)
    write_json("alternating_iteration_history.json", iteration_history)

    write_json("mathematical_verification.json", {
        "final_consistency_score": consistency_score,
        "verification_threshold": VERIFICATION_THRESHOLD,
        "passes_verification": consistency_score >= VERIFICATION_THRESHOLD,
        "verification_attempts": len(verification_history),
        "verification_history": verification_history,
        "expected_alternating_formulation": expected_alternating_formulation,
        "final_verification_result": final_verification_result,
        "incremental_intelligent_verification_used": True,
        "targeted_regeneration_capability": "incremental_problem_description_and_mathematical_solution",
        "modification_strategies": ["incremental", "section_specific", "complete_rewrite"]
    })

    # Solver execution summary. Per-solver fields are grouped by field first,
    # then by solver, e.g. all *_status keys before all *_optimal_value keys.
    solver_names = ("gurobipy", "docplex", "pyomo")
    timestamp = time.strftime("%Y-%m-%d %H:%M:%S")
    generated_codes = {}
    for solver, codes in solver_codes.items():
        generated_codes[solver] = codes.get("extracted_code", "")

    execution_summary = {}
    for field, fallback in (("status", "not_executed"), ("optimal_value", None), ("retry_attempt", "N/A")):
        for name in solver_names:
            execution_summary[f"{name}_{field}"] = solver_results.get(name, {}).get(field, fallback)
    execution_summary["solvers_consistent"] = check_solver_consistency(solver_results)
    execution_summary["retry_occurred"] = any(
        solver_results.get(name, {}).get("retry_attempt") for name in solver_names
    )

    write_json("solver_execution_results.json", {
        "solver_timeout": SOLVER_TIMEOUT,
        "optimization_environment": OPTIM_VENV_PATH,
        "execution_timestamp": timestamp,
        "max_solver_retry_attempts": MAX_SOLVER_RETRY_ATTEMPTS,
        "solver_consistency_tolerance": SOLVER_CONSISTENCY_TOLERANCE,
        "results": solver_results,
        "generated_codes": generated_codes,
        "execution_summary": execution_summary,
        "or_expert_analysis": solver_analysis_result
    })

    write_json("solver_analysis_results.json", solver_analysis_result)

    for solver_type, codes in solver_codes.items():
        source_code = codes.get("extracted_code", "")
        if source_code:
            write_text(f"{solver_type}_code.py", source_code)

    write_json("expected_alternating_formulation.json", expected_alternating_formulation)
    write_json("triple_expert_data_generation.json", triple_expert_result)

    # Schema evolution digest: one row per data-engineer step.
    def table_names(entries):
        return [entry.get("table_name", "") for entry in entries]

    evolution_rows = []
    for hist in iteration_history:
        if hist.get("type") != "data_engineer":
            continue
        step = hist.get("iteration", 0)
        implementation = hist.get("implementation", {})
        decisions = implementation.get("schema_adjustment_decisions", {})
        config_updates = implementation.get("business_configuration_logic_updates", {})
        evolution_rows.append({
            "iteration": step,
            "tables_created": table_names(decisions.get("tables_to_create", [])),
            "tables_modified": table_names(decisions.get("tables_to_modify", [])),
            "tables_deleted": table_names(decisions.get("tables_to_delete", [])),
            "business_configuration_parameters": len(config_updates.get("configuration_parameters", {})),
            "implementation_summary": implementation.get("implementation_summary", "")
        })

    write_json("schema_evolution.json", {"iterations": evolution_rows})

    logging.info(f"    Dataset instance saved successfully to: {instance_dir}")

def check_solver_consistency(solver_results, tolerance=1e-6):
    """Report whether at least two solvers agree on the optimal value.

    Only gurobipy, docplex and pyomo entries are examined, and only those
    with ``status == "optimal"`` and a non-None ``optimal_value`` count.
    Returns True as soon as one pair of those values differs by at most
    *tolerance* (absolute difference). Returns False when fewer than two
    solvers qualify.

    Args:
        solver_results: Mapping of solver name -> result dict.
        tolerance: Maximum absolute difference treated as agreement.
            Defaults to the previously hard-coded 1e-6. NOTE(review): the
            saved summary reports SOLVER_CONSISTENCY_TOLERANCE from config;
            callers may want to pass that value here - confirm intended
            semantics.

    Returns:
        bool
    """
    from itertools import combinations

    successful_values = []
    for solver in ("gurobipy", "docplex", "pyomo"):
        result = solver_results.get(solver, {})
        optimal_value = result.get("optimal_value")
        if result.get("status", "unknown") == "optimal" and optimal_value is not None:
            successful_values.append(optimal_value)

    if len(successful_values) < 2:
        return False

    return any(abs(a - b) <= tolerance for a, b in combinations(successful_values, 2))

def save_debug_prompts(output_dir, db_id, debug_history):
    """Dump each recorded prompt/response pair to its own text file.

    Files go to ``<output_dir>/<db_id>/debug_prompts/`` and are named
    ``<iteration>_<type>_<seq>.txt``. Integer iterations are zero-padded to
    two digits and the 0-based sequence number to three. The timestamp
    written in each file is the time of saving.
    """
    debug_dir = os.path.join(output_dir, db_id, "debug_prompts")
    os.makedirs(debug_dir, exist_ok=True)

    for seq, entry in enumerate(debug_history):
        iteration = entry.get('iteration', 'unknown')
        debug_type = entry.get('type', 'unknown')

        # Zero-pad integer iterations so the files sort in iteration order.
        prefix = str(iteration).zfill(2) if isinstance(iteration, int) else iteration
        target = os.path.join(debug_dir, f"{prefix}_{debug_type}_{str(seq).zfill(3)}.txt")

        with open(target, "w", encoding="utf-8") as out:
            out.write("".join([
                f"Iteration {iteration} - {debug_type.upper()}\n",
                f"Sequence: {seq + 1}\n",
                f"Timestamp: {time.strftime('%Y-%m-%d %H:%M:%S')}\n\n",
                f"Prompt:\n{entry.get('prompt', 'No prompt available')}\n\n",
                f"Response:\n{entry.get('response', 'No response available')}\n",
            ]))