import ml_collections
from jax import random

def get_config():
    """Build the experiment configuration for the "darcy" run.

    Returns:
        An ``ml_collections.ConfigDict`` holding top-level grid/geometry
        fields plus the sub-configs ``wandb``, ``dataset``, ``model``,
        ``flux``, ``lr``, ``opt``, ``flux_lr``, ``flux_opt``, ``training``
        and ``newton``.

    The PRNG keys ``model_key``, ``flux_key``, ``psi_key`` and ``newton.key``
    are derived deterministically from ``config.seed`` by successive
    ``jax.random.split`` calls (in that order), so changing the seed changes
    all of them.
    """
    config = ml_collections.ConfigDict()

    # Experiment tracking (Weights & Biases).
    config.wandb = wandb = ml_collections.ConfigDict()
    wandb.project = "darcy"
    wandb.name = "darcy"
    wandb.tag = None

    # Problem / discretization sizes.
    config.Npou = 8
    # flux.Npou below uses one more than model.out_dim.
    # NOTE(review): presumably an extra partition is added downstream — confirm.
    config.actual_Npou = config.Npou + 1
    config.num_nodes = (128, 128)
    config.num_fields = 1
    config.num_dimensions = 2
    config.extra_params = 1

    # Domain extent and node spacing: num_nodes points span each side, so
    # there are (num_nodes - 1) intervals per direction.
    config.width = 1.
    config.height = 1.
    config.h_x = config.width / (config.num_nodes[0] - 1)
    config.h_y = config.height / (config.num_nodes[1] - 1)

    config.force = 20

    # Dummy input shape used to initialize the model. The leading 16 is an
    # arbitrary init batch size. Both grid dimensions are used so that
    # non-square grids get the correct shape (previously num_nodes[0] was
    # repeated for the y axis).
    init_batch_size = 16
    config.nu_dim = (
        (init_batch_size,)
        + (config.num_nodes[0], config.num_nodes[1])
        + (1,)
    )
    # (num_nodes[0]-2)*(num_nodes[1]-2) points with num_dimensions coords each;
    # presumably the grid interior with the boundary ring removed.
    config.coords_dim = (
        ((config.num_nodes[0] - 2) * (config.num_nodes[1] - 2),)
        + (config.num_dimensions,)
    )
    config.seed = 42

    # Data loading.
    config.dataset = dataset = ml_collections.ConfigDict()
    dataset.train_num_samples = 9500
    dataset.val_num_samples = 500
    dataset.train_sub_sampling_rate = 1
    dataset.path = '/path/to/data'

    # Derive per-component PRNG keys from the seed. The split order matters:
    # reordering these lines would change every key. The final remaining
    # `key` is stored in newton.key below.
    key, config.model_key = random.split(random.PRNGKey(config.seed))
    key, config.flux_key = random.split(key)
    key, config.psi_key = random.split(key)

    config.model = model = ml_collections.ConfigDict()

    model.patch_size = (8, 8)
    # Kept in sync with the grid the model is initialized on (nu_dim).
    model.grid_size = tuple(config.num_nodes)
    model.latent_dim = 256

    # Encoder.
    model.emb_dim = 256
    model.depth = 8
    model.num_heads = 8
    model.mlp_ratio = 1

    # Decoder.
    model.dec_emb_dim = 512
    model.dec_num_heads = 16
    model.dec_depth = 1

    model.num_mlp_layers = 1
    model.out_dim = config.Npou

    model.embedding_type = "grid"
    # NOTE(review): 1e5 is unusually large for an epsilon (layer_norm_eps
    # below is 1e-5); confirm this value is intentional and not a typo.
    model.eps = 1e5
    model.layer_norm_eps = 1e-5

    model.exp_dim = config.extra_params
    model.exp_depth = 1

    model.cond_enc = "no_cond"

    # Flux network.
    config.flux = flux = ml_collections.ConfigDict()

    flux.Npou = config.actual_Npou
    flux.extra_params = config.extra_params
    # NOTE(review): differs from config.num_dimensions (2); confirm intended.
    flux.num_dimensions = 1
    flux.num_fields = config.num_fields
    flux.num_hidden_layers = 2
    flux.hidden_layer_width = 32
    flux.oriented_areas = True
    flux.use_norm = True
    flux.use_res = True
    flux.layer_norm_eps = 1e-5

    # Model learning-rate schedule. Presumably consumed by a
    # warmup + exponential-decay schedule (e.g. optax) — confirm at call site.
    config.lr = lr = ml_collections.ConfigDict()

    lr.peak_value = 1e-4
    lr.end_value = 1e-6
    lr.decay_rate = 0.95
    lr.transition_steps = 2000
    lr.warmup_steps = 2000
    lr.init_value = 0.0

    # Model optimizer settings.
    config.opt = opt = ml_collections.ConfigDict()
    opt.weight_decay = 1e-5
    opt.clip_norm = 1.0

    # Flux-network learning-rate schedule (same fields as `lr`).
    config.flux_lr = flux_lr = ml_collections.ConfigDict()

    flux_lr.peak_value = 1e-4
    flux_lr.end_value = 1e-6
    flux_lr.decay_rate = 0.95
    flux_lr.transition_steps = 2000
    flux_lr.warmup_steps = 2000
    flux_lr.init_value = 0.0

    # Flux-network optimizer settings.
    config.flux_opt = flux_opt = ml_collections.ConfigDict()
    flux_opt.weight_decay = 1e-5
    flux_opt.clip_norm = 1.0

    # Training loop.
    config.training = training = ml_collections.ConfigDict()
    training.batch_size = 16
    training.val_batch_size = 16
    training.num_steps = int(1e6) // 4  # 250_000 steps

    training.save_folder = "/path/to/weights"
    training.steps_per_save = 250
    training.log_loss = 50
    training.flux_pen = 0.00
    training.save_flag = 0

    # Newton solver settings. The key is the one left over after the
    # model/flux/psi splits above.
    config.newton = newton = ml_collections.ConfigDict()
    newton.tol = 1e-12
    newton.num_iters = 50
    newton.key = key

    return config