"""
This script provides functions and configurations for running hyperparameter optimization using Optuna and PyTorch Lightning.

The script includes:
- `Objective`: A callable class that serves as the objective function for hyperparameter optimization using Optuna.
- `hyperparameter_search`: A function to perform hyperparameter search with the given model and dataset.

Dependencies:
- optuna: Used for hyperparameter optimization.
- lightning: Used for managing model training.
- utils.path_utils: For obtaining directory paths.
- runs.hyperparameter_search.model_configurations: To retrieve model configurations.

Usage:
This script should be used when tuning hyperparameters for deep learning models using Optuna and PyTorch Lightning.
"""

import os
import lightning as L
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping

import optuna
from optuna.integration import PyTorchLightningPruningCallback

from utils.path_utils import get_directory_path
from runs.hyperparameter_search.model_configurations import get_model_config


class Objective:
    """
    Callable objective for an Optuna study.

    Each call builds a model and data module from the configuration returned
    by ``get_model_config`` for the trial, trains them with a Lightning
    ``Trainer``, and returns the ``metric_val_f1`` score of the run
    (higher is better).
    """

    def __init__(self, model_name, options, devices, accelerator, test_mode):
        """
        Initialize the objective function for Optuna hyperparameter optimization.

        Args:
            model_name (str): The name of the model to optimize.
            options (dict): A dictionary containing several possible configuration combinations for the model.
                Outside test mode it must contain ``"dataset_name"`` (a string or a list of strings).
            devices (int or list): The devices to use for training.
            accelerator (str): The type of hardware accelerator (e.g., 'gpu', 'cpu').
            test_mode (bool): If True, outputs are written under ``test_<model_name>``
                instead of ``<dataset_name>_<model_name>``.
        """
        self.model_name = model_name
        self.options = options
        self.devices = devices
        self.accelerator = accelerator
        self.test_mode = test_mode

    def __call__(self, trial: "optuna.Trial") -> float:
        """
        Train one model for ``trial`` and return the score to maximize.

        Args:
            trial: The Optuna trial; passed to ``get_model_config`` and used
                (via ``trial.number``) to name the trial's log directory.

        Returns:
            float: The best ``metric_val_f1`` recorded by the checkpoint
            callback; otherwise the last logged ``metric_val_f1``; otherwise
            ``0.0``. If training raises, ``-inf`` is returned.

        Raises:
            optuna.TrialPruned: Propagated unchanged if raised during training.
        """
        config = get_model_config(
            trial, self.model_name, self.options, self.devices, self.accelerator
        )
        data_module = config.data_module(config)
        model = config.model(config)

        trial_id = trial.number
        tensorboard_logger = TensorBoardLogger(
            self._log_root(), name=f"trial_{trial_id}"
        )
        # Use the directory the logger actually resolved so that checkpoints
        # end up next to this trial's TensorBoard files.
        checkpoint_dir = os.path.join(tensorboard_logger.log_dir, "checkpoints")
        checkpoint_callback_val_f1, checkpoint_callback_train_loss = (
            self._make_checkpoint_callbacks(checkpoint_dir)
        )

        # Optional callbacks, currently disabled:
        #
        # pruning_callback = PyTorchLightningPruningCallback(
        #     trial, monitor="metric_val_f1"
        # )
        # early_stopping_callback = EarlyStopping(
        #     monitor="metric_val_f1",
        #     patience=config.stop_patience,
        #     mode="max",
        #     verbose=True,
        # )

        trainer = L.Trainer(
            max_epochs=1000,  # Consider making this configurable via config/options
            accelerator=config.accelerator,
            devices=config.devices,
            log_every_n_steps=config.log_every_n_steps,
            logger=tensorboard_logger,
            callbacks=[
                checkpoint_callback_val_f1,
                checkpoint_callback_train_loss,
                # pruning_callback,
                # early_stopping_callback,
            ],
            accumulate_grad_batches=config.accumulation_steps,
        )

        # Train exactly once, inside the guard, so that every training
        # failure goes through the same handling.
        try:
            trainer.fit(model, datamodule=data_module)
        except optuna.TrialPruned:
            # A pruning signal is not a failure; let Optuna handle it.
            raise
        except Exception as e:
            print(f"Error during training for trial {trial_id}: {e}")
            # Report the worst possible value for a maximized metric so the
            # failed trial is never selected as the best one.
            return -float("inf")

        return self._final_score(trial_id, trainer, checkpoint_callback_val_f1)

    def _log_root(self):
        """Return the output directory shared by all trials of this search."""
        if self.test_mode:
            folder = f"test_{self.model_name}"
        else:
            # Only read "dataset_name" when it is needed, so test runs do not
            # require the key to be present.
            dataset_name = self.options["dataset_name"]
            dataset_name_str = (
                "-".join(dataset_name)
                if isinstance(dataset_name, list)
                else dataset_name
            )
            folder = f"{dataset_name_str}_{self.model_name}"
        return os.path.join(get_directory_path("model_outputs"), folder)

    @staticmethod
    def _make_checkpoint_callbacks(checkpoint_dir):
        """Build the (best validation F1, best training loss) checkpoint callbacks."""
        checkpoint_callback_val_f1 = ModelCheckpoint(
            monitor="metric_val_f1",
            mode="max",  # Higher F1 is better
            save_top_k=3,
            dirpath=checkpoint_dir,
            filename="best_val_f1-{epoch:03d}-{metric_val_f1:.4f}",
            # The filename template already spells out the metric name.
            auto_insert_metric_name=False,
        )
        checkpoint_callback_train_loss = ModelCheckpoint(
            monitor="loss_train_monitor",
            mode="min",  # Lower loss is better
            save_top_k=3,
            dirpath=checkpoint_dir,
            filename="best_train_loss-{epoch:03d}-{loss_train_monitor:.5f}",
            auto_insert_metric_name=False,
        )
        return checkpoint_callback_val_f1, checkpoint_callback_train_loss

    @staticmethod
    def _final_score(trial_id, trainer, checkpoint_callback_val_f1):
        """Pick the value reported to Optuna once training has finished."""
        best_score = checkpoint_callback_val_f1.best_model_score
        if best_score is not None:
            return best_score.item()

        # No best score was recorded by the checkpoint callback: fall back to
        # the last value of the monitored metric known to the trainer.
        last_val_f1 = trainer.callback_metrics.get("metric_val_f1")
        if last_val_f1 is not None:
            print(
                f"Warning: Trial {trial_id} finished without a best checkpoint score. Using last logged metric_val_f1: {last_val_f1.item()}"
            )
            return last_val_f1.item()

        # Nothing was logged at all: report a default low score.
        print(
            f"Warning: Trial {trial_id} finished without a best checkpoint score and couldn't find last logged metric_val_f1."
        )
        return 0.0


