'''
 *
 *     ICTP: Irreducible Cartesian Tensor Potentials
 *
 *        File:  config.py
 *
 *     Authors: Deleted for purposes of anonymity 
 *
 *     Proprietor: Deleted for purposes of anonymity --- PROPRIETARY INFORMATION
 * 
 * The software and its source code contain valuable trade secrets and shall be maintained in
 * confidence and treated as confidential information. The software may only be used for 
 * evaluation and/or testing purposes, unless otherwise explicitly stated in the terms of a
 * license agreement or nondisclosure agreement with the proprietor of the software. 
 * Any unauthorized publication, transfer to third parties, or duplication of the object or
 * source code---either totally or in part---is strictly prohibited.
 *
 *     Copyright (c) 2024 Proprietor: Deleted for purposes of anonymity
 *     All Rights Reserved.
 *
 * THE PROPRIETOR DISCLAIMS ALL WARRANTIES, EITHER EXPRESS OR 
 * IMPLIED, INCLUDING BUT NOT LIMITED TO IMPLIED WARRANTIES OF MERCHANTABILITY 
 * AND FITNESS FOR A PARTICULAR PURPOSE AND THE WARRANTY AGAINST LATENT 
 * DEFECTS, WITH RESPECT TO THE PROGRAM AND ANY ACCOMPANYING DOCUMENTATION. 
 * 
 * NO LIABILITY FOR CONSEQUENTIAL DAMAGES:
 * IN NO EVENT SHALL THE PROPRIETOR OR ANY OF ITS SUBSIDIARIES BE 
 * LIABLE FOR ANY DAMAGES WHATSOEVER (INCLUDING, WITHOUT LIMITATION, DAMAGES
 * FOR LOSS OF BUSINESS PROFITS, BUSINESS INTERRUPTION, LOSS OF INFORMATION, OR
 * OTHER PECUNIARY LOSS AND INDIRECT, CONSEQUENTIAL, INCIDENTAL,
 * ECONOMIC OR PUNITIVE DAMAGES) ARISING OUT OF THE USE OF OR INABILITY
 * TO USE THIS PROGRAM, EVEN IF the proprietor HAS BEEN ADVISED OF
 * THE POSSIBILITY OF SUCH DAMAGES.
 * 
 * For purposes of anonymity, the identity of the proprietor is not given herewith. 
 * The identity of the proprietor will be given once the review of the 
 * conference submission is completed. 
 *
 * THIS HEADER MAY NOT BE EXTRACTED OR MODIFIED IN ANY WAY.
 *
'''
import copy
import random
from typing import Optional, Dict


# Default model/training configuration; ``update_config`` applies
# user-supplied overrides on top of these entries. Entries set to None have
# no built-in default (presumably they must be supplied by the user or are
# derived later -- TODO confirm against callers).
DEFAULT_CONFIG = {
    'data_path': None,
    # Both seeds (this one and 'model_seed') are drawn once, when the module
    # is imported, so they generally differ between runs unless overridden.
    'data_seed': random.randint(0, 9999999),
    # Short property keys mapped to longer names, e.g. 'R' -> 'positions'.
    'key_mapping': {
        'R': 'positions',
        'C': 'cell',
        'Z': 'numbers',
        'E': 'energy',
        'F': 'forces',
        'N': 'n_atoms',
    },
    'atomic_types': None,
    'atomic_energies': None,
    'r_cutoff': None,
    'n_train': None,
    'n_valid': None,
    'train_batch_size': 32,
    'eval_batch_size': 100,
    'neighbors': 'matscipy',
    'model_path': None,
    'model_seed': random.randint(0, 9999999),
    'device': 'cuda:0',
    'default_dtype': 'float32',
    'readout_MLP': [16],
    'radial_MLP': [64, 64, 64],
    'n_basis': 8,
    'n_polynomial_cutoff': 5,
    'n_hidden_feats': 16,
    'n_product_feats': 16,
    'n_interactions': 2,
    'l_max_hidden_feats': 1,
    'l_max_edge_attrs': 3,
    'correlation': 3,
    'coupled_product_feats': False,
    'symmetric_product': True,
    'avg_n_neighbors': 1,
    'compute_avg_n_neighbors': True,
    'compute_regression_shift': True,
    # Training epochs and optimizer
    'max_epoch': 1000,
    'save_epoch': 100,
    'valid_epoch': 1,
    'lr': 0.01,
    'lr_factor': 0.8,
    'scheduler_patience': 50,
    'amsgrad': True,
    'max_grad_norm': None,
    'weight_decay': 5.0e-7,
    'ema': True,
    'ema_decay': 0.99,
    # Training, early-stopping, and evaluation losses. The two 'weighted_sum'
    # entries pair each item of 'losses' with the 'weights' item at the same
    # position.
    'train_loss': {
        'type': 'weighted_sum',
        'losses': [{'type': 'energy_sse'}, {'type': 'forces_sse'}],
        'weights': [1.0, 4.0],
    },
    'early_stopping_loss': {
        'type': 'weighted_sum',
        'losses': [{'type': 'energy_mae'}, {'type': 'forces_mae'}],
        'weights': [1.0, 1.0],
    },
    'eval_losses': [
        {'type': 'energy_rmse'},
        {'type': 'energy_mae'},
        {'type': 'forces_rmse'},
        {'type': 'forces_mae'},
    ],
}


def update_config(config: Optional[Dict] = None) -> Dict:
    """Builds a model/training configuration from the defaults and user overrides.

    A deep copy of ``DEFAULT_CONFIG`` is updated with the top-level entries of
    ``config``. Each key in ``config`` replaces the default value as a whole;
    nested dictionaries such as ``key_mapping`` are not merged. ``DEFAULT_CONFIG``
    itself is never modified. Repeated calls therefore do not leak overrides
    into one another, and the caller may freely mutate the returned dictionary.

    Args:
        config (Optional[Dict], optional): Dictionary containing model/training
                                           configurations that override the
                                           defaults. Defaults to None.

    Returns:
        Dict: A new dictionary holding the updated model/training configuration.
    """
    # Deep copy because the defaults contain nested dicts/lists (e.g.
    # 'key_mapping', 'train_loss'). A shallow copy would share them with, and
    # let them be mutated through, the returned config.
    updated_config = copy.deepcopy(DEFAULT_CONFIG)
    if config:
        updated_config.update(config)
    return updated_config
