import ml_collections
from jax import random

def get_config():
    """Build the experiment configuration for the ``pois_flux`` run.

    Returns:
        ml_collections.ConfigDict: top-level config holding problem-level
        fields, the JAX PRNG keys derived from ``config.seed``, and the
        sub-configs ``wandb``, ``dataset``, ``model``, ``flux``, ``lr``,
        ``opt``, ``flux_lr``, ``flux_opt``, ``training`` and ``newton``.
    """
    config = ml_collections.ConfigDict()

    # Weights & Biases logging.
    config.wandb = wandb = ml_collections.ConfigDict()
    wandb.project = "pois_flux"  # wandb project name
    wandb.name = "pois_flux"  # wandb run name

    # Problem configs.
    config.Npou = 8  # partition-of-unity size
    config.actual_Npou = config.Npou  # partition of unity including boundary partitions
    config.num_nodes = (128, 128)  # number of nodes in the original partition
    config.num_fields = 1  # dimensionality of the codomain of the solution function
    config.num_dimensions = 2  # dimensionality of the domain (last axis of coords_dim below)
    config.extra_params = 2  # size of the conditioning z in the flux neural network
    # Size of a problem-specific tensor of global PDE parameters
    # (Experiments 3 and 4 of the paper).
    config.tensor = 3

    # Size of the rectangular domain.
    config.width = 1.0
    config.height = 1.0

    # Cell size of the original mesh along x and y (the cells are square
    # only while width / num_nodes[0] == height / num_nodes[1]).
    config.h_x = config.width / config.num_nodes[0]
    config.h_y = config.height / config.num_nodes[1]

    # Dummy input shapes for Flax's lazy parameter initialization. The
    # leading 16 is an arbitrary batch size used only to initialize the model.
    config.force_dim = (16,) + config.num_nodes + (1,)
    config.tensor_dim = (16, config.tensor)
    config.coords_dim = (
        config.num_nodes[0] * config.num_nodes[1],
        config.num_dimensions,
    )
    config.seed = 42

    # Dataset configs.
    config.dataset = dataset = ml_collections.ConfigDict()
    dataset.train_num_samples = 32768
    dataset.val_num_samples = 4096
    # Subsampling of the training set, used to emulate lower-data regimes.
    dataset.train_sub_sampling_rate = 1
    # Term used to normalize the global parameters.
    dataset.measure = (config.h_x, config.h_y)
    dataset.train_path = "/path/to/data"
    dataset.val_path = "/path/to/data"

    # JAX PRNG keys. The split order matters: reordering these lines changes
    # every derived key and therefore the initialization.
    key, config.model_key = random.split(random.PRNGKey(config.seed))
    key, config.flux_key = random.split(key)
    key, config.psi_key = random.split(key)

    # CViT configs.
    config.model = model = ml_collections.ConfigDict()
    model.patch_size = (8, 8)
    # NOTE(review): equals config.num_nodes today but is set independently;
    # confirm whether it is meant to track num_nodes.
    model.grid_size = (128, 128)
    model.latent_dim = 512

    # CViT encoder configs.
    model.emb_dim = 512
    model.depth = 10
    model.num_heads = 16

    # CViT decoder configs.
    model.dec_emb_dim = 512
    model.dec_num_heads = 16
    model.dec_depth = 1

    model.num_mlp_layers = 1
    model.layer_norm_eps = 1e-5
    model.mlp_ratio = 1

    # Output size on each node = partition-of-unity size.
    model.out_dim = config.Npou

    # Coordinate-embedding configs.
    model.embedding_type = "grid"
    model.eps = 1e5

    # Conditioning configs for the flux network.
    model.exp_dim = config.extra_params
    model.exp_depth = 2

    # Discrete switch denoting whether global parameters are used as conditioning.
    model.cond_enc = "cond"

    # Disabled options (most probably irrelevant), kept for reference only:
    # config.in_layers = 1
    # config.adapter_layers = 2
    # config.adapter_hidden_dim = 32
    # config.adapter_params = config.extra_params

    # Flux network configs (simple MLP).
    config.flux = flux = ml_collections.ConfigDict()
    flux.Npou = config.actual_Npou
    flux.extra_params = config.extra_params
    # NOTE(review): differs from config.num_dimensions (2); confirm intent.
    flux.num_dimensions = 1
    flux.num_fields = config.num_fields
    flux.num_hidden_layers = 2
    flux.hidden_layer_width = 32
    flux.oriented_areas = True  # use oriented areas
    flux.use_norm = True  # enable RMS normalization
    flux.use_res = True  # enable residual connections
    flux.layer_norm_eps = 1e-5

    # Exponential LR scheduler configs.
    config.lr = lr = ml_collections.ConfigDict()
    lr.peak_value = 1e-4
    lr.end_value = 1e-6
    lr.decay_rate = 0.95
    lr.transition_steps = 2500
    lr.warmup_steps = 2500
    lr.init_value = 0.0

    # AdamW configs.
    config.opt = opt = ml_collections.ConfigDict()
    opt.weight_decay = 1e-5
    opt.clip_norm = 1.0

    # Exponential LR scheduler configs for the flux NN.
    config.flux_lr = flux_lr = ml_collections.ConfigDict()
    flux_lr.peak_value = 1e-4
    flux_lr.end_value = 1e-6
    flux_lr.decay_rate = 0.95
    flux_lr.transition_steps = 2500
    flux_lr.warmup_steps = 2500
    flux_lr.init_value = 0.0

    # AdamW configs for the flux NN.
    config.flux_opt = flux_opt = ml_collections.ConfigDict()
    flux_opt.weight_decay = 1e-5
    flux_opt.clip_norm = 1.0

    # Training configs.
    config.training = training = ml_collections.ConfigDict()
    training.batch_size = 16
    training.val_batch_size = 16
    training.num_steps = int(1e6) // 4
    training.save_folder = "/path/to/weights/"
    training.steps_per_save = 200
    training.log_loss = 200
    training.save_flag = 0
    training.flux_pen = 0.0  # flux penalty lambda

    # Interior Newton method configs.
    config.newton = newton = ml_collections.ConfigDict()
    newton.tol = 1e-8
    newton.num_iters = 50
    newton.key = key  # the key left over after the three splits above

    return config
