import os
import lightning as L
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.callbacks import ModelCheckpoint


from utils.path_utils import get_directory_path
from runs.hyperparameter_search.model_configurations import MoVQFormerConfig


class MockTrial:
    """Stand-in for ``optuna.Trial`` used to run a search config outside a study.

    Every ``suggest_*`` method deterministically answers with the first
    categorical choice or the lower bound of the range; reporting is a no-op
    and pruning is never requested. User attributes are accepted but not
    stored.
    """

    def __init__(self, number=0):
        # Mirrors ``optuna.Trial.number`` (trial index within a study).
        self.number = number

    def set_user_attr(self, key, value):
        """Discard ``key``/``value`` and hand back this trial object."""
        return self

    def suggest_categorical(self, name, choices):
        """Return the first entry of ``choices``."""
        first_choice = choices[0]
        return first_choice

    def suggest_int(self, name, low, high, step=1, log=False):
        """Return ``low``; ``high``, ``step`` and ``log`` are ignored."""
        return low

    def suggest_float(self, name, low, high, step=None, log=False):
        """Return ``low``; ``high``, ``step`` and ``log`` are ignored."""
        return low

    def report(self, value, step):
        """Accept an intermediate value without recording it."""
        return None

    def should_prune(self):
        """Never request pruning."""
        return False


def _dataset_name_to_str(options):
    """Return ``options["dataset_name"]`` as a single path component.

    A list or tuple of names is joined with ``"-"``; any other value is
    converted with ``str``.

    Raises:
        KeyError: If ``options`` has no ``"dataset_name"`` entry.
    """
    try:
        dataset_name = options["dataset_name"]
    except KeyError:
        raise KeyError("options must contain a 'dataset_name' entry") from None
    if isinstance(dataset_name, (list, tuple)):
        return "-".join(map(str, dataset_name))
    return str(dataset_name)


def pretrain_motion_encoder(
    model_name,
    options,
    devices,
    accelerator,
    *,
    max_epochs=2000,
    checkpoint_every_n_epochs=3,
    num_classes=12,
):
    """Train the model built by ``MoVQFormerConfig`` with one fixed configuration.

    A ``MockTrial`` replaces the Optuna trial, so every hyperparameter the
    config draws from the trial resolves to its first choice / lower bound.
    Logs go to TensorBoard under
    ``<model_outputs>/pretraind_encoder/<dataset name>``. Checkpoints are
    written to a ``checkpoints`` folder inside the logger's run directory:
    one every ``checkpoint_every_n_epochs`` epochs, plus ``final.ckpt`` after
    training finishes.

    Args:
        model_name: Forwarded to ``MoVQFormerConfig``.
        options: Options mapping forwarded to ``MoVQFormerConfig``. It must
            contain ``"dataset_name"``: a string, or a list/tuple of names
            joined with ``"-"`` to form the log sub-directory.
        devices: Forwarded to ``MoVQFormerConfig``. The trainer uses
            ``config.devices``.
        accelerator: Forwarded to ``MoVQFormerConfig``. The trainer uses
            ``config.accelerator``.
        max_epochs: Number of training epochs (default 2000).
        checkpoint_every_n_epochs: Period of the periodic checkpoints
            (default 3).
        num_classes: Value assigned to ``config.num_classes`` before the
            data module and model are built (default 12).

    Raises:
        KeyError: If ``options`` has no ``"dataset_name"`` entry.
    """
    # Resolve the dataset name first, so a missing entry fails before the
    # config, data module and model are constructed.
    dataset_name_str = _dataset_name_to_str(options)

    config = MoVQFormerConfig(MockTrial(), model_name, options, devices, accelerator)
    config.num_classes = num_classes

    data_module = config.data_module(config)
    model = config.model(config)

    # NOTE: "pretraind_encoder" (sic) is kept verbatim so existing output
    # directories, and presumably code that loads from them, keep working.
    log_dir = os.path.join(
        get_directory_path("model_outputs"),
        "pretraind_encoder",
        dataset_name_str,
    )

    tensorboard_logger = TensorBoardLogger(log_dir)
    # Use the logger's own run directory (it appends a name/version
    # sub-directory), so checkpoints sit next to this run's TensorBoard logs.
    checkpoint_dir = os.path.join(tensorboard_logger.log_dir, "checkpoints")

    # Periodic checkpoints that are never ranked or pruned
    # (monitor=None, save_top_k=-1 keeps every one).
    # NOTE(review): the filename assumes the model logs `loss_train_monitor`.
    # Confirm this; otherwise the value embedded in the name is not meaningful.
    periodic_checkpoint = ModelCheckpoint(
        monitor=None,
        save_top_k=-1,
        dirpath=checkpoint_dir,
        filename="epoch-{epoch:03d}-{loss_train_monitor:.5f}",
        auto_insert_metric_name=False,
        every_n_epochs=checkpoint_every_n_epochs,
    )

    trainer = L.Trainer(
        max_epochs=max_epochs,
        accelerator=config.accelerator,
        devices=config.devices,
        log_every_n_steps=config.log_every_n_steps,
        logger=tensorboard_logger,
        callbacks=[periodic_checkpoint],
        accumulate_grad_batches=config.accumulation_steps,
    )

    trainer.fit(model, datamodule=data_module)

    # The periodic schedule only saves when (epoch + 1) is a multiple of
    # `checkpoint_every_n_epochs`. With the defaults (2000 epochs, every 3)
    # the last two epochs would never be persisted, so always save the
    # final weights explicitly.
    trainer.save_checkpoint(os.path.join(checkpoint_dir, "final.ckpt"))