from typing import Any, Dict, Literal

import jax
import jax.profiler
import wandb
from jax import config
from jax.experimental.compilation_cache import compilation_cache
from seml.experiment import Experiment

from egxc import dataloading
from egxc.discretization import get_grid_fn, get_gto_grid_eval_fn, get_gto_preloader
from egxc.solver.scf import SelfConsistentFieldSolver
from egxc.systems import PreloadSystem
from egxc.training import loss, pretrain, run
from egxc.training.optimizer import OptConfig, get_optimizer
from egxc.utils.checkpointing import CheckpointManager
from egxc.utils.logging import Logger
from egxc.utils.typing import Alignment, NnParams
from egxc.xc_energy import DensityFeatures, XCModule, functionals
from egxc.xc_energy.functionals.learnable.egxc import EGXC

# Store JAX compilation-cache entries under this directory (the path is
# relative to the current working directory).
compilation_cache.set_cache_dir('./caches/jax_compile/egxc')

# Global JAX flags; they are set at import time, before any experiment code runs.
config.update('jax_platform_name', 'gpu')
config.update('jax_enable_x64', True)
config.update('jax_default_matmul_precision', 'float32')
# config.update('jax_debug_nans', True)

# Experiment object whose decorators register the config scopes, the captured
# methods and the entry point defined in the rest of this file.
ex = Experiment()

# Log in to Weights & Biases only when this file is executed as a script,
# not when it is merely imported.
if __name__ == '__main__':
    wandb.login()


@ex.config
def default_config():
    """Default experiment configuration; every local variable is a config entry.

    The sections are consumed by the ``ExperimentWrapper`` methods captured
    with the matching ``prefix`` and by ``main``.
    """
    profile_device_memory = False
    run_seed = 0
    logging = {
        'project': 'egxc',
        'dir': 'ANONYMOUS_DIR',
        'run_name': None,
        'checkpointing': {
            'directory': None,
        },
    }
    base = {
        'seed': run_seed,
        'test': False,
        'epochs': 1_000_000,
        'use_density_fitting': True,
        'spin_restricted': True,
    }
    alignment = {
        'atom': 4,
        'basis': 1,
        'grid': 512,
    }
    basis = {
        'name': '6-31G(d)',  # e.g. 'sto-6g', '6-31G(d)', '6-31G(2df,p)', '6-311++G(3df,2pd)'
        'derivative': 1,
    }
    quadrature = {
        'level': 1,
    }
    initial_density_guess = {
        'base_initial_density_guess_key': 'minao',
        'initial_ref_density_method': 'dft_lda',
        # reuses the basis set name chosen above
        'initial_ref_density_basis': basis['name'],
        'min_ref_density_interpolation': 0.5,
        'max_ref_density_interpolation': 1.0,
        'noise_eps': 0.0,
    }
    # Top-level load_from option to initialize complete model parameters
    load_from = None
    data = {
        'workers': 8,
        'worker_buffer_size': 1,
        'batch_size': 1,
        'shuffle': True,
        'seed': 0,  # random seed for splitting the dataset
        'preload': {
            'center': False,
        },
        'split': {'val_fraction': 0.1},
    }
    loss = {
        'discard_first_n': 10,
        'decay_type': 'dick2021',
        'relative_weights': {
            'energy': 1.0,
            'density': 0.0,
        },
        'density': {
            'measure': 'mean_field',
            'scale_per_electron': True,
            # mirrored from the `base` section above
            'spin_restricted': base['spin_restricted'],
            'is_density_fitted': base['use_density_fitting'],
        },
        'max_energy_volatility': 1e9,  # mEh
    }
    optimizer = {
        'early_stopping_patience': 5,
        'early_stopping_min_relative_improvement': 0.0,
        'ema_decay': 0.995,
        'kwargs': {
            'name': 'adam',
            'additional_params': {
                'b1': 0.9,
                'b2': 0.999,
            },
            'weight_decay': 0.0,
            'schedule': {
                'base_rate': 0.005,
                'warmup_schedule': 'linear',
                'warmup_steps': 2_000,
                'decay_schedule': 'inverse_time_decay',
                'decay_steps': 1000,  # in case of inverse_time_decay this parametrizes the characteristic time scale in terms of number of steps
                'min_rate': 0.0,  # has to be zero in case of inverse_time_decay
            },
            'plateau_handling': {
                'factor': 0.25,
                'patience': 3,
                'cooldown': 3,
                'accumulation_size': 1000,
                'min_scale': 1e-12,
                'min_relative_improvement': 0.0001,
            },
            'apply_every': 1,
            'clip_grad_max_norm': None,
            'skip_nans': 10,  # number of consecutive NaNs that can be skipped
        },
    }
    pretraining = {  # add load_model_weights option
        'early_stopping_patience': 5,
        'early_stopping_min_relative_improvement': 0.0,
        'ema_decay': 0.995,
        'epochs': 100_000,
        'load_from': None,
        'opt_kwargs': {
            'additional_params': None,
            'weight_decay': 0.0,
            'schedule': {  # best for ethanol + dick2021
                'warmup_schedule': 'linear',
            },
            'apply_every': 1,
            'clip_grad_max_norm': None,
            'skip_nans': 1,  # number of consecutive NaNs that can be skipped
        },
        # NOTE(review): this key sits next to 'opt_kwargs', whereas the named
        # pretraining configs in this file place 'plateau_handling' inside
        # 'opt_kwargs' -- confirm which location is intended.
        'plateau_handling': None,
    }


