"""Energy-only finetuning script for DEI-XC models.

This script finetunes a pretrained model using only energy loss against original
dataset labels (e.g., CCSD(T) for MD17). Key features:
- Uses base datasets directly (no DEIXCDataset wrapper)
- Compares against original dataset energy labels
- Supports a constant energy shift (`energy_shift`, in Hartree) added to every target label; set it per run (e.g. per MD17 molecule)
- Uses Adam/Muon optimizer with warmup + decay schedule

Usage:
    python scripts/deixc_finetune_energy.py with <config_overrides>
    seml <collection> add configs/deixc/finetune_energy_md17.yaml
"""

import pickle
from pathlib import Path
from typing import Any, Callable, Dict, List, Sequence, Tuple

import grain.python as grain
import jax
import jax.numpy as jnp
import optax
import pandas as pd
from jax import config as jax_config
from jax.experimental.compilation_cache import compilation_cache
from seml.experiment import Experiment

from egxc import dataloading, discretization, systems, xc_energy
from egxc.dataloading.dataloader import get_psys_and_dataloaders
from egxc.dataloading.datasets.base import Targets
from egxc.dataloading.datasets.ensemble import DatasetEnsemble
from egxc.dataloading.io import (
    checkpoint_best_path,
    results_directory,
    results_metrics_csv_path,
    results_summary_path,
)
from egxc.solver.scf import SelfConsistentFieldSolver
from egxc.training.loss.scalar import ScalarLossConfig, scalar_loss
from egxc.training.optimizer import OptConfig, get_optimizer
from egxc.utils.checkpointing import CheckpointManager
from egxc.utils.logging import Logger
from egxc.utils.typing import Alignment, BaseInitialGuess, NnParams

# Persist compiled XLA executables between runs so repeated finetuning jobs can reuse them.
compilation_cache.set_cache_dir('./caches/jax_compile/deixc_finetune_energy')

# Global JAX backend / numerics options, applied at import time before any JAX work in this script.
_JAX_OPTIONS = {
    'jax_platform_name': 'gpu',
    'jax_enable_x64': True,
    'jax_default_matmul_precision': 'float32',
}
for _option_name, _option_value in _JAX_OPTIONS.items():
    jax_config.update(_option_name, _option_value)
del _option_name, _option_value

# Sacred/SEML experiment object; the config scopes and captured functions below register on it.
ex = Experiment()


