import torch

class DefaultConfig:
    """Container for the Wormhole/Transformer hyperparameters.

    Every constructor argument is stored unchanged as an instance attribute of
    the same name. The one exception is ``device``: when it is ``None`` it
    becomes ``torch.device("cuda")`` if CUDA is available, and
    ``torch.device("cpu")`` otherwise. No validation is performed on any value.

    The attribute assignment order below is significant: ``summary`` prints
    the instance ``__dict__`` in insertion order.
    """

    def __init__(
        self,
        device=None,
        dtype: torch.dtype = torch.float32,
        coeff_dec: float = 1.0,
        # Original author note: "used for the transformer emb_dim".
        # NOTE(review): a separate ``emb_dim`` argument also exists; confirm
        # which of the two the model actually reads.
        latent_dim: int = 128,
        emb_dim: int = 128,
        num_heads: int = 4,
        num_layers: int = 3,
        qkv_dim: int = 512,
        mlp_dim: int = 512,
        attention_dropout_rate: float = 0.1,
        batch_size: int = 32,
        input_dim: int = 2,
        n_points: int = 100,
        lr: float = 1e-4,
        epochs: int = 10,
        threshold: float = 0.5,
        n_samples: int = 5000,
        scale_out: bool = True,
        min_val: float = -1.0,
        max_val: float = 1.0,
        decay_steps: int = 2000,
    ):
        # Fall back to the best available accelerator only when no device was given.
        if device is None:
            backend = "cuda" if torch.cuda.is_available() else "cpu"
            device = torch.device(backend)

        # Placement and tensor precision.
        self.device, self.dtype = device, dtype
        # Weight of the decoder loss term.
        self.coeff_dec = coeff_dec
        # Transformer hyperparameters (emb_dim is deliberately set before latent_dim).
        self.emb_dim, self.latent_dim = emb_dim, latent_dim
        self.num_heads, self.num_layers = num_heads, num_layers
        self.qkv_dim, self.mlp_dim = qkv_dim, mlp_dim
        self.attention_dropout_rate = attention_dropout_rate
        # Data / model shapes.
        self.input_dim, self.n_points = input_dim, n_points
        # Training loop settings.
        self.batch_size, self.lr, self.epochs = batch_size, lr, epochs
        # Data preprocessing.
        self.threshold, self.n_samples = threshold, n_samples
        # Output scaling switch and range.
        self.scale_out, self.min_val, self.max_val = scale_out, min_val, max_val
        # Presumably consumed by a learning-rate decay schedule; confirm with the
        # training code.
        self.decay_steps = decay_steps

    def summary(self):
        """Print a header line, then one ``name: value`` line per attribute.

        Attributes are listed in the insertion order of the instance ``__dict__``.
        """
        print("=== Wormhole/Transformer Config ===")
        for name, value in vars(self).items():
            print(f"{name}: {value}")