@ex.named_config
def md17_adam_pretraining_xcdiff():
    """Named config: Adam pre-training optimizer settings (cosine decay schedule)."""
    pretraining = {  # optimized for ethanol
        'opt_kwargs': {
            'name': 'adam',
            'additional_params': {
                'b1': 0.9,
                'b2': 0.999,
            },
            'weight_decay': 2e-4,
            'schedule': {
                'base_rate': 0.05,
                'min_rate': 1e-8,
                'warmup_steps': 3_000,
                'decay_steps': 7_000,
                'decay_schedule': 'cosine',
            },
            'plateau_handling': {
                'factor': 0.25,
                'patience': 3,
                'cooldown': 3,
                'accumulation_size': 1000,
                'min_scale': 1e-12,
                'min_relative_improvement': 0.0001,
            },
            'clip_grad_max_norm': 0.3,
            'skip_nans': 10,
        },
    }


@ex.named_config
def qm9_adam_pretraining_xcdiff():
    """Named config: Adam pre-training optimizer settings (inverse-time decay schedule)."""
    pretraining = {
        'opt_kwargs': {
            'name': 'adam',
            'additional_params': {
                'b1': 0.7,
                'b2': 0.9,
            },
            'schedule': {
                'base_rate': 0.005,
                'min_rate': 0,
                'warmup_steps': 100,
                'decay_steps': 1000,
                'decay_schedule': 'inverse_time_decay',
            },
            'plateau_handling': {
                'factor': 0.5,
                'patience': 5,
                'cooldown': 3,
                'accumulation_size': 1000,
                'min_scale': 1e-12,
                'min_relative_improvement': 0.0001,
            },
        },
    }


@ex.named_config
def des370k():
    """Named config: select the 'des370k' dataset key and its alignment values."""
    data = {
        'key': 'des370k',
    }
    alignment = {'atom': 4, 'basis': 8, 'grid': 1024 * 8}


@ex.named_config
def md17():
    """Named config: select the 'md17' dataset key, its alignment and loss override."""
    loss = {'max_energy_volatility': 200.0}  # mEh  # TODO: might not be needed
    data = {'key': 'md17'}
    alignment = {'atom': 1, 'basis': 1, 'grid': 1}


