#!/usr/bin/env python3
"""
Script to run the Objective Discovery pipeline.

This script provides a command-line interface for discovering objectives
from model training trajectories using various discovery methods.
"""

import argparse
import json
import os
import sys
import yaml
import random
import numpy as np
import torch
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Any, Optional
import logging
from pprint import pformat

# Add parent directory to path for imports
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from src.objectives_discovery import RandomObjectivesDiscovery, ProposedObjectivesDiscovery, StaticObjectivesDiscovery, BaseObjectivesDiscovery
from src.fixed_objectives_discovery import ObtainFixedObjectives, FixedObjectivesDiscovery
# from src.objectives_evaluator import ObjectivesEvaluator, EvaluationMetrics
from src.constants import OPENAI_API_KEY, SFT_MODELS_FILEPATHS, DATASET_NAMES_DICT, DATASET_MAX_LENGTH_DICT
from src.custom_rewards import LLMRewardFunction, RewardModelFunction


def setup_logging(output_dir: str) -> logging.Logger:
    """
    Setup comprehensive logging system that outputs to both console and file.

    The 'objective_discovery' logger is reconfigured on every call. Handlers
    attached by an earlier call are closed before being detached, so calling
    this function again neither duplicates output nor leaks the file handle
    of the previous ``logs.txt``.

    Args:
        output_dir: Directory to save log files (created if it does not exist)

    Returns:
        Logger instance writing INFO-level records to ``<output_dir>/logs.txt``
        (truncated on each call) and to the console
    """
    # Create output directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)

    # Create logger
    logger = logging.getLogger('objective_discovery')
    logger.setLevel(logging.INFO)

    # Detach AND close existing handlers; merely clearing the list would leave
    # the previous log file open until garbage collection.
    for stale_handler in list(logger.handlers):
        logger.removeHandler(stale_handler)
        stale_handler.close()

    # Create file handler for logs.txt
    log_file = os.path.join(output_dir, 'logs.txt')
    file_handler = logging.FileHandler(log_file, mode='w', encoding='utf-8')
    file_handler.setLevel(logging.INFO)

    # Create console handler
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)

    # Create formatter with detailed format
    formatter = logging.Formatter(
        '%(asctime)s | %(levelname)s | %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )

    # Set formatters
    file_handler.setFormatter(formatter)
    console_handler.setFormatter(formatter)

    # Add handlers to logger
    logger.addHandler(file_handler)
    logger.addHandler(console_handler)

    # Log initialization
    logger.info("="*80)
    logger.info("OBJECTIVE DISCOVERY LOGGING INITIALIZED")
    logger.info(f"Log file: {log_file}")
    logger.info(f"Output directory: {output_dir}")
    logger.info("="*80)
    logger.info("")

    return logger


def log_section_header(logger: logging.Logger, title: str, level: int = 1):
    """
    Emit a section header through ``logger.info``.

    Level 1 frames the centered title between two 80-character '=' rules.
    Level 2 prints a blank line and then the indented title between two
    60-character '-' rules. Any other level prints the title prefixed with
    '>>>' and surrounded by blank lines.

    Args:
        logger: Logger instance
        title: Section title
        level: Header level (1=main, 2=sub, 3=sub-sub)
    """
    if level == 1:
        rule = "=" * 80
        header_lines = (rule, title.center(80), rule)
    elif level == 2:
        rule = "-" * 60
        header_lines = ("", rule, f"  {title}", rule)
    else:
        header_lines = ("", f"    >>> {title}", "")

    # One logger call per line so every line gets its own timestamp prefix.
    for header_line in header_lines:
        logger.info(header_line)


def log_dict_pretty(logger: logging.Logger, data: Dict[str, Any], title: str = None, indent: int = 2):
    """
    Write a pretty-printed dictionary to the log, one line per record.

    Args:
        logger: Logger instance
        data: Dictionary to log
        title: Optional title, logged first at ``indent`` spaces and followed by ':'
        indent: Number of spaces to indent the title; body lines get two more
    """
    title_pad = ' ' * indent
    body_pad = ' ' * (indent + 2)

    if title:
        logger.info(f"{title_pad}{title}:")

    # pformat may return several lines; log each separately so every line
    # carries the body indentation.
    rendered = pformat(data, indent=2, width=100, compact=False)
    for row in rendered.split('\n'):
        logger.info(f"{body_pad}{row}")


def log_list_pretty(logger: logging.Logger, items: List[Any], title: str = None, max_items: int = 10, indent: int = 2):
    """
    Write a numbered, possibly truncated listing of ``items`` to the log.

    Dictionaries are rendered with ``pformat``; every other item is converted
    with ``str`` and cut to 200 characters followed by '...'. When the list is
    longer than ``max_items``, a trailing line reports how many were omitted.

    Args:
        logger: Logger instance
        items: List to log
        title: Optional title; logged together with the total item count
        max_items: Maximum number of items to display
        indent: Number of spaces to indent the title; entries get two more
    """
    if title:
        logger.info(f"{' ' * indent}{title} ({len(items)} items):")

    entry_pad = ' ' * (indent + 2)
    is_truncated = len(items) > max_items
    visible = items[:max_items] if is_truncated else items

    for position, entry in enumerate(visible, start=1):
        if isinstance(entry, dict):
            rendered = pformat(entry, indent=2, width=80)
        else:
            rendered = str(entry)
            # Truncate long strings
            if len(rendered) > 200:
                rendered = rendered[:200] + "..."
        logger.info(f"{entry_pad}{position}. {rendered}")

    if is_truncated:
        logger.info(f"{entry_pad}... and {len(items) - max_items} more items")


def load_config(config_path: str) -> Dict[str, Any]:
    """
    Load configuration from YAML file.

    Args:
        config_path: Path to the YAML configuration file

    Returns:
        Dictionary containing configuration parameters. An empty YAML
        document yields an empty dictionary.

    Raises:
        OSError: If the file cannot be opened
        ValueError: If the top level of the YAML document is not a mapping
    """
    # Read as UTF-8 explicitly so parsing does not depend on the platform's
    # default encoding.
    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)

    # safe_load returns None for an empty document; callers expect a dict
    # they can copy and index.
    if config is None:
        return {}

    if not isinstance(config, dict):
        raise ValueError(
            f"Configuration file {config_path} must contain a mapping at the top level, "
            f"got {type(config).__name__}"
        )
    return config


