from dataclasses import dataclass

@dataclass
class BasicConfig:
    """Base settings that ``Config`` inherits.

    The meaning of each field is defined by the code that consumes the
    config, which is not part of this file; the per-field notes below are
    inferred from the names and should be confirmed against that code.
    """

    # Device identifier (presumably a torch device string -- TODO confirm).
    device: str = "cuda"
    # Worker count (presumably for data loading -- TODO confirm).
    num_workers: int = 2
    # Random seed.
    seed: int = 42

    # Name of the initialisation scheme; valid values are not listed here.
    init_mode: str = "uniform"

    # Interval for a "collapse" check (unit presumably epochs -- TODO confirm).
    check_collapse_every: int = 10
    # Directory for checkpoints (presumably written here -- TODO confirm).
    ckp_dir: str = "./ckp/gmm_ckp"
    

@dataclass
class Config(BasicConfig):
    """Full experiment configuration, extending ``BasicConfig``.

    Every field has a default, so ``Config()`` is valid on its own and
    individual values can be overridden by keyword.  The section headings
    below are groupings inferred from the field names; the authoritative
    meaning of each field lives in the code that consumes the config, which
    is not part of this file.

    ``__repr__`` is overridden to print one ``key: value`` line per
    attribute instead of the dataclass-generated single-line form.
    """

    # --- experiment identity ---
    exp_name: str = "Default_Exp"

    # --- dataset ---
    dataset: str = "cifar10"
    data_root: str = "./dataset"
    resize: bool = False

    # --- classifier (presumably the pretrained model under study) ---
    arch: str = "resnet18"
    clf_ckpt: str = "./model_zoo/trained_model/resnet18_cifar10.pth"

    # --- mixture / latent sizes ---
    K: int = 7
    latent_dim: int = 128

    # --- conditioning and covariance ---
    cond_mode: str | None = 'xy'
    cov_type: str = "full"
    cov_rank: int = 0
    hidden_dim: int = 512

    # --- label embedding ---
    use_y_embedding: bool = True
    y_emb_dim: int = 128
    y_emb_normalize: bool = True

    # --- decoder ---
    use_decoder: bool = True
    decoder_backend: str = "bicubic_trainable"

    # --- perturbation budget ---
    norm: str = 'linf'
    epsilon: float = 16/255

    # --- training loop ---
    epochs: int = 2
    batch_size: int = 512
    # The default is float("inf"), which is not an int, so the annotation
    # admits both.  Presumably "no cap" on the batch index -- TODO confirm.
    batch_index_max: int | float = float("inf")

    # --- optimiser ---
    lr: float = 5e-4
    weight_decay: float = 0.0
    grad_clip: float = 5.0
    accumulate_grad: int = 1

    # --- learning-rate schedule ---
    use_lr_scheduler: bool = False
    lr_warmup_epochs: int = 20
    lr_min: float = 2e-6

    # --- loss ---
    loss_variant: str = "cw"
    kappa: float = 1

    # --- sampling ---
    num_samples: int = 32
    chunk_size: int = 32

    # --- regularisation weights ---
    reg_pi_entropy: float = 0.0
    reg_mean_div: float = 0.0

    # --- temperature schedules (init -> final) ---
    T_pi_init: float = 3.0
    T_pi_final: float = 1.0

    T_mu_init: float = 3.0
    T_mu_final: float = 1.0

    T_sigma_init: float = 1.5
    T_sigma_final: float = 1.0

    T_shared_init: float = 1.5
    T_shared_final: float = 1.0
    warmup_epochs: int = 50

    # --- Gumbel temperature annealing ---
    use_gumbel_anneal: bool = True
    gumbel_temp_init: float = 1.0
    gumbel_temp_final: float = 0.1

    def to_dict(self) -> dict:
        """Return a shallow copy of the instance attributes as a dict."""
        return self.__dict__.copy()

    def __repr__(self) -> str:
        """Return a multi-line listing with one ``key: value`` row per attribute."""
        rule = "=" * 60
        rows = [f"  {key:25s}: {value}" for key, value in self.__dict__.items()]
        return "\n".join(["Configuration:", rule, *rows, rule])


def get_config(name: str = "debug") -> Config:
    """Return the preset configuration registered under ``name``.

    All presets are built fresh on every call, so the returned ``Config``
    is not shared between callers and may be mutated freely.

    Args:
        name: Key of the preset to return.  Defaults to ``"debug"``.

    Returns:
        The ``Config`` instance registered under ``name``.

    Raises:
        ValueError: If ``name`` is not a registered preset; the message
            lists the available keys.
    """
    configs = {
        # The default ``name`` must resolve to something, otherwise calling
        # get_config() with no argument can only ever raise.
        # NOTE(review): "debug" is mapped to the plain dataclass defaults --
        # confirm that this is the intended debug preset.
        "debug": Config(),

        "resnet18_on_cifar10": Config(
            exp_name="K7_cond(xy)_decoder(trainable_128)_linf(16)_reg(none)",

            lr=5e-4,
            use_lr_scheduler=False,
            warmup_epochs=10,

            K=7,
            latent_dim=128,

            cond_mode="xy",
            cov_type="full",
            cov_rank=0,
            hidden_dim=256,

            use_y_embedding=True,
            y_emb_dim=64,
            y_emb_normalize=True,

            use_decoder=True,
            decoder_backend='bicubic_trainable',

            norm="linf",
            epsilon=16/255,
        ),
    }

    if name not in configs:
        raise ValueError(
            f"Unknown config: '{name}'\n"
            f"Available: {list(configs.keys())}"
        )

    return configs[name]


def list_configs():
    """Print the names of the available presets with a short summary.

    The summary is maintained by hand and must be kept in step with the
    presets registered in ``get_config``.  The name printed is the exact
    key that ``get_config`` accepts.
    """
    rule = "=" * 60
    print("\nAvailable configurations:")
    print(rule)
    # Summary matches the registered preset: K=7, cond_mode="xy",
    # use_decoder=True with the "bicubic_trainable" backend, linf 16/255.
    print("  resnet18_on_cifar10 - K=7, cond_mode=xy, trainable bicubic decoder, linf epsilon=16/255")
    print(rule)