"""
This script provides configuration classes and a factory function for Optuna hyperparameter optimization.

The script includes:
- `OptunaConfig`: A configuration class for setting up Optuna hyperparameter optimization.
- `get_model_config`: A factory function to retrieve the appropriate configuration class based on model name.
- `BaseConfig`: Base configuration class for model training and data processing parameters.
- Specific configurations for different models (`TSBMConfig`, `TimesNetConfig`, `ResNetConfig`).

Dependencies:
- optuna
- data.data_import for loading datasets
- models module for different models
- data.datasets for data modules
"""

import os
import math
import logging

import optuna

from data.datasets import CommonDataModule, MoVQFormerDataModule
from models import MoVQFormer, ResNet, TimesNet, ts
from utils.sensor_config_loader import SensorConfigLoader

logger = logging.getLogger(__name__)

class OptunaConfig:
    """
    Configuration class for setting up Optuna hyperparameter optimization.

    Attributes:
        storage (str): The database URL for storing Optuna studies.
        min_epochs (int): Minimum number of training epochs.
        max_epochs (int): Maximum number of training epochs.
        n_trials (int): Number of trials for the optimization.
        timeout (int): Maximum time for the optimization process (in seconds).
        sampler (optuna.samplers.BaseSampler): Sampler for the search space.
    """

    # Environment variable that, when set, overrides DEFAULT_STORAGE.
    STORAGE_ENV_VAR = "OPTUNA_STORAGE"
    # Fallback storage URL, kept so existing callers behave exactly as before.
    # SECURITY NOTE(review): this embeds database credentials in source control.
    # Prefer passing ``storage=`` or setting OPTUNA_STORAGE, then drop this default.
    DEFAULT_STORAGE = "mysql+pymysql://optuna_test:123456@172.18.36.112:33060/optuna_test"

    def __init__(
        self,
        sampler=None,
        *,
        storage=None,
        min_epochs=100,
        max_epochs=500,
        n_trials=1,
        timeout=3600 * 24 * 2,
    ):
        """
        Args:
            sampler: Optuna sampler; ``optuna.samplers.RandomSampler()`` when None.
            storage: Study storage URL; when None, the OPTUNA_STORAGE environment
                variable is used, falling back to ``DEFAULT_STORAGE``.
            min_epochs: Minimum number of training epochs.
            max_epochs: Maximum number of training epochs.
            n_trials: Number of optimization trials.
            timeout: Optimization time limit in seconds (default: two days).

        Raises:
            ValueError: If ``min_epochs`` is greater than ``max_epochs``.
        """
        if min_epochs > max_epochs:
            raise ValueError(
                f"min_epochs ({min_epochs}) must not exceed max_epochs ({max_epochs})."
            )

        if storage is None:
            storage = os.environ.get(self.STORAGE_ENV_VAR, self.DEFAULT_STORAGE)
        self.storage = storage
        self.min_epochs = min_epochs
        self.max_epochs = max_epochs
        self.n_trials = n_trials
        self.timeout = timeout

        self.sampler = optuna.samplers.RandomSampler() if sampler is None else sampler


def get_model_config(trial, model_name, options, devices, accelerator):
    """
    Build the configuration object registered for ``model_name``.

    Args:
        trial (optuna.Trial): The Optuna trial object, forwarded to the config.
        model_name (str): Name of the model to optimize ("TSBM", "MoVQFormer",
            "TimesNet" or "ResNet").
        options (dict): Configuration options for the model, forwarded as-is.
        devices (list or str): Devices to use for training.
        accelerator (str): Hardware accelerator type (e.g. 'gpu' or 'cpu').

    Returns:
        BaseConfig: An instance of the matching configuration class.

    Raises:
        ValueError: If no configuration class is registered for ``model_name``.
    """
    registry = {
        "TSBM": TSBMConfig,
        "MoVQFormer": MoVQFormerConfig,
        "TimesNet": TimesNetConfig,
        "ResNet": ResNetConfig,
    }

    config_cls = registry.get(model_name)
    if config_cls is None:
        raise ValueError(f"Model name '{model_name}' is not supported.")

    return config_cls(trial, model_name, options, devices, accelerator)




