import os
import math
import torch
import lightning as L
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.callbacks import ModelCheckpoint

from utils.seed import set_seed
from utils.path_utils import get_directory_path
from runs.hyperparameter_search.model_configurations import MoVQFormerConfig

from models.MoVQFormer import l_model as MoVQFormer

import os
import lightning as L
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.callbacks import ModelCheckpoint


from utils.path_utils import get_directory_path
from runs.hyperparameter_search.model_configurations import MoVQFormerConfig


class MockTrial:
    """Stand-in for ``optuna.Trial`` that answers every call with a fixed value.

    ``suggest_categorical`` yields the first choice and ``suggest_int`` /
    ``suggest_float`` yield the lower bound, so code written against a trial
    can run once with a deterministic configuration. Reporting and user
    attributes are accepted and discarded; pruning is never requested.
    """

    def __init__(self, number: int = 0):
        # Trial index, kept for callers that read ``trial.number``.
        self.number = number

    def set_user_attr(self, key, value):
        """Discard the attribute and hand back this trial object."""
        return self

    def suggest_categorical(self, name, choices):
        """Return the first element of ``choices``; ``name`` is ignored."""
        first_choice = choices[0]
        return first_choice

    def suggest_int(self, name, low, high, step=1, log=False):
        """Return ``low``; every other argument is ignored."""
        lower_bound = low
        return lower_bound

    def suggest_float(self, name, low, high, step=None, log=False):
        """Return ``low``; every other argument is ignored."""
        lower_bound = low
        return lower_bound

    def report(self, value, step) -> None:
        """Ignore the reported intermediate value."""
        return None

    def should_prune(self) -> bool:
        """Always answer ``False``: the mock never asks for pruning."""
        return False




# Device indices handed to the model configuration below.
# Indices 0-7 are the candidates toggled between runs; only 4 is active.
devices = [4]

# One seed is shared by set_seed() and the "random_state" option below.
seed = 2002
set_seed(seed)
accelerator = "cuda"

# Scaling knob: the per-device batch is divided by it and the accumulation
# steps are multiplied by it, so their product stays at 32 * 16 whenever the
# division is exact.
batch_scale_rate = 1

batch_size_options = {
    "device_batch_size": int(32 / batch_scale_rate),
    "accumulation_steps": int(16 * batch_scale_rate),
}

# Run options forwarded to the model configuration.
options = {
    # Listed alternatives: "pretrain", "finetune", "joint".
    "mode": "finetune",
    "freeze_encoder": True,
    "encoder_lr_factor": 0.999,
    "batch_size": batch_size_options,
    "random_state": seed,
    # Other dataset names that have been toggled here:
    # pamap2, dsads, mhealth, ucihar, uschad.
    "dataset_name": ["realworld2016"],
    "split_strategy": "concatenate_split",
    # Set to a list such as [0, 1, 2, 3, 4, 5] instead of None to give an
    # explicit label list.
    "allowed_activity_labels": None,
}

# The config constructor takes a trial object as its first argument; the mock
# answers every suggest_* call with a fixed value, so a single deterministic
# configuration comes out.
mock_trial = MockTrial()
config = MoVQFormerConfig(mock_trial, "MoVQFormer", options, devices, accelerator)

# Manual overrides applied on top of the generated configuration, in order.
config_overrides = {
    "mask_ratio": 0.0,
    "split_ratio": [0.2, 0.8, 0],
    "num_classes": 8,
    "record": "cls without pretrain",
    "window_size": [500],
    "stride": [500],
}
for attr_name, attr_value in config_overrides.items():
    setattr(config, attr_name, attr_value)

# Both factories are built only after the overrides above are in place.
data_module = config.data_module(config)
model = config.model(config)

# Collapse the dataset selection into one path component: a list of names is
# joined with "-", anything else is used unchanged.
selected_datasets = options["dataset_name"]
if isinstance(selected_datasets, list):
    dataset_name_str = "-".join(selected_datasets)
else:
    dataset_name_str = selected_datasets

# Logs go to <model_outputs>/cls_without_pretrain/<dataset names>.
output_root = get_directory_path("model_outputs")
log_dir = os.path.join(output_root, "cls_without_pretrain", dataset_name_str)

# TensorBoard event files are written beneath log_dir.
tensorboard_logger = TensorBoardLogger(save_dir=log_dir)

# --- Trainer configuration ---
# Hardware, logging cadence and gradient accumulation are read from the
# config object (log_every_n_steps is expected to be defined there).
# max_epochs is deliberately huge.
trainer_kwargs = {
    "max_epochs": 200000,
    "accelerator": config.accelerator,
    "devices": config.devices,
    "log_every_n_steps": config.log_every_n_steps,
    "logger": tensorboard_logger,
    "accumulate_grad_batches": config.accumulation_steps,
}
trainer = L.Trainer(**trainer_kwargs)

# Run training with the model and data module created above.
trainer.fit(model, datamodule=data_module)
