#!/usr/bin/env python3
"""
Script to run ReX with different parameter combinations by modifying TOML files.
This script performs a nested loop:
- Outer loop: Changes minimum_confidence_threshold in the TOML file
- Inner loop: Changes seed in the TOML file
For each combination, it executes the specified Python script with the modified TOML file.
"""

import argparse
import copy
import json
import os
import shutil
import subprocess
import sys
import time
from pathlib import Path

import toml
from tqdm import tqdm


def load_toml_config(toml_path):
    """
    Load a TOML configuration file, terminating the process on failure.

    Args:
        toml_path (str): Path to the TOML file

    Returns:
        dict: TOML configuration as dictionary

    Raises:
        SystemExit: with status 1 if the file cannot be opened or parsed.
    """
    try:
        # The TOML specification requires UTF-8; do not depend on the
        # platform's locale encoding (e.g. cp1252 on Windows).
        with open(toml_path, 'r', encoding='utf-8') as f:
            return toml.load(f)
    except Exception as e:
        # CLI boundary: report the problem on stderr and abort the sweep.
        print(f"Error loading TOML file {toml_path}: {e}", file=sys.stderr)
        sys.exit(1)


def save_toml_config(config, toml_path):
    """
    Serialize a configuration dictionary to a TOML file, terminating on failure.

    The file is overwritten if it already exists.

    Args:
        config (dict): Configuration dictionary
        toml_path (str): Path to save the TOML file

    Raises:
        SystemExit: with status 1 if the file cannot be written.
    """
    try:
        # Write UTF-8 explicitly, matching the TOML specification and the
        # encoding used when the configuration is read back.
        with open(toml_path, 'w', encoding='utf-8') as f:
            toml.dump(config, f)
    except Exception as e:
        # CLI boundary: report the problem on stderr and abort the sweep.
        print(f"Error saving TOML file {toml_path}: {e}", file=sys.stderr)
        sys.exit(1)


def modify_toml_config(config, threshold, seed):
    """
    Return a copy of the configuration with a new threshold and seed.

    Sets ``explanation.minimum_confidence_threshold`` and ``rex.seed``,
    creating either table if it is absent. The input ``config`` is never
    modified, so the same original configuration can be reused for every
    combination of a sweep.

    Args:
        config (dict): Original configuration
        threshold (float): New minimum confidence threshold
        seed (int): New seed value

    Returns:
        dict: Modified configuration (an independent deep copy)
    """
    # A deep copy is required: dict.copy() is shallow, so writing into the
    # nested 'explanation' / 'rex' tables would mutate the caller's config.
    modified_config = copy.deepcopy(config)

    modified_config.setdefault('explanation', {})['minimum_confidence_threshold'] = threshold
    modified_config.setdefault('rex', {})['seed'] = seed

    return modified_config


