import ml_collections
from jax import random


def get_config():
    """Build the experiment configuration for the "peclet_flux" run.

    Returns:
        An ``ml_collections.ConfigDict`` holding top-level discretisation
        settings, JAX PRNG keys derived from ``config.seed``, and the
        sub-configs ``wandb``, ``dataset``, ``model``, ``flux``, ``lr``,
        ``opt``, ``flux_lr``, ``flux_opt``, ``training`` and ``newton``.
    """
    config = ml_collections.ConfigDict()

    # Weights & Biases logging.
    config.wandb = wandb = ml_collections.ConfigDict()
    wandb.project = "peclet_flux"
    wandb.name = "peclet_flux"
    wandb.tag = None

    # Top-level problem / discretisation settings.
    config.Npou = 8
    # NOTE(review): Npou plus two extra entries; this is what flux.Npou uses
    # while model.out_dim uses plain Npou — confirm the intended meaning.
    config.actual_Npou = config.Npou + 2
    config.num_nodes = (128,)
    config.num_fields = 1
    config.num_dimensions = 1
    config.extra_params = 1

    config.width = 1.0
    # Spacing between consecutive nodes when num_nodes[0] evenly spaced
    # nodes span `width`.
    config.h_x = config.width / (config.num_nodes[0] - 1)

    # Shape used for model initialisation; 16 is an arbitrary batch size.
    config.nu_dim = (16, 1)
    # num_nodes[0] - 2 points, presumably the interior nodes (boundary
    # nodes excluded) — confirm against the model code.
    config.coords_dim = (config.num_nodes[0] - 2, config.num_dimensions)
    config.seed = 42

    config.dataset = dataset = ml_collections.ConfigDict()
    dataset.train_num_samples = 90
    dataset.val_num_samples = 10
    dataset.train_sub_sampling_rate = 1
    # dataset.path = "/home/pgkallin/whitney_jax/transformer/darcy_flow.h5"

    # Derive PRNG keys by successive splits of the seed key. The order of the
    # splits determines every key's value, so keep it stable for
    # reproducibility. The key left after the last split is newton.key below.
    key, config.model_key = random.split(random.PRNGKey(config.seed))
    key, config.flux_key = random.split(key)
    key, config.psi_key = random.split(key)

    config.model = model = ml_collections.ConfigDict()
    model.patch_size = (128,)
    model.grid_size = (128,)
    model.latent_dim = 256

    # Settings without the dec_ prefix (presumably the encoder side).
    model.emb_dim = 256
    model.depth = 0
    model.num_heads = 8
    model.mlp_ratio = 1

    # Decoder settings.
    model.dec_emb_dim = 256
    model.dec_num_heads = 8
    model.dec_depth = 2

    model.num_mlp_layers = 1
    model.out_dim = config.Npou

    model.embedding_type = "grid"
    # NOTE(review): 1e5 is unusually large for an epsilon (layer_norm_eps
    # below is 1e-5); confirm this value is intentional.
    model.eps = 1e5
    model.layer_norm_eps = 1e-5

    model.exp_dim = config.extra_params
    model.exp_depth = 1
    # model.cond_enc = "no_cond"

    config.flux = flux = ml_collections.ConfigDict()
    flux.Npou = config.actual_Npou
    flux.extra_params = config.extra_params
    # Taken from the top-level setting (like num_fields) so the two stay
    # consistent if the problem dimension changes.
    flux.num_dimensions = config.num_dimensions
    flux.num_fields = config.num_fields
    flux.num_hidden_layers = 2
    flux.hidden_layer_width = 32
    flux.oriented_areas = True
    flux.use_norm = True
    flux.use_res = True
    flux.layer_norm_eps = 1e-5

    # Learning-rate schedule and optimiser settings for the model.
    config.lr = lr = ml_collections.ConfigDict()
    lr.peak_value = 1e-4
    lr.end_value = 1e-6
    lr.decay_rate = 0.95
    lr.transition_steps = 1000
    lr.warmup_steps = 250
    lr.init_value = 0.0

    config.opt = opt = ml_collections.ConfigDict()
    opt.weight_decay = 1e-4
    opt.clip_norm = 1.0

    # Learning-rate schedule and optimiser settings for the flux network.
    config.flux_lr = flux_lr = ml_collections.ConfigDict()
    flux_lr.peak_value = 1e-4
    flux_lr.end_value = 1e-6
    flux_lr.decay_rate = 0.95
    flux_lr.transition_steps = 1000
    flux_lr.warmup_steps = 250
    flux_lr.init_value = 0.0

    config.flux_opt = flux_opt = ml_collections.ConfigDict()
    flux_opt.weight_decay = 1e-5
    flux_opt.clip_norm = 1.0

    config.training = training = ml_collections.ConfigDict()
    training.batch_size = 90
    training.val_batch_size = 10
    training.num_steps = 150000
    training.save_folder = "/path/to/weights/"
    training.steps_per_save = 2500
    training.log_loss = 250
    training.save_flag = 0
    training.flux_penalty = 0.08

    config.newton = newton = ml_collections.ConfigDict()
    newton.tol = 1e-12
    newton.num_iters = 50
    # Remaining key after the model/flux/psi splits above.
    newton.key = key

    return config