def merge_args_with_config(args: argparse.Namespace, config: Dict[str, Any]) -> Dict[str, Any]:
    """
    Merge command-line arguments with configuration file.
    Command-line arguments take precedence over config file.

    Entries of a nested ``evaluation`` mapping are additionally copied to the
    top level unless a top-level key of the same name already exists. The
    input ``config`` is not modified (the copy is shallow, so nested values
    are shared with the input).

    Args:
        args: Parsed command-line arguments
        config: Configuration dictionary from YAML file; ``None`` or an empty
            mapping is treated as "no configuration"

    Returns:
        Merged configuration dictionary
    """
    # Start with config as base; tolerate a missing/empty configuration
    merged = dict(config) if config else {}

    # Flatten evaluation config into main config without overriding keys
    # that already exist at top level
    evaluation = merged.get('evaluation')
    if isinstance(evaluation, dict):
        for eval_key, eval_value in evaluation.items():
            merged.setdefault(eval_key, eval_value)

    # These arguments override the config only when actually provided
    # (i.e. non-empty), not merely non-None.
    override_only_if_nonempty = ('model_sequence', 'dataset_args')

    # Override with non-None command-line arguments
    for key, value in vars(args).items():
        if value is None:
            continue
        if key in override_only_if_nonempty and not value:
            continue
        merged[key] = value

    return merged


def set_seed(seed: int):
    """
    Seed every random number generator used by the pipeline.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch;
    when CUDA is available, all CUDA devices are seeded as well.

    Args:
        seed: Random seed value
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def load_ground_truth_objectives(config: Dict[str, Any]) -> Optional[List[str]]:
    """
    Load ground-truth objectives from configuration.

    Objectives listed inline under ``ground_truth_objectives`` take
    precedence. Otherwise the JSON file named by ``ground_truth_file`` is
    read; it must contain either a list or a mapping with an ``objectives``
    key. Loading is best-effort: problems with the file are reported as a
    printed warning and ``None`` is returned instead of raising.

    Args:
        config: Configuration dictionary

    Returns:
        List of ground-truth objectives or None if not provided
    """
    # Check for ground-truth objectives in config
    inline_objectives = config.get('ground_truth_objectives')
    if inline_objectives:
        return inline_objectives

    # Check for ground-truth file
    ground_truth_file = config.get('ground_truth_file')
    if not ground_truth_file:
        return None

    try:
        # JSON is UTF-8 by specification; do not rely on the platform's
        # default encoding.
        with open(ground_truth_file, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except Exception as e:
        # Deliberately broad: a bad or missing ground-truth file must not
        # abort the run.
        print(f"Warning: Failed to load ground-truth file {ground_truth_file}: {e}")
        return None

    if isinstance(data, list):
        return data
    if isinstance(data, dict) and 'objectives' in data:
        return data['objectives']

    print(f"Warning: Invalid format in ground-truth file {ground_truth_file}")
    return None


# def evaluate_objectives(
#     predicted_objectives: List[str],
#     ground_truth_objectives: List[str],
#     config: Dict[str, Any],
#     iteration: Optional[int] = None,
#     verbose: bool = True
# ) -> Optional[EvaluationMetrics]:
#     """
#     Evaluate predicted objectives against ground-truth.
    
#     Args:
#         predicted_objectives: List of predicted/discovered objectives
#         ground_truth_objectives: List of ground-truth objectives
#         config: Configuration dictionary
#         iteration: Optional iteration number for reporting
#         verbose: Whether to print evaluation results
        
#     Returns:
#         EvaluationMetrics or None if evaluation fails
#     """
#     if not predicted_objectives or not ground_truth_objectives:
#         return None
    
#     try:
#         evaluator = ObjectivesEvaluator(
#             evaluation_model=config.get('evaluation_model', 'gpt-4o-mini'),
#             ground_truth_objectives=ground_truth_objectives,
#             predicted_objectives=predicted_objectives,
#             similarity_threshold=config.get('similarity_threshold', 7.0),
#             duplicate_threshold=config.get('duplicate_threshold', 7.0),
#             use_api=True,
#             verbose=False  # Use our own verbose control
#         )
        
#         metrics = evaluator.get_all_metrics()
        
#         if verbose:
#             iter_str = f"Iteration {iteration}" if iteration is not None else "Final"
#             print(f"\n{'='*60}")
#             print(f"EVALUATION RESULTS - {iter_str}")
#             print(f"{'='*60}")
#             print(f"Predicted objectives: {len(predicted_objectives)}")
#             print(f"Ground-truth objectives: {len(ground_truth_objectives)}")
#             print(f"True Positives: {metrics.true_positives}")
#             print(f"False Positives: {metrics.false_positives}")
#             print(f"False Negatives: {metrics.false_negatives}")
#             print(f"Precision: {metrics.precision:.3f}")
#             print(f"Recall: {metrics.recall:.3f}")
#             print(f"F1 Score: {metrics.f1_score:.3f}")
#             print(f"{'='*60}")
        
#         return metrics
        
#     except Exception as e:
#         print(f"Warning: Evaluation failed: {e}")
#         return None


def validate_config(config: Dict[str, Any]) -> bool:
    """
    Validate the configuration parameters.

    Side effect: when ``use_api_proposer`` is enabled (the default) but no
    OpenAI API key is found in ``OPENAI_API_KEY`` or the environment, a
    warning is printed and ``config['use_api_proposer']`` is set to False.

    Args:
        config: Configuration dictionary (may be modified, see above)

    Returns:
        True if configuration is valid

    Raises:
        ValueError: If configuration is invalid
    """
    # Check required fields. A key that is present but set to None (e.g. an
    # empty YAML value) is treated as missing so the checks below fail with
    # ValueError rather than an unrelated TypeError.
    required_fields = ['dataset', 'model_sequence', 'k', 'method']
    for field in required_fields:
        if config.get(field) is None:
            raise ValueError(f"Required field '{field}' missing from configuration")

    # Validate model sequence. A bare string would pass a plain len() check,
    # so insist on an actual list/tuple of checkpoints.
    model_sequence = config['model_sequence']
    if not isinstance(model_sequence, (list, tuple)) or len(model_sequence) < 2:
        raise ValueError("model_sequence must contain at least 2 model checkpoints")

    # Validate k
    k = config['k']
    if isinstance(k, bool) or not isinstance(k, int):
        raise ValueError(f"k (number of objectives) must be an integer, got {k!r}")
    if k < 1:
        raise ValueError("k (number of objectives) must be at least 1")

    # Validate method
    valid_methods = ['random', 'proposed', 'static', 'guided', 'obtain_fixed', 'fixed']
    if config['method'] not in valid_methods:
        raise ValueError(f"method must be one of {valid_methods}")

    # Check API key if using API proposer
    if config.get('use_api_proposer', True):
        if not OPENAI_API_KEY and not os.environ.get('OPENAI_API_KEY'):
            print("Warning: No OpenAI API key found. Set OPENAI_API_KEY in constants.py or environment.")
            print("Falling back to local proposer model if available.")
            config['use_api_proposer'] = False

    return True


