# config_utils.py
import argparse
from datetime import datetime
import torch
import numpy as np
import logging
import sys
from pathlib import Path


def set_random_seeds(seed):
    """Seed every random number generator this project can draw from.

    Seeds Python's built-in ``random`` module, NumPy's legacy global
    generator, and PyTorch (CPU, plus all CUDA devices when CUDA is
    available), so that repeated runs with the same ``seed`` draw the same
    random numbers from each of those sources.

    Args:
        seed: Integer seed applied to all generators.
    """
    # Imported locally so this function does not depend on a module-level
    # import being present.
    import random

    # Augmentation code and third-party utilities frequently use the
    # stdlib generator; leaving it unseeded makes runs non-reproducible
    # even when torch and numpy are seeded.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

def setup_logging(log_file_path, level=logging.INFO):
    """Configure the root logger to write to a log file and to stdout.

    The parent directory of ``log_file_path`` is created if it does not
    exist. Any handlers already attached to the root logger are removed and
    closed first, so the configuration requested here always takes effect.

    Args:
        log_file_path: Destination log file (``str`` or ``Path``).
        level: Threshold set on the root logger. Defaults to
            ``logging.INFO``.

    Returns:
        The root logger.
    """
    log_file_path = Path(log_file_path)
    log_file_path.parent.mkdir(parents=True, exist_ok=True)

    root_logger = logging.getLogger()

    # logging.basicConfig() is a silent no-op when the root logger already
    # has handlers (e.g. an imported library configured logging first, or
    # this function is called a second time). In that case the log file
    # would receive nothing and the FileHandler created below would leak.
    # Dropping the existing handlers first guarantees the call is honoured.
    for stale_handler in root_logger.handlers[:]:
        root_logger.removeHandler(stale_handler)
        stale_handler.close()

    logging.basicConfig(
        level=level,
        format='%(asctime)s [%(name)s] [%(levelname)s] %(message)s',
        handlers=[
            # Explicit encoding keeps the log file independent of the
            # platform's locale default.
            logging.FileHandler(log_file_path, encoding='utf-8'),
            logging.StreamHandler(sys.stdout),
        ],
    )
    return root_logger

