###############################################################
########## INSTRUCTIONS TO INITIALIZE DASH OPTIMIZER ##########
###############################################################

import torch

from dash import *

# Build the DASH optimizer. `model` must already be initialized above this
# point; its parameters are handed to the optimizer.
opt = DASH(
    model.parameters(),
    lr=1e-3,
    weight_decay=0,
    config=ShampooConfig(
        # --- AdamW ---
        # The three adamw_* values are only used to update layer norms with
        # AdamW, i.e. when algo_one_dim=AlgoOneDim.ADAMW.
        adamw_eps=1e-8,
        adamw_beta1=0.9,
        adamw_beta2=0.95,

        # --- Momentum / EMA coefficients ---
        beta_G=0.9,  # momentum for the gradient
        beta_LR=0.95,  # EMA decay for the L and R preconditioners
        beta_graft=0.95,  # EMA decay for the grafting buffer A

        # --- Inverse root ---
        eps_inv_root=0,  # use eps = 0 for NDB
        inv_root_method=InverseRootMethodType.from_string('ndb'),
        inv_root_freq=1,  # keep fixed for NDB

        # --- Grafting ---
        grafting_type=GraftingType.from_string('adam'),  # the only option for now
        eps_grafting=1e-8,  # keep fixed

        # --- Momentum on the update ---
        # A non-zero mu might lead to divergence (identical to Distributed
        # Shampoo).
        mu=0,
        use_nesterov=False,  # turn on if mu > 0
        use_bias_correction=True,  # keep fixed

        # --- Preconditioning schedule, blocking and precision ---
        start_prec_step=-1,  # -1 => precondition from the very first step
        # Block size B: given embedding size E, choose B such that B <= E
        # and E % B == 0.
        block_size=1024,
        # Use torch.float32 for NDB and torch.float16 for CN.
        matmul_dtype=torch.float32,

        # --- Matrix scaling ---
        matrix_scaling_type=MatrixScalingType.from_string('pim'),
        matrix_scaling_pi_steps=10,  # do not change this
        matrix_scaling_const=2,  # do not change this

        # --- Newton iterations / one-dimensional parameters ---
        newton_steps=10,  # used for NewtonDB and CoupledNewton, keep fixed
        algo_one_dim=AlgoOneDim.from_string('shmp'),  # keep fixed

        # --- EVD ---
        evd_heuristic=EVDHeuristic.from_string('shmp'),  # keep fixed

        # --- CN ---
        cn_tolerance=1e-6,  # keep fixed

        # --- CBSHV ---
        cbshv_degree=60,  # can be changed
    ),
)