def initialize_discovery_method(config: Dict[str, Any], logger: logging.Logger = None) -> BaseObjectivesDiscovery:
    """
    Initialize the appropriate discovery method based on configuration.
    
    Args:
        config: Configuration dictionary
        logger: Optional logger instance
        
    Returns:
        Initialized discovery method instance
    """
    method = config['method']
    
    # Extract common parameters
    common_params = {
        'dataset': config['dataset'],
        'model_sequence': config['model_sequence'],
        'k': config['k'],
        'verifier_epsilon_interpretable': config.get('verifier_epsilon_interpretable', 0.15),
        'verifier_epsilon_trend': config.get('verifier_epsilon_trend', 0.1),
        'verification_sample_size': config.get('verification_sample_size', 20),
        'proposer_model': config.get('proposer_model', 'gpt-4o-mini'),
        'scorer_model_name': config.get('scorer_model_name', 'gpt-4o-mini'),
        'human_scorer_models': config.get('human_scorer_models', None),
        'use_api_proposer': config.get('use_api_proposer', True),
        'device': config.get('device', 'auto'),
        'output_dir': config.get('output_dir'),  # Pass output_dir to discovery methods
        'logger': logger,  # Pass logger to discovery methods
        'multi_turn': config.get('multi_turn', False),
        'max_prompt_length': config.get('max_prompt_length', None),
        'base_model_name': config.get('base_model_name', None),
        'max_concurrent': config.get('max_concurrent', 10),
        'data_dir': config.get('data_dir', None)  # Optional data directory for datasets
    }
    
    if method == 'random':
        # Create ground truth reward function for final evaluation
        ground_truth_type = config.get('ground_truth_type', 'llm_function')
        if ground_truth_type == 'reward_model':
            # Use RewardModelFunction for ground truth
            ground_truth_reward = RewardModelFunction(
                model_name=config.get('reward_model_name'),
                device=config.get('device', 'auto'),
                max_length=config.get('reward_model_max_length', 512),
                use_quantization=config.get('use_quantization', False),
                normalize_scores=config.get('normalize_scores', True)
            )
        else:
            # Use LLMRewardFunction for ground truth (default)
            ground_truth_objectives = config.get('ground_truth_objectives')
            ground_truth_weights = config.get('ground_truth_weights', {obj: 1.0/len(ground_truth_objectives) for obj in ground_truth_objectives})
            ground_truth_reward = LLMRewardFunction(
                model_name=config.get('ground_truth_model', 'gpt-4o-mini'),
                use_api=config.get('use_api_for_ground_truth', True),
                combiner_type=config.get('ground_truth_combiner_type', 'linear'),
                objective_names=ground_truth_objectives,
                manual_weights=ground_truth_weights,
                device=config.get('device', 'auto'),
                max_length=config.get('max_length', 4096),
                use_detailed_rubric=config.get('use_detailed_rubric', True),
                dataset_type=config.get('dataset_type', 'hh'),
                normalize_scores=config.get('normalize_scores', True),
                save_dir=config.get('output_dir'),
            )

        # Add RandomObjectivesDiscovery specific parameters
        discovery = RandomObjectivesDiscovery(
            **common_params,
            ground_truth_reward=ground_truth_reward,
            samples_per_iteration=config.get('samples_per_iteration', 5),
            objectives_per_trajectory=config.get('objectives_per_trajectory', 3),
            max_iterations=config.get('max_iterations', 100),
            train_test_split_idx=config.get('train_test_split_idx', None),
            num_samples_final_eval=config.get('num_samples_final_eval', 25),
            combination_function_type=config.get('combination_function_type', 'linear_regression'),
            combination_function_params=config.get('combination_function_params', {})
        )
    elif method == 'proposed':
        # Create ground truth reward function based on type
        ground_truth_type = config.get('ground_truth_type', 'llm_function')
        if ground_truth_type == 'reward_model':
            # Use RewardModelFunction for ground truth
            ground_truth_reward = RewardModelFunction(
                model_name=config.get('reward_model_name'),
                device=config.get('device', 'auto'),
                max_length=config.get('reward_model_max_length', 512),
                use_quantization=config.get('use_quantization', False),
                normalize_scores=config.get('normalize_scores', True)  # Default to True for objective discovery
            )
        else:
            # Use LLMRewardFunction for ground truth (existing behavior)
            ground_truth_objectives = config.get('ground_truth_objectives')
            ground_truth_weights = config.get('ground_truth_weights', {obj: 1.0/len(ground_truth_objectives) for obj in ground_truth_objectives})
            ground_truth_reward = LLMRewardFunction(
                model_name=config.get('ground_truth_model', 'gpt-4o-mini'),
                use_api=config.get('use_api_for_ground_truth', True),
                combiner_type=config.get('ground_truth_combiner_type', 'linear'),
                objective_names=ground_truth_objectives,
                manual_weights=ground_truth_weights,
                device=config.get('device', 'auto'),
                max_length=config.get('max_length', 4096),
                use_detailed_rubric=config.get('use_detailed_rubric', True),
                dataset_type=config.get('dataset_type', 'hh'),
                normalize_scores=config.get('normalize_scores', True),  # Default to True for objective discovery
                save_dir=config.get('output_dir'),
            )
        # Initialize ProposedObjectivesDiscovery
        discovery = ProposedObjectivesDiscovery(
            **common_params,
            ground_truth_reward=ground_truth_reward,
            x_cand_size=config.get('x_cand_size', 100),
            x_disc_size=config.get('x_disc_size', 10),
            objectives_per_trajectory=config.get('objectives_per_trajectory', 3),
            num_parallel_trajectories=config.get('num_parallel_trajectories', 1),
            num_samples_select_best=config.get('num_samples_select_best', 20),
            num_samples_final_eval=config.get('num_samples_final_eval', 25),
            combination_function_type=config.get('combination_function_type', 'linear_regression'),
            combination_function_params=config.get('combination_function_params', {}),
            train_test_split_idx=config.get('train_test_split_idx', None),
            max_iterations=config.get('max_iterations', 50),
            use_random_sampling=config.get('use_random_sampling', False)
        )
    elif method == 'static':
        # Create ground truth reward function for final evaluation (optional)
        ground_truth_type = config.get('ground_truth_type', 'llm_function')
        ground_truth_reward = None
        # if ground_truth_type and config.get('ground_truth_objectives'):
        if ground_truth_type:
            if ground_truth_type == 'reward_model':
                # Use RewardModelFunction for ground truth
                ground_truth_reward = RewardModelFunction(
                    model_name=config.get('reward_model_name'),
                    device=config.get('device', 'auto'),
                    max_length=config.get('reward_model_max_length', 512),
                    use_quantization=config.get('use_quantization', False),
                    normalize_scores=config.get('normalize_scores', True)
                )
            else:
                # Use LLMRewardFunction for ground truth (default)
                ground_truth_objectives = config.get('ground_truth_objectives')
                ground_truth_weights = config.get('ground_truth_weights', {obj: 1.0/len(ground_truth_objectives) for obj in ground_truth_objectives})
                ground_truth_reward = LLMRewardFunction(
                    model_name=config.get('ground_truth_model', 'gpt-4o-mini'),
                    use_api=config.get('use_api_for_ground_truth', True),
                    combiner_type=config.get('ground_truth_combiner_type', 'linear'),
                    objective_names=ground_truth_objectives,
                    manual_weights=ground_truth_weights,
                    device=config.get('device', 'auto'),
                    max_length=config.get('max_length', 4096),
                    use_detailed_rubric=config.get('use_detailed_rubric', True),
                    dataset_type=config.get('dataset_type', 'hh'),
                    normalize_scores=config.get('normalize_scores', True),
                    save_dir=config.get('output_dir')
                )

        # Add StaticObjectivesDiscovery specific parameters
        discovery = StaticObjectivesDiscovery(
            **common_params,
            samples_per_discovery=config.get('samples_per_discovery', 10),
            ground_truth_reward=ground_truth_reward,
            num_samples_final_eval=config.get('num_samples_final_eval', 25),
            train_test_split_idx=config.get('train_test_split_idx', None),
            combination_function_type=config.get('combination_function_type', 'linear_regression'),
            combination_function_params=config.get('combination_function_params', {})
        )
    elif method == 'obtain_fixed':
        # Initialize ObtainFixedObjectives
        discovery = ObtainFixedObjectives(
            **common_params,
            num_parallel_trajectories=config.get('num_parallel_trajectories', 3),
            objectives_per_trajectory=config.get('objectives_per_trajectory', 3)
        )
    elif method == 'fixed':
        # Create ground truth reward function for fixed objectives method
        ground_truth_type = config.get('ground_truth_type', 'llm_function')
        if ground_truth_type == 'reward_model':
            # Use RewardModelFunction for ground truth
            ground_truth_reward = RewardModelFunction(
                model_name=config.get('reward_model_name'),
                device=config.get('device', 'auto'),
                max_length=config.get('reward_model_max_length', 512),
                use_quantization=config.get('use_quantization', False),
                normalize_scores=config.get('normalize_scores', True)
            )
        else:
            # Use LLMRewardFunction for ground truth
            ground_truth_objectives = config.get('ground_truth_objectives')
            ground_truth_weights = config.get('ground_truth_weights', {obj: 1.0/len(ground_truth_objectives) for obj in ground_truth_objectives})
            ground_truth_reward = LLMRewardFunction(
                model_name=config.get('ground_truth_model', 'gpt-4o-mini'),
                use_api=config.get('use_api_for_ground_truth', True),
                combiner_type=config.get('ground_truth_combiner_type', 'linear'),
                objective_names=ground_truth_objectives,
                manual_weights=ground_truth_weights,
                device=config.get('device', 'auto'),
                max_length=config.get('max_length', 4096),
                use_detailed_rubric=config.get('use_detailed_rubric', True),
                dataset_type=config.get('dataset_type', 'hh'),
                normalize_scores=config.get('normalize_scores', True),
                save_dir=config.get('output_dir')
            )

        # Initialize FixedObjectivesDiscovery
        discovery = FixedObjectivesDiscovery(
            **common_params,
            ground_truth_reward=ground_truth_reward,
            fixed_objectives_filepath=config.get('fixed_objectives_filepath'),
            num_fixed_objs=config.get('num_fixed_objs', 10),
            num_samples_final_eval=config.get('num_samples_final_eval', 25),
            num_samples_select_best=config.get('num_samples_select_best', 20),
            combination_function_type=config.get('combination_function_type', 'linear_regression'),
            combination_function_params=config.get('combination_function_params', {}),
            train_test_split_idx=config.get('train_test_split_idx', None),
            max_iterations=config.get('max_iterations', 50),
            # save_dir=config.get('output_dir')
        )
    elif method == 'guided':
        # Placeholder for future guided method
        raise NotImplementedError("Guided discovery method not yet implemented")
    else:
        raise ValueError(f"Unknown discovery method: {method}")
    
    return discovery