def get_device():
    """Pick the torch device to run on and log what was found.

    Returns:
        ``torch.device('cuda')`` when ``torch.cuda.is_available()`` reports
        True, otherwise ``torch.device('cpu')``.
    """
    log = logging.getLogger(__name__)

    # Guard clause: without CUDA there is nothing further to inspect.
    if not torch.cuda.is_available():
        cpu_device = torch.device('cpu')
        log.info("GPU not available, using CPU.")
        return cpu_device

    cuda_device = torch.device('cuda')
    log.info(f"Using GPU: {torch.cuda.get_device_name(0)}")

    # The memory report is best-effort: a failed property lookup is
    # downgraded to a warning so device selection still succeeds.
    try:
        log.info(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
    except Exception as e:
        log.warning(f"Could not get GPU memory properties: {e}")

    log.info(f"Number of available GPUs: {torch.cuda.device_count()}")
    return cuda_device

def parse_arguments(argv=None):
    """Parse the experiment's command-line arguments.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) the arguments are taken from ``sys.argv[1:]``,
            which is argparse's standard behaviour. Passing an explicit
            list lets callers and tests drive the parser programmatically.

    Returns:
        The populated ``argparse.Namespace``. When ``--debug`` is given,
        ``epochs`` is forced to 3 and ``log_interval`` to 10, regardless of
        the values supplied for those options.
    """
    parser = argparse.ArgumentParser(description='Neural Collapse Training Experiments')

    # Model and data selection.
    parser.add_argument('--model', type=str, default='resnet18',
                        choices=['resnet18', 'resnet50', 'vgg16'],
                        help='Model architecture')
    parser.add_argument('--dataset', type=str, default='mnist',
                        choices=['cifar10', 'cifar100', 'mnist'], help='Dataset')
    parser.add_argument('--layer-norm', type=str, default='rms', choices=['none', 'standard', 'rms'],
                        help='Layer normalization type to use (for ResNet/VGG feature extractor if applicable)')

    # Optimisation hyper-parameters.
    parser.add_argument('--batch-size', type=int, default=100, help='Batch size')
    parser.add_argument('--epochs', type=int, default=10, help='Number of epochs')
    parser.add_argument('--lr', type=float, default=0.001, help='Learning rate')
    parser.add_argument('--optimizer', type=str, default='adam', choices=['adam', 'sgd'],
                        help='Optimizer to use (adam or sgd)')
    parser.add_argument('--lr-step-size', type=int, default=20,
                        help='Step size for learning rate scheduler (epochs)')
    parser.add_argument('--lr-gamma', type=float, default=0.1,
                        help='Multiplicative factor for learning rate decay')

    # Run control, reproducibility and outputs.
    parser.add_argument('--log-interval', type=int, default=100, help='Batches between logging')
    parser.add_argument('--seed', type=int, default=420, help='Random seed for the first run')
    parser.add_argument('--workers', type=int, default=4, help='Number of data loading workers')
    parser.add_argument('--repeat', type=int, default=1,
                        help='Number of times to repeat the experiment with incrementing seeds')
    parser.add_argument('--save-model', action='store_true',
                        help='Save the final model state_dict for each run')
    # NOTE: there is intentionally no --divergence-threshold option; it was
    # removed because the current train_model logic does not use it.
    parser.add_argument('--debug', action='store_true',
                        help='Enable debug mode (e.g., fewer epochs)')
    parser.add_argument('--output-base-dir', type=str, default="output_experiments",
                        help='Base directory for saving experiment outputs')

    args = parser.parse_args(argv)

    if args.debug:
        # Debug runs are kept short and log frequently.
        args.epochs = 3
        args.log_interval = 10
    return args

def create_run_config(args):
    """Build the run configuration dict and create the run's output directory.

    The returned dict starts as a copy of ``vars(args)`` and gains:

    * ``output_dir`` -- ``Path`` of the newly created run directory.
    * ``run_name``   -- name of that directory:
      ``<model>_<dataset>_lr<lr>_opt<optimizer>_<YYYYmmdd_HHMMSS>``, with
      ``_<n>`` appended if a directory of that name already exists.
    * ``log_file``   -- ``output_dir / 'training.log'`` (a ``Path``).
    * ``num_epochs`` -- same value as ``args.epochs``.

    Args:
        args: Namespace providing at least ``model``, ``dataset``, ``lr``,
            ``optimizer``, ``epochs`` and ``output_base_dir``.

    Returns:
        The configuration dictionary described above.
    """
    config = vars(args).copy()

    # A relative output_base_dir is anchored at the directory containing
    # this file; an absolute one replaces the anchor entirely (pathlib's
    # `/` operator discards the left side when the right side is absolute).
    script_dir = Path(__file__).resolve().parent
    base_dir = script_dir / args.output_base_dir
    base_dir.mkdir(parents=True, exist_ok=True)

    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    name_stem = "_".join(
        [args.model, args.dataset, f"lr{args.lr}", f"opt{args.optimizer}", timestamp]
    )

    # The timestamp only has one-second resolution, so two runs with the
    # same settings started in the same second would otherwise share a
    # directory and overwrite each other's log. Create the directory
    # exclusively and fall back to a numeric suffix on collision.
    run_name = name_stem
    attempt = 0
    while True:
        output_dir = base_dir / run_name
        try:
            output_dir.mkdir()
            break
        except FileExistsError:
            attempt += 1
            run_name = f"{name_stem}_{attempt}"

    config['output_dir'] = output_dir
    config['run_name'] = run_name
    config['log_file'] = output_dir / 'training.log'

    # Mirror the (possibly debug-overridden) epoch count under 'num_epochs'.
    config['num_epochs'] = args.epochs

    return config