from copy import deepcopy
from typing import Any, Callable, Dict, Literal, Tuple

import jax
import jax.profiler
import numpy as onp
import wandb
from jax import config
from jax.experimental.compilation_cache import compilation_cache
from seml.experiment import Experiment

from deixc.dataset import DEIXCDataset
from deixc.scf import DerivativeInformedSelfConsistentFieldSolver
from deixc.training import final_evaluation, run
from deixc.training.loss import (
    DynamicLossConfig,
    StaticLossConfig,
)
from egxc import dataloading
from egxc.discretization import (
    get_grid_fn,
    get_gto_grid_eval_fn,
    get_gto_preloader,
)
from egxc.systems import PreloadSystem
from egxc.training.optimizer import OptConfig
from egxc.utils.checkpointing import CheckpointManager
from egxc.utils.logging import Logger
from egxc.utils.typing import Alignment, MethodKey, NnParams
from egxc.xc_energy import DensityFeatures, XCModule, functionals

# Directory used by JAX's compilation cache for this project.
compilation_cache.set_cache_dir('./caches/jax_compile/deixc')
# jax.config.update("jax_explain_cache_misses", True)

# Global JAX settings: GPU platform, 64-bit types enabled, float32 matmul precision.
config.update('jax_platform_name', 'gpu')
config.update('jax_enable_x64', True)
config.update('jax_default_matmul_precision', 'float32')
# config.update('jax_debug_nans', True)

# Experiment object; its decorators register the config scopes and captured
# methods defined below.
ex = Experiment()

# Only log in to Weights & Biases when this file is executed as a script.
if __name__ == '__main__':
    wandb.login()

# Allowed labels for the origin of initial model parameters
# (cf. 'checkpoint_type' in the default config's 'initial_model_params').
CheckpointType = Literal['scratch', 'static_pretrain', 'static_train', 'dynamic_train']