@ex.config
def default_config():
    """Default Sacred/SEML configuration for energy-only finetuning.

    Every local variable assigned in this config scope becomes a config entry,
    so only configuration values belong here.

    Keys without a usable default that the experiment config must supply:
    ``data.key`` and ``data.data_set_kwargs`` (required arguments of
    ``_init_dataset``), ``model.name`` (forwarded to ``get_functional``), and
    ``evaluation_checkpoints.root_dir`` / ``evaluation_checkpoints.indices``
    (used to locate the pretrained checkpoint; only the first index is used).
    """
    # Logger / results-directory settings (captured by the experiment __init__).
    logging = {
        'project': 'DEI-XC',
        'dir': 'ANONYMOUS_DIR',
        'name': None,
        'output_csv': 'results.csv',
    }
    # Remaining model keys are forwarded as-is to xc_energy.functionals.get_functional.
    model = {
        'name': None,
    }
    base = {
        'use_density_fitting': True,
        'spin_restricted': True,
    }
    # Checkpoint to load
    # chkp_type is passed to checkpoint_best_path; the checkpoint directory is
    # <root_dir>/<name>_<index>, and only indices[0] is finetuned.
    evaluation_checkpoints = {
        'chkp_type': 'dynamic_train',
        'root_dir': None,
        'name': None,
        'indices': None,
    }
    # derivative is the GTO derivative order evaluated on the grid (get_gto_grid_eval_fn deriv=).
    basis = {
        'name': 'def2-SVP',
        'derivative': 1,
    }
    # Arguments of Alignment(atom, basis, grid).
    alignment = {
        'atom': 1,
        'basis': 1,
        'grid': 32 * 1024,
    }
    # Integration grid level passed to discretization.get_grid_fn.
    quadrature = {
        'level': 1,
    }
    # initial_ref_density_method, when set, must be a dict with 'key' and 'kwargs'.
    initial_density_guess = {
        'base_initial_density_guess_key': 'minao',
        'initial_ref_density_method': None,
        'min_ref_density_interpolation': 0.0,
        'max_ref_density_interpolation': 0.0,
        'noise_eps': 0.0,
    }
    # 'key' and 'data_set_kwargs' have no default here and must be provided.
    # data_split_seed also seeds dataloader shuffling; split is forwarded to
    # DatasetEnsemble.infer_split.
    data = {
        'batch_size': 1,
        'shuffle': True,
        'workers': 3,
        'data_split_seed': 0,
        'preload': {
            'center': False,
        },
        'split': {'val_fraction': 0.1},
    }
    # Only 'scf' is supported. kwargs are forwarded to SelfConsistentFieldSolver.
    # NOTE(review): convergence_threshold is stored by init_solver but not passed
    # to the solver in this script — confirm whether that is intended.
    solver = {
        'method': 'scf',
        'convergence_threshold': 1e-7,
        'kwargs': {
            'cycles': 15,
            'convergence_acceleration_method': 'DIIS',
        },
    }
    # Energy shift (in Hartree) - added to target energy
    # (one constant applied to every sample's label).
    energy_shift = 0.0
    # Training config
    training = {
        'loss_measure': 'mse',
        'scale_per_electron': True,
    }
    # Optimizer config (matching deixc_main structure)
    # Forwarded to OptConfig.create; 'epochs' sets the length of the training loop.
    # NOTE(review): the training loop in this script does not itself implement
    # early stopping; confirm whether the early_stopping_* keys have any effect here.
    optimizer = {
        'name': 'muon',
        'weight_decay': 0.0,  # we do not want to decay the pretraining
        'decay_only_graph_readout': False,
        'schedule': {
            'base_rate': 0.001,
            'min_rate': 0.0,
            'warmup_steps': 1000,
            'decay_steps': 500,  # small dataset
            'warmup_schedule': 'linear',
            'decay_schedule': 'inverse_time_decay',
        },
        'plateau_handling': None,
        'metropolis_stabilizer': None,
        'lookahead': None,
        'apply_every': 1,
        'clip_grad_max_norm': 0.2,
        'skip_nans': 3,
        'additional_params': None,
        'epochs': 10_000,
        'ema_decay': 0.0,
        'early_stopping_patience': 50,
        'early_stopping_min_relative_improvement': 1e-4,
        'restart_epochs': [],
        'restart_lr_scales': [],
    }
    # Checkpointing
    # 'directory' is required (init_checkpointer raises ValueError if it is None).
    # NOTE(review): 'save_every' is stored but not used by the training loop.
    checkpointing = {
        'save_every': 10,
        'directory': None,
    }


@ex.named_config
def scaled_nagai2020():
    """Named config: ``nagai2020`` functional with local_n_layers=4 and local_hidden_dim=23."""
    # These keys are forwarded to xc_energy.functionals.get_functional(**model).
    model = {
        'name': 'nagai2020',
        'local_n_layers': 4,
        'local_hidden_dim': 23,
    }