def save_results(
    objectives: List[str],
    stats: Dict[str, Any],
    config: Dict[str, Any],
    output_dir: str
):
    """
    Save discovery results to files.

    Three files sharing one timestamp suffix are written to ``output_dir``:
    ``objectives_<ts>.txt`` (numbered list), ``results_<ts>.json``
    (objectives, statistics and configuration; values that are not JSON
    serializable are stringified) and ``summary_<ts>.md``. All files are
    written as UTF-8 so objectives containing non-ASCII text are saved
    correctly regardless of the platform's default encoding. The path of
    each file is printed after it is written.

    Args:
        objectives: List of discovered objectives
        stats: Discovery statistics
        config: Configuration used; 'method', 'dataset' and 'k' are required
            for the summary
        output_dir: Directory to save results (created if missing)
    """
    # Create output directory
    os.makedirs(output_dir, exist_ok=True)

    # Create timestamp for filenames
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    # Save objectives to text file
    objectives_file = os.path.join(output_dir, f"objectives_{timestamp}.txt")
    with open(objectives_file, 'w', encoding='utf-8') as f:
        f.write("DISCOVERED OBJECTIVES\n")
        f.write("=" * 80 + "\n\n")
        f.writelines(f"{i}. {obj}\n\n" for i, obj in enumerate(objectives, 1))

    print(f"Objectives saved to: {objectives_file}")

    # Save full results to JSON
    results = {
        'timestamp': timestamp,
        'objectives': objectives,
        'statistics': stats,
        'configuration': config
    }

    json_file = os.path.join(output_dir, f"results_{timestamp}.json")
    with open(json_file, 'w', encoding='utf-8') as f:
        # default=str keeps the dump from failing on values such as loggers
        # or paths that may be present in the configuration.
        json.dump(results, f, indent=2, default=str)

    print(f"Full results saved to: {json_file}")

    # Save summary to markdown
    md_file = os.path.join(output_dir, f"summary_{timestamp}.md")
    with open(md_file, 'w', encoding='utf-8') as f:
        f.write("# Objective Discovery Results\n\n")
        f.write(f"**Date**: {timestamp}\n")
        f.write(f"**Method**: {config['method']}\n")
        f.write(f"**Dataset**: {config['dataset']}\n")
        f.write(f"**Target Objectives**: {config['k']}\n\n")

        f.write("## Discovered Objectives\n\n")
        f.writelines(f"{i}. {obj}\n" for i, obj in enumerate(objectives, 1))

        f.write("\n## Statistics\n\n")
        f.write(f"- **Objectives Found**: {stats.get('discovered_count', len(objectives))}\n")
        f.write(f"- **Total Proposals**: {stats.get('total_proposals', 'N/A')}\n")
        f.write(f"- **Acceptance Rate**: {stats.get('acceptance_rate', 0):.2%}\n")
        f.write(f"- **Iterations**: {stats.get('total_iterations', 'N/A')}\n")
        f.write(f"- **Time Elapsed**: {stats.get('time_elapsed', 0):.2f} seconds\n")

        if 'verification_failures' in stats:
            failures = stats['verification_failures']
            f.write("\n### Verification Failures\n\n")
            f.write(f"- **Interpretability**: {failures.get('interpretability', 0)}\n")
            f.write(f"- **Trend**: {failures.get('trend', 0)}\n")

    print(f"Summary saved to: {md_file}")