@ex.config
def default_config():
    """
    Default experiment configuration.

    Defines the logging, base, alignment, basis, quadrature, initial-guess, data,
    loss, optimizer and per-stage training settings. Later entries reuse earlier
    ones (e.g. ``data['deixc_method_kwargs']`` reads ``basis``, ``base`` and
    ``quadrature``), so the order of the assignments matters.

    NOTE(review): ``static_pretrain['epochs']`` and ``static_train['epochs']``
    default to 1_000_000 here, while the corresponding optimizer initializers in
    ``ExperimentWrapper`` assert ``epochs == 0``; presumably the run configuration
    overrides these values - confirm.
    """
    profile_device_memory = False
    # random integer below 1e6, drawn anew whenever this scope is evaluated
    run_seed = onp.random.default_rng().integers(int(1e6), dtype=onp.int32)
    logging = {
        'project': 'DEI-XC',
        'dir': 'ANONYMOUS_DIR',
        'run_name': None,
        'checkpointing': {
            'directory': None,
        },
        'benchmark': True,
    }
    base = {
        'seed': run_seed,
        'test': False,
        'use_density_fitting': True,
        'spin_restricted': True,
    }
    alignment = {
        'atom': 1,
        'basis': 1,
        'grid': 32 * 1024,
    }
    basis = {
        'name': 'def2-SVP',  # def2-TZVPD
        'derivative': 1,
    }
    quadrature = {
        'level': 1,
    }
    initial_density_guess = {
        'base_initial_density_guess_key': 'minao',  # 'minao', 'atom'
        'initial_ref_density_method': None,  # if None, no reference density is used
        'min_ref_density_interpolation': 0.0,
        'max_ref_density_interpolation': 0.0,
        'noise_eps': 0.0,
    }
    initial_model_params = {
        'load_from': None,
        'checkpoint_type': 'scratch',
        'load_as_local_params': False,
    }
    data = {
        'workers': 8,
        'worker_buffer_size': 1,
        'batch_size': 1,
        'shuffle': True,
        'data_split_seed': 0,  # random seed for splitting the dataset
        'n_test_samples': None,  # number of samples to use for testing, if None, use the full test set
        'preload': {
            'center': False,
        },
        'split': {'val_fraction': 0.1},
        'deixc_method_key': 'ks_dft',
        'align_scf_trajectory': 7,
        'shift_dispersion': False,
        'deixc_method_kwargs': {
            'xc_str': 'B3LYP',
            'basis': basis['name'],
            'backend': 'custom',  # TODO: check
            'use_eri_density_fitting': base['use_density_fitting'],
            'use_exchange_density_fitting': base['use_density_fitting'],
            'spin_restricted': base['spin_restricted'],
            'quadrature_grid_level': quadrature['level'],
        },
    }

    static_loss = {
        'energy': {
            'measure': 'mse',  # 'mae', 'mse', 'huber', 'asinh'
            'scale_per_electron': True,
            'scale_parameter': 0.1,  # Hartree; only relevant for huber and asinh
        },
        'xc_potential': {
            'measure': 'L2',
            'scale_per_electron': True,
            'per_sample_optimal_gauge': 'L2',
        },
        'orbital_rotation_gradient': {
            'measure': 'L2',
            'scale_per_entry': False,
            'oep_weighting': False,
        },
        'reference_basis_is_same': basis['name'] == data['deixc_method_kwargs']['basis'],
        'relative_weights': {
            'xc_energy': 1.0,
            'forces': 0.0,
            'xc_potential': 0.0,
            'orbital_rotation_gradient': 5e-4,
        },
        'vectorize_along_scf': {
            'xc_energy': 'egxc2024',
            'forces': 'egxc2024',
            'xc_potential': 'egxc2024',
            'orbital_rotation_gradient': 'egxc2024',
        },
    }
    static_optimizer = {
        'name': 'adam',  # 'lamb',
        'additional_params': {'b1': 0.9, 'b2': 0.999},
        'weight_decay': 0.0,
        'decay_only_graph_readout': False,
        'apply_every': 1,
        'clip_grad_max_norm': None,
        'skip_nans': 1,
        'ema_decay': 0.995,
        'plateau_handling': {
            'factor': 0.25,
            'patience': 5,
            'cooldown': 5,
            'accumulation_size': 1000,
            'min_scale': 1e-12,
            'min_relative_improvement': 0.02,
        },
    }
    # Configuration for pretraining using precomputed densities (optional additional stage)
    static_pretrain = {
        'epochs': 1_000_000,
        'early_stopping_patience': 5,
        'early_stopping_min_relative_improvement': 0.01,
        'schedule': {
            'base_rate': 0.01,
            'warmup_schedule': 'linear',
            'warmup_steps': 1000,
            'decay_schedule': 'inverse_time_decay',
            'decay_steps': 3000,  # in case of inverse_time_decay this parametrizes the characteristic time scale in terms of number of steps
            'min_rate': 0.0,  # has to be zero in case of inverse_time_decay
        },
    }
    static_train = {
        'epochs': 1_000_000,
        'early_stopping_patience': 5,
        'early_stopping_min_relative_improvement': 0.005,
        'schedule': {
            'base_rate': 0.001,
            'warmup_schedule': 'linear',
            'warmup_steps': 1000,
            'decay_schedule': 'inverse_time_decay',
            'decay_steps': 1000,  # in case of inverse_time_decay this parametrizes the characteristic time scale in terms of number of steps
            'min_rate': 0.0,  # has to be zero in case of inverse_time_decay
        },
    }

    dynamic_loss = {
        'with_dynamic_reference': True,
        'energy': {
            'measure': 'mse',  # 'mae', 'mse', 'huber', 'asinh'
            'scale_per_electron': True,
            'scale_parameter': 0.1,  # Hartree; only relevant for huber and asinh
        },
        'density': {
            'measure': 'mean_field',
            'scale_per_electron': True,
            'spin_restricted': base['spin_restricted'],
            'is_density_fitted': base['use_density_fitting'],
        },
        'xc_potential': {
            'measure': 'L2',
            'scale_per_electron': True,
            'per_sample_optimal_gauge': 'L2',  # 'none' or 'L1' or 'L2'
        },
        'orbital_rotation_gradient': {
            'measure': 'L2',
            'scale_per_entry': False,
            'oep_weighting': False,
        },
        'orbital_rotation_hessian': {
            'measure': 'L2',
            'scale_per_entry': False,
            'oep_weighting': False,
            'normalization': 'fro',  # 'fro', 'nuc', '1', '-1', '2', '-2'
            'n_perturbations': 8,
            'differentiate_through_ground_state': False,
        },
        'reference_basis_is_same': basis['name'] == data['deixc_method_kwargs']['basis'],
        'max_energy_volatility': 1e9,  # mEh
        'relative_weights': {
            'xc_energy': 0.0,
            'forces': 0.0,
            'xc_potential': 0.0,
            'orbital_rotation_gradient': 0.0,
            'orbital_rotation_hessian': 0.0,
            'total_energy': 1.0,  # in the dynamic reference stage this is to be understood as the final total energy
            'density': 0.0,
        },
        'vectorize_along_scf': {
            'xc_energy': ('egxc2024', 12),
            'forces': ('egxc2024', 12),
            'xc_potential': ('egxc2024', 8),
            'orbital_rotation_gradient': ('egxc2024', 8),
            # in the dynamic reference stage this is to be understood as the final total energy
            'total_energy': ('egxc2024', 12),
            'density': ('egxc2024', 12),
        },
    }
    dynamic_optimizer = {
        'name': 'adam',
        'restart_epochs': list[int](),  # empty list by default
        'restart_lr_scales': list[float](),  # empty list by default
        'ema_decay': 0.995,
        'additional_params': {'b1': 0.9, 'b2': 0.999},
        'weight_decay': 1e-5,
        'decay_only_graph_readout': False,
        'plateau_handling': {
            'factor': 0.2,
            'patience': 3,
            'cooldown': 10,
            'accumulation_size': 1000,
            'min_scale': 1e-12,
            'min_relative_improvement': 0.005,
        },
        'metropolis_stabilizer': {
            'method': 'outlier',
            'method_kwargs': {
                'warmup_steps': 20,
                'sigma_tol': 5.0,  # loss increase tolerance in units of the relative loss change scale estimate
                'min_sigma': 0.2,  # lower bound on the relative loss change scale estimate
                'max_sigma': 10.0,  # upper bound 10 -> de facto infinity
            },
            'initial_tries': 5,
            'reinit_during_tryouts': True,
            'loss_statistics_beta': 0.75,
            'consecutive_rejections_threshold': 2,
            'momentum_scaling_on_consecutive_reject': 0.7,  # is applied at every consecutive rejection above the threshold
        },
        'apply_every': 1,
        'clip_grad_max_norm': None,
        'skip_nans': 10,  # number of consecutive NaNs that can be skipped
    }
    dynamic_train = {
        'epochs': 1_000_000,
        'early_stopping_patience': 5,
        'early_stopping_min_relative_improvement': 0.0,
        'schedule': {
            'base_rate': 0.005,
            'warmup_schedule': 'linear',
            'warmup_steps': 1000,
            'decay_schedule': 'inverse_time_decay',
            'decay_steps': 750,  # in case of inverse_time_decay this parametrizes the characteristic time scale in terms of number of steps
            'min_rate': 0.0,  # has to be zero in case of inverse_time_decay
        },
    }


@ex.named_config
def scaled_nagai2020():
    """
    Model config selecting the 'nagai2020' functional with 4 local layers.

    Hidden dimension scaled to 23 from 100 in the original Nagai2020 functional.
    This is done to match the number of trainable parameters of the XCDiff functional.
    """
    model = {
        'name': 'nagai2020',
        'local_n_layers': 4,
        'local_hidden_dim': 23,
    }


@ex.named_config
def nagai2020_orbital_free():
    """
    Model config selecting the 'nagai2020_orbital_free' functional
    (5 local layers, hidden dimension 24).

    Scaled to fit on an A100 GPU.
    """
    model = {
        'name': 'nagai2020_orbital_free',
        'local_n_layers': 5,
        'local_hidden_dim': 24,
    }


@ex.named_config
def xcdiff_orbital_free():
    """
    Model config selecting the 'xcdiff_orbital_free' functional
    (4 local layers, hidden dimension 16).
    """
    model = {
        'name': 'xcdiff_orbital_free',
        'local_n_layers': 4,
        'local_hidden_dim': 16,
    }


@ex.named_config
def scaled_skala_mgga():
    """
    Model config selecting the 'skala_mgga' functional
    (4 local layers, hidden dimension 22).

    Scaled to match parameter count of XCDiff functional.
    """
    model = {
        'name': 'skala_mgga',
        'local_n_layers': 4,
        'local_hidden_dim': 22,
    }


