# trainer_module.py
import torch
import time
import logging
from pathlib import Path

# Module-level logger, named after this module via __name__; used by train_model for all progress output.
logger = logging.getLogger(__name__)


def _to_float(value):
    """Convert a metric (tensor-like, number, or ``None``) to a plain float; ``None`` becomes NaN."""
    if value is None:
        return float('nan')
    item = getattr(value, 'item', None)
    # Tensors and NumPy scalars expose .item(); plain Python numbers do not.
    return float(item()) if callable(item) else float(value)


def _train_one_epoch(model, train_loader, criterion, optimizer, device, log_interval, progress_prefix):
    """Run one optimisation pass over ``train_loader``.

    Returns:
        A tuple ``(mean_loss, accuracy_percent, last_batch_loss)``. ``mean_loss`` and
        ``accuracy_percent`` are 0.0 when there is nothing to average over, and
        ``last_batch_loss`` is NaN when the loader produced no batches.
    """
    model.train()
    num_batches = len(train_loader)
    running_loss = 0.0
    correct = 0
    total = 0
    last_loss = float('nan')  # stays NaN if the loader yields no batches

    for batch_idx, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)

        # The model is expected to return a pair; only the first element (logits) is used here.
        outputs, _ = model(images)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        last_loss = loss.item()
        running_loss += last_loss
        # detach() instead of the discouraged .data attribute; argmax gives the predicted class index.
        predicted = outputs.detach().argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        # A falsy log_interval (0 / None) disables per-batch logging instead of
        # raising on the modulo.
        if log_interval and (batch_idx + 1) % log_interval == 0:
            acc_text = f"{100 * correct / total:.2f}%" if total > 0 else "N/A (0 samples)"
            logger.info(f"{progress_prefix} - Batch [{batch_idx + 1}/{num_batches}] "
                        f"Loss: {last_loss:.4f}, Running Acc: {acc_text}")

    mean_loss = running_loss / num_batches if num_batches > 0 else 0.0
    accuracy = 100 * correct / total if total > 0 else 0.0
    return mean_loss, accuracy, last_loss


