from dataclasses import dataclass

import torch
from src.logger.logger_utils import LoggerType


@dataclass
class LLMConfig:
    """Configuration for the model together with its data and training setup.

    Every field has a default, so ``LLMConfig()`` is a complete config and any
    subset of fields can be overridden with keyword arguments.

    ``device`` defaults to ``'auto'``, which ``__post_init__`` resolves to
    ``'cuda'`` when ``torch.cuda.is_available()`` and ``'cpu'`` otherwise.
    An explicitly requested device is kept as given (converted to ``str``),
    except that a CUDA device (``'cuda'`` / ``'cuda:N'``) falls back to
    ``'cpu'`` with a ``RuntimeWarning`` when CUDA is not available.
    """

    enc_in: int = 7  # presumably number of input variables/channels — confirm
    d_model_e: int = 16
    patch_size: int = 16
    stride: int = 8
    block_size: int = 128
    n_encoder_layers: int = 2
    n_decoder_layers: int = 2
    n_heads: int = 8
    d_model: int = 128
    dropout: float = 0.2
    bias: bool = False
    # 'auto' is a sentinel resolved to a concrete device in __post_init__.
    device: str = 'auto'
    is_causal: bool = False
    return_attention: bool = False
    input_type: str = 'numeric'
    has_encoder: bool = True
    # -1 looks like an "unset" sentinel, presumably filled in elsewhere — confirm.
    patch_num: int = -1
    warmup_epochs: int = 2

    embed: str = 'timeF'
    features: str = 'M'
    target: str = 'OT'
    freq: str = 'd'
    learning_rate: float = 0.0001
    label_len: int = 0
    seq_len: int = 366
    pred_len: int = 96
    batch_size: int = 16
    checkpoints: str = './checkpoint.pth'
    train_epochs: int = 10
    data: str = ''
    data_path: str = ''
    root_path: str = ''
    num_workers: int = 0
    patience: int = 10
    lradj: str = 'type1'

    # Logging backend selector (a member of src.logger.logger_utils.LoggerType).
    logging: LoggerType = LoggerType.MLFLOW

    def __post_init__(self):
        """Resolve ``self.device`` to a concrete device string.

        * ``'auto'`` (the default) or ``None`` -> ``'cuda'`` if CUDA is
          available, else ``'cpu'``.
        * A CUDA device (``'cuda'``, ``'cuda:N'``) when CUDA is unavailable ->
          ``'cpu'``, emitting a ``RuntimeWarning``.
        * Anything else (e.g. ``'cpu'`` on a GPU host) is kept as given, so an
          explicit choice is never silently overridden.
        """
        import warnings  # local import keeps the module's import block unchanged

        cuda_available = torch.cuda.is_available()
        # str() also accepts a torch.device passed in place of a string.
        requested = 'auto' if self.device is None else str(self.device)

        if requested == 'auto':
            self.device = 'cuda' if cuda_available else 'cpu'
        elif requested.startswith('cuda') and not cuda_available:
            warnings.warn(
                f"device={requested!r} requested but CUDA is not available; "
                "falling back to 'cpu'.",
                RuntimeWarning,
                # 1 = __post_init__, 2 = generated __init__, 3 = caller.
                stacklevel=3,
            )
            self.device = 'cpu'
        else:
            self.device = requested