@ex.named_config
def scaled_deixc_mgga():
    """
    'custom' model config whose local part is 'deixc' with 4 layers.

    Hidden dimension scaled to match the number of trainable parameters of the XCDiff functional.
    """
    model = {
        'name': 'custom',
        'local': {
            'name': 'deixc',
            'local_n_layers': 4,
            'local_hidden_dim': 23,
        },
    }


@ex.named_config
def egxc():
    """
    'custom' model config combining a 'nagai2020' local part (6 layers, hidden
    dimension 32) with a 'nequip' GNN section and a 'non_locality' section using
    the 'reweighting_without_mGGA_feats' grid feature mode.
    """
    model = {
        'name': 'custom',
        'local': {
            'name': 'nagai2020',
            'local_n_layers': 6,
            'local_hidden_dim': 32,
        },
    }
    # NOTE(review): deixc2025 spells this type 'NequIP'; confirm the lookup is
    # case-insensitive.
    model['gnn'] = {
        'type': 'nequip',
        'encoder': {
            'cutoff': 5.0,  # Angstrom
            'num_radial_filters': 32,
            'nuclei_partitioning': 'Exponential',
            '_quadrature_points_per_atom_scaling': 12,  # ablate
        },
        'kwargs': {
            'irreps_str': '128x0e + 64x1o',  # per-L message passing dims
            'output_irreps_str': '16x0e + 16x1o',  # per-L output dims
            'message_cutoff': 5.0,
            'n_radial_basis': 8,
            'layers': 3,
            'init_graph_readout_to_zero': True,
            'energy_graph_readout_hidden_dims': (256, 256, 1),
        },
    }
    model['non_locality'] = {
        'grid_feature_mode': 'reweighting_without_mGGA_feats',
        'graph_readout': True,
        'decoder': {
            'spatial_feature_dim': 16,  # output spatial feature dimension
        },
        'reweighting': {
            'layers': 3,
            'hidden_dim': 16,
            'init_scale': 0.1,
            'output_activation_type': 'None',  # TODO: try scaled_sigmoid again
        },
    }


@ex.named_config
def egxc_orbital_free():
    """
    'custom' model config with a 'nagai2020_orbital_free' local part (5 layers,
    hidden dimension 24); the 'gnn' and 'non_locality' sections carry the same
    values as in the ``egxc`` named config.
    """
    model = {
        'name': 'custom',
        'local': {
            'name': 'nagai2020_orbital_free',
            'local_n_layers': 5,
            'local_hidden_dim': 24,
        },
    }
    model['gnn'] = {
        'type': 'nequip',
        'encoder': {
            'cutoff': 5.0,  # Angstrom
            'num_radial_filters': 32,
            'nuclei_partitioning': 'Exponential',
            '_quadrature_points_per_atom_scaling': 12,  # ablate
        },
        'kwargs': {
            'irreps_str': '128x0e + 64x1o',  # per-L message passing dims
            'output_irreps_str': '16x0e + 16x1o',  # per-L output dims
            'message_cutoff': 5.0,
            'n_radial_basis': 8,
            'layers': 3,
            'init_graph_readout_to_zero': True,
            'energy_graph_readout_hidden_dims': (256, 256, 1),
        },
    }
    model['non_locality'] = {
        'grid_feature_mode': 'reweighting_without_mGGA_feats',
        'graph_readout': True,
        'decoder': {
            'spatial_feature_dim': 16,  # output spatial feature dimension
        },
        'reweighting': {  # ablate these
            'layers': 3,
            'hidden_dim': 16,
            'init_scale': 0.1,
            'output_activation_type': 'None',  # TODO: try scaled_sigmoid again
        },
    }


@ex.named_config
def nagai2020_deeper():
    """
    'custom' model config with only a 'nagai2020' local part
    (6 layers, hidden dimension 32).
    """
    model = {
        'name': 'custom',
        'local': {
            'name': 'nagai2020',
            'local_n_layers': 6,
            'local_hidden_dim': 32,
        },
    }


@ex.named_config
def skala_mgga_deeper():
    """
    'custom' model config with only a 'skala_mgga' local part
    (6 layers, hidden dimension 32).
    """
    model = {
        'name': 'custom',
        'local': {
            'name': 'skala_mgga',
            'local_n_layers': 6,
            'local_hidden_dim': 32,
        },
    }


@ex.named_config
def deixc2025():
    """
    'custom' model config with a 'deixc' local part (6 layers, hidden dimension 16),
    a 'NequIP' GNN section with a 10 Angstrom message cutoff, and a 'non_locality'
    section using the 'injection' grid feature mode.
    """
    model = {
        'name': 'custom',
        'local': {
            'name': 'deixc',
            'local_n_layers': 6,
            'local_hidden_dim': 16,
        },
    }
    # NOTE(review): the egxc configs spell this type 'nequip'; confirm the lookup
    # is case-insensitive.
    model['gnn'] = {
        'type': 'NequIP',
        'encoder': {
            # Note this cutoff is used in the decoder too
            'cutoff': 5.0,  # Angstrom
            'num_radial_filters': 16,
            'nuclei_partitioning': 'Exponential',
        },
        'kwargs': {
            'irreps_str': '256x0e + 128x1o',  # per-L message passing dims
            'output_irreps_str': '16x0e + 16x1o',  # per-L output dims
            # message passing parameters:
            'message_cutoff': 10.0,  # [Angstrom]  Note: KS-DFT scales O(N^3), hence all to all communication does not dominate the scaling
            'n_radial_basis': 64,
            'layers': 3,
            'init_graph_readout_to_zero': True,
            'energy_graph_readout_hidden_dims': (256, 256, 1),
        },
    }
    model['non_locality'] = {
        'grid_feature_mode': 'injection',  # 'local_only' or 'reweighting'
        'graph_readout': True,
        'decoder': {
            'spatial_feature_dim': 16,  # output spatial feature dimension
            # 'add_partitioning_as_feature': True,  # TODO: think whether this makes sense?
        },
        # 'reweighting': {
        #     'layers': 3,  # try larger?
        #     'hidden_dim': 16,
        # },
    }


@ex.named_config
def des370k():
    """
    Select the 'des370k' dataset key and set the alignment values to
    atom=4, basis=8, grid=8192.
    """
    data = {
        'key': 'des370k',
    }
    alignment = {'atom': 4, 'basis': 8, 'grid': 1024 * 8}


@ex.named_config
def md17():
    """
    Select the 'md17' dataset key, set all alignment values to 1 and set
    ``dynamic_loss['max_energy_volatility']`` to 200.0.
    """
    dynamic_loss = {
        'max_energy_volatility': 200.0,
    }
    data = {'key': 'md17'}
    alignment = {'atom': 1, 'basis': 1, 'grid': 1}