def train_model(
        model,
        train_loader,
        test_loader,
        criterion,
        optimizer,
        scheduler,
        config,  # Dict containing: 'num_epochs', 'log_interval', 'repeat' (for total runs)
        device,
        num_classes,
        output_dir: Path,  # Specific output directory for this run
        eval_fn,  # Callable for evaluate_accuracy (e.g., from metrics_utils)
        nc_fn,  # Callable for measure_neural_collapse (e.g., from metrics_utils)
        save_metrics_fn,  # Callable for save_metrics_to_csv (e.g., from metrics_utils)
        repeat_idx: int = 0,
        run_seed: int = None
):
    """Train ``model`` for ``config['num_epochs']`` epochs, evaluating and saving metrics after each.

    Per epoch: one training pass over ``train_loader``, one ``scheduler.step()``,
    then ``eval_fn`` and ``nc_fn`` are called and the resulting metrics are handed
    to ``save_metrics_fn``.

    Args:
        model: Callable module returning a pair whose first element is the logits.
        train_loader: Iterable of ``(images, labels)`` batches supporting ``len()``.
        test_loader: Passed through unchanged to ``eval_fn``.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepped once per batch.
        scheduler: LR scheduler; ``get_last_lr()[0]`` is logged and recorded, and
            ``step()`` is called once per epoch after the training pass.
        config: Mapping with required keys ``'num_epochs'`` and ``'log_interval'``
            and optional ``'repeat'`` (total number of runs, used only in log
            messages; defaults to 1). A ``log_interval`` of 0 or ``None`` turns
            per-batch logging off.
        device: Target passed to ``.to(...)`` for every batch.
        num_classes: Not used by this function; kept for interface compatibility.
        output_dir: Directory in which the metrics file path is formed.
        eval_fn: Called as ``eval_fn(model, test_loader, description)``; must return
            a number (the test accuracy).
        nc_fn: Called as ``nc_fn(model, train_loader)``; must return a 3-tuple
            ``(nc1, nc2, nc3)``. ``nc1 is None`` is treated as divergence.
        save_metrics_fn: Called as ``save_metrics_fn(metrics_dict, epoch, metrics_file)``.
        repeat_idx: Zero-based index of this run (used in log text and file name).
        run_seed: Seed used only to build the metrics file name.

    Returns:
        ``(success, best_accuracy, metrics_file)``. ``success`` is ``False`` when
        training stopped early because ``nc_fn`` reported ``nc1 is None``; the
        metrics of that last epoch are still saved, with NaN for all NC values.
    """
    epochs = config['num_epochs']
    log_interval = config['log_interval']
    # Total number of repetitions for logging purposes (e.g., "Run 1/5")
    repeat_total_runs = config.get('repeat', 1)

    metrics_file = Path(output_dir) / f'metrics_run_{repeat_idx}_seed_{run_seed}.csv'
    best_accuracy = 0.0
    # Built once up front: it does not depend on the epoch, and the final log
    # line needs it even when the loop body never runs (num_epochs == 0).
    run_info_prefix = f"Run {repeat_idx + 1}/{repeat_total_runs}"
    nan = float('nan')

    for epoch in range(epochs):
        epoch_start_time = time.time()
        epoch_label = f"Epoch {epoch + 1}/{epochs}"

        logger.info(f"{run_info_prefix} - {epoch_label} - Training...")
        # LR in effect for this epoch's training (read before scheduler.step()).
        current_lr = scheduler.get_last_lr()[0]
        logger.info(f"Current learning rate: {current_lr:.6e}")

        epoch_loss, train_accuracy, last_loss = _train_one_epoch(
            model, train_loader, criterion, optimizer, device, log_interval,
            f"{run_info_prefix} - Epoch [{epoch + 1}/{epochs}]",
        )

        # Step the scheduler once per epoch, after the optimizer steps of the training pass.
        scheduler.step()

        logger.info(f"{run_info_prefix} - "
                    f"{epoch_label} - Evaluating...")
        test_accuracy = eval_fn(
            model, test_loader, f"{run_info_prefix} - Testing Epoch {epoch + 1}"
        )
        nc1, nc2, nc3 = nc_fn(model, train_loader)  # NC is measured on the training data

        diverged = nc1 is None
        epoch_time_taken = time.time() - epoch_start_time
        if diverged:
            nc_values = (nan, nan, nan)
        else:
            nc_values = tuple(_to_float(v) for v in (nc1, nc2, nc3))

        metrics_data = {
            'train_loss': epoch_loss,
            'train_acc': train_accuracy,
            'test_acc': test_accuracy,
            'nc1': nc_values[0],
            'nc2': nc_values[1],
            'nc3': nc_values[2],
            'epoch_time': epoch_time_taken,
            'learning_rate': current_lr,
        }

        if diverged:
            logger.warning(
                f"{run_info_prefix} - {epoch_label} - NC1 calculation failed (potential divergence). Last training loss: {last_loss:.4f}"
            )
            # Persist what we have for this epoch before giving up on the run.
            save_metrics_fn(metrics_data, epoch, metrics_file)
            return False, best_accuracy, metrics_file  # Indicate divergence

        save_metrics_fn(metrics_data, epoch, metrics_file)

        # Log epoch summary
        logger.info(f"{run_info_prefix} - {epoch_label} Summary:")
        logger.info(f"  Training Loss: {epoch_loss:.4f}")
        logger.info(f"  Training Accuracy: {train_accuracy:.2f}%")
        logger.info(f"  Test Accuracy: {test_accuracy:.2f}%")
        logger.info(f"  Neural Collapse NC1 (Sw/Sb) on Train data: {metrics_data['nc1']:.4f}")
        logger.info(f"  Neural Collapse NC2 (Subspace Alignment) on Train data: {metrics_data['nc2']:.4f}")
        logger.info(f"  Neural Collapse NC3 (Classifier Alignment) on Train data: {metrics_data['nc3']:.4f}")
        logger.info(f"  Time: {epoch_time_taken:.2f}s")

        # Track best accuracy
        if test_accuracy > best_accuracy:
            best_accuracy = test_accuracy
            logger.info(f"{run_info_prefix} - New best accuracy: {best_accuracy:.2f}% (Epoch {epoch + 1})")

    logger.info(f"{run_info_prefix} - Training completed. Best test accuracy for this run: {best_accuracy:.2f}%")
    return True, best_accuracy, metrics_file  # Indicate success