# config.py
from dataclasses import dataclass
from typing import Tuple, Optional, List, Union
import os


@dataclass
class TrainingConfig:
    """Hyper-parameters and paths for one training run.

    The seven "basic" fields have no defaults and must always be supplied;
    every other field has a default. ``__post_init__`` fills ``size`` from
    ``look_back_window`` and ``T`` when it is not given, and for
    ``model_name == 'SwinUnet'`` attaches an extra (non-field)
    ``transformer_config`` attribute.
    """

    # Basic parameters (required: no defaults)
    h: int  # for SwinUnet, used as the second IMG_SIZE dimension
    look_back_window: int
    T: int
    batch_size: int
    max_scale: float
    model_name: str
    epochs: int

    # Model parameters
    norm_insequence: bool = True
    model_ada: bool = False
    model_path: Optional[str] = None
    dropout: float = 0.1
    tap: int = 3
    ta: float = 0.9
    features: str = 'S'
    optimizer: str = 'Adam'
    num_workers: int = 0
    learning_rate: float = 2e-4
    out_channels: int = 8
    root_path: Optional[str] = None
    d_norm: float = 1
    weight_decay: float = 0
    model_size: int = 4  # for SwinUnet, EMBED_DIM = int(96 * model_size)
    name: str = 'VR'
    downsample_factor: int = 16
    model_extend: bool = False
    one_channel: bool = False
    device_num: int = 0

    # Training parameters
    cycle_time: int = 5
    deep_cycle_time: int = 1
    change_epoch: int = 5
    kernel_size: int = 31
    focus: int = 2
    memory_test_mode: bool = False
    fft: bool = False
    temperature: float = 1
    aspp: bool = True
    low_level: bool = True
    load_history: bool = False
    scal: float = 1

    # Loss and evaluation
    loss_type: str = 'EMD'
    patch_size: Tuple[int, int] = (4, 32)
    gpu_ids: Optional[List[int]] = None
    curve: bool = False
    more_random: bool = False
    mu_norm: float = 1
    expand_data: int = 1
    set_syn_data_type: str = 'all'
    random_cut: bool = False
    early_stop: int = 50
    # None -> [look_back_window, 0, T] (filled in __post_init__)
    size: Optional[List[int]] = None
    flag: str = 'train'
    target: str = 'OT'

    def __post_init__(self):
        """Derive ``size`` and, for SwinUnet, build ``transformer_config``."""
        if self.size is None:
            self.size = [self.look_back_window, 0, self.T]

        if self.model_name == 'SwinUnet':
            # Imported only on this branch, so other models do not require
            # the transformer_config_my module to be importable.
            from transformer_config_my import get_defaultconfig
            config = get_defaultconfig()
            config.DATA.IMG_SIZE = (self.size[0] + self.size[-1], self.h)
            config.DATA.BATCH_SIZE = self.batch_size
            config.MODEL.SWIN.EMBED_DIM = int(96 * self.model_size)
            self.transformer_config = config

    @classmethod
    def from_dict(cls, config_dict):
        """Build a config from a mapping, silently ignoring non-field keys.

        Keys are matched against the dataclass's init fields. (A
        ``hasattr(cls, k)`` check is wrong here: fields without a default,
        such as ``h`` or ``T``, are not class attributes, so they would be
        dropped, while method names like ``to_dict`` would slip through.)

        Args:
            config_dict: mapping of field name -> value; extra keys
                (e.g. ``transformer_config`` from ``to_dict``) are ignored.

        Returns:
            A new ``TrainingConfig``.

        Raises:
            TypeError: if a required field is missing from ``config_dict``.
        """
        from dataclasses import fields
        field_names = {f.name for f in fields(cls) if f.init}
        return cls(**{k: v for k, v in config_dict.items()
                      if k in field_names})

    def to_dict(self):
        """Return the instance attributes whose names do not start with '_'.

        This includes the non-field ``transformer_config`` attribute when
        ``__post_init__`` created it. Values are not copied (``size``, for
        example, is the same list object as on the instance).
        """
        return {k: v for k, v in self.__dict__.items()
                if not k.startswith('_')}