@ex.named_config
def qm9():
    """
    Select the 'qm9' dataset key with ``heavy_atoms_thresh=4`` and
    ``exclude_fluorine=False`` as dataset keyword arguments.
    """
    data = {
        'key': 'qm9',
        'data_set_kwargs': {'heavy_atoms_thresh': 4, 'exclude_fluorine': False},
    }


@ex.named_config
def qm9_debug():
    """
    QM9 with ``heavy_atoms_thresh='debug'`` and a single epoch with early-stopping
    patience 1 in every training stage.

    NOTE(review): the static optimizer initializers in ``ExperimentWrapper`` assert
    ``epochs == 0``; confirm the static-stage values set here are still usable.
    """
    data = {'key': 'qm9', 'data_set_kwargs': {'heavy_atoms_thresh': 'debug'}}
    static_pretrain = {'epochs': 1, 'early_stopping_patience': 1}
    static_train = {'epochs': 1, 'early_stopping_patience': 1}
    dynamic_train = {'epochs': 1, 'early_stopping_patience': 1}


@ex.named_config
def qm9_debug_larger():
    """
    QM9 with ``heavy_atoms_thresh='debug_larger'`` and a single epoch with
    early-stopping patience 1 in every training stage.

    NOTE(review): the static optimizer initializers in ``ExperimentWrapper`` assert
    ``epochs == 0``; confirm the static-stage values set here are still usable.
    """
    data = {'key': 'qm9', 'data_set_kwargs': {'heavy_atoms_thresh': 'debug_larger'}}
    static_pretrain = {'epochs': 1, 'early_stopping_patience': 1}
    static_train = {'epochs': 1, 'early_stopping_patience': 1}
    dynamic_train = {'epochs': 1, 'early_stopping_patience': 1}


@ex.named_config
def qm9_4():
    """
    QM9 with ``heavy_atoms_thresh=4``; early-stopping patience 200 for both static
    stages and 400 for the dynamic stage.
    """
    data = {'key': 'qm9', 'data_set_kwargs': {'heavy_atoms_thresh': 4}}
    static_pretrain = {'early_stopping_patience': 200}
    static_train = {'early_stopping_patience': 200}
    dynamic_train = {'early_stopping_patience': 400}


@ex.named_config
def qm9_5():
    """
    QM9 with ``heavy_atoms_thresh=5``; early-stopping patience 55 for both static
    stages and 110 for the dynamic stage.
    """
    data = {'key': 'qm9', 'data_set_kwargs': {'heavy_atoms_thresh': 5}}
    static_pretrain = {'early_stopping_patience': 55}
    static_train = {'early_stopping_patience': 55}
    dynamic_train = {'early_stopping_patience': 110}


@ex.named_config
def qm9_6():
    """
    QM9 with ``heavy_atoms_thresh=6``; early-stopping patience 12 for both static
    stages and 24 for the dynamic stage.
    """
    data = {'key': 'qm9', 'data_set_kwargs': {'heavy_atoms_thresh': 6}}
    static_pretrain = {'early_stopping_patience': 12}
    static_train = {'early_stopping_patience': 12}
    dynamic_train = {'early_stopping_patience': 24}


@ex.named_config
def qm9_7():
    """
    QM9 with ``heavy_atoms_thresh=7``; early-stopping patience 3 for both static
    stages and 6 for the dynamic stage.
    """
    data = {'key': 'qm9', 'data_set_kwargs': {'heavy_atoms_thresh': 7}}
    static_pretrain = {'early_stopping_patience': 3}
    static_train = {'early_stopping_patience': 3}
    dynamic_train = {'early_stopping_patience': 6}


@ex.named_config
def scf():
    """
    Solver config selecting the 'scf' solver with 15 cycles and 'DIIS' as the
    convergence acceleration method.

    The 'kwargs' entry is forwarded to the solver by ``ExperimentWrapper.init_solver``.
    """
    solver = {
        'solver': 'scf',
        'kwargs': {
            'cycles': 15,
            'convergence_acceleration_method': 'DIIS',
        },
    }


@ex.named_config
def lookahead():
    """
    Add a 'lookahead' section (slow step size 0.5, sync period 5) to the
    dynamic optimizer settings.
    """
    dynamic_optimizer = {
        'lookahead': {
            'slow_step_size': 0.5,
            'sync_period': 5,
            'adaptive_slow_step_size': False,
            'reset_state': False,
            'adam_b2': 0.999,
        },
    }


@ex.named_config
def interpolate_from_lda():
    """
    Initial-guess config using an LDA 'ks_dft' reference density method (pyscf
    backend, quadrature grid level 3) with reference-density interpolation bounds
    of 0.5 (min) and 1.0 (max).
    """
    initial_density_guess = {
        'base_initial_density_guess_key': 'minao',  # 'minao', 'atom'
        'initial_ref_density_method': {
            'key': 'ks_dft',
            'kwargs': {  # method_specific kwargs
                'xc_str': 'LDA',
                'backend': 'pyscf',  # pyscf or custom
                'use_eri_density_fitting': True,
                'use_exchange_density_fitting': True,  # LDA does not use hf exchange
                'quadrature_grid_level': 3,
            },
        },
        'min_ref_density_interpolation': 0.5,
        'max_ref_density_interpolation': 1.0,
    }


@ex.named_config
def init_with_lda():
    """
    Initial-guess config using an LDA 'ks_dft' reference density method (pyscf
    backend, quadrature grid level 3) with both reference-density interpolation
    bounds set to 1.0.
    """
    initial_density_guess = {
        'base_initial_density_guess_key': 'minao',  # 'minao', 'atom'
        'initial_ref_density_method': {
            'key': 'ks_dft',
            'kwargs': {  # method_specific kwargs
                'xc_str': 'LDA',
                'backend': 'pyscf',  # pyscf or custom
                'use_eri_density_fitting': True,
                'use_exchange_density_fitting': True,  # LDA does not use hf exchange
                'quadrature_grid_level': 3,
            },
        },
        'min_ref_density_interpolation': 1.0,
        'max_ref_density_interpolation': 1.0,
    }


@ex.named_config
def empty():  # empty config for A-B testing
    """
    Named config that defines no entries.
    """
    pass