def execute_python_script(script_path, toml_path, output_dir, db_path, threshold, seed, additional_args=None):
    """
    Execute the Python script with the modified TOML file and log the outcome.

    The command, wall-clock time, return code, parameter values and the
    child's captured stdout/stderr are written to
    ``<output_dir>/execution_log.txt`` (overwritten on each call).

    Args:
        script_path (str): Path to the Python script to execute
        toml_path (str): Path to the TOML configuration file (passed as --config)
        output_dir (str): Output directory for results (passed as --output_dir);
            must already exist because the log file is written there
        db_path (str): Database path passed to the script as --database
        threshold (float): Current threshold value (for logging)
        seed (int): Current seed value (for logging)
        additional_args (list): Additional command line arguments, appended verbatim

    Returns:
        bool: True if the script exited with status 0, False otherwise
        (including when it could not be launched or the log could not be written)
    """
    # Run the child with the interpreter executing this sweep; a bare "python"
    # on PATH may be absent, Python 2, or belong to a different environment.
    cmd = [sys.executable, script_path,
           "--config", toml_path,
           "--output_dir", output_dir,
           "--database", db_path]
    if additional_args:
        cmd.extend(additional_args)

    cmd_str = ' '.join(cmd)
    print(f"Executing: {cmd_str}")

    try:
        start_time = time.perf_counter()
        # errors='replace' prevents undecodable child output from raising
        # UnicodeDecodeError and turning a successful run into a failure.
        result = subprocess.run(cmd, capture_output=True, text=True, errors='replace')
        execution_time = time.perf_counter() - start_time

        log_file = os.path.join(output_dir, "execution_log.txt")
        with open(log_file, 'w', encoding='utf-8') as f:
            f.write(f"Command: {cmd_str}\n")
            f.write(f"Execution time: {execution_time:.2f} seconds\n")
            f.write(f"Return code: {result.returncode}\n")
            f.write(f"Threshold: {threshold}\n")
            f.write(f"Seed: {seed}\n\n")
            f.write("STDOUT:\n")
            f.write(result.stdout)
            f.write("\nSTDERR:\n")
            f.write(result.stderr)

        if result.returncode == 0:
            print(f"✅ Success: threshold={threshold}, seed={seed} (took {execution_time:.2f}s)")
            return True

        print(f"❌ Failed: threshold={threshold}, seed={seed} (return code: {result.returncode})")
        print(f"Error: {result.stderr}")
        return False

    except Exception as e:
        # Best-effort per run: one broken combination must not abort the sweep.
        print(f"❌ Exception: threshold={threshold}, seed={seed} - {e}")
        return False


def _run_name(model_name, seed, threshold, masking_value=None):
    """Build the base name shared by a combination's output directory and database file."""
    if masking_value is None:
        return f"{model_name}_seed_{seed}_threshold_{threshold}"
    return f"{model_name}_val_{masking_value}_seed_{seed}_threshold_{threshold}"


def run_parameter_combinations(toml_path, script_path, output_dir, model_name, thresholds, seeds, additional_args=None):
    """
    Run the script with all combinations of thresholds and seeds.

    Each combination writes to
    ``<output_dir>/Threshold_<t>/<run_name>/`` with database ``<run_name>.db``.
    Combinations whose directory already exists are skipped, which makes the
    sweep resumable.

    Args:
        toml_path (str): Path to the original TOML file (never modified)
        script_path (str): Path to the Python script to execute
        output_dir (str): Base output directory
        model_name (str): Name of the model; if it contains 'masking', the
            config's ``rex.mask_value`` is included in the run name
        thresholds (list): List of threshold values to test
        seeds (list): List of seed values to test
        additional_args (list): Additional command line arguments

    Returns:
        dict: Summary with keys 'total_combinations', 'successful', 'failed',
        'skipped' and 'combinations' (one entry per executed combination)
    """
    original_config = load_toml_config(toml_path)

    # Derive the temp path from the file name so it can never equal the
    # original. A plain str.replace('.toml', ...) leaves the path unchanged when
    # there is no lowercase '.toml' in it. The sweep would then overwrite the
    # user's config and delete it during cleanup.
    source = Path(toml_path)
    temp_toml_path = str(source.with_name(f"{source.stem}_temp.toml"))

    results = {
        'total_combinations': len(thresholds) * len(seeds),
        'successful': 0,
        'failed': 0,
        'skipped': 0,
        'combinations': []
    }

    print(f"Starting parameter sweep with {len(thresholds)} thresholds and {len(seeds)} seeds")
    print(f"Total combinations: {results['total_combinations']}")
    print(f"Thresholds: {thresholds}")
    print(f"Seeds: {seeds}")
    print("-" * 60)

    try:
        # Nested loop: outer loop for thresholds, inner loop for seeds
        for threshold in tqdm(thresholds, desc="Thresholds"):
            for seed in tqdm(seeds, desc=f"Seeds (threshold={threshold})", leave=False):
                modified_config = modify_toml_config(original_config, threshold, seed)
                # Masking models must define rex.mask_value (KeyError otherwise).
                masking_value = modified_config['rex']['mask_value'] if 'masking' in model_name else None

                run_name = _run_name(model_name, seed, threshold, masking_value)
                target_dir = os.path.join(output_dir, f"Threshold_{threshold}", run_name)
                db_path = os.path.join(target_dir, f"{run_name}.db")

                # An existing directory marks a combination handled by an
                # earlier invocation. NOTE(review): failed runs also leave their
                # directory behind, so they are skipped on re-run as well.
                if os.path.exists(target_dir):
                    print(f"Skipping threshold {threshold} and seed {seed} because it already exists")
                    results['skipped'] += 1
                    continue

                # Write the config before creating the directory: a failed save
                # exits the process and must not leave a directory that a later
                # run would mistake for a completed combination.
                save_toml_config(modified_config, temp_toml_path)
                os.makedirs(target_dir)

                success = execute_python_script(
                    script_path,
                    temp_toml_path,
                    target_dir,
                    db_path,
                    threshold,
                    seed,
                    additional_args
                )

                results['combinations'].append({
                    'threshold': threshold,
                    'seed': seed,
                    'success': success,
                    'output_dir': target_dir
                })
                if success:
                    results['successful'] += 1
                else:
                    results['failed'] += 1
    finally:
        # Remove the temporary config even on errors or Ctrl-C.
        if os.path.exists(temp_toml_path):
            os.remove(temp_toml_path)

    return results


