import os
import math
import torch
import lightning as L
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.callbacks import ModelCheckpoint

from utils.seed import set_seed
from utils.path_utils import get_directory_path
from runs.hyperparameter_search.model_configurations import MoVQFormerConfig

from models.MoVQFormer import l_model as MoVQFormer


def update_config_list_para(config):
    """Expand per-dataset list parameters on ``config`` in place.

    For every name in ``config.dataset_name`` this looks up the dataset's
    ``"original_sample_rate"`` and ``"num_channels"`` in
    ``config.dataset_options``. It also broadcasts the first
    ``config.target_sample_rate`` entry to every dataset. Finally it computes
    ``config.min_segment_length``: the number of samples a raw segment needs
    at its original rate so that, after resampling to the target rate, it
    still spans one window.

    Args:
        config: Mutable configuration object. It must provide:

            * ``dataset_options``: a mapping from dataset name to a dict with
              ``"original_sample_rate"`` and ``"num_channels"``.
            * ``dataset_name``: a list of dataset names.
            * ``target_sample_rate``: a non-empty list.
            * ``window_size``: a list. A single-entry list is applied to
              every dataset.

    Returns:
        The same ``config`` object, updated in place.
    """
    dataset_options = config.dataset_options
    dataset_names = config.dataset_name

    config.original_sample_rate = [
        dataset_options[name]["original_sample_rate"] for name in dataset_names
    ]
    config.num_channels = [dataset_options[name]["num_channels"] for name in dataset_names]
    # One shared target rate for all datasets: the first configured entry.
    config.target_sample_rate = [config.target_sample_rate[0] for _ in dataset_names]

    # Broadcast a single window size across all datasets, mirroring
    # target_sample_rate above. Without this, zip() below would silently
    # truncate min_segment_length to a single entry.
    window_sizes = config.window_size
    if len(window_sizes) == 1:
        window_sizes = list(window_sizes) * len(dataset_names)

    min_segment_length = []
    for ws, osr, tsr in zip(window_sizes, config.original_sample_rate, config.target_sample_rate):
        if tsr > 0 and osr > 0:
            # Multiply before dividing. ws * (osr / tsr) can overshoot an
            # exact integer in floating point, e.g.
            # 100 * (7 / 100) == 7.000000000000001, so ceil() would come out
            # one sample too large.
            min_segment_length.append(math.ceil(ws * osr / tsr))
        else:
            # Non-positive rate(s): fall back to one sample more than the window.
            min_segment_length.append(ws + 1)
    config.min_segment_length = min_segment_length

    return config

# GPU indices to train on. Uncomment entries to add devices.
devices = [
    # 0,
    # 1,
    # 2,
    # 3,
    4,
    # 5,
    # 6,
    # 7,
]

seed = 2002
set_seed(seed)
accelerator = "cuda"

# Batch-size scaling. The reference setup is 32 samples per device with 16
# gradient-accumulation steps. batch_scale_rate moves work from accumulation
# into the per-device batch while keeping the product roughly constant. For
# example, 0.125 gives 256 x 2 = 32 x 16 = 512. The product is exact only
# when both quotients are whole numbers.
base_device_batch_size = 32
base_accumulation_steps = 16
batch_scale_rate = 0.125

options = {
    # One of "pretrain", "finetune", "joint".
    "mode": "finetune",
    "freeze_encoder": True,
    # round() rather than int() avoids truncating values such as 4.8 down
    # to 4. max(1, ...) keeps both counts valid: zero accumulation steps or
    # an empty batch would be meaningless for extreme scale rates.
    "batch_size": {
        "device_batch_size": max(1, round(base_device_batch_size / batch_scale_rate)),
        "accumulation_steps": max(1, round(base_accumulation_steps * batch_scale_rate)),
    },
    "random_state": seed,
    "dataset_name": [
        # "pamap2",
        # "dsads",
        # "mhealth",
        # "realworld2016",
        # "ucihar",
        "uschad",
    ],
    "split_strategy": "concatenate_split",
    # "split_dataset_assignments": {
    #     "train": ["pamap2", "dsads", "mhealth", "realworld2016", "uschad"],
    #     "val": [],
    #     "test": [],
    # },
    "allowed_activity_labels": None,
    # "allowed_activity_labels": [0, 1, 2, 3, 4, 5],
}


