import torch
import os

class Config:
    """Hyper-parameter and path settings, exposed as plain class attributes.

    Values are read directly as ``Config.NAME``; no instance is needed.
    A variant configuration can be made by subclassing and overriding
    attributes, and ``to_dict`` reports the effective (inherited + overridden)
    values of such a subclass.
    """

    # -------------------------- Basic Env Config --------------------------
    # Fall back to CPU when no CUDA device is visible instead of hard-coding
    # "cuda", which is unusable on CPU-only hosts.
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    NUM_WORKERS = 4
    BATCH_SIZE = 64

    # -------------------------- Data Config --------------------------
    PROCESSED_DATA_DIR = "./data"
    DATASET_NAME = "maze5"
    MAX_PUZZLES = 2500
    # NOTE(review): TRAIN_SPLIT + VAL_SPLIT = 0.9; the remaining 0.1 is
    # presumably the test split -- confirm against the data loader.
    TRAIN_SPLIT = 0.8
    VAL_SPLIT = 0.1

    # -------------------------- Path Config --------------------------
    RESULT_DIR = "./maze5-2k"
    # These are joined once, when the class body executes. A subclass that
    # overrides RESULT_DIR must also override them; they do not follow it.
    VIS_DIR = os.path.join(RESULT_DIR, "vis_tlad_dynamics")
    SAVE_DIR = os.path.join(RESULT_DIR, "checkpoints")
    LOG_DIR = os.path.join(RESULT_DIR, "logs")

    # -------------------------- Base Model Architecture --------------------------
    GRID_SIZE = 11
    # Derived from GRID_SIZE (11 * 11 = 121, the previously hard-coded value)
    # so the two cannot drift apart.
    SEQ_LEN = GRID_SIZE * GRID_SIZE
    INPUT_CLASSES = 4
    OUTPUT_CLASSES = 2
    NUM_HEADS = 4
    NUM_LAYERS_S1 = 4
    HIDDEN_DIM = 128

    # -------------------------- Loss Weight & Dynamics Engine --------------------------
    DICE = 2.0
    DEGREE = 0.1

    EBA_STEP_SIZE = 1.0
    EBA_MOMENTUM = 0.9
    TEMP_START = 5.0
    TEMP_END = 0.2

    W_TASK = 2.5
    LIT_THRESHOLD = 0.5

    W_WALL = 5.0
    W_ENDPOINT = 5.0
    W_INVALID = 25.0
    W_DEGREE = 25.0
    W_ENTROPY = 0.05

    # -------------------------- Training Strategy --------------------------
    LR_PHASE1 = 1e-3
    LR_PHASE2 = 1e-4

    # -------------------------- Tunable Core Parameters --------------------------
    EPOCHS_PHASE1 = 10
    EPOCHS_PHASE2 = 30
    EBA_STEPS = 24
    TRANSITION_CENTER = 0.8
    ANNEALING_SLOPE = 2.0
    LAMBDA_INIT = 20
    LAMBDA_MAX = 20

    # -------------------------- Utility --------------------------
    @classmethod
    def to_dict(cls):
        """Return the effective configuration as a fresh ``{name: value}`` dict.

        Walks the MRO from the most basic class to ``cls`` itself, so a
        subclass reports inherited settings as well as its overrides. For
        each name, the value is the one normal attribute lookup on ``cls``
        resolves to. Dunder names (``__doc__``, ``__module__``, ...) and
        callables (such as this method) are skipped.

        Keys appear in first-definition order: base-class settings keep their
        original position even when a subclass overrides them.
        """
        settings = {}
        for klass in reversed(cls.__mro__):
            for name in vars(klass):
                if name.startswith('__') or name in settings:
                    continue
                value = getattr(cls, name)  # MRO-resolved (most-derived) value
                if not callable(value):
                    settings[name] = value
        return settings