def _str_to_bool(value):
    """Parse a command-line boolean such as ``true``/``false``/``1``/``0``.

    ``argparse`` with ``type=bool`` treats every non-empty string as True
    (so ``--flag False`` would enable the flag). This converter maps the
    usual textual spellings to real booleans instead.

    Args:
        value: Raw command-line string (or an actual bool).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    # The empty string stays falsy, matching the old bool('') behaviour.
    if normalized in ('false', 'f', 'no', 'n', '0', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")


def main():
    """Run the objective discovery pipeline from the command line.

    Steps, in order: parse command-line arguments, load the YAML config (if
    the file exists) and merge the arguments into it, create a timestamped
    output directory, save the merged config there, set up logging, validate
    the config, run the selected discovery method, print the results, and
    optionally write follow-up artifacts (a txt file of objectives for the
    ``obtain_fixed`` method, or an auto-generated PPO/GRPO training config).

    Returns:
        1 if configuration validation, discovery initialization, or the
        discovery run fails (or is interrupted); None otherwise.
    """

    # Parse command-line arguments
    parser = argparse.ArgumentParser(
        description="Run objective discovery on model training trajectories",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    # Configuration file
    parser.add_argument('--config', '-c', type=str, default='configs/objective_discovery.yaml', help='Path to configuration YAML file')
    # Dataset parameters
    parser.add_argument('--dataset', type=str, help='Dataset name or path (e.g., "trl-internal-testing/tldr-preference-sft-trl-style")')
    parser.add_argument('--data_dir', type=str, help='Optional data directory for datasets that support it (e.g., "helpful-base" for HH-RLHF)')
    parser.add_argument('--multi_turn', type=_str_to_bool, help='Whether the dataset is multi-turn')
    # Model sequence
    parser.add_argument('--model_sequence', nargs='+', help='List of model checkpoint paths')
    # Discovery parameters
    parser.add_argument('--k', type=int, help='Number of objectives to discover')
    parser.add_argument('--method', type=str, choices=['random', 'proposed', 'static', 'guided', 'obtain_fixed', 'fixed'], help='Discovery method to use')
    parser.add_argument('--samples_per_iteration', type=int, help='Number of prompts to sample per iteration (random method)')
    parser.add_argument('--objectives_per_trajectory', type=int, help='Number of objectives to propose per trajectory')
    parser.add_argument('--num_parallel_trajectories', type=int, help='Number of trajectories to process together in discovery prompt (proposed method)')
    parser.add_argument('--max_iterations', type=int, help='Maximum number of iterations')
    # Proposed method specific parameters
    parser.add_argument('--x_cand_size', type=int, help='Size of candidate sample set for informative sample selection (proposed method)')
    parser.add_argument('--x_disc_size', type=int, help='Number of most informative samples to use for discovery (proposed method)')
    parser.add_argument('--num_samples_select_best', type=int, help='Number of random samples for objective selection phase (proposed method)')
    parser.add_argument('--combination_function_type', type=str, help='Type of combination function: linear, linear_regression, gradient_boosting, mlp (proposed method)')
    parser.add_argument('--train_test_split_idx', type=int, help='Index to split model sequence for train/test (proposed method)')
    parser.add_argument('--ground_truth_model', type=str, help='Model for ground truth reward function (proposed method)')
    parser.add_argument('--ground_truth_combiner_type', type=str, help='Combiner type for ground truth reward (proposed method)')
    parser.add_argument('--normalize_scores', type=_str_to_bool, help='Normalize all scores to [0, 1] range for objective discovery')
    # Verification parameters
    parser.add_argument('--verifier_epsilon_interpretable', type=float, help='Threshold for human-interpretability verification')
    parser.add_argument('--verifier_epsilon_trend', type=float, help='Threshold for predictable trend verification')
    parser.add_argument('--verification_sample_size', type=int, help='Number of samples to use for objective verification')
    # Model parameters
    parser.add_argument('--proposer_model', type=str, help='Model to use for proposing objectives')
    parser.add_argument('--scorer_model_name', type=str, help='Model name to use for scoring objectives')
    parser.add_argument('--human_scorer_models', nargs='+', help='List of models for human scoring in verification (API or HuggingFace)')
    parser.add_argument('--use_api_proposer', type=_str_to_bool, help='Whether to use API for proposer (vs local model)')
    # Evaluation parameters
    parser.add_argument('--ground_truth_objectives', nargs='+', help='List of ground-truth objectives for evaluation')
    parser.add_argument('--ground_truth_file', type=str, help='JSON file containing ground-truth objectives')
    parser.add_argument('--evaluation_model', type=str, help='Model to use for evaluation (e.g., gpt-4o-mini)')
    parser.add_argument('--similarity_threshold', type=float, help='Similarity threshold for true positive (1-10)')
    parser.add_argument('--duplicate_threshold', type=float, help='Threshold for duplicate detection (1-10)')
    parser.add_argument('--evaluate_iterations', action='store_true', help='Evaluate at each iteration')
    parser.add_argument('--num_samples_final_eval', type=int, help='Number of samples for final evaluation')
    # Output parameters
    parser.add_argument('--output_dir', type=str, help='Directory to save results')
    parser.add_argument('--exp_name', type=str, help='Experiment name')
    parser.add_argument('--exp_alignment_algo', type=str, help='Experiment alignment algorithm name')
    parser.add_argument('--exp_alignment_model', type=str, help='Experiment alignment model name')
    parser.add_argument('--save_results', type=_str_to_bool, help='Whether to save results to files')
    # Reward combiner parameters
    parser.add_argument('--use_reward_combiner_weights', type=_str_to_bool, help='Use fitted reward combiner weights in PPO/GRPO configs')
    # Concurrency parameters
    # The help formatter appends the real default, so it is not repeated in the help text.
    parser.add_argument('--max_concurrent', type=int, default=10, help='Maximum number of concurrent API calls')
    # Other parameters
    parser.add_argument('--seed', type=int, help='Random seed for reproducibility')
    parser.add_argument('--device', type=str, help='Device to use (auto, cuda, cpu)')
    parser.add_argument('--verbose', action='store_true', help='Enable verbose output')
    # Prompt length filtering
    parser.add_argument('--max_prompt_length', type=int, help='Maximum prompt length in tokens for filtering')
    parser.add_argument('--base_model_name', type=str, help='Base model name for tokenizer (required if max_prompt_length is set)')

    args = parser.parse_args()

    # Load configuration
    if os.path.exists(args.config):
        print(f"Loading configuration from: {args.config}")
        config = load_config(args.config)
    else:
        print(f"Warning: Config file {args.config} not found. Using command-line arguments only.")
        config = {}

    # Merge command-line arguments with config
    config = merge_args_with_config(args, config)

    # Setup output directory and logging early
    # Each run gets its own <output_dir>/<exp_name>_<timestamp> directory.
    config['output_dir'] = os.path.join(config['output_dir'], config['exp_name'] + f'_{datetime.now().strftime("%Y%m%d_%H%M%S")}')
    if not os.path.exists(config['output_dir']):
        print(f"Creating output directory: {config['output_dir']}")
        os.makedirs(config['output_dir'], exist_ok=True)

    # Save the full merged config (with all command-line overrides) to output directory
    config_save_path = os.path.join(config['output_dir'], 'config_used.yaml')
    with open(config_save_path, 'w') as f:
        yaml.dump(config, f, default_flow_style=False, sort_keys=False)
    print(f"Config saved to: {config_save_path}")

    # Initialize logging system
    logger = setup_logging(config['output_dir'])
    logger.info("Starting Objective Discovery Pipeline")

    # Log configuration
    log_section_header(logger, "CONFIGURATION", level=2)
    log_dict_pretty(logger, config, title="Full Configuration")

    # Validate configuration
    try:
        validate_config(config)
    except ValueError as e:
        print(f"Configuration error: {e}")
        return 1

    # NOTE(review): seeding is disabled here, so config['seed'] is not applied
    # by this function.
    # if config.get('seed'):
    #     print(f"Setting random seed: {config['seed']}")
    #     set_seed(config['seed'])

    # Load ground-truth objectives if provided
    ground_truth_objectives = load_ground_truth_objectives(config)

    # Print configuration summary
    print("\n" + "=" * 80)
    print("OBJECTIVE DISCOVERY CONFIGURATION")
    print("=" * 80)
    print(f"Method: {config['method']}")
    print(f"Dataset: {config['dataset']}")
    print(f"Model sequence ({len(config['model_sequence'])} checkpoints):")
    for i, checkpoint in enumerate(config['model_sequence'], 1):
        print(f"  {i}. {checkpoint}")
    print(f"Target objectives (k): {config['k']}")
    print(f"Output directory: {config['output_dir']}")

    if ground_truth_objectives:
        print(f"\nEvaluation enabled with {len(ground_truth_objectives)} ground-truth objectives:")
        for i, obj in enumerate(ground_truth_objectives, 1):
            print(f"  {i}. {obj}")
        print(f"Evaluation model: {config.get('evaluation_model', 'gpt-4o-mini')}")
        print(f"Similarity threshold: {config.get('similarity_threshold', 7.0)}")
        print(f"Duplicate threshold: {config.get('duplicate_threshold', 7.0)}")
        print(f"Evaluate at iterations: {config.get('evaluate_iterations', False)}")

    if config.get('verbose'):
        print("\nFull configuration:")
        for key, value in config.items():
            if key not in ['model_sequence', 'ground_truth_objectives']:  # Already printed above
                print(f"  {key}: {value}")

    print("=" * 80 + "\n")

    # Initialize discovery method
    try:
        logger.info("Initializing discovery method...")
        discovery = initialize_discovery_method(config, logger)
        logger.info(f"✓ Initialized {config['method']} discovery method")
    except Exception as e:
        print(f"Error initializing discovery method: {e}")
        return 1

    # Run discovery
    try:
        print("\nStarting objective discovery...")
        print("-" * 60)
        objectives, stats = discovery.obtain_objectives()
        print("-" * 60)
        print("✓ Discovery complete!")
    except KeyboardInterrupt:
        print("\n\nDiscovery interrupted by user.")
        return 1
    except Exception as e:
        print(f"\nError during discovery: {e}")
        import traceback
        if config.get('verbose'):
            traceback.print_exc()
        return 1

    # Display results
    print("\n" + "=" * 80)
    print("RESULTS")
    print("=" * 80)

    if objectives:
        print(f"\nDiscovered {len(objectives)} objective(s):\n")
        for i, obj in enumerate(objectives, 1):
            print(f"{i}. {obj}\n")

        # Special handling for obtain_fixed method - save objectives to txt file
        if config.get('method') == 'obtain_fixed':
            # Save objectives to txt file
            dataset_name = config.get('dataset', 'unknown').replace('/', '_')
            txt_filename = f"fixed_objs_{config.get('exp_alignment_algo')}_{dataset_name}.txt"
            txt_path = os.path.join(config['output_dir'], txt_filename)

            with open(txt_path, 'w') as f:
                for obj in objectives:
                    f.write(f"{obj}\n")

            print(f"\n✓ Objectives saved to: {txt_path}")
            logger.info(f"Objectives saved to: {txt_path}")

        # Automatically create YAML config for PPO/GRPO run (for other methods)
        elif config.get('exp_alignment_algo') in ['ppo', 'grpo']:
            # Bound before the try so the except handler can always name the algorithm.
            algo = config.get('exp_alignment_algo')
            try:
                # Get SFT model path from constants
                model_key = (config.get('exp_alignment_model'), config.get('dataset'), config.get('multi_turn'))
                sft_model_path = SFT_MODELS_FILEPATHS.get(model_key)

                if not sft_model_path:
                    logger.warning(f"No SFT model found for {model_key}")
                    print(f"\nWarning: No SFT model found for {model_key}")
                else:
                    # Determine output directory based on algorithm
                    # (empty string means the config is written relative to the
                    # current working directory)
                    if algo == 'ppo':
                        config_dir = ''
                    else:  # grpo
                        config_dir = ''

                    # Find next available experiment number
                    exp_name = config.get('exp_name')
                    exp_num = 1
                    while True:
                        config_filename = f"{exp_name}_{exp_num}.yaml"
                        config_path = os.path.join(config_dir, config_filename)
                        if not os.path.exists(config_path):
                            break
                        exp_num += 1

                    # Convert objectives to lowercase for config files
                    objectives_lower = [obj.lower() for obj in objectives]

                    # Create equal weights for discovered objectives
                    # (three objectives use a fixed 0.33 rather than 1/3)
                    if len(objectives_lower) == 3:
                        weight_per_objective = 0.33
                    else:
                        weight_per_objective = 1.0 / len(objectives_lower) if objectives_lower else 1.0

                    # Reward combiner path/type are only taken from the run when
                    # use_reward_combiner_weights is enabled; otherwise fall back
                    # to a plain linear combiner with no saved path.
                    reward_combiner_path = None
                    reward_combiner_type = 'linear'
                    if config.get('use_reward_combiner_weights', False):
                        reward_combiner_path = stats.get('reward_combiner_path', None)
                        reward_combiner_type = config.get('combination_function_type', 'linear')

                    # Determine training-schedule settings based on ground_truth_type.
                    # Start from the llm_function values so every template
                    # variable is defined whichever branch runs.
                    ground_truth_type = config.get('ground_truth_type', 'llm_function')
                    max_ppo_steps = 101
                    max_grpo_steps = 301
                    use_score_norm = True
                    use_score_scaling = True
                    reward_model_name = 'gpt-4o-mini'
                    use_custom_scheduler = False
                    warmup_ratio = 0.025
                    if ground_truth_type == 'reward_model':
                        # NOTE(review): a 'PQA2A3' special case (max_ppo_steps = 226)
                        # used to sit here but was unconditionally overwritten by
                        # 301; the effective value is kept. Reinstate it as an
                        # if/else if 226 was intended.
                        max_ppo_steps = 301
                        max_grpo_steps = 501
                        use_score_norm = False
                        use_score_scaling = False
                        reward_model_name = 'gpt-4.1-mini'
                        use_custom_scheduler = True
                        warmup_ratio = 0.16611295681
                    elif ground_truth_type != 'llm_function':
                        # Default to 81 for unknown types; remaining settings keep
                        # the llm_function values assigned above.
                        max_ppo_steps = 81
                        logger.warning(f"Unknown ground_truth_type '{ground_truth_type}', defaulting to max_ppo_steps=81")

                    # Create config based on algorithm type
                    if algo == 'ppo':
                        # Create PPO config content
                        ppo_content = f"""# PPO Training Configuration - Auto-generated from objective discovery

