from enum import Enum
import torch

#############################################################################################
##################################### DATA TYPES ############################################
#############################################################################################
class MatType(str, Enum):
    """Identifiers for the matrix-generation schemes.

    The ``A_type`` and ``C_type`` entries of ``cfg_dyn_sys`` hold members of
    this enum. Because the enum mixes in ``str``, every member *is* a string
    equal to its value, e.g. ``MatType.ID == "8"``. The values are written as
    string literals so this is explicit; integer literals would be coerced by
    ``str`` to exactly these strings anyway.
    """

    # Random
    # Diagonal matrix; diagonal entries drawn uniformly within the given bounds.
    DIAG_UNIF_1BALL = "1"
    # Diagonal matrix; diagonal entries drawn near the boundary of the unit ball.
    DIAG_UNIF_BOUNDARY_1BALL = "2"
    # Orthogonal matrix of the given dimension, sampled from the orthogonal group.
    UNIF_ORTHOGRP = "3"
    # Symmetric matrix built from an eigenvalue decomposition whose eigenvectors
    # are sampled from the orthogonal group.
    # NOTE(review): the original comment was cut off after "and"; the name
    # suggests eigenvalues inside the unit ball — confirm.
    SYMM_1BALL_UNIF_ORTHOGRP = "4"
    PSD_1BALL_UNIF_ORTHOGRP = "5"
    DIAG_WITH_UNIT_EIGVALS = "6"
    DIAG_WITH_EIGVALS_OUTSIDE_BALL = "7"
    DIAG_NONDIAG_UNIF_1BALL = "10"

    # Deterministic
    ID = "8"
    TWO_ID = "9"
    DETERMINISTIC_ONES = "11"


class AugmType(str, Enum):
    """Augmented-token schemes, selected via ``aug_type`` in ``cfg_tok``.

    Mixes in ``str``, so each member equals its string value
    (``AugmType.TF_GD == "1"``). Values are written as string literals to make
    that explicit; integer literals would be coerced to the same strings.
    """

    # Random
    # NOTE(review): this grouping label looks copied from MatType; nothing here
    # shows these schemes are random — confirm the intended grouping.
    TF_GD = "1"
    MESA_OPT = "2"


############################################################################################
##################################### CONSTANTS ############################################
############################################################################################

# Extra positions each augmentation type adds on top of ``wdw_size + 1`` when
# sizing the transformer's ``model_dim``.
AUG_TOK_EXTRA_POS = {
    AugmType.TF_GD: 0,
    AugmType.MESA_OPT: 1,
}

# Folder names for saved models and experiment results.
# NOTE(review): despite the REL_ prefix these begin with "/", so os.path.join or
# pathlib would treat them as absolute paths — confirm how callers combine them.
REL_PATH_DATA_FLDR = "/saved_models"
REL_PATH_EXP_FLDR = "/exp_results"

# Output dimension of the dynamical system; also scales the transformer widths.
OUT_DIM = 1

cfg_randomness = dict(
    ############# RANDOMNESS / SEEDING
    rand_seed=666013,
    # presumably switches to iterating over rand_multi_seeds — confirm with caller
    multi_rand_seeds=False,
    test_rand_seed_number=1000,
    rand_multi_seeds=[666013, 1, 0],
)


cfg_dyn_sys = dict(
    ############# DYNAMICAL SYSTEM
    state_dim=5,
    output_dim=OUT_DIM,
    noise_var=1e-2,
    seq_len=30,
    # Alternatives (previously left commented out): MatType.DIAG_UNIF_1BALL,
    # MatType.DIAG_UNIF_BOUNDARY_1BALL, MatType.PSD_1BALL_UNIF_ORTHOGRP,
    # MatType.SYMM_1BALL_UNIF_ORTHOGRP, MatType.UNIF_ORTHOGRP.
    A_type=MatType.DIAG_NONDIAG_UNIF_1BALL,
    # Alternatives (previously left commented out): MatType.DIAG_UNIF_BOUNDARY_1BALL,
    # MatType.PSD_1BALL_UNIF_ORTHOGRP, MatType.SYMM_1BALL_UNIF_ORTHOGRP,
    # MatType.UNIF_ORTHOGRP.
    C_type=MatType.DETERMINISTIC_ONES,
    single_sys=False,
    is_diag=False,
    shuffle=False,
)


cfg_tok = dict(
    ############# AUGMENTED TOKENS' CONFIG
    # Alternatives: None, AugmType.MESA_OPT.
    # NOTE(review): AUG_TOK_EXTRA_POS has no None key, so aug_type=None makes the
    # model_dim computation in cfg_transformer raise KeyError.
    aug_type=AugmType.TF_GD,
    wdw_size=1,
)


cfg_transformer = dict(
    ############# TRANSFORMER CONFIG
    # Derived from cfg_tok, so cfg_tok must already be defined at this point.
    model_dim=(cfg_tok["wdw_size"] + 1 + AUG_TOK_EXTRA_POS[cfg_tok["aug_type"]]) * OUT_DIM,
    qk_dim=OUT_DIM,
    io_layer_dim=OUT_DIM,
    no_att_layers=1,
    n_heads=1,
    layer_norm=False,
    lin_att=True,
    extra_input_lin_layer=False,
    extra_output_lin_layer=False,
    projection=False,
    pos_enc_type=None,  # alternative: PosEnc.LEARNABLE_NON_AUG
    att_init_scale=1e-5,
    # Resolved once, when this module is imported.
    device="cuda" if torch.cuda.is_available() else "cpu",
)


cfg_optimizer = dict(
    ############# OPTIMIZER CONFIG
    optimizer="adamw",
    b1=0.9,
    b2=0.98,
    eps=1e-9,
    # When a scheduler is used, this is the peak (max) learning rate it reaches.
    lr=5e-2,
    wd=0.005,
    schedule_lr=True,
    warmup_steps=500,  # warmup steps for learning-rate scheduling
    max_decay_steps=2000,  # maximum number of iterations covered by the schedule
    # Final / minimal learning rate.
    # NOTE(review): the original comment said it is reached "at max_iters train
    # steps", but max_iters lives in cfg_training (20000) while the schedule
    # length here is max_decay_steps (2000) — confirm which applies.
    end_lr=1e-3,
    loss="mse",
)


cfg_training = dict(
    ############# TRAINING CONFIG
    max_gr_nrm=300,  # presumably the gradient-norm clipping threshold — confirm
    batch_sz=1000,
    max_iters=20000,
    train_record_freq=100,
    freq_heatmap=500,
    camera_ready=True,
    val_set_sz=3000,
)


cfg_plotting = dict(
    ############# PLOT CONFIG
    plot_tube=False,
)

# Single flat config built from every section, in order. On a duplicate key the
# later section would win; none of the sections above currently share a key.
cfg = {
    **cfg_randomness,
    **cfg_dyn_sys,
    **cfg_tok,
    **cfg_transformer,
    **cfg_optimizer,
    **cfg_training,
    **cfg_plotting,
}