# Source checkpoint from the pretraining run. It supplies both the base
# configuration returned below and the encoder weights loaded at the end.
ckpt_path = "epoch-029-0.70658.ckpt"

# Only the returned configuration is used further down; model_ckpt is not
# referenced again in this script.
model_ckpt, config = MoVQFormer.load_from_checkpoint(ckpt_path, map_location="cpu")

# Fine-tuning overrides, applied to config in this exact order.
# NOTE(review): split_ratio presumably means [train, val, test] = [0.2, 0.8, 0].
# Confirm that training on 20% is intended.
# NOTE(review): num_classes is hard-coded to 13. Confirm it against the label
# set of the selected dataset(s).
_finetune_overrides = {
    "finetune_source_ckpt_path": ckpt_path,
    "devices": devices,
    "accelerator": accelerator,
    "mode": options["mode"],
    "freeze_encoder": options["freeze_encoder"],
    "dataset_name": options["dataset_name"],
    "split_strategy": options["split_strategy"],
    "random_state": options["random_state"],
    "batch_size": options["batch_size"]["device_batch_size"],
    "accumulation_steps": options["batch_size"]["accumulation_steps"],
    "allowed_activity_labels": options["allowed_activity_labels"],
    "mask_ratio": 0.0,
    "split_ratio": [0.2, 0.8, 0],
    "num_classes": 13,
    "record": "fine-tune transformer and pretrain without cls",
    "window_size": [500],
    "stride": [500],
}
for _attr, _value in _finetune_overrides.items():
    setattr(config, _attr, _value)

# Expand the per-dataset list parameters. This must run after window_size
# is set above, because update_config_list_para reads it.
config = update_config_list_para(config)

# config.data_module / config.model are presumably classes or factories
# carried on the config. Confirm against the config definition.
data_module = config.data_module(config)
model = config.model(config)

# Initialise the newly built model's encoder from the source checkpoint.
model.load_encoder_weights(ckpt_path)

# Join the selected dataset name(s) into a single path component for this
# run. For example, ["pamap2", "dsads"] becomes "pamap2-dsads"; a plain
# string is used unchanged.
_selected_datasets = options["dataset_name"]
if isinstance(_selected_datasets, list):
    dataset_name_str = "-".join(_selected_datasets)
else:
    dataset_name_str = _selected_datasets

# Runs are filed under <model_outputs>/test/<datasets>/. Replace "test" with
# "finetune" to file them under the fine-tuning folder instead.
_output_root = get_directory_path("model_outputs")
log_dir = os.path.join(_output_root, "test", dataset_name_str)

# TensorBoard logger rooted at log_dir. Its log_dir property is the concrete,
# versioned directory it actually writes to. With the default name this is
# <log_dir>/lightning_logs/version_<n>. Per-run artifacts should go there.
tensorboard_logger = TensorBoardLogger(log_dir)
tb_log_dir = tensorboard_logger.log_dir


# Optional periodic checkpointing (currently disabled). To re-enable it,
# uncomment this block and the matching callbacks entry in the Trainer below.
# checkpoint_callback_train_loss = ModelCheckpoint(
#     monitor=None,
#     save_top_k=-1,
#     dirpath=os.path.join(tb_log_dir, "checkpoints"),
#     filename="epoch-{epoch:03d}-{metric_val_f1:.5f}",
#     auto_insert_metric_name=False,
#     every_n_epochs=20,
# )


# --- Trainer configuration ---
# The original code assumed config.log_every_n_steps always exists on the
# loaded config. If it is missing, fall back to Lightning's own default (50)
# instead of failing with AttributeError.
# No callbacks are passed, so Lightning's default checkpointing behaviour
# applies.
# NOTE(review): max_epochs=20000 looks like "train until stopped". Confirm
# whether an early-stopping or epoch limit is intended.
trainer = L.Trainer(
    max_epochs=20000,
    accelerator=config.accelerator,
    devices=config.devices,
    log_every_n_steps=getattr(config, "log_every_n_steps", 50),
    logger=tensorboard_logger,
    # callbacks=[
    #     checkpoint_callback_train_loss,
    # ],
    accumulate_grad_batches=config.accumulation_steps,
)

trainer.fit(model, datamodule=data_module)