@ex.named_config
def qm9():
    """Named config: select the 'qm9' dataset key with dataset kwargs, alignment and loss override."""
    loss = {'discard_first_n': 12}
    data = {
        'key': 'qm9',
        # forwarded as keyword arguments to the dataset constructor
        'data_set_kwargs': {
            'heavy_atoms_thresh': 4,
            'exclude_fluorine': True,
        },
    }
    alignment = {'atom': 1, 'basis': 1, 'grid': 1024 * 32}  # TODO: fix padding


@ex.named_config
def scf():
    """Named config: use the 'scf' solver; 'args' are forwarded to the solver constructor."""
    solver = {
        'solver': 'scf',
        'args': {
            'cycles': 15,
            'convergence_acceleration_method': 'DIIS',
        },
    }


@ex.named_config
def egxc_painn():
    """Named config: 'custom' model with an 'xcdiff' local part and a PaiNN graph part.

    The whole ``model`` dict is passed as keyword arguments to
    ``functionals.get_functional``.
    """
    model = {
        'name': 'custom',
        'local': {
            'name': 'xcdiff',
        },
        'graph': {
            'irreps': '0e + 1o + 2e',
            'with_reweighting': True,
            'with_graph_readout': True,
            'atom_output_feature_dim': 16,
            'encoder': {
                # Note this cutoff is used in the decoder too
                'cutoff': 5.0,  # Angstrom  TODO: check units
                'num_radial_filters': 32,
            },
            'gnn_type': 'PaiNN',
            'gnn': {
                'node_feature_dim': 64,
                # message passing parameters:
                'cutoff': 5.0,
                'layers': 3,
            },
            'reweighting': {
                'layers': 3,  # try larger?
                'hidden_dim': 16,
            },
        },
    }