# Basic model configuration
base_model_path: '{config.get('exp_alignment_model')}'
model_id_path: '{sft_model_path}'
model_save_path: ''

# Dataset configuration
dataset_name: '{config.get('dataset')}'
dataset_max_length: {DATASET_MAX_LENGTH_DICT[config.get('dataset')]}
multi_turn: {config.get('multi_turn', False)}  # False for single-turn (first assistant response only), True for multi-turn

# Wandb configuration
wandb_run_name: '{exp_name}_{exp_num}'
local_wandb_dir: ''

# Training parameters
use_sanity_check: false
max_ppo_steps: {max_ppo_steps}
eval_freq: 5
log_filename: 'logs.txt'

# PPO hyperparameters
learning_rate: 0.000002  # 2e-6
max_ppo_epochs: 1
mini_batch_size: 4
batch_size: 32
init_kl_coef: 0.0325
max_grad_norm: 1.0
target_kl: 1.0
adap_kl_ctrl: false
use_score_scaling: {use_score_scaling}
use_score_norm: {use_score_norm}
score_clip: None
early_stopping: false
whiten_rewards: true
gradient_accumulation_steps: 2
use_custom_scheduler: {use_custom_scheduler}
warmup_ratio: {warmup_ratio}

# Generation parameters
max_new_tokens: 512
missing_eos_penalty: 1.0

