import ml_collections
from jax import random

def get_config():
    """Build and return the ``ad_flux`` experiment configuration.

    The returned ``ml_collections.ConfigDict`` has top-level fields that
    describe the grid and the initialisation shapes. It also has these
    sub-configs: ``wandb``, ``dataset``, ``model``, ``flux``, ``lr``,
    ``opt``, ``flux_lr``, ``flux_opt``, ``training`` and ``newton``.

    Derived fields are computed eagerly, once, while this function runs.
    These include ``actual_Npou``, ``h_x``/``h_y``, ``force_dim``,
    ``tensor_dim``, ``coords_dim``, ``dataset.measure``,
    ``model.grid_size``, ``model.out_dim`` and ``flux.Npou``. They are
    plain values, not ``FieldReference``s, so overriding a base value such
    as ``num_nodes`` on the returned config does NOT update them. Change
    the base values here instead.

    Returns:
        ml_collections.ConfigDict: the populated configuration.
    """
    config = ml_collections.ConfigDict()

    # --- Experiment tracking ---------------------------------------------
    config.wandb = wandb = ml_collections.ConfigDict()
    wandb.project = "ad_flux"
    wandb.name = "ad_flux"
    wandb.tag = None

    # --- Problem / discretisation ----------------------------------------
    config.Npou = 8
    # ``actual_Npou`` (= Npou + 1) is what the flux network receives as
    # ``flux.Npou``; the model's ``out_dim`` uses ``Npou`` itself.
    config.actual_Npou = config.Npou + 1
    config.num_nodes = (128, 128)
    config.num_fields = 1
    config.num_dimensions = 2
    config.extra_params = 2
    config.tensor = 3

    # Domain extent.
    config.width = 1.0
    config.height = 1.0

    # Uniform spacing: num_nodes points span the domain, giving
    # num_nodes - 1 intervals per direction.
    config.h_x = config.width / (config.num_nodes[0] - 1)
    config.h_y = config.height / (config.num_nodes[1] - 1)

    # Dummy batch size used only to build example input shapes for model
    # initialisation; it is independent of ``training.batch_size``.
    init_batch_size = 16
    config.force_dim = (init_batch_size,) + config.num_nodes + (1,)
    config.tensor_dim = (init_batch_size, config.tensor)
    # Number of points = grid with the outer ring of nodes removed in both
    # directions, i.e. (nx - 2) * (ny - 2).
    # NOTE(review): the original expression was ``(nx - 2) * ny - 2``,
    # which looks like a misplaced parenthesis; confirm this count matches
    # the coordinates produced by the data pipeline.
    config.coords_dim = (
        (config.num_nodes[0] - 2) * (config.num_nodes[1] - 2),
        config.num_dimensions,
    )
    config.seed = 42

    # --- Data --------------------------------------------------------------
    config.dataset = dataset = ml_collections.ConfigDict()
    dataset.train_num_samples = 32768
    dataset.val_num_samples = 4096
    dataset.train_sub_sampling_rate = 1
    # Grid spacings (h_x, h_y).
    dataset.measure = (config.h_x, config.h_y)
    dataset.train_path = '/path/to/data/'
    dataset.val_path = '/path/to/data/'

    # --- PRNG keys ---------------------------------------------------------
    # Successively split the seed key: each split hands one sub-key to a
    # component and keeps the other for further splitting. The leftover
    # ``key`` is stored as ``newton.key`` at the end.
    key, config.model_key = random.split(random.PRNGKey(config.seed))
    key, config.flux_key = random.split(key)
    key, config.psi_key = random.split(key)

    # --- Model -------------------------------------------------------------
    config.model = model = ml_collections.ConfigDict()

    model.patch_size = (8, 8)
    # Derived from the discretisation so the two cannot drift apart.
    model.grid_size = config.num_nodes
    model.latent_dim = 256

    # Encoder.
    model.emb_dim = 256
    model.depth = 10
    model.num_heads = 8
    model.mlp_ratio = 1

    # Decoder.
    model.dec_emb_dim = 512
    model.dec_num_heads = 16
    model.dec_depth = 1

    model.num_mlp_layers = 1

    model.out_dim = config.Npou

    model.embedding_type = "grid"
    # NOTE(review): 1e5 is unusually large for an "eps" (``layer_norm_eps``
    # is 1e-5); confirm this is not a typo for 1e-5.
    model.eps = 1e5
    model.layer_norm_eps = 1e-5

    model.exp_dim = config.extra_params
    model.exp_depth = 2

    model.cond_enc = "cond"

    # --- Adapter -----------------------------------------------------------
    config.in_layers = 1
    config.adapter_layers = 2
    config.adapter_hidden_dim = 32
    config.adapter_params = config.extra_params

    # --- Flux network ------------------------------------------------------
    config.flux = flux = ml_collections.ConfigDict()

    flux.Npou = config.actual_Npou
    flux.extra_params = config.extra_params
    # NOTE(review): differs from config.num_dimensions (2); presumably
    # intentional — confirm.
    flux.num_dimensions = 1
    flux.num_fields = config.num_fields
    flux.num_hidden_layers = 2
    flux.hidden_layer_width = 32
    flux.oriented_areas = True
    flux.use_norm = True
    flux.use_res = True
    flux.layer_norm_eps = 1e-5

    # --- Model learning-rate schedule and optimiser ------------------------
    config.lr = lr = ml_collections.ConfigDict()

    lr.peak_value = 1e-4
    lr.end_value = 1e-6
    lr.decay_rate = 0.95
    lr.transition_steps = 2500
    lr.warmup_steps = 2500
    lr.init_value = 0.0

    config.opt = opt = ml_collections.ConfigDict()

    opt.weight_decay = 1e-5
    opt.clip_norm = 1.0

    # --- Flux learning-rate schedule and optimiser -------------------------
    config.flux_lr = flux_lr = ml_collections.ConfigDict()

    flux_lr.peak_value = 1e-4
    flux_lr.end_value = 1e-6
    flux_lr.decay_rate = 0.95
    flux_lr.transition_steps = 2500
    flux_lr.warmup_steps = 2500
    flux_lr.init_value = 0.0

    config.flux_opt = flux_opt = ml_collections.ConfigDict()
    flux_opt.weight_decay = 1e-5
    flux_opt.clip_norm = 1.0

    # --- Training loop -----------------------------------------------------
    config.training = training = ml_collections.ConfigDict()
    training.batch_size = 16
    training.val_batch_size = 16
    training.num_steps = int(1e6) // 4

    training.save_folder = "/path/to/weights"
    training.steps_per_save = 200
    training.flux_pen = 0.01
    training.log_loss = 200
    training.save_flag = 0

    # --- Newton solver -----------------------------------------------------
    config.newton = newton = ml_collections.ConfigDict()
    newton.tol = 1e-10
    newton.num_iters = 50
    newton.key = key

    return config