class ExperimentWrapper:
    """
    Holds the state of one experiment run: logger, dataset and dataloaders,
    optimizer and loss configs, checkpointer and the SCF solver wrapping the XC model.

    Most members are populated by methods decorated with ``ex.capture``, which
    declare the config section they draw their arguments from via ``prefix``.
    """

    test: bool

    # physics model
    basis_str: str
    basis_derivative: int
    use_density_fitting: bool  # density fitted or exact
    target_functional_str: str
    # ml model
    module: XCModule
    # training
    seed: int
    epochs: int
    dynamic_loss_config: DynamicLossConfig
    static_loss_config: StaticLossConfig
    # optimizer configs
    static_pretrain_opt_config: OptConfig
    static_train_opt_config: OptConfig
    dynamic_train_opt_config: OptConfig
    # logging and checkpointing
    logger: Logger
    checkpointer: CheckpointManager
    # data
    dataloaders: dataloading.DataLoaders
    init_psys: PreloadSystem
    # padding
    alignment: Alignment
    # initial guess
    base_initial_density_guess_key: Literal['minao']
    initial_ref_density_method_key: MethodKey | None
    initial_ref_density_method_kwargs: Dict[str, Any] | None
    initial_ref_density_basis: str | None
    interpolation_min: float
    noise_eps: float
    # solver
    solver: DerivativeInformedSelfConsistentFieldSolver  # contains model
    grid_level: int  # quadrature
    spin_restricted: bool
    data_split_seed: int  # seed used for the dataset split (set in _init_dataloader)

    @ex.capture(prefix='logging')  # type: ignore
    def __init__(
        self,
        overwrite: int,
        project: str,
        dir: str,
        run_name: str,
        checkpointing: Dict[str, Any],
        benchmark: bool,
    ) -> None:
        """
        Create the logger, then set up members, data pipeline, checkpointer and solver.

        Arguments not passed explicitly are expected to be injected by
        ``ex.capture(prefix='logging')``.

        Args:
            overwrite: Appended to ``run_name`` (separated by ``_``) to form the
                name used for logging and checkpointing.
            project: First positional argument of ``Logger``.
            dir: Forwarded to ``Logger`` as ``dir``.
            run_name: Base name of the run.
            checkpointing: Keyword arguments forwarded to ``init_checkpointer``.
            benchmark: Stored as ``self.benchmark``.
        """
        self.benchmark = benchmark

        # slurm_id = ex.current_run.
        full_run_name = f'{run_name}_{overwrite}'
        logger_kwargs = {'dir': dir, 'name': full_run_name}
        self.logger = Logger(
            project,
            ex.current_run.config,  # type: ignore
            **logger_kwargs,
        )

        # The order below matters: the dataset setup reads members assigned by
        # set_all, and the checkpointer reads the data split seed assigned
        # during dataset setup.
        self.set_all()
        self.get_initial_density_matrix_fn()  # type: ignore
        self.init_checkpointer(run_name=full_run_name, **checkpointing)
        self.checkpointer.save_config(ex.current_run.config)  # type: ignore
        self.init_solver()  # type: ignore

    def set_all(self) -> None:
        """
        Run every setter whose result does not depend on other class members.
        """
        independent_setters = (
            self.set_base,
            self.set_alignment,
            self.set_quadrature,
            self.set_basis,
            self.init_optimizer_configs,
        )
        for setter in independent_setters:
            setter()  # type: ignore

    @ex.capture(prefix='model')  # type: ignore
    def init_checkpointer(self, name, run_name, directory) -> None:
        """
        Create ``self.checkpointer``.

        Args:
            name: Passed to ``CheckpointManager`` as ``model_name``; declared for
                injection via ``ex.capture(prefix='model')``.
            run_name: Passed to ``CheckpointManager`` as ``name``.
            directory: First positional argument of ``CheckpointManager``.

        Requires ``self.basis_str`` and ``self.data_split_seed`` to be set already.
        """
        manager_kwargs = {
            'model_name': name,
            'basis': self.basis_str,
            'name': run_name,
            'data_split_seed': self.data_split_seed,
        }
        self.checkpointer = CheckpointManager(directory, **manager_kwargs)

    @ex.capture(prefix='base')  # type: ignore
    def set_base(
        self,
        test: bool,
        seed: int,
        use_density_fitting: bool,
        spin_restricted: bool,
    ):
        """
        Store the entries of the 'base' config section as instance members.
        """
        self.test, self.seed = test, seed
        self.use_density_fitting, self.spin_restricted = (
            use_density_fitting,
            spin_restricted,
        )

    @ex.capture(prefix='alignment')  # type: ignore
    def set_alignment(self, atom: int, basis: int, grid: int) -> None:
        """
        Build ``self.alignment`` from the atom, basis and grid alignment values.
        """
        padding = Alignment(atom, basis, grid)
        self.alignment = padding

    @ex.capture(prefix='quadrature')  # type: ignore
    def set_quadrature(self, level: int) -> None:
        """
        Store the quadrature level as ``self.grid_level`` (invoked from ``set_all``).
        """
        self.grid_level = level

    @ex.capture(prefix='basis')  # type: ignore
    def set_basis(self, name: str, derivative: int) -> None:
        """
        Store the basis set name and its derivative order as instance members.
        """
        self.basis_str, self.basis_derivative = name, derivative

    @ex.capture
    def init_optimizer_configs(
        self, static_optimizer: Dict[str, Any], dynamic_optimizer: Dict[str, Any]
    ) -> None:
        """
        Build the optimizer configs of all three training stages.

        Both static stages receive the ``static_optimizer`` settings as keyword
        arguments; the dynamic stage receives ``dynamic_optimizer``.
        """
        static_initializers = (
            self._init_static_pretrain_optimizer,
            self._init_static_train_optimizer,
        )
        for init_static_stage in static_initializers:
            init_static_stage(**static_optimizer)  # type: ignore
        self._init_dynamic_train_optimizer(**dynamic_optimizer)  # type: ignore

    @ex.capture(prefix='static_pretrain')  # type: ignore
    def _init_static_pretrain_optimizer(
        self,
        epochs: int,
        early_stopping_patience: int,
        early_stopping_min_relative_improvement: float,
        schedule: Dict[str, Any],
        **kwargs,
    ):
        """
        Create ``self.static_pretrain_opt_config``.

        The named arguments are declared for injection via
        ``ex.capture(prefix='static_pretrain')``; ``kwargs`` carries the optimizer
        settings passed by the caller and is forwarded to ``OptConfig.create``.

        NOTE(review): only ``epochs == 0`` passes the assertion, whereas the default
        config sets 1_000_000 - presumably overridden per run; confirm.
        """
        assert epochs == 0, 'static pretraining is presently not maintained'
        stage_settings = dict(
            epochs=epochs,
            early_stopping_patience=early_stopping_patience,
            early_stopping_min_relative_improvement=early_stopping_min_relative_improvement,
            schedule=schedule,
        )
        self.static_pretrain_opt_config = OptConfig.create(**stage_settings, **kwargs)

    @ex.capture(prefix='static_train')  # type: ignore
    def _init_static_train_optimizer(
        self,
        epochs: int,
        early_stopping_patience: int,
        early_stopping_min_relative_improvement: float,
        schedule: Dict[str, Any],
        **kwargs,
    ):
        """
        Create ``self.static_train_opt_config``.

        The named arguments are declared for injection via
        ``ex.capture(prefix='static_train')``; ``kwargs`` carries the optimizer
        settings passed by the caller and is forwarded to ``OptConfig.create``.

        NOTE(review): only ``epochs == 0`` passes the assertion, whereas the default
        config sets 1_000_000 - presumably overridden per run; confirm.
        """
        assert epochs == 0, 'static training is presently not maintained'
        stage_settings = dict(
            epochs=epochs,
            early_stopping_patience=early_stopping_patience,
            early_stopping_min_relative_improvement=early_stopping_min_relative_improvement,
            schedule=schedule,
        )
        self.static_train_opt_config = OptConfig.create(**stage_settings, **kwargs)

    @ex.capture(prefix='dynamic_train')  # type: ignore
    def _init_dynamic_train_optimizer(
        self,
        epochs: int,
        early_stopping_patience: int,
        early_stopping_min_relative_improvement: float,
        schedule: Dict[str, Any],
        **kwargs,
    ):
        """
        Create ``self.dynamic_train_opt_config``.

        The named arguments are declared for injection via
        ``ex.capture(prefix='dynamic_train')``; ``kwargs`` carries the optimizer
        settings passed by the caller and is forwarded to ``OptConfig.create``.
        """
        stage_settings = dict(
            epochs=epochs,
            early_stopping_patience=early_stopping_patience,
            early_stopping_min_relative_improvement=early_stopping_min_relative_improvement,
            schedule=schedule,
        )
        self.dynamic_train_opt_config = OptConfig.create(**stage_settings, **kwargs)

    @ex.capture(prefix='initial_density_guess')  # type: ignore
    def get_initial_density_matrix_fn(
        self,
        base_initial_density_guess_key: Literal['minao'],
        initial_ref_density_method: Dict[str, Any] | None,
        min_ref_density_interpolation: float,
        max_ref_density_interpolation: float,
        noise_eps: float,
    ) -> None:
        """
        Configure the initial density guess, then trigger dataset initialization.

        Sets ``base_initial_density_guess_key``, ``initial_ref_density_method_key``,
        ``initial_ref_density_method_kwargs`` and ``initial_density_matrix_fn`` on
        the instance and finally calls ``_init_dataset``. Despite its name, this
        method returns nothing.

        Without a reference density method both interpolation bounds must be 0.0.
        With one, its kwargs are copied and extended by this run's
        ``spin_restricted`` flag and basis name.
        """
        self.base_initial_density_guess_key = base_initial_density_guess_key

        if initial_ref_density_method is not None:
            # Copy first so the config entry itself is never modified.
            ref_kwargs = deepcopy(initial_ref_density_method['kwargs'])
            ref_kwargs['spin_restricted'] = self.spin_restricted
            ref_kwargs['basis'] = self.basis_str
            ref_key = initial_ref_density_method['key']
        else:
            assert min_ref_density_interpolation == 0.0, (
                'min_ref_density_interpolation must be 0.0 when no reference density is provided'
            )
            assert max_ref_density_interpolation == 0.0, (
                'max_ref_density_interpolation must be 0.0 when no reference density is provided'
            )
            ref_key = None
            ref_kwargs = None
        self.initial_ref_density_method_key = ref_key
        self.initial_ref_density_method_kwargs = ref_kwargs

        make_density_matrix_fn = dataloading.transform.get_initial_density_matrix_fn
        self.initial_density_matrix_fn = make_density_matrix_fn(
            min_ref_density_interpolation=min_ref_density_interpolation,
            max_ref_density_interpolation=max_ref_density_interpolation,
            noise_eps=noise_eps,
        )
        self._init_dataset()  # type: ignore

    @ex.capture(prefix='data')  # type: ignore
    def _init_dataset(
        self,
        key: str,
        data_set_kwargs: Dict[str, Any],
        deixc_method_key: MethodKey,
        deixc_method_kwargs: Dict[str, Any],
        align_scf_trajectory: int,
        shift_dispersion: bool,
    ) -> None:
        """
        Build ``self.dataset`` and then set up the dataloaders and main-thread transform.

        The base dataset class is looked up in ``dataloading.key_to_dataset`` by the
        lower-cased ``key`` and wrapped in a ``DEIXCDataset``. Also records
        ``deixc_method_kwargs['xc_str']`` as ``self.target_functional_str``.
        """
        dataset_cls = dataloading.key_to_dataset[key.lower()]
        base_dataset: dataloading.BaseDataset = dataset_cls(
            initial_ref_density_method_key=self.initial_ref_density_method_key,
            initial_ref_density_method_kwargs=self.initial_ref_density_method_kwargs,
            **data_set_kwargs,
        )
        self.target_functional_str = deixc_method_kwargs['xc_str']

        wrapper_kwargs = {
            'method_kwargs': deixc_method_kwargs,
            'align_scf_trajectory': align_scf_trajectory,
            'shift_dispersion': shift_dispersion,
        }
        self.dataset = DEIXCDataset(base_dataset, deixc_method_key, **wrapper_kwargs)

        # The transform setup reads self.max_angular_momentum, which the
        # dataloader setup assigns.
        self._init_dataloader()  # type: ignore
        self._init_main_thread_transform()  # type: ignore

    @ex.capture(prefix='data')  # type: ignore
    def _init_dataloader(
        self,
        split,
        batch_size: int,
        shuffle: bool,
        workers: int | None,
        worker_buffer_size: int,
        data_split_seed: int,
        n_test_samples: int | None,
        preload: Dict[str, bool],
    ) -> None:
        """
        Split the dataset and create the dataloaders together with an initial system.

        Sets ``data_split_seed``, ``max_angular_momentum``, ``init_psys``,
        ``dataloaders`` and ``init_preloaded_basis_fns`` on the instance.
        """
        self.data_split_seed = data_split_seed
        ensemble = dataloading.DatasetEnsemble.infer_split(
            self.dataset, data_split_seed=data_split_seed, **split
        )

        gto_preload = get_gto_preloader(self.basis_str, self.dataset.unique_elements)
        self.max_angular_momentum, basis_fn_preloader = gto_preload

        preload_transform = dataloading.get_preload_transform(
            batch_size,
            self.basis_str,
            self.spin_restricted,
            self.alignment,
            self.use_density_fitting,
            base_initial_density_guess=self.base_initial_density_guess_key,
            basis_fn_preloader=basis_fn_preloader,
            **preload,
        )

        # Note that while the split should stay constant, the shuffling should be
        # random - hence self.seed rather than data_split_seed is passed here.
        psys_and_loaders = dataloading.get_psys_and_dataloaders(
            ensemble,
            preload_transform,
            shuffle,
            workers,
            worker_buffer_size,
            self.seed,
            n_test_samples,
        )
        self.init_psys, self.dataloaders = psys_and_loaders
        self.init_preloaded_basis_fns = basis_fn_preloader(self.init_psys.atom_z)

    @ex.capture(prefix='data')  # type: ignore
    def _init_main_thread_transform(self, preload: Dict[str, bool]) -> None:
        """
        Create ``self.main_thread_transform`` from a grid function and a GTO
        grid-evaluation function.

        ``preload`` is accepted but not used in the body.
        """
        self.main_thread_transform = dataloading.get_jax_transform(
            get_grid_fn(
                self.grid_level, self.dataset.unique_elements, self.alignment.grid
            ),
            get_gto_grid_eval_fn(
                deriv=self.basis_derivative,
                max_angular_momentum=self.max_angular_momentum,
            ),
        )

    @ex.capture(prefix='solver')  # type: ignore
    def init_solver(self, solver: str, kwargs: Dict[str, Any]) -> None:
        """
        Build the XC module and the solver, then create the loss configs.

        Sets ``self.cycles`` (from ``kwargs['cycles']``) and ``self.solver``, and
        finally calls ``_init_loss_configs``, which reads ``self.cycles``.

        Args:
            solver: Name of the solver; only 'scf' is implemented.
            kwargs: Keyword arguments forwarded to the solver; must contain 'cycles'.

        Raises:
            NotImplementedError: If ``solver`` is 'direct_minimization', which has
                no implementation yet.
            ValueError: If ``solver`` is any other unknown name.
        """
        if solver == 'direct_minimization':
            # TODO: Implement direct minimization solver
            # Fail explicitly: continuing would leave self.solver and self.cycles
            # unset and surface later as an AttributeError in _init_loss_configs.
            raise NotImplementedError(
                "Solver 'direct_minimization' is not implemented yet; use 'scf'"
            )
        if solver != 'scf':
            raise ValueError(f'Unknown solver: {solver}')

        model = self._init_xc_module()  # type: ignore
        self.cycles = kwargs['cycles']
        self.solver = DerivativeInformedSelfConsistentFieldSolver(
            model,
            use_density_fitting=self.use_density_fitting,
            spin_restricted=self.spin_restricted,
            **kwargs,
        )
        self._init_loss_configs()  # type: ignore

    @ex.capture  # type: ignore
    def _init_xc_module(self, model: Dict[str, Any]) -> XCModule:
        """
        Build the XC module from the 'model' config section (called by ``init_solver``).

        The density features are spin resolved only if the functional exposes a
        truthy ``requires_spin_resolved_features`` attribute.
        """
        functional = functionals.get_functional(**model)
        spin_resolved = getattr(functional, 'requires_spin_resolved_features', False)
        density_features = DensityFeatures(
            self.spin_restricted, spin_resolved=spin_resolved
        )
        return XCModule(functional, density_features)

    @ex.capture
    def _init_loss_configs(
        self,
        static_loss: Dict[str, Any],
        dynamic_loss: Dict[str, Any],
    ) -> None:
        """
        Create the static and dynamic loss configs from their config sections.

        The static config additionally receives the dataset's
        ``align_scf_trajectory`` as ``ref_scf_cycles``; the dynamic config receives
        ``self.cycles`` as ``scf_cycles``.
        """
        ref_scf_cycles = self.dataset.align_scf_trajectory
        self.static_loss_config = StaticLossConfig.create(
            **static_loss, ref_scf_cycles=ref_scf_cycles
        )

        scf_cycles = self.cycles
        self.dynamic_loss_config = DynamicLossConfig.create(
            **dynamic_loss, scf_cycles=scf_cycles
        )

    def __call__(self) -> None:
        """
        Runs the training phases that remain for the initial checkpoint type.

        Phases, in order: static pretraining (only for 'scratch'), static
        training (for 'scratch' and 'static_pretrain') and dynamic training
        (for every listed checkpoint type). When `self.test` is truthy,
        `final_evaluation` is called afterwards.
        """
        params, params_init_fn, checkpoint_type = self._get_initial_model_params()  # type: ignore

        # Arguments shared by every `run` call below.
        shared_run_kwargs = dict(
            solver=self.solver,
            dataloaders=self.dataloaders,
            main_thread_transform=self.main_thread_transform,
            logger=self.logger,
            checkpointer=self.checkpointer,
            benchmark=self.benchmark,
        )

        functional_tree = params['params']['xc_module']['functional']
        has_local_model = 'local_model' in functional_tree

        # Phase 1: static pretraining, only when starting from scratch.
        if checkpoint_type == 'scratch':
            # With a local model present, only that sub-tree is pretrained.
            pretrain_target = (
                functional_tree['local_model'] if has_local_model else functional_tree
            )
            pretrained = run(
                ('static', 'pretrain'),
                {'params': {'functional': pretrain_target}},
                params_init_fn,
                **shared_run_kwargs,
                opt_config=self.static_pretrain_opt_config,
                loss_config=self.static_loss_config,  # type: ignore
                prng_key=None,  # type: ignore
            )
            pretrained_tree = pretrained['params']['functional']
            if has_local_model:
                # Write the trained local model back into the solver parameters.
                params['params']['xc_module']['functional']['local_model'] = (
                    pretrained_tree
                )
            else:
                params['params']['xc_module']['functional'] = pretrained_tree
            # From here on, work with the full functional rather than the local model only.
            functional_tree = params['params']['xc_module']['functional']

        # Phase 2: static training.
        if checkpoint_type in ('scratch', 'static_pretrain'):
            statically_trained = run(
                ('static', 'train'),
                {'params': {'functional': functional_tree}},
                params_init_fn,
                **shared_run_kwargs,
                opt_config=self.static_train_opt_config,
                loss_config=self.static_loss_config,
                prng_key=None,  # type: ignore
            )
            functional_tree = statically_trained['params']['functional']
            params['params']['xc_module']['functional'] = functional_tree

        # Phase 3: dynamic training on the full solver parameters.
        dynamic_entry_points = (
            'scratch',
            'static_pretrain',
            'static_train',
            'dynamic_train',  # keep training if epochs > 0
        )
        if checkpoint_type in dynamic_entry_points:
            params = run(
                ('dynamic', 'train'),
                params,
                params_init_fn,
                **shared_run_kwargs,
                initial_density_matrix_fn=self.initial_density_matrix_fn,
                max_energy_volatility=self.dynamic_loss_config.max_energy_volatility,
                opt_config=self.dynamic_train_opt_config,
                loss_config=self.dynamic_loss_config,
                prng_key=jax.random.PRNGKey(self.seed),
                ref_functional=self._get_ref_functional(),
            )

        # Optional final evaluation with the resulting parameters.
        if self.test:
            final_evaluation(
                params,
                self.solver,
                self.dynamic_loss_config,
                self.dataloaders,
                self.main_thread_transform,
                self.logger,
                self.initial_density_matrix_fn,
            )

    def _get_ref_functional(self) -> XCModule | None:
        """
        Returns the reference XC module for dynamic training, if one is requested.

        Returns:
            An `XCModule` built from `self.target_functional_str` when
            `self.dynamic_loss_config.with_dynamic_reference` is truthy,
            otherwise None.
        """
        if not self.dynamic_loss_config.with_dynamic_reference:
            return None

        reference = functionals.get_functional(
            self.target_functional_str,
            spin_restricted=self.spin_restricted,
            use_density_fitting=self.use_density_fitting,
        )
        # Same feature selection rule as for the trainable functional:
        # spin-resolved features only if the functional asks for them.
        spin_resolved = getattr(reference, 'requires_spin_resolved_features', False)
        return XCModule(
            reference,
            DensityFeatures(self.spin_restricted, spin_resolved=spin_resolved),
        )

    @ex.capture(prefix='initial_model_params')  # type: ignore
    def _get_initial_model_params(
        self,
        load_from: str | None,
        checkpoint_type: CheckpointType,
        load_as_local_params: bool,
    ) -> Tuple[NnParams, Callable[[int], NnParams] | None, CheckpointType]:
        """
        Initializes the model parameters and optionally overwrites the functional
        parameters with values loaded from a checkpoint.
        NOTE: Currently static pretraining is not maintained.

        Parameters are always created first via `self.solver.init`. If
        `load_from` is given, the loaded functional parameters are then inserted
        into that freshly initialized structure.

        Args:
            load_from: The path to the checkpoint to load from, or None to keep
                the freshly initialized parameters.
            checkpoint_type: The checkpoint type.
            load_as_local_params: Whether to insert the loaded functional
                parameters as the local model of the current functional.

        Returns:
            params: The initial model parameters.
            params_init_fn : A function that initializes the model parameters.
            checkpoint_type: The checkpoint type.

        Raises:
            ValueError: If more than one initial try is configured while
                `checkpoint_type` is not 'scratch', if `load_from` is given while
                `checkpoint_type` is 'scratch', or if `load_as_local_params` is
                set but the functional has no `local_model` attribute.
        """
        n_tries = self.dynamic_train_opt_config.metropolis_stabilizer.initial_tries

        # Configuration checks use explicit raises (asserts are stripped under
        # `python -O`) and run before the transform / parameter initialization.
        if n_tries > 1 and checkpoint_type != 'scratch':
            raise ValueError(
                'Only one ensemble member is supported when loading from checkpoint'
            )
        if load_from is not None and checkpoint_type == 'scratch':
            raise ValueError(
                'Cannot load from checkpoint when initializing from scratch'
            )

        P0s, system = self.main_thread_transform(
            self.init_psys, self.init_preloaded_basis_fns
        )

        if n_tries > 1:
            # One PRNG key per ensemble member; the index selects the member.
            ensemble_keys = jax.random.split(jax.random.PRNGKey(self.seed), n_tries)
            params_init_fn = lambda i: self.solver.init(ensemble_keys[i], P0s[0], system)
        else:
            # Single member: the index is ignored and the run seed is used directly.
            params_init_fn = lambda _: self.solver.init(
                jax.random.PRNGKey(self.seed), P0s[0], system
            )
        params = params_init_fn(0)

        if load_from is not None:
            loaded_params = self.checkpointer.load_params(
                load_from, prefix=checkpoint_type
            )
            loaded_params = loaded_params['params']

            # correctly insert the parameters into the model structure
            if 'pretrain' in checkpoint_type:
                # only load the local model parameters, branch for composite non-local models
                if hasattr(self.solver.xc_module.functional, 'local_model'):
                    src = loaded_params['xc_module']['functional']['local_model']
                    params['params']['xc_module']['functional']['local_model'] = src
                else:
                    src = loaded_params['xc_module']['functional']
                    params['params']['xc_module']['functional'] = src
            elif load_as_local_params:
                if not hasattr(self.solver.xc_module.functional, 'local_model'):
                    raise ValueError('Local model not found in functional')
                # The loaded functional becomes the local model of the current functional.
                src = loaded_params['xc_module']['functional']
                params['params']['xc_module']['functional']['local_model'] = src
            else:
                # For training (non-pretrain), load the entire functional
                src = loaded_params['xc_module']['functional']
                params['params']['xc_module']['functional'] = src

        return params, params_init_fn, checkpoint_type


@ex.automain
def main(profile_device_memory: bool, overwrite: int):
    """
    Entry point: builds the experiment wrapper, runs it and optionally
    saves a device memory profile.

    Args:
        profile_device_memory: If truthy, a JAX device memory profile is
            written to 'memory_profile.prof' after the run has finished.
        overwrite: Passed as the only argument to `ExperimentWrapper`.
    """
    experiment = ExperimentWrapper(overwrite)  # type: ignore
    experiment()
    if not profile_device_memory:
        return
    # The snapshot is taken while `experiment` is still bound to a local name.
    jax.profiler.save_device_memory_profile('memory_profile.prof')
