import os
from dataclasses import dataclass, field, fields
from typing import Optional

@dataclass
class Config:
    """Flat experiment configuration with a default for every setting.

    Fields are grouped by concern: general, training, data loading,
    scheduler, optimizer, dataset and model. ``print_config`` and ``save``
    dump every field, in declaration order, as aligned ``name value`` lines.
    """

    experiment_name: Optional[str] = "default"
    root_dir: Optional[str] = "./"
    device: Optional[str] = "cuda"  # cuda or cpu
    use_default_args: Optional[bool] = False  # will depend on model
    # NOTE(review): presumably the epochs at which checkpoints are saved — confirm with the training loop.
    save_at_epochs: Optional[list] = field(default_factory=list)
    # NOTE(review): presumably the names of layers/modules left uncompressed — confirm with the model code.
    keep_full_precision: Optional[list] = field(default_factory=list)

    # training parameters
    epochs: Optional[int] = 120
    clip_grad_norm: Optional[float] = -1  # -1 means no clipping
    seed: Optional[int] = 42

    # data loader parameters
    train_shuffle: Optional[bool] = True
    test_shuffle: Optional[bool] = False
    training_batch_size: Optional[int] = 128
    test_batch_size: Optional[int] = 256
    num_workers: Optional[int] = 4
    pin_memory: Optional[bool] = True

    # scheduler parameters
    scheduler_type: Optional[str] = "none"  # multi_step or none
    scheduler_gamma: Optional[float] = 0
    scheduler_milestones: Optional[list] = field(default_factory=list)

    # optimizer parameters
    optimizer_type: Optional[str] = "sgd"  # sgd or adam
    optimizer_lr: Optional[float] = 0.1
    optimizer_momentum: Optional[float] = 0.9
    optimizer_weight_decay: Optional[float] = 0.0
    # NOTE(review): presumably parameter names excluded from weight decay — confirm with the optimizer setup.
    optimizer_no_decay: Optional[list] = field(default_factory=list)

    # dataset parameters
    dataset_name: Optional[str] = "cifar10"  # cifar10, cifar100, tinyimagenet
    dataset_root: Optional[str] = "data"
    dataset_valid_size: Optional[float] = 0.1
    dataset_num_classes: Optional[int] = 10
    dataset_order_seed: Optional[int] = 42

    # Model parameters
    model_name: Optional[str] = "cifar_resnet_20"  # cifar_resnet_D_W, cifar_dsresnet_D_W
    model_compression_strategy: Optional[str] = "none"  # birealnet
    # The field name is misspelled ("bugdet"); it is kept as-is because
    # renaming it would break existing keyword arguments and attribute access.
    model_required_bugdet: Optional[float] = 1.0

    def print_config(self):
        """Print a banner followed by every field name and value to stdout."""
        print("=" * 50 + "\nConfig\n" + "=" * 50)
        # ``fld`` (not ``field``) avoids shadowing ``dataclasses.field``.
        for fld in fields(self):
            print(fld.name.ljust(30), getattr(self, fld.name))
        print("=" * 50)

    def save(self, root_dir):
        """Write every field name and value to ``<root_dir>/config.txt``.

        Args:
            root_dir: Directory to write into, given as a ``str`` or any
                path-like object (e.g. ``pathlib.Path``). The directory must
                already exist; an existing ``config.txt`` is overwritten.

        The file is always encoded as UTF-8, independent of the platform
        locale.
        """
        # os.path.join accepts path-like objects and avoids producing
        # "/config.txt" (the filesystem root) when root_dir is empty.
        path = os.path.join(root_dir, "config.txt")
        # Render everything before opening the file so a failing __str__ on a
        # value cannot leave a truncated config.txt behind.
        lines = ["=" * 50 + "\nConfig\n" + "=" * 50 + "\n"]
        lines.extend(
            fld.name.ljust(30) + ": " + str(getattr(self, fld.name)) + "\n"
            for fld in fields(self)
        )
        lines.append("=" * 50)
        with open(path, "w", encoding="utf-8") as f:
            f.writelines(lines)