class EnergyFinetuningExperiment:
    """Finetune a pretrained DEI-XC model with an energy-only loss.

    Constructing the object reads the Sacred/SEML config (through ``ex.capture``)
    and builds the dataset, dataloaders, SCF solver, loss config and checkpoint
    manager. Calling it does the following:

    - loads the first pretrained checkpoint listed in
      ``evaluation_checkpoints.indices``;
    - trains on the ``train`` split against the dataset energy labels, shifted
      by the constant ``energy_shift``;
    - evaluates the ``val`` split after every epoch;
    - saves the parameters whenever the validation loss improves.
    """

    output_dir: Path
    output_csv: Path  # per-epoch training history, written by __call__
    metrics_summary_path: Path  # NOTE(review): assigned but never written in this class

    basis_name: str
    basis_derivative: int
    use_density_fitting: bool

    dataset: dataloading.BaseDataset
    dataloaders: dataloading.DataLoaders
    init_psys: systems.PreloadSystem

    preload_transform: Sequence[grain.Transformation]
    init_preloaded_basis_fns: discretization.PreloadedGTOBasis
    main_thread_transform: dataloading.ToJaxTransform

    alignment: Alignment
    grid_level: int
    spin_restricted: bool

    base_initial_density_guess_key: BaseInitialGuess
    solver: SelfConsistentFieldSolver
    _solver_apply: Callable[..., Any]  # jax.jit-compiled ``solver.apply``

    # Training
    optimizer: optax.GradientTransformation  # assigned in __call__
    loss_config: ScalarLossConfig
    energy_shift: float  # constant (Hartree) added to every target energy
    epochs: int
    save_every: int
    checkpoint_dir: Path  # NOTE(review): declared but never assigned in this class
    # Logging
    logger: Logger
    checkpointer: CheckpointManager | None

    @ex.capture(prefix='logging')  # type: ignore
    def __init__(
        self,
        overwrite: int,
        project: str,
        dir: str,
        output_csv: str | None,
        name: str,
    ) -> None:
        """Build every component of the experiment from the Sacred config.

        Args:
            overwrite: Run index passed by ``main``; appended to ``name`` to form
                the logger / checkpoint run name.
            project, dir, output_csv, name: Captured from the ``logging`` config.
        """
        self.output_dir = results_directory(dir, name, exists_ok=True)
        self.output_csv = results_metrics_csv_path(self.output_dir, output_csv)
        self.metrics_summary_path = results_summary_path(self.output_dir)
        run_name = f'{name}_{overwrite}'
        self.run_name = run_name
        self.logger = Logger(
            project,
            ex.current_run.config,  # type: ignore
            dir=dir,
            name=run_name,
        )
        self.set_all()
        # This also builds the dataset and dataloaders, which set ``data_split_seed``.
        # ``init_checkpointer`` below reads that value, so this call must come first.
        self.get_initial_density_matrix_fn()  # type: ignore
        self.init_checkpointer(
            run_name=run_name,
            **ex.current_run.config['checkpointing'],  # type: ignore
        )
        self.init_solver()  # type: ignore
        self.init_training()  # type: ignore
        self.init_checkpointing()  # type: ignore

    def set_all(self) -> None:
        """Copy the simple scalar config sections onto the instance."""
        self.set_base()  # type: ignore
        self.set_alignment()  # type: ignore
        self.set_quadrature()  # type: ignore
        self.set_basis()  # type: ignore
        self.set_evaluation_checkpoints()  # type: ignore
        self.set_energy_shift()  # type: ignore

    @ex.capture(prefix='evaluation_checkpoints')  # type: ignore
    def set_evaluation_checkpoints(
        self,
        chkp_type: str,
        root_dir: str | None,
        name: str | None,
        indices: List[int] | None,
    ) -> None:
        """Store where the pretrained checkpoint(s) live; validated lazily in ``__call__``."""
        self.chkp_type = chkp_type
        self.chkp_root_dir = root_dir
        self.chkp_name = name
        self.chkp_indices = indices

    @ex.capture(prefix='base')  # type: ignore
    def set_base(
        self,
        use_density_fitting: bool,
        spin_restricted: bool,
    ):
        """Store the density-fitting and spin-restriction flags."""
        self.spin_restricted = spin_restricted
        self.use_density_fitting = use_density_fitting

    @ex.capture(prefix='alignment')  # type: ignore
    def set_alignment(self, atom: int, basis: int, grid: int) -> None:
        """Build the ``Alignment`` passed to the preload and grid transforms."""
        self.alignment = Alignment(atom, basis, grid)

    @ex.capture(prefix='quadrature')  # type: ignore
    def set_quadrature(self, level: int) -> None:
        """Store the integration grid level."""
        self.grid_level = level

    @ex.capture(prefix='basis')  # type: ignore
    def set_basis(self, name: str, derivative: int) -> None:
        """Store the basis-set name and the GTO derivative order to evaluate."""
        self.basis_name = name
        self.basis_derivative = derivative

    @ex.capture  # type: ignore
    def set_energy_shift(self, energy_shift: float) -> None:
        """Store the constant (Hartree) added to every target energy label."""
        self.energy_shift = energy_shift

    @ex.capture(prefix='initial_density_guess')  # type: ignore
    def get_initial_density_matrix_fn(
        self,
        base_initial_density_guess_key: BaseInitialGuess,
        initial_ref_density_method: Dict[str, Any] | None,
        min_ref_density_interpolation: float,
        max_ref_density_interpolation: float,
        noise_eps: float,
    ) -> None:
        """Configure the initial density guess, then build the dataset.

        Despite its name, this returns nothing. It sets
        ``self.initial_density_matrix_fn`` and then calls ``_init_dataset``.
        The dataset constructor needs the reference-density settings stored here.
        """
        self.base_initial_density_guess_key = base_initial_density_guess_key

        if initial_ref_density_method is None:
            self.initial_ref_density_method_key = None
            self.initial_ref_density_method_kwargs = None
        else:
            self.initial_ref_density_method_key = initial_ref_density_method['key']
            self.initial_ref_density_method_kwargs = initial_ref_density_method['kwargs']

        self.initial_density_matrix_fn = dataloading.get_initial_density_matrix_fn(
            min_ref_density_interpolation,
            max_ref_density_interpolation,
            noise_eps,
        )

        self._init_dataset()  # type: ignore

    @ex.capture(prefix='data')  # type: ignore
    def _init_dataset(
        self,
        key: str,
        data_set_kwargs: Dict[str, Any],
        data_split_seed: int,
        split: Dict[str, Any],
        preload: Dict[str, bool],
        batch_size: int,
        shuffle: bool,
        workers: int,
    ) -> None:
        """Instantiate the base dataset and its dataloaders / transforms.

        ``key`` (lower-cased) selects the class from ``dataloading.key_to_dataset``.
        ``data_set_kwargs`` are forwarded to its constructor. Neither has a
        default in ``default_config``.
        """
        self.data_split_seed = data_split_seed
        self.data_split = split
        self.preload = preload
        self.preload_center = preload['center']
        self.shuffle = shuffle
        self.batch_size = batch_size
        self.workers = workers

        # Use base dataset directly - no DEIXCDataset wrapper
        self.dataset = dataloading.key_to_dataset[key.lower()](
            initial_ref_density_method_key=self.initial_ref_density_method_key,
            initial_ref_density_method_kwargs=self.initial_ref_density_method_kwargs,
            **data_set_kwargs,
        )

        self._init_dataloader()
        self._init_main_thread_transform()

    def _init_dataloader(self) -> None:
        """Split the dataset, build the preload transform and create the dataloaders."""
        dataset_ensemble = DatasetEnsemble.infer_split(
            self.dataset, data_split_seed=self.data_split_seed, **self.data_split
        )

        self.max_angular_momentum, basis_fn_preloader = discretization.get_gto_preloader(
            self.basis_name, self.dataset.unique_elements
        )
        self.preload_transform = dataloading.get_preload_transform(
            self.batch_size,
            self.basis_name,
            self.spin_restricted,
            self.alignment,
            self.use_density_fitting,
            base_initial_density_guess=self.base_initial_density_guess_key,
            basis_fn_preloader=basis_fn_preloader,
            center=self.preload_center,
        )

        # The split seed is also reused as the shuffling seed.
        self.init_psys, self.dataloaders = get_psys_and_dataloaders(
            datasets=dataset_ensemble,
            transformations=self.preload_transform,
            shuffle=self.shuffle,
            shuffling_seed=self.data_split_seed,
            workers=self.workers,
            worker_buffer_size=1,
        )
        self.init_preloaded_basis_fns = basis_fn_preloader(self.init_psys.atom_z)

    def _init_main_thread_transform(self) -> None:
        """Build the main-thread transform from the grid and GTO-on-grid evaluation functions."""
        grid_fn = discretization.get_grid_fn(
            self.grid_level, self.dataset.unique_elements, self.alignment.grid
        )
        basis_fn = discretization.get_gto_grid_eval_fn(
            deriv=self.basis_derivative,
            max_angular_momentum=self.max_angular_momentum,
        )
        self.main_thread_transform = dataloading.get_jax_transform(
            grid_fn,
            basis_fn,
        )

    @ex.capture
    def init_solver(self, solver: Dict[str, Any], model: Dict[str, Any]) -> None:
        """Build the SCF solver around the configured XC model and jit its ``apply``.

        Raises:
            ValueError: if ``solver['method']`` is not ``'scf'``.
        """
        model_module = self._xc_module(model)  # type: ignore
        solver_type = solver['method']
        solver_kwargs = solver['kwargs']
        if solver_type != 'scf':
            raise ValueError(f'Unsupported solver type: {solver_type}')
        # NOTE(review): stored but not passed to the solver nor read elsewhere in this class.
        self.convergence_threshold = solver['convergence_threshold']
        self.solver = SelfConsistentFieldSolver(
            model_module,
            use_density_fitting=self.use_density_fitting,
            spin_restricted=self.spin_restricted,
            **solver_kwargs,
        )
        self._solver_apply = jax.jit(self.solver.apply)

    @ex.capture
    def _xc_module(self, model) -> xc_energy.XCModule:
        """Instantiate the XC functional from ``model`` and pair it with density features.

        Spin-resolved features are requested only when the functional sets
        ``requires_spin_resolved_features`` to a truthy value.
        """
        functional = xc_energy.functionals.get_functional(**model)
        requires_spin_resolved = getattr(
            functional, 'requires_spin_resolved_features', False
        )
        feature_fn = xc_energy.features.DensityFeatures(
            self.spin_restricted, spin_resolved=requires_spin_resolved
        )
        module = xc_energy.XCModule(functional, feature_fn)
        return module

    @ex.capture(prefix='training')  # type: ignore
    def init_training(
        self,
        loss_measure: str,
        scale_per_electron: bool,
    ) -> None:
        """Build the scalar energy-loss configuration."""
        self.loss_config = ScalarLossConfig(
            measure=loss_measure,  # type: ignore
            scale_per_electron=scale_per_electron,
        )

    @ex.capture
    def init_optimizer(self, optimizer: Dict[str, Any]) -> optax.GradientTransformation:
        """Build the optax optimizer from the ``optimizer`` config; also sets ``self.epochs``."""
        opt_config = OptConfig.create(**optimizer)
        self.epochs = opt_config.epochs
        return get_optimizer(opt_config)

    @ex.capture(prefix='model')  # type: ignore
    def init_checkpointer(
        self, name: str, run_name: str, directory: str | None, **_
    ) -> None:
        """Create the checkpoint manager and save the run config into it.

        ``name`` is captured from ``model.name``. ``run_name`` and ``directory``
        are passed explicitly; other ``checkpointing`` keys are ignored via ``**_``.

        Raises:
            ValueError: if ``directory`` is None.
        """
        if directory is None:
            raise ValueError('checkpointing.directory must be set for finetuning')
        self.checkpointer = CheckpointManager(
            directory,
            model_name=name,
            basis=self.basis_name,
            name=run_name,
            data_split_seed=self.data_split_seed,
        )
        self.checkpointer.save_config(ex.current_run.config)  # type: ignore

    @ex.capture(prefix='checkpointing')  # type: ignore
    def init_checkpointing(self, save_every: int, directory: str | None) -> None:
        """Store ``save_every``.

        NOTE(review): ``save_every`` is not read by the training loop; only
        best-validation checkpoints are saved.
        """
        self.save_every = save_every

    def get_checkpoint_name(self, idx: int) -> str:
        """Return the directory name ``<chkp_name>_<idx>`` of pretrained checkpoint ``idx``."""
        return f'{self.chkp_name}_{idx}'

    def get_checkpoint_path(self, idx: int) -> str:
        """Return the split directory of pretrained checkpoint ``idx``.

        Picks the last ``split_0*`` entry in lexicographic order. If none exist,
        it falls back to ``<dir>/split_0``, so a missing checkpoint surfaces as a
        FileNotFoundError in ``load_params``.

        NOTE(review): with numeric suffixes, lexicographic order may not select
        the highest number — confirm the naming scheme.
        """
        checkpoint_dir = Path(self.chkp_root_dir) / self.get_checkpoint_name(idx)
        split_dirs = sorted(checkpoint_dir.glob('split_0*'))
        if not split_dirs:
            return str(checkpoint_dir / 'split_0')
        return str(split_dirs[-1])

    def load_params(self, idx: int) -> NnParams:
        """Load pickled parameters of pretrained checkpoint ``idx``.

        ``pickle`` can execute arbitrary code, so only load checkpoints you trust.

        Raises:
            FileNotFoundError: if the best-checkpoint file does not exist.
        """
        checkpoint_file_path = Path(
            checkpoint_best_path(self.get_checkpoint_path(idx), self.chkp_type)
        )
        if not checkpoint_file_path.exists():
            raise FileNotFoundError(f'Checkpoint not found at {checkpoint_file_path}')
        with checkpoint_file_path.open('rb') as stream:
            params = pickle.load(stream)
        return params

    def _prepare_batch(
        self,
        psys: systems.PreloadSystem,
        basis_fns,
        targets: Targets,
    ) -> Tuple[Any, Any, float]:
        """Turn a loader batch into solver inputs plus the shifted energy target.

        Returns:
            ``(initial_dm, system, target_energy)``, where ``target_energy`` is the
            dataset energy label plus ``self.energy_shift``.
        """
        density_matrices, system = self.main_thread_transform(psys, basis_fns)
        initial_dm, _ = self.initial_density_matrix_fn(density_matrices, None)
        # ``.item()`` requires a single-element energy, i.e. presumably batch_size == 1 — TODO confirm.
        target_energy = targets.energy.item() + self.energy_shift  # type: ignore
        return initial_dm, system, target_energy

    def _predict_energy(
        self, params: NnParams, initial_dm, system
    ) -> Tuple[jax.Array, jax.Array]:
        """Run the SCF solver and return ``(total_energy, volatility_mEh)``.

        ``total_energy`` is the sum of three terms: the last entry of ``e_xc``,
        the last entry of ``e_hj``, and ``systems.nuclear_energy_fn``.
        ``volatility_mEh`` is the absolute change of ``e_xc + e_hj`` between the
        last two entries, in mEh. The entries are presumably SCF cycles.
        """
        (e_hj, e_xc), _ = self._solver_apply(params, initial_dm, system)
        total_pred = (
            e_xc[-1] + e_hj[-1] + systems.nuclear_energy_fn(system._nuc_pos, system)
        )
        volatility_mEh = jnp.abs((e_xc[-1] + e_hj[-1]) - (e_xc[-2] + e_hj[-2])) * 1e3
        return total_pred, volatility_mEh

    def _energy_loss(self, target_energy: float, total_pred, system) -> Any:
        """Scalar loss between the shifted label and the predicted total energy."""
        return scalar_loss(
            jnp.array(target_energy),
            total_pred,
            system.n_electrons,
            self.loss_config,
        )

    def train_step(
        self,
        params: NnParams,
        opt_state: optax.OptState,
        psys: systems.PreloadSystem,
        basis_fns,
        targets: Targets,
    ) -> Tuple[NnParams, optax.OptState, float, float, float, float]:
        """Run one optimizer step on a single batch.

        Returns:
            ``(params, opt_state, loss, energy_diff_mEh, energy_mae_mEh,
            volatility_mEh)``. The energy error is prediction minus shifted label.
            All metrics are Python floats computed with the pre-update parameters.
        """
        initial_dm, system, target_energy = self._prepare_batch(psys, basis_fns, targets)

        def loss_fn(trainable_params):
            total_pred, volatility_mEh = self._predict_energy(
                trainable_params, initial_dm, system
            )
            loss = self._energy_loss(target_energy, total_pred, system)
            return loss, (total_pred, volatility_mEh)

        (loss, (total_pred, volatility_mEh)), grads = jax.value_and_grad(
            loss_fn, has_aux=True
        )(params)
        updates, opt_state = self.optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)

        energy_diff_mEh = float((total_pred - target_energy) * 1e3)
        return (
            params,
            opt_state,
            float(loss),
            energy_diff_mEh,
            abs(energy_diff_mEh),
            float(volatility_mEh),
        )

    def train_epoch(
        self,
        params: NnParams,
        opt_state: optax.OptState,
    ) -> Tuple[NnParams, optax.OptState, float, float]:
        """Train for one pass over the train loader.

        Returns:
            ``(params, opt_state, mean_loss, mean_mae_mEh)``. The means are the
            logger's running means over the epoch.
        """
        self.logger.start_mean(
            [
                'train/loss',
                'train/energy_mae_mEh',
            ]
        )

        for psys, basis_fns, targets in self.dataloaders.train:
            params, opt_state, loss, energy_diff, mae, volatility = self.train_step(
                params, opt_state, psys, basis_fns, targets
            )

            # Log per-step metrics
            self.logger.log(
                {
                    'train/loss': loss,
                    'train/energy_diff_mEh': energy_diff,
                    'train/energy_mae_mEh': mae,
                    'debug/train/energy_volatility_mEh': volatility,
                },
            )

        mean_loss = self.logger.get_current_mean('train/loss')
        mean_mae = self.logger.get_current_mean('train/energy_mae_mEh')
        self.logger.stop_mean()
        return params, opt_state, mean_loss, mean_mae

    def evaluate_epoch(self, params: NnParams, split: str = 'val') -> Tuple[float, float]:
        """Evaluate ``params`` on the dataloader attribute named ``split``.

        Returns:
            ``(mean_loss, mean_mae_mEh)``, taken from the logger's running means.
        """
        self.logger.start_mean(
            [
                f'{split}/loss',
                f'{split}/energy_mae_mEh',
            ]
        )
        loader = getattr(self.dataloaders, split)

        for psys, basis_fns, targets in loader:
            initial_dm, system, target_energy = self._prepare_batch(
                psys, basis_fns, targets
            )
            total_pred, volatility_mEh = self._predict_energy(params, initial_dm, system)
            loss = self._energy_loss(target_energy, total_pred, system)

            energy_diff_mEh = float((total_pred - target_energy) * 1e3)

            # Log per-step metrics
            self.logger.log(
                {
                    f'{split}/loss': float(loss),
                    f'{split}/energy_diff_mEh': energy_diff_mEh,
                    f'{split}/energy_mae_mEh': abs(energy_diff_mEh),
                    f'debug/{split}/energy_volatility_mEh': float(volatility_mEh),
                },
            )

        mean_loss = self.logger.get_current_mean(f'{split}/loss')
        mean_mae = self.logger.get_current_mean(f'{split}/energy_mae_mEh')
        self.logger.stop_mean()
        return mean_loss, mean_mae

    def __call__(self) -> None:
        """Run finetuning.

        Loads the pretrained parameters, then trains and validates for
        ``self.epochs`` epochs. After each epoch:

        - one row (epoch, train/val loss and energy MAE) is appended to
          ``self.output_csv``, so an interrupted run keeps its completed epochs;
        - if the validation loss improved, the parameters are saved with
          ``prefix='finetune'``.

        NOTE(review): this loop always runs all epochs; it implements no early
        stopping itself.

        Raises:
            ValueError: if ``evaluation_checkpoints.indices`` is unset or empty.
        """
        print(f'\n{"=" * 80}')
        print('Energy-only finetuning')
        print(f'Energy shift: {self.energy_shift} Eh')
        print(f'{"=" * 80}\n')

        # Only the first configured checkpoint is finetuned.
        if not self.chkp_indices:
            raise ValueError(
                'evaluation_checkpoints.indices must contain at least one checkpoint index'
            )
        idx = self.chkp_indices[0]
        params = self.load_params(idx)
        print(f'Loaded checkpoint: {self.get_checkpoint_name(idx)}')

        # Initialize optimizer (also sets self.epochs)
        self.optimizer = self.init_optimizer()  # type: ignore
        opt_state = self.optimizer.init(params)

        best_val_loss = float('inf')

        for epoch in range(self.epochs):
            params, opt_state, train_loss, train_mae = self.train_epoch(params, opt_state)
            val_loss, val_mae = self.evaluate_epoch(params, 'val')

            print(
                f'Epoch {epoch:4d}: train_loss={train_loss:.6f}, train_mae={train_mae:.3f} mEh | '
                f'val_loss={val_loss:.6f}, val_mae={val_mae:.3f} mEh'
            )

            record = {
                'epoch': epoch,
                'train_loss': train_loss,
                'train_energy_mae_mEh': train_mae,
                'val_loss': val_loss,
                'val_energy_mae_mEh': val_mae,
            }
            # Stream the history to disk. The first epoch truncates any stale file
            # and writes the header; later epochs append a single row.
            pd.DataFrame([record]).to_csv(
                self.output_csv,
                mode='w' if epoch == 0 else 'a',
                header=epoch == 0,
                index=False,
            )

            # Save best model
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                if self.checkpointer is not None:
                    self.checkpointer.save_best_params(params, prefix='finetune')
                print(f'  New best val_loss: {val_loss:.6f}')

            # Track the running best validation loss in the logger.
            self.logger.log(
                {
                    'best_val_loss': best_val_loss,
                }
            )

        print(f'\n{"=" * 80}')
        print('Finetuning complete.')
        print(f'Best val_loss: {best_val_loss:.6f}')
        print(f'Results saved to: {self.output_dir}')
        print(f'{"=" * 80}')


@ex.automain
def main(overwrite: int):
    """Sacred entry point: build the finetuning experiment for this run and execute it."""
    finetuner = EnergyFinetuningExperiment(overwrite)  # type: ignore
    finetuner()