class BaseConfig:
    """
    Base configuration for model training and data processing (keeps hard-coded values).

    Only a few values are read from ``options``. Everything else is fixed in this
    class and shared by the model-specific subclasses.

    The per-dataset attributes are lists aligned index-by-index with
    ``dataset_name``: ``target_sample_rate``, ``original_sample_rate``,
    ``num_channels``, ``window_size``, ``min_segment_length`` and ``stride``.

    Attributes:
        devices (list or str): Devices used for training.
        accelerator (str): Hardware accelerator type.
        model_name (str): Name of the model being optimized.
        random_state (int): Random seed (``options["random_state"]``, default 42).
        dataset_name (list): Dataset names; a single string is wrapped in a list
            and a missing value becomes an empty list.
        split_strategy: ``options["split_strategy"]`` (None if absent).
        split_dataset_assignments: ``options["split_dataset_assignments"]`` (None if absent).
        allowed_activity_labels: ``options["allowed_activity_labels"]`` (None if absent).
        log_every_n_steps (int): How often to log during training.
        stop_patience (int): Epochs to wait before early stopping.
        target_sample_rate (list): Target sample rate after resampling, per dataset.
        original_sample_rate (list): Native sample rate per dataset (0 if unknown).
        num_channels (list): Channel count per dataset (0 if unknown).
        window_size (list): Segmentation window size per dataset.
        min_segment_length (list): Minimum segment length per dataset, i.e.
            ``ceil(window_size * original_rate / target_rate)``.
        stride (list): Sliding-window stride per dataset.
        num_classes (int or None): Number of activity classes, derived from
            ``allowed_activity_labels`` when given.
        split_ratio (list): Split ratios for the train, validation and test sets.
        batch_size (int): Per-device batch size.
        accumulation_steps (int): Gradient accumulation steps.
        lr (float): Learning rate.
        train_shuffle, drop_last_train, drop_last_eval, num_workers,
        persistent_workers, pin_memory, prefetch_factor: Hard-coded DataLoader settings.
    """

    def __init__(self, trial, model_name, options, devices=None, accelerator="cuda"):
        # ``options`` is expected to be dict-like (it must support ``.get``).
        # Treat a missing mapping as empty instead of failing on ``None.get``.
        if options is None:
            options = {}

        self.devices = devices
        self.accelerator = accelerator
        self.log_every_n_steps = 1
        self.stop_patience = 100
        self.random_state = options.get("random_state", 42)

        self.model_name = model_name

        # Values passed through from options (None when absent).
        self.dataset_name = options.get("dataset_name", None)
        self.split_strategy = options.get("split_strategy", None)  # the DataModule must cope with None
        self.split_dataset_assignments = options.get("split_dataset_assignments", None)
        self.allowed_activity_labels = options.get("allowed_activity_labels", None)

        # --- Hard-coded section ---
        if not self.dataset_name:
            # Every per-dataset list below will be empty in this case.
            logger.warning("dataset_name 未在 options 中提供，硬编码可能依赖于它。")

        # Normalize dataset_name to a list so the per-dataset comprehensions work.
        if isinstance(self.dataset_name, str):
            self.dataset_name = [self.dataset_name]
        elif self.dataset_name is None:
            self.dataset_name = []

        # Target sample rate after resampling (hard-coded, per dataset).
        self.target_sample_rate = [100 for _ in self.dataset_name]

        # Hard-coded metadata for the known datasets.
        self.dataset_options = {
            "mhealth": {"original_sample_rate": 50, "num_channels": 18},
            "pamap2": {"original_sample_rate": 100, "num_channels": 36},
            "dsads": {"original_sample_rate": 25, "num_channels": 45},
            "realworld2016": {"original_sample_rate": 50, "num_channels": 63},
            "ucihar": {"original_sample_rate": 50, "num_channels": 6},
            "uschad": {"original_sample_rate": 100, "num_channels": 6},
        }

        # Hard-coded window size (per dataset).
        self.window_size = [500 for _ in self.dataset_name]

        # Unknown datasets get 0 for both values. A rate of 0 triggers the
        # fallback in _min_segment_length.
        self.original_sample_rate = [
            self.dataset_options.get(name, {}).get("original_sample_rate", 0)
            for name in self.dataset_name
        ]
        self.num_channels = [
            self.dataset_options.get(name, {}).get("num_channels", 0)
            for name in self.dataset_name
        ]

        self.min_segment_length = [
            self._min_segment_length(ws, osr, tsr)
            for ws, osr, tsr in zip(
                self.window_size, self.original_sample_rate, self.target_sample_rate
            )
        ]

        # Hard-coded stride (per dataset).
        self.stride = [100 for _ in self.dataset_name]

        # The class count is only known when an explicit label list is supplied.
        self.num_classes = (
            len(self.allowed_activity_labels) if self.allowed_activity_labels else None
        )

        # DataLoader settings (hard-coded).
        # Order is (train, validation, test) per the original comment.
        # NOTE(review): that means 20% train / 80% validation — confirm this is intended.
        self.split_ratio = [0.2, 0.8, 0]
        self.batch_size, self.accumulation_steps = self._resolve_batch_settings(
            options.get("batch_size")
        )

        # Optimization settings (hard-coded).
        self.lr = 1e-4

        # Other DataLoader / sampler settings (hard-coded).
        self.train_shuffle = True
        self.drop_last_train = False  # keep the last incomplete training batch
        self.drop_last_eval = False  # keep the last incomplete evaluation batch
        self.num_workers = 4
        self.persistent_workers = True
        self.pin_memory = True
        self.prefetch_factor = 1  # only effective when num_workers > 0

        # --- Optuna integration (only when a trial object is given) ---
        if trial:
            self._record_trial_attrs(trial)

    @staticmethod
    def _min_segment_length(ws, osr, tsr):
        """Return ceil(ws * osr / tsr), or ``ws + 1`` when a sample rate is not positive."""
        if tsr > 0 and osr > 0:
            # Exact integer ceiling division. The float form ceil(ws * (osr / tsr))
            # can overshoot by one, e.g. 100 * (7 / 100) == 7.000000000000001.
            return int(-(-ws * osr // tsr))
        logger.warning(f"数据集的采样率无效 (osr={osr}, tsr={tsr})，将使用默认 min_segment_length")
        return ws + 1  # at least slightly larger than the window

    @staticmethod
    def _resolve_batch_settings(batch_option):
        """Return ``(device_batch_size, accumulation_steps)`` for the ``batch_size`` option.

        Accepts None (defaults 4 and 1), a plain int (used as the per-device batch
        size with no accumulation), or a mapping with optional
        ``device_batch_size`` / ``accumulation_steps`` keys.
        """
        if batch_option is None:
            return 4, 1
        if isinstance(batch_option, int):
            return batch_option, 1
        return (
            batch_option.get("device_batch_size", 4),
            batch_option.get("accumulation_steps", 1),
        )

    def _record_trial_attrs(self, trial):
        """Store the effective configuration on the trial via ``set_user_attr``."""
        trial.set_user_attr("random_state", self.random_state)

        # Optional values are recorded only when they are set.
        for key in (
            "dataset_name",
            "split_strategy",
            "split_dataset_assignments",
            "allowed_activity_labels",
        ):
            value = getattr(self, key)
            if value:
                trial.set_user_attr(key, value)

        for key in (
            "target_sample_rate",
            "original_sample_rate",
            "num_channels",
            "window_size",
            "stride",
            "min_segment_length",
        ):
            trial.set_user_attr(key, getattr(self, key))

        if self.num_classes is not None:
            trial.set_user_attr("num_classes", self.num_classes)

        trial.set_user_attr("split_ratio", self.split_ratio)  # recorded even though hard-coded
        trial.set_user_attr("batch_size", self.batch_size)
        trial.set_user_attr("accumulation_steps", self.accumulation_steps)
        trial.set_user_attr("lr", self.lr)
            # 可以选择性地记录其他硬编码参数


class MoVQFormerConfig(BaseConfig):
    """
    MoVQFormer-specific configuration on top of BaseConfig (keeps hard-coded values).

    Attributes:
        model (callable): ``MoVQFormer.l_model``.
        data_module (callable): ``MoVQFormerDataModule``.
        mode: ``options["mode"]`` (None if absent).
        block_size (int): Block size.
        codebook_size (int): Codebook size.
        embedding_dim (int): Internal dimension of the VQ and Transformer parts.
        transformer_nhead, transformer_dim_feedforward, transformer_dropout,
        transformer_num_layers, transformer_activation: Transformer hyperparameters.
        vq_dim (int): VQ dimension; always equal to ``embedding_dim``.
        text_embedding_dim (int): Text embedding dimension (if used).
        max_sensors (int): Maximum number of sensors.
        mask_ratio (float): Masking ratio (if used for pre-training).
        weight_decay (float): Optimizer weight decay.
        freeze_encoder (bool): ``options["freeze_encoder"]`` (default False).
        encoder_lr_factor (float): ``options["encoder_lr_factor"]`` (default 0.1).
    """

    def __init__(self, trial, model_name, options, devices=None, accelerator="gpu"):
        # Initialize the shared base configuration first.
        super().__init__(trial, model_name, options, devices, accelerator)

        self.model = MoVQFormer.l_model
        self.data_module = MoVQFormerDataModule

        # --- MoVQFormer config ---
        self.mode = options.get("mode", None)
        self.block_size = 50
        self.codebook_size = 1024
        self.embedding_dim = 256  # internal dimension of the VQ and the Transformer
        self.transformer_nhead = 8
        self.transformer_dim_feedforward = 1 * self.embedding_dim
        self.transformer_dropout = 0
        self.transformer_num_layers = 5
        self.transformer_activation = "gelu"

        # The VQ dimension must match embedding_dim, so derive it instead of
        # duplicating the constant.
        self.vq_dim = self.embedding_dim
        self.text_embedding_dim = 768  # text embedding dimension (if used)
        # NOTE(review): possibly tied to SensorConfigLoader's sensor count — confirm.
        self.max_sensors = 256
        self.mask_ratio = 0.25  # masking ratio (if used for pre-training)

        self.weight_decay = 1e-5
        self.freeze_encoder = options.get("freeze_encoder", False)
        self.encoder_lr_factor = options.get("encoder_lr_factor", 1e-1)

        # --- Optuna integration (only when a trial object is given) ---
        if trial:
            # Record every MoVQFormer hyperparameter, including the optimizer
            # and fine-tuning settings that also shape the run.
            for key in (
                "block_size",
                "codebook_size",
                "embedding_dim",
                "transformer_nhead",
                "transformer_dim_feedforward",
                "transformer_dropout",
                "transformer_num_layers",
                "vq_dim",
                "mask_ratio",
                "text_embedding_dim",
                "max_sensors",
                "transformer_activation",
                "weight_decay",
                "freeze_encoder",
                "encoder_lr_factor",
            ):
                trial.set_user_attr(key, getattr(self, key))
            if self.mode is not None:
                trial.set_user_attr("mode", self.mode)





class TSBMConfig(BaseConfig):
    """
    Configuration class for the TSBM model.

    Attributes:
        model (callable): ``ts.l_model``.
        calculate_stats_n_jobs (int): Job count for statistics computation
            (CPU count minus two, but at least 1).
        stats_adj_threshold (float): Hard-coded statistics adjustment threshold.
        block_size (int): Size of each block for segmentation.
        num_blocks (int): ``window_size // block_size``. All datasets must share
            one window size that is divisible by ``block_size``.
        overlap_ratio (float): Overlap ratio between consecutive blocks.
        embedding_dim (int): Dimensionality of the embedding space.
        num_stats_features (int): Hard-coded number of statistical features.
        codebook_size (int): Size of the codebook for quantization.
        commitment_weight (float): Weight for the commitment loss term.
        num_attention_heads (int): Number of attention heads in the model.
        num_decoder_layers (int): Number of layers in the decoder.
        num_body_part, num_sensor, num_axis (int): Vocabulary sizes taken from
            ``SensorConfigLoader().global_mappings``, each plus one.
        meta_embed_dim (int): Dimensionality of the metadata embeddings.
    """

    def __init__(self, trial, model_name, options, devices=None, accelerator="gpu"):
        super().__init__(trial, model_name, options, devices, accelerator)
        self.model = ts.l_model
        # NOTE(review): no data_module is assigned for TSBM (the TSBMDataModule
        # assignment was commented out) — confirm callers do not need one.
        # os.cpu_count() may return None; leave two CPUs free but use at least one job.
        self.calculate_stats_n_jobs = max(1, (os.cpu_count() or 1) - 2)
        self.stats_adj_threshold = 0.95

        # Model-specific parameters for TSBM
        self.block_size = 50
        # BaseConfig stores window_size as a per-dataset list, so the former
        # ``self.window_size // self.block_size`` always raised TypeError.
        self.num_blocks = self._compute_num_blocks(self.window_size, self.block_size)
        self.overlap_ratio = 0

        self.embedding_dim = 30  # original author note: "dsads 11 mhealth 21 pamap2 318 18"
        self.num_stats_features = 783
        self.codebook_size = 512
        self.commitment_weight = 1
        self.num_attention_heads = 2
        self.num_decoder_layers = 2

        global_mappings = SensorConfigLoader().global_mappings
        # Each vocabulary gets one extra slot (presumably padding/unknown — confirm).
        self.num_body_part = len(global_mappings["body_parts"]) + 1
        self.num_sensor = len(global_mappings["sensors"]) + 1
        self.num_axis = len(global_mappings["axes"]) + 1

        self.meta_embed_dim = 16

        # BaseConfig supports trial=None. The previous unconditional
        # set_user_attr calls crashed in that case.
        if trial:
            for key in (
                "block_size",
                "num_blocks",
                "overlap_ratio",
                "codebook_size",
                "commitment_weight",
                "num_attention_heads",
                "num_decoder_layers",
                "num_body_part",
                "num_sensor",
                "num_axis",
                "meta_embed_dim",
            ):
                trial.set_user_attr(key, getattr(self, key))

    @staticmethod
    def _compute_num_blocks(window_size, block_size):
        """Return the single ``window_size // block_size`` shared by all datasets.

        ``window_size`` may be an int or a per-dataset list.

        Raises:
            ValueError: If no window size is configured, if a window size is not
                divisible by ``block_size``, or if datasets yield different block counts.
        """
        sizes = [window_size] if isinstance(window_size, int) else list(window_size)
        if not sizes:
            raise ValueError(
                "'window_size' is empty; provide at least one dataset in options['dataset_name']."
            )
        for ws in sizes:
            if ws % block_size != 0:
                raise ValueError(
                    f"'num_blocks' must be an integer. Ensure 'window_size' ({ws}) "
                    f"is divisible by 'block_size' ({block_size})."
                )
        counts = {ws // block_size for ws in sizes}
        if len(counts) != 1:
            raise ValueError(
                f"All datasets must share the same number of blocks, got window sizes {sizes}."
            )
        return counts.pop()


class TimesNetConfig(BaseConfig):
    """
    Configuration class for the TimesNet model.

    TimesNet is configured with a single input length and channel count.
    Every configured dataset must therefore share the same ``window_size`` and
    ``num_channels``.

    Attributes:
        task_name (str): The name of the task ('classification').
        pred_len (int): Prediction length (not used in classification).
        num_kernels (int): Number of kernels for convolutional layers.
        embed (str): Type of embedding.
        freq (str): Frequency type for temporal embedding.
        num_class (int or None): Number of activity classes (``num_classes``).
        seq_len (int): Input sequence length (the shared window size).
        enc_in (int): Input channels for the encoder (the shared channel count).
        d_model (int): Dimension of model layers.
        d_ff (int): Dimension of feedforward layers.
        e_layers (int): Number of encoder layers.
        top_k (int): ``top_k`` hyperparameter passed to TimesNet.
        dropout (float): Dropout rate.
    """

    def __init__(
        self, trial, model_name, dataset_name, devices=None, accelerator="gpu"
    ):
        # NOTE: despite its name, ``dataset_name`` receives the options mapping
        # (get_model_config passes ``options`` in this position). It is
        # forwarded unchanged to BaseConfig.
        super().__init__(trial, model_name, dataset_name, devices, accelerator)
        self.model = TimesNet.l_model
        self.data_module = CommonDataModule

        # Model-specific parameters for TimesNet
        self.task_name = "classification"
        self.pred_len = 0
        self.num_kernels = 6
        self.embed = "timeF"
        self.freq = "s"
        self.num_class = self.num_classes
        # BaseConfig stores window_size / num_channels as per-dataset lists.
        # The old ``dataset_options[self.dataset_name]`` lookup always failed
        # because a list cannot be used as a dict key.
        self.seq_len = self._single_value(self.window_size, "window_size")
        self.enc_in = self._single_value(self.num_channels, "num_channels")
        if self.enc_in <= 0:
            raise ValueError(
                f"Unknown channel count for dataset(s) {self.dataset_name}; "
                "add them to dataset_options."
            )
        self.d_model = 16
        self.d_ff = 32
        self.e_layers = 3
        self.top_k = 5
        self.dropout = 0.1

    @staticmethod
    def _single_value(values, name):
        """Return the one value shared by every entry of ``values`` (or ``values`` if an int).

        Raises:
            ValueError: If ``values`` is empty or contains differing entries.
        """
        if isinstance(values, int):
            return values
        unique = set(values)
        if len(unique) != 1:
            raise ValueError(
                f"TimesNet needs a single '{name}' across all datasets, got {list(values)}."
            )
        return unique.pop()


class ResNetConfig(BaseConfig):
    """
    ResNet configuration: the shared BaseConfig settings plus the ResNet model
    and the generic data module.

    Attributes:
        model (callable): ``ResNet.l_model``.
        data_module (callable): ``CommonDataModule``.
    """

    def __init__(
        self, trial, model_name, dataset_name, devices=None, accelerator="gpu"
    ):
        # NOTE: despite its name, ``dataset_name`` receives the options mapping
        # (get_model_config passes ``options`` in this position). It is
        # forwarded unchanged to BaseConfig.
        super().__init__(
            trial, model_name, dataset_name, devices=devices, accelerator=accelerator
        )
        self.data_module = CommonDataModule
        self.model = ResNet.l_model