def hyperparameter_search(model_name, options, devices, accelerator, optuna_config, test_mode=True):
    """
    Run hyperparameter search using Optuna.

    Creates the study (or loads it if one with the same name already exists
    in the configured storage) and optimizes an ``Objective`` with it. The
    study maximizes the value returned by the objective.

    Args:
        model_name (str): The name of the model to optimize.
        options (dict): A dictionary containing several possible configuration combinations for the model.
            Outside test mode it must contain ``"dataset_name"`` (a string or a list of strings).
        devices (int or list): The devices to use for training.
        accelerator (str): The type of hardware accelerator (e.g., 'gpu', 'cpu').
        optuna_config (OptunaConfig): Configuration for Optuna optimization; its
            ``storage``, ``sampler``, ``n_trials`` and ``timeout`` attributes are used.
        test_mode (bool): If True (default), the study is named ``test_<model_name>``
            instead of ``<dataset_name>_<model_name>``.

    Returns:
        optuna.Study: The study after optimization has finished.
    """
    if test_mode:
        study_name = f"test_{model_name}"
    else:
        # Join multi-dataset names with "-" (the same convention Objective
        # uses for its log directory) instead of embedding the list's repr.
        # NOTE: a study previously stored under the raw list repr will not be
        # found by load_if_exists under the new name.
        dataset_name = options["dataset_name"]
        if isinstance(dataset_name, list):
            dataset_name = "-".join(dataset_name)
        study_name = f"{dataset_name}_{model_name}"

    # Pruning is disabled for now. To enable it, swap in for example:
    # pruner = optuna.pruners.HyperbandPruner(
    #     min_resource=optuna_config.min_epochs, max_resource=optuna_config.max_epochs
    # )
    pruner = optuna.pruners.NopPruner()

    # Create or load the study
    study = optuna.create_study(
        study_name=study_name,
        storage=optuna_config.storage,
        direction="maximize",
        pruner=pruner,
        sampler=optuna_config.sampler,
        load_if_exists=True,
    )

    # Start the optimization
    study.optimize(
        Objective(model_name, options, devices, accelerator, test_mode),
        n_trials=optuna_config.n_trials,
        timeout=optuna_config.timeout,
    )

    return study