class ExperimentWrapper:
    """Assemble and run one training experiment from the captured config.

    Construction wires things up in this order: the logger, the scalar
    settings (``set_all``), the initial-density function together with the
    dataset, dataloaders and main-thread transform, the checkpoint manager,
    and finally the solver with its loss configuration.  Calling the instance
    starts training.

    Most methods are config-captured: arguments that the caller does not pass
    are filled in from the config section named by the decorator's ``prefix``.
    """

    test: bool

    # physics model
    basis_str: str
    basis_derivative: int
    use_density_fitting: bool
    # ml model
    module: XCModule  # declared only; not assigned anywhere in this class
    # training
    seed: int
    epochs: int
    loss_config: loss.LossConfig
    # logging and checkpointing
    logger: Logger
    checkpointer: CheckpointManager
    # data
    dataloaders: dataloading.DataLoaders
    init_psys: PreloadSystem
    # padding
    alignment: Alignment
    # initial guess
    base_initial_density_guess_key: Literal['minao']
    initial_ref_density_method: str
    initial_ref_density_basis: str
    interpolation_min: float  # declared only; not assigned anywhere in this class
    noise_eps: float  # declared only; not assigned anywhere in this class
    max_angular_momentum: int
    # solver
    solver: SelfConsistentFieldSolver  # contains model
    grid_level: int  # quadrature
    spin_restricted: bool

    @ex.capture(prefix='logging')  # type: ignore
    def __init__(
        self,
        overwrite: int,
        project: str,
        dir: str,
        run_name: str | None,
        checkpointing: Dict[str, Any],
    ) -> None:
        """Create the logger and initialise all experiment components.

        Args:
            overwrite: Suffix appended to the run name; passed by the caller.
            project: Logging project name (``logging.project``).
            dir: Logging directory (``logging.dir``).
            run_name: Base run name (``logging.run_name``).
            checkpointing: Keyword arguments forwarded to
                ``init_checkpointer`` (``logging.checkpointing``).
        """
        # The suffix is appended unconditionally, so a ``None`` run name
        # becomes the literal string 'None_<overwrite>'.
        run_name = f'{run_name}_{overwrite}'
        self.logger = Logger(
            project,
            ex.current_run.config,  # type: ignore
            dir=dir,
            name=run_name,
        )
        # Order matters: later steps read attributes set by earlier ones
        # (e.g. the checkpointer needs `basis_str` and `data_split_seed`).
        self.set_all()
        self.get_initial_density_matrix_fn()  # type: ignore
        self.init_checkpointer(run_name=run_name, **checkpointing)
        self.checkpointer.save_config(ex.current_run.config)  # type: ignore
        self.init_solver()  # type: ignore

    def set_all(self) -> None:
        """
        sets all independent class members
        """
        # `set_base` must come first: `set_optimizer_config` reads `self.epochs`.
        self.set_base()  # type: ignore
        self.set_alignment()  # type: ignore
        self.set_quadrature()  # type: ignore
        self.set_basis()  # type: ignore
        self.set_optimizer_config()  # type: ignore

    @ex.capture(prefix='model')  # type: ignore
    def init_checkpointer(self, name, run_name, directory) -> None:
        """Create the checkpoint manager.

        Args:
            name: Model name, captured from ``model.name``.
            run_name: Run name, passed by the caller.
            directory: Checkpoint directory, passed by the caller.
        """
        self.checkpointer = CheckpointManager(
            directory,
            model_name=name,  # model name
            basis=self.basis_str,
            name=run_name,
            data_split_seed=self.data_split_seed,
        )

    @ex.capture(prefix='base')  # type: ignore
    def set_base(
        self,
        test: bool,
        seed: int,
        epochs: int,
        use_density_fitting: bool,
        spin_restricted: bool,
    ):
        """Store the entries of the ``base`` config section on the instance."""
        self.test = test
        self.seed = seed
        self.epochs = epochs

        self.use_density_fitting = use_density_fitting
        self.spin_restricted = spin_restricted

    @ex.capture(prefix='alignment')  # type: ignore
    def set_alignment(self, atom: int, basis: int, grid: int) -> None:
        """Build the ``Alignment`` from the ``alignment`` config section."""
        self.alignment = Alignment(atom, basis, grid)

    @ex.capture(prefix='quadrature')  # type: ignore
    def set_quadrature(self, level: int) -> None:
        """Store the quadrature grid level (``quadrature.level``)."""
        self.grid_level = level

    @ex.capture(prefix='basis')  # type: ignore
    def set_basis(self, name: str, derivative: int) -> None:
        """Store the basis set name and derivative order (``basis`` section)."""
        self.basis_str = name
        self.basis_derivative = derivative

    @ex.capture(prefix='optimizer')  # type: ignore
    def set_optimizer_config(
        self,
        ema_decay: float,
        early_stopping_patience: int,
        early_stopping_min_relative_improvement: float,
        kwargs: Dict[str, Any],
    ) -> None:
        """Store training hyper-parameters and build the optimizer config.

        Requires ``self.epochs`` to be set already (see ``set_base``).

        Args:
            ema_decay: ``optimizer.ema_decay``.
            early_stopping_patience: ``optimizer.early_stopping_patience``.
            early_stopping_min_relative_improvement:
                ``optimizer.early_stopping_min_relative_improvement``.
            kwargs: Extra keyword arguments for ``OptConfig.create``
                (``optimizer.kwargs``).
        """
        self.ema_decay = ema_decay
        self.early_stopping_patience = early_stopping_patience
        self.early_stopping_min_relative_improvement = (
            early_stopping_min_relative_improvement
        )
        # OptConfig requires epochs, ema_decay, and early_stopping_patience
        self.opt_config = OptConfig.create(
            epochs=self.epochs,
            ema_decay=ema_decay,
            early_stopping_patience=early_stopping_patience,
            early_stopping_min_relative_improvement=early_stopping_min_relative_improvement,
            **kwargs,
        )

    @ex.capture(prefix='initial_density_guess')  # type: ignore
    def get_initial_density_matrix_fn(
        self,
        base_initial_density_guess_key: Literal['minao'],
        initial_ref_density_method: str,
        initial_ref_density_basis: str,
        min_ref_density_interpolation: float,
        max_ref_density_interpolation: float,
        noise_eps: float,
    ) -> None:
        """Set up the initial-density function, then initialise the dataset.

        Despite its name this method returns nothing: it stores the guess
        settings and ``self.initial_density_matrix_fn`` on the instance and
        then triggers ``_init_dataset``.  All arguments are captured from the
        ``initial_density_guess`` config section.
        """
        self.base_initial_density_guess_key = base_initial_density_guess_key
        self.initial_ref_density_method = initial_ref_density_method
        self.initial_ref_density_basis = initial_ref_density_basis
        self.initial_density_matrix_fn = (
            dataloading.transform.get_initial_density_matrix_fn(
                min_ref_density_interpolation=min_ref_density_interpolation,
                max_ref_density_interpolation=max_ref_density_interpolation,
                noise_eps=noise_eps,
            )
        )
        self._init_dataset()  # type: ignore

    @ex.capture(prefix='data')  # type: ignore
    def _init_dataset(
        self,
        key: str,
        data_set_kwargs: Dict[str, Any] | None = None,
    ) -> None:
        """Instantiate the dataset, then the dataloaders and main-thread transform.

        Args:
            key: Dataset key (``data.key``), looked up case-insensitively in
                ``dataloading.key_to_dataset``.
            data_set_kwargs: Optional extra keyword arguments for the dataset
                constructor (``data.data_set_kwargs``).  Defaults to no extra
                arguments, so configs that omit the entry still work.
        """
        extra_kwargs = data_set_kwargs or {}
        self.dataset: dataloading.BaseDataset = dataloading.key_to_dataset[key.lower()](
            initial_ref_density_method=self.initial_ref_density_method,
            initial_ref_density_basis=self.initial_ref_density_basis,
            **extra_kwargs,
        )
        self._init_dataloader()  # type: ignore
        self._init_main_thread_transform()  # type: ignore

    @ex.capture(prefix='data')  # type: ignore
    def _init_dataloader(
        self,
        split,
        batch_size: int,
        shuffle: bool,
        workers: int | None,
        worker_buffer_size: int,
        seed: int,
        preload: Dict[str, bool],
    ) -> None:
        """Split the dataset and create the dataloaders.

        Sets ``data_split_seed``, ``max_angular_momentum``, ``init_psys``,
        ``dataloaders`` and ``init_pvec_basis_fns``.  All arguments are
        captured from the ``data`` config section.
        """
        self.data_split_seed = seed
        dataset_ensemble = dataloading.DatasetEnsemble.infer_split(
            self.dataset, data_split_seed=self.data_split_seed, **split
        )
        # Prepare vectorized basis preloader and preload transformations
        self.max_angular_momentum, basis_fn_preloader = get_gto_preloader(
            self.basis_str, self.dataset.unique_elements
        )
        preload_transformations = dataloading.get_preload_transform(
            batch_size,
            self.basis_str,
            self.spin_restricted,
            self.alignment,
            self.use_density_fitting,
            base_initial_density_guess=self.base_initial_density_guess_key,
            center=preload['center'],
            basis_fn_preloader=basis_fn_preloader,
        )
        self.init_psys, self.dataloaders = dataloading.get_psys_and_dataloaders(
            dataset_ensemble,
            preload_transformations,
            shuffle,
            workers,
            worker_buffer_size,
            seed,
        )
        self.init_pvec_basis_fns = basis_fn_preloader(self.init_psys.atom_z)

    @ex.capture(prefix='data')  # type: ignore
    def _init_main_thread_transform(self, preload: Dict[str, bool]) -> None:
        """Build ``self.main_thread_transform`` from the grid and basis functions.

        Args:
            preload: Captured from ``data.preload``; currently unused here.
        """
        elements = self.dataset.unique_elements
        grid_fn = get_grid_fn(self.grid_level, elements, self.alignment.grid)
        basis_fn = get_gto_grid_eval_fn(
            deriv=self.basis_derivative, max_angular_momentum=self.max_angular_momentum
        )
        transform = dataloading.get_jax_transform(
            grid_fn,
            basis_fn,
        )
        self.main_thread_transform = transform

    @ex.capture(prefix='solver')  # type: ignore
    def init_solver(self, solver: str, args: Dict[str, Any]) -> None:
        """Create the solver around the XC module, then the loss configuration.

        Args:
            solver: Solver key (``solver.solver``); only ``'scf'`` is supported.
            args: Keyword arguments for the solver constructor (``solver.args``).

        Raises:
            NotImplementedError: For ``'direct_minimization'``, which has no
                implementation yet.
            ValueError: For any other unknown solver key.
        """
        model = self._xc_module()  # type: ignore
        if solver == 'scf':
            self.solver = SelfConsistentFieldSolver(
                model,
                use_density_fitting=self.use_density_fitting,
                spin_restricted=self.spin_restricted,
                **args,
            )
        elif solver == 'direct_minimization':
            # TODO: Implement direct minimization solver
            # Fail loudly here: continuing would leave `self.solver` unset and
            # crash later with an unrelated AttributeError.
            raise NotImplementedError(
                "Solver 'direct_minimization' is not implemented yet"
            )
        else:
            raise ValueError(f'Unknown solver: {solver}')
        self._init_loss_config()  # type: ignore

    @ex.capture  # type: ignore
    def _xc_module(self, model) -> XCModule:
        """Build the XC module from the top-level ``model`` config section.

        Called by ``init_solver``.
        """
        functional = functionals.get_functional(**model)
        module = XCModule(functional, DensityFeatures(self.spin_restricted))
        return module

    @ex.capture(prefix='loss')  # type: ignore
    def _init_loss_config(
        self,
        discard_first_n: int,
        decay_type: Literal['dick2021', 'li2021', 'egxc2024', 'only_final'],
        relative_weights: Dict[str, float],
        max_energy_volatility: float,
        density: Dict[str, Any],
    ) -> None:
        """Build ``self.loss_config`` from the ``loss`` config section.

        Requires ``self.solver`` to exist because its ``cycles`` attribute is
        passed on to ``LossConfig.create``.
        """
        self.loss_config = loss.LossConfig.create(
            self.solver.cycles,
            discard_first_n,
            decay_type,
            relative_weights,
            max_energy_volatility,
            density,
        )

    def __call__(self) -> None:
        """Obtain the initial parameters and run the main training loop."""
        opt = get_optimizer(self.opt_config)
        model = self.solver

        init_params = self._get_initial_model_params(model)  # type: ignore

        run(
            init_params,
            model,
            opt,
            self.ema_decay,
            self.early_stopping_patience,
            self.early_stopping_min_relative_improvement,
            self.loss_config,
            self.epochs,
            self.dataloaders,
            self.main_thread_transform,
            self.initial_density_matrix_fn,
            self.logger,
            self.checkpointer,
            self.test,
            jax.random.PRNGKey(self.seed),
        )

    @ex.capture  # type: ignore
    def _get_initial_model_params(
        self, model: SelfConsistentFieldSolver, load_from: str | None
    ) -> NnParams:
        """Return the parameters training starts from.

        Args:
            model: The solver whose parameters are needed.
            load_from: Top-level ``load_from`` config entry.  When set, the
                complete parameters are loaded from it; otherwise the
                pre-training logic is used.
        """
        # Priority: top-level load_from > pretraining logic
        if load_from is not None:
            # Load complete model parameters directly
            complete_params = self.checkpointer.load_params(load_from, prefix='')
            return complete_params
        else:
            # Use existing pretraining logic
            return self._get_pretrained_params(model)  # type: ignore

    @ex.capture(prefix='pretraining')  # type: ignore
    def _get_pretrained_params(
        self, model: SelfConsistentFieldSolver, load_from: str | None
    ) -> NnParams:
        """Initialise the model and merge in loaded or pre-trained parameters.

        Args:
            model: The solver to initialise.
            load_from: ``pretraining.load_from``.  When set, pre-trained
                parameters are loaded from it instead of running pre-training.

        Returns:
            * the freshly initialised parameters when no pre-trained
              parameters are available (pre-training disabled or nothing to
              train),
            * for an ``EGXC`` functional, the initialised parameters with the
              local-model sub-tree replaced by the pre-trained ones,
            * otherwise the pre-trained parameters themselves.
        """
        if load_from is not None:
            pretrained_params = self.checkpointer.load_params(
                load_from, prefix='pretrain'
            )
        else:
            pretrained_params = self._pretrain(model)  # type: ignore

        P0s, system = self.main_thread_transform(
            self.init_psys, self.init_pvec_basis_fns
        )
        init_params = model.init(jax.random.PRNGKey(self.seed), P0s[0], system)

        if not pretrained_params:
            # Nothing was pre-trained or loaded: keep the freshly initialised
            # parameters instead of replacing them with an empty result.
            return init_params
        if type(model.xc_module.functional) is EGXC:
            # overwrite local params
            init_params['params']['xc_module']['functional']['local_model'] = (
                pretrained_params['params']['xc_module']['functional']
            )
            return init_params
        return pretrained_params

    @ex.capture(prefix='pretraining')  # type: ignore
    def _pretrain(
        self,
        model: SelfConsistentFieldSolver,
        epochs: int,
        ema_decay: float,
        early_stopping_patience: int,
        early_stopping_min_relative_improvement: float,
        opt_kwargs: Dict[str, Any],
    ) -> NnParams:
        """Pre-train the (local) model and return its parameters.

        For an ``EGXC`` functional only its ``local_model`` is pre-trained.

        Args:
            model: The solver to pre-train.
            epochs: ``pretraining.epochs``; pre-training is skipped unless
                this is positive.
            ema_decay: ``pretraining.ema_decay``.
            early_stopping_patience: ``pretraining.early_stopping_patience``.
            early_stopping_min_relative_improvement:
                ``pretraining.early_stopping_min_relative_improvement``.
            opt_kwargs: Keyword arguments for ``OptConfig.create``
                (``pretraining.opt_kwargs``).

        Returns:
            The pre-trained parameters, the untrained (empty) parameters if
            the model has none, or ``{}`` when pre-training is disabled.
        """
        if epochs <= 0:
            return {}

        # NOTE(review): unlike `set_optimizer_config`, `epochs`, `ema_decay`
        # and the early-stopping settings are not passed to `OptConfig.create`
        # here -- confirm that `opt_kwargs` alone is sufficient.
        opt_config = OptConfig.create(**opt_kwargs)
        opt = get_optimizer(opt_config)
        if type(model.xc_module.functional) is EGXC:
            xc_module = model.xc_module.copy(
                functional=model.xc_module.functional.local_model
            )
            model = model.copy(xc_module=xc_module)
        # run pre-training
        P0s, system = self.main_thread_transform(
            self.init_psys, self.init_pvec_basis_fns
        )
        local_params = model.init(jax.random.PRNGKey(self.seed), P0s[0], system)
        # only train if parameters dict is not empty
        if local_params:
            local_params = pretrain.run(
                local_params,
                model,
                opt,
                ema_decay,
                early_stopping_patience,
                early_stopping_min_relative_improvement,
                self.loss_config,
                epochs,
                self.dataloaders,
                self.main_thread_transform,
                self.initial_density_matrix_fn,
                self.logger,
                self.checkpointer,
                jax.random.PRNGKey(self.seed),
            )
            jax.clear_caches()  # free all compiled functions
        return local_params


@ex.automain
def main(profile_device_memory: bool, overwrite: int):
    """Build the experiment, run training and optionally dump a memory profile.

    Args:
        profile_device_memory: When true, a device memory profile is written
            to 'memory_profile.prof' after training.
        overwrite: Forwarded to ``ExperimentWrapper``.
    """
    experiment = ExperimentWrapper(overwrite)  # type: ignore
    experiment()
    if not profile_device_memory:
        return
    jax.profiler.save_device_memory_profile('memory_profile.prof')