"""Main training script for neural operator models.

This script sets up and runs the training process for neural operator models
with support for distributed training, automatic batch size probing, and
comprehensive logging.
"""

import argparse
import os

import src.utils.wandb_utils
import wandb
from src.config import Config
from src.dataset.dataloader_factory import build_data_loaders
from src.models.model_factory import build_model
from src.oned_fixed_inferencer import OneDFixedInferencer
from src.trainer import Trainer
from src.utils.logger import CustomLogger
from src.utils.logger_ctx import set_logger
from src.utils.training_utils import (
    convert_to_debug_config,
    set_random_seed,
    get_optimizer,
    get_scheduler,
    get_loss_function,
)


def main() -> None:
    """Run the end-to-end training pipeline from the command line.

    Steps, in order:
    - Parse ``--config`` / ``--debug`` and load the YAML configuration
    - Optionally start a Weights & Biases run and make the output directory run-specific
    - Optionally load a checkpoint and let ``overwrite_model_config`` update the config from it
    - Build the logger, model, data loaders, losses, optimizer, scheduler and inferencer
    - Hand everything to ``Trainer`` and start training

    Raises:
        SystemExit: If ``--config`` is missing or the command line is otherwise invalid
            (raised by ``argparse`` before any configuration is loaded).
    """

    # --------------------------------------------------------------------------- Load environment variables and config

    # Parse command line arguments
    parser = argparse.ArgumentParser(description="Train neural operator models")
    # The config path is mandatory: fail here with a usage message instead of passing None to Config.from_yaml.
    parser.add_argument("--config", required=True, help="Path to the YAML configuration file")
    parser.add_argument("--debug", action="store_true", help="Run in debug mode with reduced dataset and epochs")
    args = parser.parse_args()

    # Load and process configuration
    config = Config.from_yaml(args.config)

    # Debug mode can be requested either on the command line or in the config file.
    config.debug = args.debug or config.debug
    if config.debug:
        config = convert_to_debug_config(config)

    # -------------------------------------------------------------------- Initialize Weights & Biases (possibly sweep)

    # NOTE(review): no process-rank check is performed here; every process that runs this script and has
    # use_wandb enabled starts its own run.
    if config.training.use_wandb:
        # "offline" is only a default: a `mode` entry in the config section overrides it, and a missing
        # (None) section is treated as empty rather than crashing on `**None`.
        wandb_kwargs = {"mode": "offline", **(config.training.wandb or {})}
        run = wandb.init(config=config.to_dict(), **wandb_kwargs)
        # Keep the artifacts of different runs apart by nesting them under the run id.
        config.output_dir = os.path.join(config.output_dir, run.id)
    else:
        run = None

    # If we load an old model, we need to ensure the config is updated
    # NOTE(review): this happens after wandb.init, so the config recorded by wandb is the one from before
    # the checkpoint overwrite — confirm this is intended.
    if config.model.checkpoint_path is not None:
        checkpoint = src.utils.wandb_utils.load_checkpoint(config.model.checkpoint_path)
        config = src.utils.wandb_utils.overwrite_model_config(config, checkpoint)
    else:
        checkpoint = None

    # -------------------------------------------------------------------------------------------------- Set up Logging
    logger = CustomLogger(log_dir="logs", run=run)
    set_logger(logger)

    if config.debug:
        logger.info("Running in debug mode. Modifying configuration for quick testing.")

    logger.info(f"Training with device: {config.training.device}")

    # ------------------------------------------------------------------------ Build model, data loaders, and optimizer
    # Seed before the model and loaders are built.
    set_random_seed(config.seed)

    state_dict = checkpoint['model_state_dict'] if checkpoint is not None else None
    model = build_model(config, state_dict=state_dict)

    loaders = build_data_loaders(
        batch_size=config.training.batch_size,
        n_workers=config.training.n_workers,
        debug=config.debug,
        **config.to_dict()['dataset'],
    )

    criterion, val_criterion = get_loss_function(config.training.loss, config.training.val_loss)
    optimizer = get_optimizer(model, config)
    scheduler = get_scheduler(optimizer, config)
    inference_engine = OneDFixedInferencer(model, loaders['val'], config)

    # ------------------------------------------------------------------------------------------------- Train the model

    trainer = Trainer(
        model=model,
        config=config,
        logger=logger,
        train_loader=loaders["train"],
        val_loader=loaders["val"],
        criterion=criterion,
        val_criterion=val_criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        inference_engine=inference_engine,
        checkpoint=checkpoint,
    )

    trainer.train()


# Script entry point: start training only when this file is executed directly, not when it is imported.
if __name__ == "__main__":
    main()