# Reward configuration - Discovered objectives
ground_truth_type: 'llm_function'
ground_truth_objectives:
"""
                        for obj in objectives_lower:
                            ppo_content += f"  - \"{obj}\"\n"

                        ppo_content += "\nground_truth_weights:\n"
                        # NOTE(review): objective text is emitted as an unquoted
                        # YAML key; an objective containing ': ' or '#' would
                        # likely produce an unparsable file - confirm upstream.
                        for obj in objectives_lower:
                            ppo_content += f"  {obj}: {weight_per_objective:.2f}\n"

                        ppo_content += f"""
# Reward model configuration
reward_model_name: '{reward_model_name}'
use_api: true
reward_combiner_type: '{reward_combiner_type}'
manual_bias: 0.0
use_detailed_rubric: true
"""
                        # Add reward combiner path if available
                        if reward_combiner_path:
                            ppo_content += f"reward_combiner_path: '{reward_combiner_path}'\n"
                            print(f"  Using fitted reward combiner from: {reward_combiner_path}")

                        ppo_content += f"cache_dir: '{config['output_dir']}'\n"
                        ppo_content += f"max_concurrent: {config.get('max_concurrent', 10)}\n"

                        # Write PPO config
                        # os.makedirs('') raises FileNotFoundError, so only create
                        # the directory when one is actually configured.
                        if config_dir:
                            os.makedirs(config_dir, exist_ok=True)
                        with open(config_path, 'w') as f:
                            f.write(ppo_content)

                    elif algo == 'grpo':  # grpo
                        # Determine dataset type
                        dataset_type = DATASET_NAMES_DICT[config.get('dataset')]

                        # Create GRPO config content
                        # NOTE(review): reward_model_name is hard-coded to
                        # 'gpt-4o-mini' in this template, whereas the PPO template
                        # uses the value chosen above - confirm this is intended.
                        grpo_content = f"""# GRPO Training Configuration - Auto-generated from objective discovery