def main():
    """Parse command-line arguments, validate inputs, run the sweep and print a summary."""
    parser = argparse.ArgumentParser(description='Run ReX with different parameter combinations')
    parser.add_argument('--toml_file', type=str, required=True,
                       help='Path to the TOML configuration file to modify')
    parser.add_argument('--python_script', type=str, required=True,
                       help='Path to the Python script to execute')
    parser.add_argument('--output_dir', type=str, required=True,
                       help='Path to the output directory')
    parser.add_argument('--model_name', type=str, required=True,
                       help='Name of the model')
    parser.add_argument('--thresholds', type=float, nargs='+',
                       default=[0, 0.1, 0.3, 0.5, 0.7, 0.9],
                       help='List of minimum confidence thresholds to test')
    parser.add_argument('--seeds', type=int, nargs='+',
                       default=[42, 43, 44, 45],
                       help='List of seed values to test')
    parser.add_argument('--additional_args', type=str, nargs='*',
                       help='Additional command line arguments to pass to the Python script')

    args = parser.parse_args()

    # Validate inputs
    if not os.path.exists(args.toml_file):
        print(f"Error: TOML file {args.toml_file} does not exist!")
        sys.exit(1)

    if not os.path.exists(args.python_script):
        print(f"Error: Python script {args.python_script} does not exist!")
        sys.exit(1)

    print(f"TOML file: {args.toml_file}")
    print(f"Python script: {args.python_script}")
    print(f"Thresholds: {args.thresholds}")
    print(f"Seeds: {args.seeds}")
    if args.additional_args:
        print(f"Additional args: {args.additional_args}")
    print("-" * 60)

    results = run_parameter_combinations(
        args.toml_file,
        args.python_script,
        args.output_dir,
        args.model_name,
        args.thresholds,
        args.seeds,
        args.additional_args
    )

    # Report the outcome instead of discarding it. Skipped combinations are
    # whatever was neither executed successfully nor failed.
    skipped = results['total_combinations'] - results['successful'] - results['failed']
    print("-" * 60)
    print(f"Sweep finished: {results['successful']} succeeded, {results['failed']} failed, "
          f"{skipped} skipped (already present) out of {results['total_combinations']} combinations")
    for combination in results['combinations']:
        if not combination['success']:
            print(f"  Failed: threshold={combination['threshold']}, seed={combination['seed']} "
                  f"-> {combination['output_dir']}")


if __name__ == '__main__':
    main()
