#!/usr/bin/env python3
"""TDDFT/TDA demo linearized around a DEI-XC (ML) ground state.

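This script, for each discovered checkpoint:
- Loads the trained DEI-XC checkpoint (pickled Flax params + YAML config).
- Builds molecular `System`s consistent with the checkpoint (basis/grid/alignment),
  either for randomly sampled QM5 molecules or for a single example molecule.
- Runs the project's SCF machinery (DEI-XC) to obtain the ground-state density.
- Builds TDA / full TDDFT Casida matrix-vector products at that reference point
  and solves for a few lowest excitation energies with Davidson iterations.
- Repeats the SCF + linear-response calculation with the built-in SCAN functional
  for comparison and writes all results to per-checkpoint CSV files.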

Notes:
- The reference point is the converged DEI-XC density.
"""

from __future__ import annotations

import argparse
import os
from pathlib import Path
from typing import Any, Dict, Tuple

# Default to CPU to avoid large-constant allocation issues on GPU for full 4-index ERIs.
# Override by setting `JAX_PLATFORM_NAME=gpu` in the environment.
os.environ.setdefault('JAX_PLATFORM_NAME', 'cpu')
# Silence XLA compilation warnings
os.environ.setdefault('TF_CPP_MIN_LOG_LEVEL', '3')

import time
import traceback

import jax
import jax.numpy as jnp
import numpy as np
import pandas as pd
import wandb
import yaml
from pyscf import dft
from pyscf import scf as pyscf_scf

jax.config.update('jax_enable_x64', True)

from deixc.scf import DerivativeInformedSelfConsistentFieldSolver
from egxc.dataloading import QM9
from egxc.dataloading.io import unpickle_dictionary
from egxc.discretization import (
    GTOBasis,
    get_grid_fn,
    get_gto_grid_eval_fn,
    get_gto_preloader,
)
from egxc.systems import Grid, System, examples, nuclear_energy_fn
from egxc.systems.preload import preload_system_using_pyscf
from egxc.utils.linalg import modified_generalized_eigenvalue_problem
from egxc.utils.typing import Alignment, cast_to_integer_tuple
from egxc.xc_energy import DensityFeatures
from egxc.xc_energy.functionals import get_functional
from egxc.xc_energy.functionals.classical import BaseRangeSeparatedHybrid
from egxc.xc_energy.xc_module import XCModule
from tddft import Davidson, Davidson_Casida, build_cassida_mv

# Hartree -> electronvolt conversion factor; excitation energies are reported in eV.
# NOTE(review): differs from the CODATA 2018 value (27.211386245988 eV) by ~1e-6 eV;
# confirm which reference this value was taken from.
Hartree_to_eV = 27.211385050


# Per-result fields copied onto every CSV row, in output column order.
_CSV_METADATA_KEYS = (
    'checkpoint_id',
    'molecule_idx',
    'pipeline',
    'status',
    'scf_volatility',
    'scf_final_energy_Ha',
    'n_atoms',
    'elements',
    'time_scf_s',
    'config_path',
    'static_loss.relative_weights.forces',
    'static_loss.relative_weights.orbital_rotation_gradient',
    'static_loss.relative_weights.xc_energy',
    'static_loss.relative_weights.xc_potential',
    'dynamic_loss.relative_weights.density',
    'dynamic_loss.relative_weights.forces',
    'dynamic_loss.relative_weights.orbital_rotation_gradient',
    'dynamic_loss.relative_weights.orbital_rotation_hessian',
    'dynamic_loss.relative_weights.total_energy',
    'dynamic_loss.relative_weights.xc_energy',
    'dynamic_loss.relative_weights.xc_potential',
    'base.seed',
    'run_seed',
    'seed',
)
# Fields that only exist once the SCF has run; a failed result may lack them,
# in which case they are written as NaN instead of raising KeyError.
_CSV_OPTIONAL_METADATA_KEYS = frozenset(
    {'scf_volatility', 'scf_final_energy_Ha', 'time_scf_s'}
)


def _save_results_to_csv(results: list[Dict[str, Any]], output_path: Path):
    """Save TDDFT results to CSV in long format (one row per excited state).

    Each result contributes its TDA rows (if present) followed by its TDDFT rows
    (if present). Every row carries the result's metadata columns
    (`_CSV_METADATA_KEYS`) plus `state_index`, `method`, `energy_eV` and
    `time_davidson_s`. A result with `status == 'failed'` and no energies still
    produces a single placeholder `TDA` row with NaN energy and timing.

    Args:
        results: List of result dictionaries from the per-molecule pipelines.
        output_path: Path to save the CSV file; parent directories are created.

    Raises:
        KeyError: If a result lacks a required (non-SCF) metadata key.
    """
    rows = []
    for result in results:
        metadata = {
            key: (
                result.get(key, np.nan)
                if key in _CSV_OPTIONAL_METADATA_KEYS
                else result[key]
            )
            for key in _CSV_METADATA_KEYS
        }
        # Energies are optional: TDA only runs on request, and failed results
        # may have neither.
        tda_energies = list(result.get('tda_energies_eV', []))
        tddft_energies = list(result.get('tddft_energies_eV', []))

        for method, energies, time_key in (
            ('TDA', tda_energies, 'time_tda_s'),
            ('TDDFT', tddft_energies, 'time_tddft_s'),
        ):
            time_davidson = result.get(time_key, np.nan)
            for state_idx, energy in enumerate(energies):
                rows.append(
                    {
                        **metadata,
                        'state_index': state_idx,
                        'method': method,
                        'energy_eV': energy,
                        'time_davidson_s': time_davidson,
                    }
                )

        # Keep failed molecules visible in the CSV with a NaN placeholder row.
        if result['status'] == 'failed' and not tda_energies and not tddft_energies:
            rows.append(
                {
                    **metadata,
                    'state_index': 0,
                    'method': 'TDA',
                    'energy_eV': np.nan,
                    'time_davidson_s': np.nan,
                }
            )

    df = pd.DataFrame(rows)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    df.to_csv(output_path, index=False)
    print(f'\nSaved results to {output_path}')


def _load_checkpoint_config(checkpoint_dir: Path) -> Tuple[Dict[str, Any], Path]:
    """Load the single YAML config shipped alongside an evaluation checkpoint.

    Args:
        checkpoint_dir: Directory containing exactly one `*.yaml` file.

    Returns:
        Tuple of (parsed YAML config as a Python dict, path to the YAML file).

    Raises:
        ValueError: If the directory does not contain exactly one `*.yaml` file.
    """
    yamls = sorted(checkpoint_dir.glob('*.yaml'))
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if len(yamls) != 1:
        raise ValueError(f'Expected exactly one YAML in {checkpoint_dir}, got {yamls}')
    config_path = yamls[0]
    return yaml.safe_load(config_path.read_text(encoding='utf-8')), config_path


def _extract_xc_module_params(solver_params: Any) -> Any:
    """Return only the XC-module parameters from a (possibly solver-level) checkpoint.

    DEI-XC runs may checkpoint the whole SCF solver module. In that case the payload
    looks like ``{'params': {'xc_module': subtree, ...}}``, and ``subtree`` is
    re-wrapped as ``{'params': subtree}`` so it can be passed to
    `XCModule.apply(...)`. Any other payload is returned unchanged (same object).

    Args:
        solver_params: The deserialized checkpoint payload.

    Returns:
        A Flax params dict suitable for `XCModule.apply(...)`.
    """
    has_params_root = isinstance(solver_params, dict) and 'params' in solver_params
    nested = solver_params['params'] if has_params_root else None
    # Only a plain dict root with an `xc_module` entry is treated as solver-level.
    has_xc_subtree = isinstance(nested, dict) and 'xc_module' in nested
    return {'params': nested['xc_module']} if has_xc_subtree else solver_params


def _scan_reference_from_custom_scf(
    *,
    sys,
    basis: str,
    cycles: int,
    xc_module: XCModule,
    use_density_fitting: bool = False,
):
    """Run the repo SCF for `xc_module` and collect the linear-response reference.

    The solver parameters are freshly initialized (PRNG key 0) rather than loaded,
    which is how the SCAN comparison pipeline in this script uses it.

    Args:
        sys: Molecular `System` providing `to_pyscf`, `fock_tensors` and nuclei.
        basis: Basis-set name forwarded to `sys.to_pyscf`.
        cycles: Number of SCF iterations for the solver.
        xc_module: XC module the SCF is run with.
        use_density_fitting: Forwarded to the SCF solver.

    Returns:
        Tuple ``(P_ref, orbo, orbv, e_ia, hdiag, xc_params, total_e)``:
        the final density matrix (NumPy), occupied and virtual MO coefficient
        columns, the gaps ``e_ia[i, a] = eps[a] - eps[i]`` of shape
        ``(n_occ, n_vir)``, their flattened form (passed as `hdiag` to Davidson),
        the XC-module params extracted from the solver params, and the final
        electronic energy (last-iterate E_HJ + E_XC) plus nuclear repulsion.
    """
    pyscf_mol = sys.to_pyscf(basis)
    guess_dm = pyscf_scf.RKS(pyscf_mol).get_init_guess()

    solver = DerivativeInformedSelfConsistentFieldSolver(
        xc_module=xc_module,
        spin_restricted=True,
        cycles=int(cycles),
        use_density_fitting=use_density_fitting,
        convergence_acceleration_method='DIIS',
    )
    solver_params = solver.init(jax.random.PRNGKey(0), jnp.asarray(guess_dm), sys)
    xc_params = _extract_xc_module_params(solver_params)

    energies, trajectories = jax.jit(solver.apply)(
        solver_params, jnp.asarray(guess_dm), sys
    )
    e_hj, e_xc = energies
    _coeff_traj, density_traj, fock_traj, _vxc_traj = trajectories
    density_ref = np.asarray(density_traj[-1])

    # Orbitals and orbital energies of the last Fock matrix define the reference.
    orbital_energies, coeffs = modified_generalized_eigenvalue_problem(
        fock_traj[-1], sys.fock_tensors.diagonal_overlap
    )
    orbital_energies = np.asarray(orbital_energies)
    coeffs = np.asarray(coeffs)
    occupancies = np.asarray(sys.fock_tensors.occupancies)
    occ_idx = np.where(occupancies > 0)[0]
    vir_idx = np.where(occupancies == 0)[0]
    gaps = orbital_energies[vir_idx] - orbital_energies[occ_idx, None]

    nuclear_e = float(nuclear_energy_fn(sys._nuc_pos, sys))
    electronic_e = float(np.asarray(e_hj[-1] + e_xc[-1]))

    return (
        density_ref,
        coeffs[:, occ_idx],
        coeffs[:, vir_idx],
        gaps,
        gaps.ravel(),
        xc_params,
        electronic_e + nuclear_e,
    )


def _build_system_from_sample(
    nuc_pos: np.ndarray,
    atom_z: np.ndarray,
    basis: str,
    grid_level: int,
    alignment: Alignment,
    use_density_fitting: bool = False,
) -> System:
    """Build a neutral, closed-shell `System` with integration grid from a QM9 sample.

    Atoms are reordered by atomic number (stable sort) before preloading, so the
    atom order of the returned system may differ from the input order.

    Args:
        nuc_pos: Nuclear positions (Angstrom), one row per atom.
        atom_z: Atomic numbers, aligned with `nuc_pos`.
        basis: Basis set name.
        grid_level: Quadrature grid level.
        alignment: Alignment settings forwarded to preloading and grid construction.
        use_density_fitting: Whether to preload the system with density fitting.

    Returns:
        System object (charge 0, spin 0, restricted) including AO values on the grid.
    """
    num_heavy_atom = len(atom_z) - np.sum(atom_z == 1)
    print(
        f'Building system with: natoms={len(atom_z)}, nheavy_atoms={num_heavy_atom}, elements={np.unique(atom_z)}, basis={basis}, grid_level={grid_level}, alignment={alignment}'
    )
    # Sort atoms by atomic number for consistency. `kind='stable'` matches the
    # NumPy >= 2.0 `stable=True` keyword but also works on older NumPy releases.
    order = np.argsort(atom_z, kind='stable')
    nuc_pos = nuc_pos[order]
    atom_z = atom_z[order]

    # Build preloaded system
    psys = preload_system_using_pyscf(
        idx=-1,
        nuc_pos=nuc_pos,
        atom_z=atom_z,
        charge=0,
        spin=0,
        reference_density=None,
        basis=basis,
        spin_restricted=True,
        alignment=alignment,
        base_initial_density_guess='minao',
        center=False,
        use_density_fitting=bool(use_density_fitting),
        cache_pyscf_mole=False,
        range_separation=None,
    )

    # Only atoms selected by `atom_mask` enter the grid and basis construction.
    real_z = psys.atom_z[psys.atom_mask]
    real_pos = psys.nuc_pos[psys.atom_mask]

    # Build grid
    grid_fn = get_grid_fn(grid_level, set(real_z), alignment.grid)
    coords, weights = grid_fn(real_pos, cast_to_integer_tuple(real_z))

    # Evaluate AOs (and their gradients: derivative order 1) on the grid.
    l_max, pvec_basis_fn_factory = get_gto_preloader(basis, set(real_z))
    vec_basis_fns = GTOBasis.from_preloaded(pvec_basis_fn_factory(real_z))
    basis_fn = get_gto_grid_eval_fn(1, l_max)
    aos, grad_aos = basis_fn(
        coords,
        real_pos,
        vec_basis_fns.radial_primitives,
        vec_basis_fns.compile_statics,
    )
    grid = Grid.create(coords, weights, aos, grad_aos)

    return System.from_preloaded(psys, grid)


def _discover_checkpoints(base_dir: Path) -> list[tuple[str, Path]]:
    """Discover all checkpoint directories below `base_dir`.

    A directory named `checkpoint` (at any depth) is accepted when it contains at
    least one `*.yaml` file and `best_dynamic_train_params.flax`. Its id is the
    path between `base_dir` and the `checkpoint` directory joined with `_`
    (e.g. `b3lyp/EGXC/baseline/checkpoint` -> `b3lyp_EGXC_baseline`); a
    `checkpoint` directly inside `base_dir` is named after `base_dir` itself.

    If nothing qualifies, falls back to accepting any `<subdir>/checkpoint`
    directory (id = `<subdir>`) without checking its contents.

    Args:
        base_dir: Base directory containing checkpoint subdirectories
            (e.g., evaluations/1000/XCdiff_scan_qm5).

    Returns:
        List of (checkpoint_id, checkpoint_path) tuples, sorted by path; empty if
        `base_dir` is not an existing directory.
    """
    if not base_dir.is_dir():
        return []

    checkpoints = []
    for checkpoint_path in sorted(base_dir.rglob('checkpoint')):
        if not checkpoint_path.is_dir():
            continue
        has_yaml = any(checkpoint_path.glob('*.yaml'))
        has_params = (checkpoint_path / 'best_dynamic_train_params.flax').exists()
        if not (has_yaml and has_params):
            continue
        # `.parts` keeps the id independent of the OS path separator; a checkpoint
        # directly in `base_dir` has no parent parts, so use the base name instead
        # of '.' (which would produce a hidden '._tddft_results.csv').
        run_parts = checkpoint_path.relative_to(base_dir).parent.parts
        checkpoint_id = '_'.join(run_parts) if run_parts else base_dir.resolve().name
        checkpoints.append((checkpoint_id, checkpoint_path))

    # Legacy layout fallback: <base_dir>/<run>/checkpoint, contents not checked.
    if not checkpoints:
        for subdir in sorted(base_dir.iterdir()):
            checkpoint_path = subdir / 'checkpoint'
            if subdir.is_dir() and checkpoint_path.is_dir():
                checkpoints.append((subdir.name, checkpoint_path))

    return checkpoints


def _run_tddft_for_molecule(
    sys,
    basis,
    cfg,
    params_solver,
    xc_params,
    nstates,
    conv_tol,
    checkpoint_id: str = '',
    molecule_idx: int = -1,
    use_density_fitting: bool = False,
    do_tda: bool = False,
):
    """Run the DEI-XC SCF for one system, then TDDFT (and optionally TDA) on top.

    Args:
        sys: System object.
        basis: Basis set name.
        cfg: Checkpoint config dict; `cfg['model']` rebuilds the functional and
            `cfg['solver']['kwargs']['cycles']` sets the number of SCF iterations.
        params_solver: Solver parameters (the loaded checkpoint payload).
        xc_params: XC module parameters used by the linear-response kernel.
        nstates: Number of states to compute.
        conv_tol: Davidson convergence tolerance.
        checkpoint_id: Identifier for the checkpoint (stored as metadata).
        molecule_idx: Index of the molecule in the dataset (stored as metadata).
        use_density_fitting: Use density fitting in the SCF and response products.
        do_tda: Also solve the Tamm-Dancoff (TDA) eigenproblem.

    Returns:
        Dictionary with results including energies (eV), timings (s), SCF
        diagnostics and metadata.

    Raises:
        ValueError: If density fitting is requested for a range-separated hybrid.
    """
    # Extract atom metadata (only atoms selected by `atom_mask` are counted).
    atom_z = np.asarray(sys.atom_z)
    atom_mask = np.asarray(sys.atom_mask)
    unique_elements = sorted(set(atom_z[atom_mask]))
    elements_str = ','.join(map(str, unique_elements))
    n_atoms = int(np.sum(atom_mask))

    result = {
        'checkpoint_id': checkpoint_id,
        'molecule_idx': molecule_idx,
        'n_atoms': n_atoms,
        'elements': elements_str,
        'status': 'success',
        'pipeline': 'DEI-XC',
    }

    # Reconstruct the XC module (architecture) from a checkpoint YAML config.
    functional = get_functional(**cfg['model'])
    if bool(use_density_fitting) and isinstance(functional, BaseRangeSeparatedHybrid):
        raise ValueError(
            'Density fitting is not supported for range-separated hybrid functionals '
            '(range-separated ERIs are not implemented with DF in this repo).'
        )
    requires_spin_resolved = getattr(functional, 'requires_spin_resolved_features', False)
    feature_fn = DensityFeatures(
        spin_restricted=True, spin_resolved=requires_spin_resolved
    )
    xc_module = XCModule(functional, feature_fn)

    cycles = int(cfg['solver']['kwargs']['cycles'])
    scf_solver = DerivativeInformedSelfConsistentFieldSolver(
        xc_module=xc_module,
        spin_restricted=True,
        cycles=cycles,
        use_density_fitting=bool(use_density_fitting),
        convergence_acceleration_method='DIIS',
    )
    scf_apply = jax.jit(scf_solver.apply)

    time_start = time.time()
    mol = sys.to_pyscf(basis)
    dm0 = dft.RKS(mol).get_init_guess()
    (e_hj, e_xc), (_C_traj, P_traj, F_traj, _Vxc_traj) = scf_apply(
        params_solver, jnp.asarray(dm0), sys
    )
    # JAX dispatches asynchronously: wait for the SCF to finish so the timing
    # measures the computation, not just tracing/compilation and dispatch.
    e_hj, e_xc, P_traj, F_traj = jax.block_until_ready((e_hj, e_xc, P_traj, F_traj))
    P_ref = P_traj[-1]
    F_ref = F_traj[-1]
    time_scf = time.time() - time_start
    result['time_scf_s'] = time_scf
    print(f'Time taken for SCF: {time_scf:.1f} seconds')

    # Convergence proxy: change in E_XC over the last SCF step (needs >= 2 steps).
    if len(e_xc) >= 2:
        e_volatility = float(abs(e_xc[-2] - e_xc[-1]))
    else:
        e_volatility = float('nan')
    result['scf_volatility'] = e_volatility
    if e_volatility > 1e-3:
        print(f'WARNING: SCF not converged (volatility: {e_volatility * 1e3:.3f} mHa)')

    # Orbitals of the final Fock matrix define the occupied/virtual reference.
    time_start = time.time()
    eps, C = modified_generalized_eigenvalue_problem(
        F_ref, sys.fock_tensors.diagonal_overlap
    )
    eps = np.asarray(eps)
    C = np.asarray(C)
    occ = np.asarray(sys.fock_tensors.occupancies)
    occidx = np.where(occ > 0)[0]
    viridx = np.where(occ == 0)[0]
    orbo, orbv = C[:, occidx], C[:, viridx]
    e_ia = eps[viridx] - eps[occidx, None]  # (n_occ, n_vir) orbital-energy gaps
    hdiag = e_ia.ravel()
    time_end = time.time()
    print(f'Time taken for eigenvalue problem: {time_end - time_start:.1f} seconds')

    print('DEI-XC SCF reference')
    e_nuc = float(nuclear_energy_fn(sys._nuc_pos, sys))
    total_e = float(np.asarray(e_hj[-1] + e_xc[-1])) + e_nuc
    result['scf_final_energy_Ha'] = total_e
    print('  SCF cycles:', cycles)
    print('  Final E_HJ + E_XC + E_nuc (Ha):', f'{total_e:.8f}')

    if do_tda:
        time_start = time.time()
        tda_mv = build_cassida_mv(
            sys=sys,
            xc_module=xc_module,
            params=xc_params,
            occupied_orbs=jnp.asarray(orbo),
            virtual_orbs=jnp.asarray(orbv),
            e_ia=jnp.asarray(e_ia),
            P_ref=jnp.asarray(P_ref),
            spin_restricted=True,
            use_density_fitting=bool(use_density_fitting),
            tda_approx=True,
        )
        e_tda, _ = Davidson(
            lambda X: np.asarray(tda_mv(jnp.asarray(X))),
            hdiag,
            N_states=nstates,
            conv_tol=conv_tol,
        )
        time_tda = time.time() - time_start
        result['time_tda_s'] = time_tda
        result['tda_energies_eV'] = list(e_tda * Hartree_to_eV)
        print('  DEI-XC TDA energies (eV):', e_tda * Hartree_to_eV)
        print(f'Time taken for TDA: {time_tda:.1f} seconds')

    time_start = time.time()
    tddft_mv = build_cassida_mv(
        sys=sys,
        xc_module=xc_module,
        params=xc_params,
        occupied_orbs=jnp.asarray(orbo),
        virtual_orbs=jnp.asarray(orbv),
        e_ia=jnp.asarray(e_ia),
        P_ref=jnp.asarray(P_ref),
        spin_restricted=True,
        use_density_fitting=bool(use_density_fitting),
        tda_approx=False,
    )

    def mv_xy(X: np.ndarray, Y: np.ndarray):
        """NumPy adapter around the JAX Casida (X, Y) matrix-vector product."""
        U1, U2 = tddft_mv(jnp.asarray(X), jnp.asarray(Y))
        return np.asarray(U1), np.asarray(U2)

    e_tddft, _, _ = Davidson_Casida(mv_xy, hdiag, N_states=nstates, conv_tol=conv_tol)
    time_tddft = time.time() - time_start
    result['time_tddft_s'] = time_tddft
    result['tddft_energies_eV'] = list(e_tddft * Hartree_to_eV)
    print('  DEI-XC TDDFT energies (eV):', e_tddft * Hartree_to_eV)
    print(f'Time taken for TDDFT: {time_tddft:.1f} seconds')

    return result


def _run_scan_repo_for_molecule(
    sys,
    basis: str,
    grid_level: int,
    nstates: int,
    conv_tol: float,
    checkpoint_id: str = '',
    molecule_idx: int = -1,
    use_density_fitting: bool = False,
    do_tda: bool = False,
    cycles: int = 15,
):
    """Run SCAN (repo SCF + repo TDDFT) for a single system and return structured results.

    Args:
        sys: System object (its grid is used as-is).
        basis: Basis set name.
        grid_level: Unused; kept for call-site compatibility (the grid comes from `sys`).
        nstates: Number of states to compute.
        conv_tol: Davidson convergence tolerance.
        checkpoint_id: Identifier for the checkpoint (stored as metadata).
        molecule_idx: Index of the molecule in the dataset (stored as metadata).
        use_density_fitting: Use density fitting in the SCF and response products.
        do_tda: Also solve the Tamm-Dancoff (TDA) eigenproblem.
        cycles: Number of SCF iterations for the SCAN reference (default 15).

    Returns:
        Dictionary with results including energies (eV), timings (s) and metadata.
        `scf_volatility` is NaN because it is not tracked for this pipeline.
    """
    # Extract atom metadata (only atoms selected by `atom_mask` are counted).
    atom_z = np.asarray(sys.atom_z)
    atom_mask = np.asarray(sys.atom_mask)
    unique_elements = sorted(set(atom_z[atom_mask]))
    elements_str = ','.join(map(str, unique_elements))
    n_atoms = int(np.sum(atom_mask))

    result = {
        'checkpoint_id': checkpoint_id,
        'molecule_idx': molecule_idx,
        'n_atoms': n_atoms,
        'elements': elements_str,
        'status': 'success',
        'pipeline': 'SCAN-repo',
    }

    # Build SCAN functional and XC module. Its parameters are initialized inside
    # `_scan_reference_from_custom_scf`, so no separate `xc_module.init` is needed.
    functional = get_functional(
        'scan', spin_restricted=True, use_density_fitting=use_density_fitting
    )
    xc_module = XCModule(functional, DensityFeatures(spin_restricted=True))

    # Run custom SCF with SCAN
    time_start = time.time()
    (
        P_ref,
        orbo,
        orbv,
        e_ia,
        hdiag,
        xc_params,
        total_e,
    ) = _scan_reference_from_custom_scf(
        sys=sys,
        basis=basis,
        cycles=cycles,
        xc_module=xc_module,
        use_density_fitting=use_density_fitting,
    )
    time_scf = time.time() - time_start
    result['time_scf_s'] = time_scf
    result['scf_final_energy_Ha'] = total_e
    result['scf_volatility'] = np.nan  # Not tracked for SCAN
    print(f'SCAN-repo SCF final E (Ha): {total_e}')
    print(f'Time taken for SCAN-repo SCF: {time_scf:.1f} seconds')

    if do_tda:
        # Davidson TDA
        time_start = time.time()
        tda_mv = build_cassida_mv(
            sys=sys,
            xc_module=xc_module,
            params=xc_params,
            occupied_orbs=jnp.asarray(orbo),
            virtual_orbs=jnp.asarray(orbv),
            e_ia=jnp.asarray(e_ia),
            P_ref=jnp.asarray(P_ref),
            spin_restricted=True,
            use_density_fitting=bool(use_density_fitting),
            tda_approx=True,
        )
        e_tda, _ = Davidson(
            lambda X: np.asarray(tda_mv(jnp.asarray(X))),
            hdiag,
            N_states=nstates,
            conv_tol=conv_tol,
        )
        time_tda = time.time() - time_start
        result['time_tda_s'] = time_tda
        result['tda_energies_eV'] = list(e_tda * Hartree_to_eV)
        print('  SCAN-repo TDA energies (eV):', e_tda * Hartree_to_eV)
        print(f'Time taken for TDA: {time_tda:.1f} seconds')

    # Davidson TDDFT
    time_start = time.time()
    tddft_mv = build_cassida_mv(
        sys=sys,
        xc_module=xc_module,
        params=xc_params,
        occupied_orbs=jnp.asarray(orbo),
        virtual_orbs=jnp.asarray(orbv),
        e_ia=jnp.asarray(e_ia),
        P_ref=jnp.asarray(P_ref),
        spin_restricted=True,
        use_density_fitting=bool(use_density_fitting),
        tda_approx=False,
    )

    def mv_xy(X: np.ndarray, Y: np.ndarray):
        """NumPy adapter around the JAX Casida (X, Y) matrix-vector product."""
        U1, U2 = tddft_mv(jnp.asarray(X), jnp.asarray(Y))
        return np.asarray(U1), np.asarray(U2)

    e_tddft, _, _ = Davidson_Casida(mv_xy, hdiag, N_states=nstates, conv_tol=conv_tol)
    time_tddft = time.time() - time_start
    result['time_tddft_s'] = time_tddft
    result['tddft_energies_eV'] = list(e_tddft * Hartree_to_eV)
    print('  SCAN-repo TDDFT energies (eV):', e_tddft * Hartree_to_eV)
    print(f'Time taken for TDDFT: {time_tddft:.1f} seconds')

    return result


def _relative_to_cwd(path: Path) -> str:
    """Return `path` relative to the working directory, or absolute if outside it."""
    resolved = path.resolve()
    try:
        return str(resolved.relative_to(Path.cwd()))
    except ValueError:
        # `relative_to` raises when the checkpoint lives outside the CWD.
        return str(resolved)


def _checkpoint_cfg_metadata(cfg: Dict[str, Any], cfg_path: Path) -> Dict[str, Any]:
    """Collect the config fields that are copied onto every CSV row of a checkpoint."""
    static_weights = cfg['static_loss']['relative_weights']
    dynamic_weights = cfg['dynamic_loss']['relative_weights']
    metadata: Dict[str, Any] = {'config_path': _relative_to_cwd(cfg_path)}
    for key in ('forces', 'orbital_rotation_gradient', 'xc_energy', 'xc_potential'):
        metadata[f'static_loss.relative_weights.{key}'] = static_weights[key]
    for key in (
        'density',
        'forces',
        'orbital_rotation_gradient',
        'orbital_rotation_hessian',
        'total_energy',
        'xc_energy',
        'xc_potential',
    ):
        metadata[f'dynamic_loss.relative_weights.{key}'] = dynamic_weights[key]
    metadata['base.seed'] = cfg['base']['seed']
    metadata['run_seed'] = cfg['run_seed']
    metadata['seed'] = cfg['seed']
    return metadata


def _print_mae_statistics_safely(csv_path: Path) -> None:
    """Print MAE statistics, reporting (not raising) any failure."""
    try:
        print_mae_statistics(csv_path)
    except Exception:
        # Statistics are informational only; never abort the remaining checkpoints.
        traceback.print_exc()


def _combine_checkpoint_csvs(output_base: Path, use_wandb: bool) -> None:
    """Concatenate all per-checkpoint CSVs under `output_base` into one file."""
    # Sorted so the combined file has a deterministic row order.
    all_csvs = sorted(output_base.glob('*/*_tddft_results.csv'))
    if not all_csvs:
        return
    combined_df = pd.concat([pd.read_csv(f) for f in all_csvs], ignore_index=True)
    combined_path = output_base / 'all_checkpoints_combined.csv'
    combined_df.to_csv(combined_path, index=False)
    print(f'\nSaved combined results ({len(all_csvs)} checkpoints) to {combined_path}')

    if use_wandb:
        wandb.save(str(combined_path))
        print(f'Logged combined CSV to wandb: {combined_path}')


def run_tddft_for_checkpoints(
    checkpoint_base_dir: str,
    nstates: int = 5,
    conv_tol: float = 1e-5,
    n_samples: int = 10,
    seed: int = 42,
    data_dir: str = 'ANONYMOUS_DIR',
    output_base_dir: str = './results_tddft/XCdiff_scan_qm5',
    use_density_fitting: bool = False,
    molecule: str | None = None,
    redo: bool = False,
    use_wandb: bool = False,
    do_tda: bool = False,
):
    """Run TDDFT calculations on multiple checkpoints and save results to CSV.

    Compares 2 pipelines per molecule:
    1) DEI-XC (repo SCF) + local TDDFT
    2) SCAN (repo SCF) + local TDDFT

    Per checkpoint, results go to
    `<output_base_dir>/<checkpoint_id>/<checkpoint_id>_tddft_results.csv`; at the
    end all such CSVs are concatenated into `all_checkpoints_combined.csv`.

    Args:
        checkpoint_base_dir: Base directory containing checkpoint subdirectories,
            or a path to a params file inside a `checkpoint` directory.
        nstates: Number of excitation energies to compute.
        conv_tol: Davidson residual tolerance.
        n_samples: Number of QM5 molecules to process per checkpoint.
        seed: Random seed for sampling QM5 molecules (seeds NumPy's global RNG).
        data_dir: Path to datasets directory.
        output_base_dir: Base directory for output CSV files.
        use_density_fitting: Use density fitting in SCF and response.
        molecule: Optional molecule name (e.g. "water") to use instead of QM5 dataset.
        redo: Recompute checkpoints whose results CSV already exists.
        use_wandb: Log results to Weights & Biases.
        do_tda: Also compute Tamm-Dancoff (TDA) excitation energies.
    """
    if use_wandb:
        wandb.init(
            project='anonymous',
            config={
                'checkpoint_base_dir': checkpoint_base_dir,
                'nstates': nstates,
                'conv_tol': conv_tol,
                'n_samples': n_samples,
                'seed': seed,
                'use_density_fitting': use_density_fitting,
                'molecule': molecule,
                'do_tda': do_tda,
            },
        )

    # Get one or multiple checkpoints
    base_path = Path(checkpoint_base_dir)
    if base_path.is_file():
        # A params file was given directly: .../<run>/checkpoint/<file>.
        checkpoints = [(base_path.parent.parent.stem, base_path.parent)]
    else:
        checkpoints = _discover_checkpoints(base_path)

    if not checkpoints:
        print(f'No checkpoints found in {checkpoint_base_dir}')
        return

    print(f'Found {len(checkpoints)} checkpoints')

    if molecule is not None:
        # Single molecule mode (e.g. water)
        print(f'Using example molecule: {molecule}')
        sample_indices = [0]
        dataset = []
    else:
        # Load QM5 dataset once; only the training split is sampled.
        print('Loading QM5 dataset (heavy_atoms_thresh=5, exclude_fluorine=True)')
        full_dataset = QM9(
            data_dir=data_dir,
            heavy_atoms_thresh=5,
            exclude_fluorine=True,
            energy_unit='hartree',
            distance_unit='ang',
        )
        dataset, _val_set, _test_set = full_dataset.random_split(
            val_fraction=0.0, seed=0
        )

        total_samples = len(dataset)
        print(f'Total QM5 samples: {total_samples}')

        # Sample indices (same for all checkpoints)
        np.random.seed(seed)
        sample_indices = np.random.choice(
            total_samples, min(n_samples, total_samples), replace=False
        )

    output_base = Path(output_base_dir)
    for ckpt_count, (checkpoint_id, checkpoint_path) in enumerate(checkpoints, 1):
        print(f'\n{"=" * 80}')
        print(f'[{ckpt_count}/{len(checkpoints)}] Processing checkpoint: {checkpoint_id}')
        print(f'{"=" * 80}')

        output_path = output_base / checkpoint_id / f'{checkpoint_id}_tddft_results.csv'
        if output_path.exists() and not redo:
            print(f'Results already exist for checkpoint {checkpoint_id}, skipping.')
            _print_mae_statistics_safely(output_path)
            continue

        cfg, cfg_path = _load_checkpoint_config(checkpoint_path)
        params_solver = unpickle_dictionary(
            str(checkpoint_path / 'best_dynamic_train_params.flax')
        )
        xc_params = _extract_xc_module_params(params_solver)
        cfg_metadata = _checkpoint_cfg_metadata(cfg, cfg_path)

        # Get system configuration
        basis = cfg['basis']['name']
        grid_level = int(cfg['quadrature']['level'])
        a = cfg['alignment']
        alignment = Alignment(int(a['atom']), int(a['basis']), int(a['grid']))

        results = []

        # Process each molecule
        for mol_count, idx in enumerate(sample_indices, 1):
            print(f'\n{"-" * 60}')
            print(f'[{mol_count}/{len(sample_indices)}] Molecule {idx}')
            print(f'{"-" * 60}')
            start_time_molecule = time.time()

            if molecule is not None:
                sys = examples.get(
                    molecule,
                    basis=basis,
                    alignment=alignment,
                    use_density_fitting=bool(use_density_fitting),
                    spin_restricted=True,
                    include_grid=True,
                    grid_level=grid_level,
                )
            else:
                # Load sample
                _, (nuc_pos, atom_z, _, _, _), _ = dataset[idx]
                nuc_pos = np.asarray(nuc_pos)
                atom_z = np.asarray(atom_z)

                print(f'Atoms: {len(atom_z)}, Elements: {np.unique(atom_z)}')

                sys = _build_system_from_sample(
                    nuc_pos,
                    atom_z,
                    basis,
                    grid_level,
                    alignment,
                    use_density_fitting=use_density_fitting,
                )
                n_electrons = int(sys.n_electrons)
                print(f'Electrons: {n_electrons}, Basis: {basis}')

            # Pipeline 1: DEI-XC
            print('\n1) DEI-XC (repo SCF) + local TDDFT')
            result_deixc = _run_tddft_for_molecule(
                sys,
                basis,
                cfg,
                params_solver,
                xc_params,
                nstates,
                conv_tol,
                checkpoint_id=checkpoint_id,
                molecule_idx=idx,
                use_density_fitting=use_density_fitting,
                do_tda=do_tda,
            )
            result_deixc.update(cfg_metadata)
            results.append(result_deixc)
            log_dict = {f'deixc/{k}': v for k, v in result_deixc.items()}

            # Pipeline 2: SCAN-repo
            print('\n2) SCAN (repo SCF) + local TDDFT')
            result_scan_repo = _run_scan_repo_for_molecule(
                sys,
                basis,
                grid_level,
                nstates,
                conv_tol,
                checkpoint_id=checkpoint_id,
                molecule_idx=idx,
                use_density_fitting=use_density_fitting,
                do_tda=do_tda,
            )
            result_scan_repo.update(cfg_metadata)
            results.append(result_scan_repo)
            log_dict.update({f'scan_repo/{k}': v for k, v in result_scan_repo.items()})

            # Log all pipelines for this molecule as one wandb step; the step index is
            # monotonic across checkpoints (skipped checkpoints just leave gaps).
            if use_wandb:
                step = (ckpt_count - 1) * len(sample_indices) + mol_count - 1
                wandb.log(log_dict, step=step)

            end_time_molecule = time.time()
            print(
                f'Time taken for molecule {idx}: {int(end_time_molecule - start_time_molecule)} s'
            )

        print('\n', '-' * 20, sep='')

        # Save results for this checkpoint
        _save_results_to_csv(results, output_path)
        _print_mae_statistics_safely(output_path)
        print(f'\nCheckpoint {checkpoint_id} complete')

    # Aggregate all checkpoint CSVs into one combined file
    _combine_checkpoint_csvs(output_base, use_wandb)

    print('\nAll checkpoints complete')


def _print_excitation_mae(
    perpipeline: pd.DataFrame, method: str, model: str, reference: str
) -> None:
    """Print per-state and overall excitation-energy MAE of `model` vs `reference`.

    Args:
        perpipeline: Wide table built by `print_mae_statistics`. It has one row per
            (checkpoint_id, molecule_idx, state_index, method) and one `energy_eV`
            column per pipeline name.
        method: Value of the `method` column to restrict to (e.g. 'TDDFT').
        model: Pipeline column treated as the prediction.
        reference: Pipeline column treated as the reference.

    Only rows where both pipelines have a value are compared. If a required column
    is missing, a note is printed instead of raising.
    """
    tag = method.lower()
    missing = [
        col
        for col in ('method', 'state_index', model, reference)
        if col not in perpipeline.columns
    ]
    if missing:
        print(f'\nSkipping {method}: missing column(s) {missing}')
        return

    subset = perpipeline[perpipeline['method'] == method]
    # Report coverage *before* pairing so missing results per pipeline are visible.
    print(f'{tag} length', len(subset))
    print(f'{tag} {model} length', int(subset[model].notna().sum()))
    print(f'{tag} {reference} length', int(subset[reference].notna().sum()))

    paired = subset.dropna(subset=[model, reference])
    if paired.empty:
        print(f'\nNo paired {method} results for {model} vs {reference}')
        return

    abs_diff = (paired[model] - paired[reference]).abs()
    rel_diff = abs_diff / paired[reference].abs()
    print(f'\n{model} vs {reference} ({method}):')
    for state_idx in sorted(paired['state_index'].unique()):
        mask = paired['state_index'] == state_idx
        print(
            f'  State {state_idx}: MAE = {abs_diff[mask].mean():.2e} eV, '
            f'RelMAE = {rel_diff[mask].mean() * 100:.2f}%'
        )
    print(
        f'  All states: MAE = {abs_diff.mean():.2e} eV, '
        f'RelMAE = {rel_diff.mean() * 100:.2f}%'
    )


def print_mae_statistics(
    csv_path: Path,
    methods: Tuple[str, ...] = ('TDDFT',),
    model_pipeline: str = 'DEI-XC',
    reference_pipeline: str = 'SCAN-repo',
):
    """Print MAE statistics for the given results CSV.

    Reads the columns `checkpoint_id`, `molecule_idx`, `state_index`, `method`,
    `pipeline`, `energy_eV` and `scf_final_energy_Ha`. It then prints:

    - the ground-state energy MAE/RelMAE (Ha) of `model_pipeline` vs
      `reference_pipeline`, using one row per (checkpoint, molecule, pipeline);
    - excitation-energy MAE/RelMAE (eV), per state and overall, for each entry
      in `methods`.

    Args:
        csv_path: Path (or path string) to the CSV file.
        methods: Values of the `method` column to report excitation statistics for.
        model_pipeline: Pipeline name treated as the prediction.
        reference_pipeline: Pipeline name treated as the reference.

    Comparisons whose pipeline columns are absent are skipped with a note rather
    than raising `KeyError`.
    """
    csv_path = Path(csv_path)
    if not csv_path.exists():
        print(f'No results found for {csv_path}')
        return

    df = pd.read_csv(csv_path)
    print('df length', len(df))
    if df.empty:
        print(f'No results found for {csv_path}')
        return

    # Wide table: one excitation-energy column per pipeline.
    perpipeline: pd.DataFrame = df.pivot_table(
        index=['checkpoint_id', 'molecule_idx', 'state_index', 'method'],
        columns='pipeline',
        values='energy_eV',
    ).reset_index()
    print('perpipeline length', len(perpipeline))

    # The SCF energy is repeated on every state row; keep one row per pipeline run.
    scf_pivot: pd.DataFrame = df.drop_duplicates(
        subset=['checkpoint_id', 'molecule_idx', 'pipeline']
    ).pivot_table(
        index=['checkpoint_id', 'molecule_idx'],
        columns='pipeline',
        values='scf_final_energy_Ha',
    )
    print('Ground state energy scf_pivot length', len(scf_pivot))
    print('\nGround state energy differences (Ha):')
    if {model_pipeline, reference_pipeline}.issubset(scf_pivot.columns):
        gs_abs = (scf_pivot[model_pipeline] - scf_pivot[reference_pipeline]).abs()
        gs_rel = gs_abs / scf_pivot[reference_pipeline].abs()
        print(
            f'  {model_pipeline} - {reference_pipeline}: MAE={gs_abs.mean():.2e}, '
            f'RelMAE={gs_rel.mean() * 100:.2f}%'
        )

    print('\nExcitation energy differences (eV):')
    for method in methods:
        _print_excitation_mae(perpipeline, method, model_pipeline, reference_pipeline)


def _parse_args():
    """Parse CLI args for the DEI-XC TDDFT eval."""
    p = argparse.ArgumentParser()
    p.add_argument(
        '--checkpoint-base-dir',
        type=str,
        default='./evaluations/1000/XCdiff_scan_qm5',
        # default='./evaluations/1000/XCdiff_scan_qm5/XCdiff_scan_qm5_-1/checkpoint/best_dynamic_train_params.flax',
        help='Base directory containing checkpoint subdirectories',
    )
    p.add_argument('--nstates', type=int, default=2)
    p.add_argument('--conv-tol', type=float, default=1e-5)
    p.add_argument(
        '--n-samples',
        type=int,
        default=10,
        help='Number of QM5 molecules to process per checkpoint',
    )
    p.add_argument(
        '--seed',
        type=int,
        default=42,
        help='Random seed for QM5 sampling',
    )
    p.add_argument(
        '--data-dir',
        type=str,
        default='../Datastore/',
        # default='ANONYMOUS_DIR',
        help='Path to datasets directory',
    )
    p.add_argument(
        '--output-base-dir',
        type=str,
        default='./results_tddft',
        help='Base directory for output CSV files',
    )
    p.add_argument(
        '--redo',
        action='store_true',
        help='Redo existing results.',
    )
    p.add_argument(
        '--molecule',
        type=str,
        default=None,
        help='Optional molecule name (e.g. "water") to use instead of QM5 dataset.',
    )
    p.add_argument(
        '--wandb',
        action='store_true',
        help='Log results to Weights & Biases.',
    )
    return p.parse_args()


"""
uv run scripts/run_tddft.py --checkpoint-base-dir ./evaluations/1000/XCdiff_scan_qm5/evaluations/1000/XCdiff_scan_qm5/XCdiff_scan_qm5_-1/results.csv --nstates 5

uv run scripts/run_tddft.py --checkpoint-base-dir evaluations/tddft/b3lyp --nstates 5 --n-samples 10 --output-base-dir ./results_tddft/tddft_b3lyp

Takes about 2 minutes per molecule geometry on CPU.
10 samples -> 20 minutes per checkpoint
Total QM5 samples: 174 -> 350 minutes = 5.8 hours per checkpoint
"""
if __name__ == '__main__':
    start_time_script = time.time()
    args = _parse_args()

    # set a directory name for the output
    dir_name = ''
    if Path(args.checkpoint_base_dir).is_file():
        dir_name += '/' + str(Path(args.checkpoint_base_dir).parent.parent.stem)
    else:
        dir_name += '/' + args.checkpoint_base_dir.split('/')[-1]
    dir_name += '/'
    if args.molecule is not None:
        dir_name += args.molecule
    else:
        dir_name += 'QM5'
    dir_name += '_n' + str(args.nstates)
    dir_name += '_s' + str(args.n_samples)
    output_base_dir = args.output_base_dir + dir_name

    run_tddft_for_checkpoints(
        checkpoint_base_dir=args.checkpoint_base_dir,
        nstates=args.nstates,
        conv_tol=args.conv_tol,
        n_samples=args.n_samples,
        seed=args.seed,
        data_dir=args.data_dir,
        output_base_dir=output_base_dir,
        use_density_fitting=True,  # =bool(args.use_density_fitting),
        molecule=args.molecule,
        redo=args.redo,
        use_wandb=args.wandb,
    )

    end_time_script = time.time()
    print(f'Time taken for script: {end_time_script - start_time_script:.1f} seconds')