method: 'grpo'
model_name: '{config.get('exp_alignment_model')}'
model_id: '{sft_model_path}'
dataset_name: '{config.get('dataset')}'
dataset_dirs: ''
multi_turn: {config.get('multi_turn', False)}  # False for single-turn (first assistant response only), True for multi-turn
root_save_dir: ''
trust_remote_code: true
dataset_num_proc: 1
local_wandb_dir: ''
val_dataset_size: 16
max_prompt_length: {DATASET_MAX_LENGTH_DICT[config.get('dataset')]}

# BitsAndBytesConfig parameters
load_in_4bit: true
bnb_4bit_use_double_quant: true
bnb_4bit_quant_type: 'nf4'
bnb_4bit_compute_dtype: 'bfloat16'

# AutoModelForCausalLM parameters
device_map: 'auto'
use_cache: false
attn_implementation: 'flash_attention_2'
torch_dtype: 'bfloat16'

# LoraConfig parameters
lora_alpha: 128
lora_dropout: 0.05
lora_r: 256
lora_bias: 'none'
lora_target_modules: 'all-linear'
lora_task_type: 'CAUSAL_LM'

# Training parameters
num_train_epochs: 1
max_steps: {max_grpo_steps}
per_device_train_batch_size: 8
per_device_eval_batch_size: 8
gradient_accumulation_steps: 1
save_strategy: 'steps'
save_steps: 5
eval_strategy: 'steps'
eval_steps: {max_grpo_steps//2}
logging_strategy: 'steps'
logging_steps: 5
push_to_hub: false
report_to: 'wandb'

# Reward function configuration
reward_type: 'llm_function'

# For LLM-based reward
reward_model_name: 'gpt-4o-mini'
use_api: true
reward_combiner_type: '{reward_combiner_type}'
reward_objectives:"""

                        # Add objectives to GRPO config
                        # (switch to double quotes when the text contains an apostrophe)
                        for obj in objectives_lower:
                            if "'" in obj:
                                grpo_content += f'\n  - "{obj}"'
                            else:
                                grpo_content += f"\n  - '{obj}'"

                        grpo_content += "\nreward_manual_weights:\n"
                        for obj in objectives_lower:
                            grpo_content += f"  {obj}: {weight_per_objective:.2f}\n"

                        grpo_content += f"""reward_manual_bias: 0.0
reward_max_length: 512
dataset_type: '{dataset_type}'
use_detailed_rubric: true
"""
                        # Add reward combiner path if available
                        if reward_combiner_path:
                            grpo_content += f"reward_combiner_path: '{reward_combiner_path}'\n"
                            print(f"  Using fitted reward combiner from: {reward_combiner_path}")

                        grpo_content += f"cache_dir: '{config['output_dir']}'\n"
                        grpo_content += f"max_concurrent: {config.get('max_concurrent', 10)}\n"

                        # Write GRPO config
                        # os.makedirs('') raises FileNotFoundError, so only create
                        # the directory when one is actually configured.
                        if config_dir:
                            os.makedirs(config_dir, exist_ok=True)
                        with open(config_path, 'w') as f:
                            f.write(grpo_content)

                    else:
                        raise NotImplementedError(f"Unknown alignment algorithm: {algo}")

                    print(f"\n✓ Created {algo.upper()} config file: {config_path}")
                    logger.info(f"Created {algo.upper()} config file: {config_path}")

            except Exception as e:
                # Config generation is best-effort: report the failure and carry
                # on to the statistics/saving steps below.
                logger.error(f"Error creating {algo.upper()} config: {e}")
                print(f"\nWarning: Could not create {algo.upper()} config: {e}")

    else:
        print("\nNo objectives discovered.")

    print("\nStatistics:")
    print(f"  Total proposals: {stats.get('total_proposals', 'N/A')}")
    print(f"  Acceptance rate: {stats.get('acceptance_rate', 0):.2%}")
    print(f"  Iterations: {stats.get('total_iterations', 'N/A')}")
    print(f"  Time elapsed: {stats.get('time_elapsed', 0):.2f} seconds")

    # Save results if requested
    if config.get('save_results', True):
        print("\nSaving results...")
        save_results(objectives, stats, config, config['output_dir'])

    print("\n✓ All done!")

if __name__ == "__main__":
    # main() returns 1 on failure and None on success; hand that to sys.exit
    # so the process exit status reflects errors instead of always being 0.
    sys.exit(main())

# OLD PPO HYPERPARAMETERS
# learning_rate: 0.000002  # 2e-6
# max_ppo_epochs: 4
# mini_batch_size: 4
# batch_size: 32
# init_kl_coef: 0.2
# max_grad_norm: 0.5
# target_kl: 1.0
# adap_kl_ctrl: true
# use_score_scaling: true
# use_score_norm: true
# score_clip: None
# early_stopping: true
# whiten_rewards: false
# gradient_accumulation_steps: 2
# use_custom_scheduler: false
# warmup_ratio: 0.025