#!/usr/bin/env python3
"""TDDFT/TDA demo linearized around a DEI-XC (ML) ground state.

This script performs size-extrapolation studies by evaluating a fixed number of molecules
(10 by default) per heavy-atom count, sampled from the QM9 dataset (excluding fluorine).

This script:
- Loads a trained DEI-XC checkpoint (pickled Flax params + YAML config).
- Builds a molecular `System` consistent with the checkpoint (basis/grid/alignment).
- Runs the project's SCF machinery (DEI-XC) to obtain the ground-state density.
- Builds TDA / full TDDFT Casida matrix-vector products at that reference point
  and solves for a few lowest excitation energies with Davidson iterations.

Notes:
- The reference point is the converged SCF density of the respective pipeline
  (DEI-XC, or SCAN for the SCAN baseline).
- Molecules are sampled per heavy-atom count (10 by default) for the size-extrapolation analysis.
"""

from __future__ import annotations

import argparse
import os
from collections import defaultdict
from pathlib import Path
from typing import Any, Dict, Tuple

# Default to CPU to avoid large-constant allocation issues on GPU for full 4-index ERIs.
# Override by setting `JAX_PLATFORM_NAME=gpu` in the environment.
os.environ.setdefault('JAX_PLATFORM_NAME', 'cpu')
# Silence XLA compilation warnings
os.environ.setdefault('TF_CPP_MIN_LOG_LEVEL', '3')

import time
import traceback

import jax
import jax.numpy as jnp
import numpy as np
import pandas as pd
import wandb
import yaml
from pyscf import dft
from pyscf import scf as pyscf_scf

jax.config.update('jax_enable_x64', True)

from deixc.scf import DerivativeInformedSelfConsistentFieldSolver
from egxc.dataloading import QM9
from egxc.dataloading.io import unpickle_dictionary
from egxc.discretization import (
    GTOBasis,
    get_grid_fn,
    get_gto_grid_eval_fn,
    get_gto_preloader,
)
from egxc.systems import Grid, System, examples, nuclear_energy_fn
from egxc.systems.preload import preload_system_using_pyscf
from egxc.utils.linalg import modified_generalized_eigenvalue_problem
from egxc.utils.typing import Alignment, cast_to_integer_tuple
from egxc.xc_energy import DensityFeatures
from egxc.xc_energy.functionals import get_functional
from egxc.xc_energy.functionals.classical import BaseRangeSeparatedHybrid
from egxc.xc_energy.xc_module import XCModule
from tddft import Davidson, Davidson_Casida, build_cassida_mv

# Hartree -> eV conversion factor (numerically the CODATA 2010 recommended value;
# CODATA 2018 gives 27.211386245988).
Hartree_to_eV = 27.21138505

# Atomic number -> element symbol for the elements handled by this script.
ATOMIC_SYMBOLS = dict(zip((1, 6, 7, 8, 9), ('H', 'C', 'N', 'O', 'F')))


def _atom_z_to_formula(atom_z: np.ndarray) -> str:
    """Convert atomic numbers to a chemical formula string.

    Carbon is written first and hydrogen second; every other element follows in
    alphabetical order of its symbol (e.g. ``C2H6O`` for ethanol, ``CHFN`` rather
    than ``CHNF``). A count of 1 is omitted. Atomic numbers that are not in
    ``ATOMIC_SYMBOLS`` are written as ``Z<number>``.

    Args:
        atom_z: Atomic numbers (any array-like; it is flattened before counting).

    Returns:
        Chemical formula string; an empty string for empty input.
    """
    from collections import Counter

    # Cast to plain ints so NumPy scalars (or float-typed arrays) yield clean
    # fallback labels such as 'Z16' instead of 'Z16.0'.
    counts = Counter(int(z) for z in np.asarray(atom_z).ravel())

    def _term(symbol: str, n: int) -> str:
        # A count of 1 is implicit in chemical formulas.
        return f'{symbol}{n if n > 1 else ""}'

    parts = []
    # Carbon first, hydrogen second.
    for z, symbol in ((6, 'C'), (1, 'H')):
        if z in counts:
            parts.append(_term(symbol, counts.pop(z)))

    # Remaining elements alphabetically by symbol (not by atomic number).
    rest = sorted((ATOMIC_SYMBOLS.get(z, f'Z{z}'), n) for z, n in counts.items())
    parts.extend(_term(symbol, n) for symbol, n in rest)
    return ''.join(parts)


def _save_results_to_csv(results: list[Dict[str, Any]], output_path: Path):
    """Save TDDFT results to CSV in long format.

    Args:
        results: List of result dictionaries from _run_tddft_for_molecule.
        output_path: Path to save CSV file.
    """
    rows = []
    for result in results:
        checkpoint_id = result['checkpoint_id']
        molecule_idx = result['molecule_idx']
        status = result['status']
        pipeline = result['pipeline']
        scf_volatility = result['scf_volatility']
        scf_final_energy = result['scf_final_energy_Ha']
        n_atoms = result['n_atoms']
        n_heavy_atoms = result['n_heavy_atoms']
        elements = result['elements']
        chemical_formula = result['chemical_formula']
        time_scf = result['time_scf_s']
        # Config metadata
        config_path = result['config_path']
        loss_forces = result['static_loss.relative_weights.forces']
        loss_orb_rot_grad = result[
            'static_loss.relative_weights.orbital_rotation_gradient'
        ]
        loss_xc_energy = result['static_loss.relative_weights.xc_energy']
        loss_xc_potential = result['static_loss.relative_weights.xc_potential']

        # Dynamic loss weights (present in checkpoint config)
        dyn_loss_density = result['dynamic_loss.relative_weights.density']
        dyn_loss_forces = result['dynamic_loss.relative_weights.forces']
        dyn_loss_orb_rot_grad = result[
            'dynamic_loss.relative_weights.orbital_rotation_gradient'
        ]
        dyn_loss_orb_rot_hessian = result[
            'dynamic_loss.relative_weights.orbital_rotation_hessian'
        ]
        dyn_loss_total_energy = result['dynamic_loss.relative_weights.total_energy']
        dyn_loss_xc_energy = result['dynamic_loss.relative_weights.xc_energy']
        dyn_loss_xc_potential = result['dynamic_loss.relative_weights.xc_potential']
        base_seed = result['base.seed']
        run_seed = result['run_seed']
        seed = result['seed']

        # Shared metadata for CSV rows
        _metadata = {
            'checkpoint_id': checkpoint_id,
            'molecule_idx': molecule_idx,
            'pipeline': pipeline,
            'status': status,
            'scf_volatility': scf_volatility,
            'scf_final_energy_Ha': scf_final_energy,
            'n_atoms': n_atoms,
            'n_heavy_atoms': n_heavy_atoms,
            'elements': elements,
            'chemical_formula': chemical_formula,
            'time_scf_s': time_scf,
            'config_path': config_path,
            'static_loss.relative_weights.forces': loss_forces,
            'static_loss.relative_weights.orbital_rotation_gradient': loss_orb_rot_grad,
            'static_loss.relative_weights.xc_energy': loss_xc_energy,
            'static_loss.relative_weights.xc_potential': loss_xc_potential,
            'dynamic_loss.relative_weights.density': dyn_loss_density,
            'dynamic_loss.relative_weights.forces': dyn_loss_forces,
            'dynamic_loss.relative_weights.orbital_rotation_gradient': dyn_loss_orb_rot_grad,
            'dynamic_loss.relative_weights.orbital_rotation_hessian': dyn_loss_orb_rot_hessian,
            'dynamic_loss.relative_weights.total_energy': dyn_loss_total_energy,
            'dynamic_loss.relative_weights.xc_energy': dyn_loss_xc_energy,
            'dynamic_loss.relative_weights.xc_potential': dyn_loss_xc_potential,
            'base.seed': base_seed,
            'run_seed': run_seed,
            'seed': seed,
        }

        # Add TDA energies
        if 'tda_energies_eV' in result:
            tda_energies = result['tda_energies_eV']
            time_tda = result['time_tda_s']
            for state_idx, energy in enumerate(tda_energies):
                rows.append(
                    {
                        **_metadata,
                        'state_index': state_idx,
                        'method': 'TDA',
                        'energy_eV': energy,
                        'time_davidson_s': time_tda,
                    }
                )
        else:
            tda_energies = None

        # Add TDDFT energies
        tddft_energies = result['tddft_energies_eV']
        time_tddft = result['time_tddft_s']
        for state_idx, energy in enumerate(tddft_energies):
            rows.append(
                {
                    **_metadata,
                    'state_index': state_idx,
                    'method': 'TDDFT',
                    'energy_eV': energy,
                    'time_davidson_s': time_tddft,
                }
            )

        # If status is failed, add at least one row with NaN energies
        if status == 'failed' and not tda_energies and not tddft_energies:
            rows.append(
                {
                    **_metadata,
                    'state_index': 0,
                    'method': 'TDA',
                    'energy_eV': np.nan,
                    'time_davidson_s': np.nan,
                }
            )

    df = pd.DataFrame(rows)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    df.to_csv(output_path, index=False)
    print(f'\nSaved results to {output_path}')


def _load_checkpoint_config(checkpoint_dir: Path) -> Tuple[Dict[str, Any], Path]:
    """Load the single YAML config shipped alongside an evaluation checkpoint.

    Args:
        checkpoint_dir: Directory containing exactly one `*.yaml` file.

    Returns:
        Tuple of (parsed YAML config as a Python dict, path to the YAML file).
    """
    yamls = sorted(checkpoint_dir.glob('*.yaml'))
    assert len(yamls) == 1, f'Expected exactly one YAML in {checkpoint_dir}, got {yamls}'
    return yaml.safe_load(yamls[0].read_text()), yamls[0]


def _extract_xc_module_params(solver_params: Any) -> Any:
    """Extract `xc_module` params from a solver checkpoint if nested.

    Some DEI-XC runs checkpoint the full solver module, whose params tree contains a
    `xc_module` submodule. The TDDFT only needs the XC parameters.

    Args:
        solver_params: The deserialized checkpoint payload.

    Returns:
        A Flax params dict suitable for `XCModule.apply(...)`.
    """
    # In DEI-XC training, the checkpoint typically stores params for the solver module,
    # which contains a submodule named "xc_module".
    if isinstance(solver_params, dict) and 'params' in solver_params:
        root = solver_params['params']
        if isinstance(root, dict) and 'xc_module' in root:
            return {'params': root['xc_module']}
    return solver_params


def _scan_reference_from_custom_scf(
    *,
    sys,
    basis: str,
    cycles: int,
    xc_module: XCModule,
    use_density_fitting: bool = False,
):
    """Converge the repo SCF for ``xc_module`` and return the linear-response reference.

    The initial density matrix is ``get_init_guess()`` of a PySCF RKS object built
    for ``sys`` in ``basis``. Solver parameters are initialised with
    ``jax.random.PRNGKey(0)`` and the XC sub-tree is pulled out with
    `_extract_xc_module_params`. The last Fock matrix of the SCF trajectory is
    diagonalised to obtain occupied/virtual orbitals.

    Args:
        sys: System object.
        basis: Basis set name for the PySCF molecule used for the initial guess.
        cycles: Number of SCF cycles.
        xc_module: XC module used by the SCF solver.
        use_density_fitting: Forwarded to the SCF solver.

    Returns:
        Tuple ``(P_ref, orbo, orbv, e_ia, hdiag, xc_params, total_e)``:
        the final density matrix (NumPy); the occupied and virtual MO coefficient
        columns; orbital-energy differences ``e_ia[i, a] = eps[a] - eps[i]`` of
        shape ``(n_occ, n_vir)``; their row-major flattening (Davidson diagonal);
        the XC parameters; and the final ``E_HJ + E_XC + E_nuc`` (Ha).
    """
    pyscf_mol = sys.to_pyscf(basis)
    guess_dm = pyscf_scf.RKS(pyscf_mol).get_init_guess()

    solver = DerivativeInformedSelfConsistentFieldSolver(
        xc_module=xc_module,
        spin_restricted=True,
        cycles=int(cycles),
        use_density_fitting=use_density_fitting,
        convergence_acceleration_method='DIIS',
    )
    solver_params = solver.init(jax.random.PRNGKey(0), jnp.asarray(guess_dm), sys)
    xc_params = _extract_xc_module_params(solver_params)

    run_scf = jax.jit(solver.apply)
    energies, trajectories = run_scf(solver_params, jnp.asarray(guess_dm), sys)
    e_hj, e_xc = energies
    # Trajectories are (C, P, F, Vxc); only density and Fock matrices are needed.
    _, density_traj, fock_traj, _ = trajectories

    P_ref = np.asarray(density_traj[-1])
    mo_energy, mo_coeff = modified_generalized_eigenvalue_problem(
        fock_traj[-1], sys.fock_tensors.diagonal_overlap
    )
    mo_energy = np.asarray(mo_energy)
    mo_coeff = np.asarray(mo_coeff)

    occupancies = np.asarray(sys.fock_tensors.occupancies)
    occ_idx = np.nonzero(occupancies > 0)[0]
    vir_idx = np.nonzero(occupancies == 0)[0]
    orbo = mo_coeff[:, occ_idx]
    orbv = mo_coeff[:, vir_idx]
    # Broadcast (n_vir,) against (n_occ, 1) -> (n_occ, n_vir).
    e_ia = mo_energy[vir_idx] - mo_energy[occ_idx, None]
    hdiag = e_ia.ravel()

    e_nuc = float(nuclear_energy_fn(sys._nuc_pos, sys))
    total_e = float(np.asarray(e_hj[-1] + e_xc[-1])) + e_nuc

    return P_ref, orbo, orbv, e_ia, hdiag, xc_params, total_e


def _build_system_from_sample(
    nuc_pos: np.ndarray,
    atom_z: np.ndarray,
    basis: str,
    grid_level: int,
    alignment: Alignment,
    use_density_fitting: bool = False,
) -> System:
    """Build a `System` (preloaded PySCF tensors plus integration grid) from a QM9 sample.

    Atoms are reordered by atomic number before preloading. The sort is stable, so
    atoms of the same element keep their relative order. The grid and the GTO
    values on it are evaluated for the atoms selected by the preloaded system's
    ``atom_mask``.

    Args:
        nuc_pos: Nuclear positions, one row per atom (presumably Angstrom —
            NOTE(review): confirm the unit expected by `preload_system_using_pyscf`).
        atom_z: Atomic numbers, aligned with the rows of ``nuc_pos``.
        basis: Basis set name.
        grid_level: Integration grid level passed to `get_grid_fn`.
        alignment: Alignment settings; forwarded to preloading, and
            ``alignment.grid`` is used for the grid.
        use_density_fitting: Forwarded to `preload_system_using_pyscf`.

    Returns:
        System object combining the preloaded tensors and the grid.
    """
    num_heavy_atom = len(atom_z) - np.sum(atom_z == 1)
    print(
        f'Building system with: natoms={len(atom_z)}, nheavy_atoms={num_heavy_atom}, '
        f'elements={np.unique(atom_z)}, basis={basis}, grid_level={grid_level}, '
        f'alignment={alignment}'
    )
    # Sort atoms by atomic number for consistency. `kind='stable'` is accepted by
    # every NumPy version; the `stable=` keyword only exists from NumPy 2.0 on.
    order = np.argsort(atom_z, kind='stable')
    nuc_pos = nuc_pos[order]
    atom_z = atom_z[order]

    # Build preloaded system
    psys = preload_system_using_pyscf(
        idx=-1,
        nuc_pos=nuc_pos,
        atom_z=atom_z,
        charge=0,
        spin=0,
        reference_density=None,
        basis=basis,
        spin_restricted=True,
        alignment=alignment,
        base_initial_density_guess='minao',
        center=False,
        use_density_fitting=bool(use_density_fitting),
        cache_pyscf_mole=False,
        range_separation=None,
    )

    # Only atoms selected by the preloaded mask enter the grid and the basis.
    masked_pos = psys.nuc_pos[psys.atom_mask]
    masked_z = psys.atom_z[psys.atom_mask]
    elements = set(masked_z)

    # Build grid
    grid_fn = get_grid_fn(grid_level, elements, alignment.grid)
    coords, weights = grid_fn(masked_pos, cast_to_integer_tuple(masked_z))

    # Evaluate basis functions (and gradients) on the grid.
    l_max, pvec_basis_fn_factory = get_gto_preloader(basis, elements)
    vec_basis_fns = GTOBasis.from_preloaded(pvec_basis_fn_factory(masked_z))
    basis_fn = get_gto_grid_eval_fn(1, l_max)
    aos, grad_aos = basis_fn(
        coords,
        masked_pos,
        vec_basis_fns.radial_primitives,
        vec_basis_fns.compile_statics,
    )
    grid = Grid.create(coords, weights, aos, grad_aos)

    return System.from_preloaded(psys, grid)


def _discover_checkpoints(base_dir: Path) -> list[tuple[str, Path]]:
    """Discover all checkpoint directories in the base directory.

    Args:
        base_dir: Base directory containing checkpoint subdirectories (e.g., evaluations/1000/XCdiff_scan_qm5)

    Returns:
        List of (checkpoint_id, checkpoint_path) tuples
    """
    checkpoints = []

    # Try recursive search for deeply nested checkpoint directories
    for checkpoint_path in sorted(base_dir.rglob('checkpoint')):
        if checkpoint_path.is_dir():
            # Check if this checkpoint directory contains the required files
            yamls = list(checkpoint_path.glob('*.yaml'))
            flax_file = checkpoint_path / 'best_dynamic_train_params.flax'

            if yamls and flax_file.exists():
                # Generate checkpoint_id from the relative path
                # e.g., b3lyp/EGXC/baseline or b3lyp/NNmGGA/grad/NNmGGA2_b3lyp_qm7_58
                rel_path = checkpoint_path.relative_to(base_dir)
                # Remove the 'checkpoint' suffix from the path
                checkpoint_id = str(rel_path.parent).replace('/', '_')
                checkpoints.append((checkpoint_id, checkpoint_path))

    # Fallback to old behavior if no checkpoints found with recursive search
    if not checkpoints:
        for subdir in sorted(base_dir.iterdir()):
            if subdir.is_dir():
                checkpoint_path = subdir / 'checkpoint'
                if checkpoint_path.exists() and checkpoint_path.is_dir():
                    checkpoint_id = subdir.name
                    checkpoints.append((checkpoint_id, checkpoint_path))

    return checkpoints


def _run_tddft_for_molecule(
    sys,
    basis,
    cfg,
    params_solver,
    xc_params,
    nstates,
    conv_tol,
    checkpoint_id: str = '',
    molecule_idx: int = -1,
    use_density_fitting: bool = False,
    do_tda: bool = False,
):
    """Run the DEI-XC SCF and linear-response excitations for one system.

    The XC module architecture is rebuilt from ``cfg['model']``. The SCF runs with
    ``params_solver``, and the TDA/TDDFT Casida products are built with
    ``xc_params`` around the final SCF density.

    Args:
        sys: System object
        basis: Basis set name (used for the PySCF initial-guess molecule)
        cfg: Checkpoint config dict; reads ``cfg['model']`` and
            ``cfg['solver']['kwargs']['cycles']``
        params_solver: Solver parameters
        xc_params: XC module parameters (used in the response kernel)
        nstates: Number of states to compute
        conv_tol: Davidson convergence tolerance
        checkpoint_id: Identifier for the checkpoint (metadata only)
        molecule_idx: Index of the molecule in the dataset (metadata only)
        use_density_fitting: Use density fitting in the SCF and the response
        do_tda: Also compute Tamm-Dancoff (TDA) excitation energies

    Returns:
        Dictionary with metadata, ``time_scf_s``, ``scf_volatility`` (NaN if fewer
        than two SCF energies were recorded), ``scf_final_energy_Ha``,
        ``tddft_energies_eV``/``time_tddft_s`` and, if ``do_tda``,
        ``tda_energies_eV``/``time_tda_s``.

    Raises:
        ValueError: If density fitting is requested for a range-separated hybrid.
    """

    # Extract atom metadata (only atoms selected by `atom_mask` are counted).
    atom_z = np.asarray(sys.atom_z)
    atom_mask = np.asarray(sys.atom_mask)
    unique_elements = sorted(set(atom_z[atom_mask]))
    elements_str = ','.join(map(str, unique_elements))
    n_atoms = int(np.sum(atom_mask))
    n_heavy_atoms = int(np.sum((atom_z[atom_mask] != 1)))
    chemical_formula = _atom_z_to_formula(atom_z[atom_mask])

    result = {
        'checkpoint_id': checkpoint_id,
        'molecule_idx': molecule_idx,
        'n_atoms': n_atoms,
        'n_heavy_atoms': n_heavy_atoms,
        'elements': elements_str,
        'chemical_formula': chemical_formula,
        'status': 'success',
        'pipeline': 'DEI-XC',
    }

    # Reconstruct the XC module (architecture) from a checkpoint YAML config.
    functional = get_functional(**cfg['model'])
    if bool(use_density_fitting) and isinstance(functional, BaseRangeSeparatedHybrid):
        raise ValueError(
            'Density fitting is not supported for range-separated hybrid functionals '
            '(range-separated ERIs are not implemented with DF in this repo).'
        )
    requires_spin_resolved = getattr(functional, 'requires_spin_resolved_features', False)
    feature_fn = DensityFeatures(
        spin_restricted=True, spin_resolved=requires_spin_resolved
    )
    xc_module = XCModule(functional, feature_fn)

    solver_cfg = cfg['solver']
    cycles = int(solver_cfg['kwargs']['cycles'])
    scf_solver = DerivativeInformedSelfConsistentFieldSolver(
        xc_module=xc_module,
        spin_restricted=True,
        cycles=cycles,
        use_density_fitting=bool(use_density_fitting),
        convergence_acceleration_method='DIIS',
    )
    scf_apply = jax.jit(scf_solver.apply)

    # SCF timing includes JIT compilation of the solver.
    time_start = time.time()
    mol = sys.to_pyscf(basis)
    dm0 = dft.RKS(mol).get_init_guess()
    (e_hj, e_xc), (C_traj, P_traj, F_traj, _Vxc_traj) = scf_apply(
        params_solver, jnp.asarray(dm0), sys
    )
    P_ref = P_traj[-1]
    F_ref = F_traj[-1]
    time_scf = time.time() - time_start
    result['time_scf_s'] = time_scf
    print(f'Time taken for SCF: {time_scf:.1f} seconds')

    # Check convergence via the E_xc change over the last iteration; this needs at
    # least two recorded energies, otherwise convergence cannot be assessed.
    e_volatility = float(abs(e_xc[-2] - e_xc[-1])) if len(e_xc) > 1 else float('nan')
    result['scf_volatility'] = e_volatility
    if np.isnan(e_volatility) or e_volatility > 1e-3:
        print(f'WARNING: SCF not converged (volatility: {e_volatility * 1e3:.3f} mHa)')

    time_start = time.time()
    eps, C = modified_generalized_eigenvalue_problem(
        F_ref, sys.fock_tensors.diagonal_overlap
    )
    eps = np.asarray(eps)
    C = np.asarray(C)
    occ = np.asarray(sys.fock_tensors.occupancies)
    occidx = np.where(occ > 0)[0]
    viridx = np.where(occ == 0)[0]
    orbo, orbv = C[:, occidx], C[:, viridx]
    # Orbital-energy differences, shape (n_occ, n_vir); flattened as Davidson diagonal.
    e_ia = eps[viridx] - eps[occidx, None]
    hdiag = e_ia.ravel()
    time_end = time.time()
    print(f'Time taken for eigenvalue problem: {time_end - time_start:.1f} seconds')

    print('DEI-XC SCF reference')
    e_nuc = float(nuclear_energy_fn(sys._nuc_pos, sys))
    total_e = float(np.asarray(e_hj[-1] + e_xc[-1])) + e_nuc
    result['scf_final_energy_Ha'] = total_e
    print('  SCF cycles:', cycles)
    # Full precision: a 3-significant-figure total energy is not informative.
    print(f'  Final total energy E_HJ + E_XC + E_nuc (Ha): {total_e:.8f}')

    if do_tda:
        time_start = time.time()
        tda_mv = build_cassida_mv(
            sys=sys,
            xc_module=xc_module,
            params=xc_params,
            occupied_orbs=jnp.asarray(orbo),
            virtual_orbs=jnp.asarray(orbv),
            e_ia=jnp.asarray(e_ia),
            P_ref=jnp.asarray(P_ref),
            spin_restricted=True,
            use_density_fitting=bool(use_density_fitting),
            tda_approx=True,
        )
        e_tda, _ = Davidson(
            lambda X: np.asarray(tda_mv(jnp.asarray(X))),
            hdiag,
            N_states=nstates,
            conv_tol=conv_tol,
        )
        time_tda = time.time() - time_start
        result['time_tda_s'] = time_tda
        result['tda_energies_eV'] = list(e_tda * Hartree_to_eV)
        print('  DEI-XC TDA energies (eV):', e_tda * Hartree_to_eV)
        print(f'Time taken for TDA: {time_tda:.1f} seconds')

    time_start = time.time()
    tddft_mv = build_cassida_mv(
        sys=sys,
        xc_module=xc_module,
        params=xc_params,
        occupied_orbs=jnp.asarray(orbo),
        virtual_orbs=jnp.asarray(orbv),
        e_ia=jnp.asarray(e_ia),
        P_ref=jnp.asarray(P_ref),
        spin_restricted=True,
        use_density_fitting=bool(use_density_fitting),
        tda_approx=False,
    )

    def mv_xy(X: np.ndarray, Y: np.ndarray):
        """Wrap `tddft_mv` so `Davidson_Casida` can work with NumPy arrays."""
        U1, U2 = tddft_mv(jnp.asarray(X), jnp.asarray(Y))
        return np.asarray(U1), np.asarray(U2)

    e_tddft, _, _ = Davidson_Casida(mv_xy, hdiag, N_states=nstates, conv_tol=conv_tol)
    time_tddft = time.time() - time_start
    result['time_tddft_s'] = time_tddft
    result['tddft_energies_eV'] = list(e_tddft * Hartree_to_eV)
    print('  DEI-XC TDDFT energies (eV):', e_tddft * Hartree_to_eV)
    print(f'Time taken for TDDFT: {time_tddft:.1f} seconds')

    return result


def _run_scan_repo_for_molecule(
    sys,
    basis: str,
    grid_level: int,
    nstates: int,
    conv_tol: float,
    checkpoint_id: str = '',
    molecule_idx: int = -1,
    use_density_fitting: bool = False,
    do_tda: bool = False,
    scf_cycles: int = 15,
):
    """Run SCAN (repo SCF + repo TDDFT) for a single system and return structured results.

    The SCAN functional is built with ``get_functional('scan', ...)``. The SCF,
    including parameter initialisation, is delegated to
    `_scan_reference_from_custom_scf`, and the XC parameters it returns are used
    for the response calculations.

    Args:
        sys: System object
        basis: Basis set name
        grid_level: Not used by this function; kept for interface compatibility
        nstates: Number of states to compute
        conv_tol: Davidson convergence tolerance
        checkpoint_id: Identifier for the checkpoint (metadata only)
        molecule_idx: Index of the molecule in the dataset (metadata only)
        use_density_fitting: Use density fitting in the SCF and the response
        do_tda: Also compute Tamm-Dancoff (TDA) excitation energies
        scf_cycles: Number of SCF cycles (default 15)

    Returns:
        Dictionary with results including energies, timings, and metadata
        (``scf_volatility`` is NaN because it is not tracked for SCAN)
    """

    # Extract atom metadata (only atoms selected by `atom_mask` are counted).
    atom_z = np.asarray(sys.atom_z)
    atom_mask = np.asarray(sys.atom_mask)
    unique_elements = sorted(set(atom_z[atom_mask]))
    elements_str = ','.join(map(str, unique_elements))
    n_atoms = int(np.sum(atom_mask))
    n_heavy_atoms = int(np.sum((atom_z[atom_mask] != 1)))
    chemical_formula = _atom_z_to_formula(atom_z[atom_mask])

    result = {
        'checkpoint_id': checkpoint_id,
        'molecule_idx': molecule_idx,
        'n_atoms': n_atoms,
        'n_heavy_atoms': n_heavy_atoms,
        'elements': elements_str,
        'chemical_formula': chemical_formula,
        'status': 'success',
        'pipeline': 'SCAN-repo',
    }

    # Build SCAN functional and XC module
    functional = get_functional(
        'scan', spin_restricted=True, use_density_fitting=use_density_fitting
    )
    xc_module = XCModule(functional, DensityFeatures(spin_restricted=True))

    # Run custom SCF with SCAN. The helper initialises the solver (and thus the XC
    # parameters) itself, so no separate `xc_module.init` is needed here.
    time_start = time.time()
    (
        P_ref,
        orbo,
        orbv,
        e_ia,
        hdiag,
        xc_params,
        total_e,
    ) = _scan_reference_from_custom_scf(
        sys=sys,
        basis=basis,
        cycles=scf_cycles,
        xc_module=xc_module,
        use_density_fitting=use_density_fitting,
    )
    time_scf = time.time() - time_start
    result['time_scf_s'] = time_scf
    result['scf_final_energy_Ha'] = total_e
    result['scf_volatility'] = np.nan  # Not tracked for SCAN
    print(f'SCAN-repo SCF final E (Ha): {total_e}')
    print(f'Time taken for SCAN-repo SCF: {time_scf:.1f} seconds')

    if do_tda:
        # Davidson TDA
        time_start = time.time()
        tda_mv = build_cassida_mv(
            sys=sys,
            xc_module=xc_module,
            params=xc_params,
            occupied_orbs=jnp.asarray(orbo),
            virtual_orbs=jnp.asarray(orbv),
            e_ia=jnp.asarray(e_ia),
            P_ref=jnp.asarray(P_ref),
            spin_restricted=True,
            use_density_fitting=bool(use_density_fitting),
            tda_approx=True,
        )
        e_tda, _ = Davidson(
            lambda X: np.asarray(tda_mv(jnp.asarray(X))),
            hdiag,
            N_states=nstates,
            conv_tol=conv_tol,
        )
        time_tda = time.time() - time_start
        result['time_tda_s'] = time_tda
        result['tda_energies_eV'] = list(e_tda * Hartree_to_eV)
        print('  SCAN-repo TDA energies (eV):', e_tda * Hartree_to_eV)
        print(f'Time taken for TDA: {time_tda:.1f} seconds')

    # Davidson TDDFT
    time_start = time.time()
    tddft_mv = build_cassida_mv(
        sys=sys,
        xc_module=xc_module,
        params=xc_params,
        occupied_orbs=jnp.asarray(orbo),
        virtual_orbs=jnp.asarray(orbv),
        e_ia=jnp.asarray(e_ia),
        P_ref=jnp.asarray(P_ref),
        spin_restricted=True,
        use_density_fitting=bool(use_density_fitting),
        tda_approx=False,
    )

    def mv_xy(X: np.ndarray, Y: np.ndarray):
        """Wrap `tddft_mv` so `Davidson_Casida` can work with NumPy arrays."""
        U1, U2 = tddft_mv(jnp.asarray(X), jnp.asarray(Y))
        return np.asarray(U1), np.asarray(U2)

    e_tddft, _, _ = Davidson_Casida(mv_xy, hdiag, N_states=nstates, conv_tol=conv_tol)
    time_tddft = time.time() - time_start
    result['time_tddft_s'] = time_tddft
    result['tddft_energies_eV'] = list(e_tddft * Hartree_to_eV)
    print('  SCAN-repo TDDFT energies (eV):', e_tddft * Hartree_to_eV)
    print(f'Time taken for TDDFT: {time_tddft:.1f} seconds')

    return result


def _sample_molecules_by_heavy_atom_count(
    dataset: QM9,
    samples_per_heavy_atom: int = 10,
    seed: int = 42,
) -> Dict[int, list[int]]:
    """Sample molecules from QM9 dataset grouped by heavy atom count.

    Args:
        dataset: QM9 dataset
        samples_per_heavy_atom: Number of samples per heavy atom count
        seed: Random seed for sampling

    Returns:
        Dictionary mapping heavy atom count to list of dataset indices
    """
    np.random.seed(seed)

    # Group molecules by heavy atom count
    heavy_atom_groups = defaultdict(list)

    print('Grouping molecules by heavy atom count...')
    for idx in range(len(dataset)):
        _, (nuc_pos, atom_z, _, _, _), _ = dataset[idx]
        atom_z = np.asarray(atom_z)
        n_heavy = int(np.sum(atom_z != 1))
        heavy_atom_groups[n_heavy].append(idx)

    # Print statistics
    print('\nMolecules per heavy atom count:')
    for n_heavy in sorted(heavy_atom_groups.keys()):
        print(f'  {n_heavy} heavy atoms: {len(heavy_atom_groups[n_heavy])} molecules')

    # Sample from each group
    sampled_indices = {}
    print(f'\nSampling {samples_per_heavy_atom} molecules per heavy atom count...')
    for n_heavy, indices in sorted(heavy_atom_groups.items()):
        if len(indices) >= samples_per_heavy_atom:
            sampled = np.random.choice(indices, samples_per_heavy_atom, replace=False)
        else:
            print(
                f'  Warning: Only {len(indices)} molecules with {n_heavy} heavy atoms (requested {samples_per_heavy_atom})'
            )
            sampled = np.array(indices)
        sampled_indices[n_heavy] = sorted(sampled.tolist())
        print(f'  {n_heavy} heavy atoms: sampled {len(sampled)} molecules')

    return sampled_indices


# Loss-weight keys copied from the checkpoint config into every CSV row (order = CSV column order).
_STATIC_LOSS_WEIGHT_KEYS = (
    'forces',
    'orbital_rotation_gradient',
    'xc_energy',
    'xc_potential',
)
_DYNAMIC_LOSS_WEIGHT_KEYS = (
    'density',
    'forces',
    'orbital_rotation_gradient',
    'orbital_rotation_hessian',
    'total_energy',
    'xc_energy',
    'xc_potential',
)


def _cwd_relative_path(path: Path) -> str:
    """Return ``path`` relative to the CWD, or its absolute form if it lies outside the CWD."""
    resolved = path.resolve()
    try:
        return str(resolved.relative_to(Path.cwd()))
    except ValueError:
        # `relative_to` raises when the config is not below the CWD (e.g. an
        # absolute checkpoint directory elsewhere); keep the absolute path.
        return str(resolved)


def _build_cfg_metadata(cfg: Dict[str, Any], cfg_path: Path) -> Dict[str, Any]:
    """Flatten the checkpoint-config fields that are attached to every result row."""
    static_weights = cfg['static_loss']['relative_weights']
    dynamic_weights = cfg['dynamic_loss']['relative_weights']
    metadata: Dict[str, Any] = {'config_path': _cwd_relative_path(cfg_path)}
    metadata.update(
        {
            f'static_loss.relative_weights.{key}': static_weights[key]
            for key in _STATIC_LOSS_WEIGHT_KEYS
        }
    )
    metadata.update(
        {
            f'dynamic_loss.relative_weights.{key}': dynamic_weights[key]
            for key in _DYNAMIC_LOSS_WEIGHT_KEYS
        }
    )
    metadata['base.seed'] = cfg['base']['seed']
    metadata['run_seed'] = cfg['run_seed']
    metadata['seed'] = cfg['seed']
    return metadata


def _combine_checkpoint_csvs(output_base: Path, use_wandb: bool) -> None:
    """Concatenate all per-checkpoint result CSVs below ``output_base`` into one file."""
    all_csvs = list(output_base.glob('*/*_tddft_results.csv'))
    if not all_csvs:
        return
    combined_df = pd.concat([pd.read_csv(f) for f in all_csvs], ignore_index=True)
    combined_path = output_base / 'all_checkpoints_combined.csv'
    combined_df.to_csv(combined_path, index=False)
    print(f'\nSaved combined results ({len(all_csvs)} checkpoints) to {combined_path}')

    if use_wandb:
        wandb.save(str(combined_path))
        print(f'Logged combined CSV to wandb: {combined_path}')


def run_tddft_for_checkpoints(
    checkpoint_base_dir: str,
    nstates: int = 5,
    conv_tol: float = 1e-5,
    samples_per_heavy_atom: int = 10,
    seed: int = 42,
    data_dir: str = 'ANONYMOUS_DIR',
    output_base_dir: str = './results_tddft/size_extrapolation',
    use_density_fitting: bool = False,
    molecule: str | None = None,
    redo: bool = False,
    use_wandb: bool = False,
):
    """Run TDDFT calculations for size extrapolation studies.

    Samples molecules from QM9 (excluding fluorine) grouped by heavy atom count,
    or uses a single example molecule when ``molecule`` is given.

    Compares 2 pipelines per molecule:
    1) DEI-XC (repo SCF) + local TDDFT
    2) SCAN (repo SCF) + local TDDFT

    Per checkpoint, results are written to
    ``<output_base_dir>/<checkpoint_id>/<checkpoint_id>_tddft_results.csv``; at the
    end all such CSVs under ``output_base_dir`` are concatenated into
    ``all_checkpoints_combined.csv``.

    Args:
        checkpoint_base_dir: Base directory containing checkpoint subdirectories,
            or the path of a single checkpoint file (its parent directory is then
            the checkpoint directory and its grandparent's name the checkpoint id).
        nstates: Number of excitation energies to compute.
        conv_tol: Davidson residual tolerance.
        samples_per_heavy_atom: Number of molecules to sample per heavy atom count.
        seed: Random seed for sampling molecules.
        data_dir: Path to datasets directory.
        output_base_dir: Base directory for output CSV files.
        use_density_fitting: Forwarded to system construction and both pipelines.
        molecule: Optional molecule name (e.g. "water") to use instead of QM9 dataset.
        redo: Recompute checkpoints whose results CSV already exists (otherwise
            they are skipped and only their statistics are printed).
        use_wandb: Log results to Weights & Biases.
    """
    if use_wandb:
        wandb.init(
            project='anonymous',
            config={
                'checkpoint_base_dir': checkpoint_base_dir,
                'nstates': nstates,
                'conv_tol': conv_tol,
                'samples_per_heavy_atom': samples_per_heavy_atom,
                'seed': seed,
                'use_density_fitting': use_density_fitting,
                'molecule': molecule,
            },
        )

    # get one or multiple checkpoints
    base_path = Path(checkpoint_base_dir)
    if base_path.is_file():
        checkpoints = [(base_path.parent.parent.stem, base_path.parent)]
    else:
        checkpoints = _discover_checkpoints(base_path)

    if not checkpoints:
        print(f'No checkpoints found in {checkpoint_base_dir}')
        if use_wandb:
            wandb.finish()
        return

    print(f'Found {len(checkpoints)} checkpoints')

    if molecule is not None:
        # Single molecule mode (e.g. water): one pseudo-group with one pseudo-index.
        print(f'Using example molecule: {molecule}')
        sample_indices_by_heavy_atom = {0: [0]}
        dataset = []
    else:
        # Load QM9 dataset (excluding fluorine, no heavy atom threshold)
        print('Loading QM9 dataset (exclude_fluorine=True, no heavy atom threshold)')
        dataset = QM9(
            data_dir=data_dir,
            heavy_atoms_thresh=100,  # No threshold - use full QM9
            exclude_fluorine=True,
            energy_unit='hartree',
            distance_unit='ang',
        )
        # Only the first split is used for sampling.
        dataset, _, _ = dataset.random_split(val_fraction=0.0, seed=0)

        print(f'Total QM9 samples: {len(dataset)}')

        # Sample molecules grouped by heavy atom count
        sample_indices_by_heavy_atom = _sample_molecules_by_heavy_atom_count(
            dataset, samples_per_heavy_atom, seed
        )

    total_molecules = sum(
        len(indices) for indices in sample_indices_by_heavy_atom.values()
    )

    for ckpt_count, (checkpoint_id, checkpoint_path) in enumerate(checkpoints, 1):
        print(f'\n{"=" * 80}')
        print(f'[{ckpt_count}/{len(checkpoints)}] Processing checkpoint: {checkpoint_id}')
        print(f'{"=" * 80}')

        output_dir = Path(output_base_dir) / checkpoint_id
        output_path = output_dir / f'{checkpoint_id}_tddft_results.csv'
        if output_path.exists() and not redo:
            print(f'Results already exist for checkpoint {checkpoint_id}, skipping.')
            try:
                print_mae_statistics(output_path)
            except Exception as e:
                print(f'Error printing statistics: {e}')
            continue

        cfg, cfg_path = _load_checkpoint_config(checkpoint_path)
        params_solver = unpickle_dictionary(
            str(checkpoint_path / 'best_dynamic_train_params.flax')
        )
        xc_params = _extract_xc_module_params(params_solver)

        # Config metadata copied into every CSV row of this checkpoint
        cfg_metadata = _build_cfg_metadata(cfg, cfg_path)

        # Get system configuration
        basis = cfg['basis']['name']
        grid_level = int(cfg['quadrature']['level'])
        a = cfg['alignment']
        alignment = Alignment(int(a['atom']), int(a['basis']), int(a['grid']))

        results = []
        mol_counter = 0

        # Process molecules grouped by heavy atom count
        for n_heavy in sorted(sample_indices_by_heavy_atom.keys()):
            indices = sample_indices_by_heavy_atom[n_heavy]
            print(f'\n{"#" * 80}')
            print(f'Processing {len(indices)} molecules with {n_heavy} heavy atoms')
            print(f'{"#" * 80}')

            for mol_idx_in_group, idx in enumerate(indices, 1):
                mol_counter += 1
                print(f'\n{"-" * 60}')
                print(
                    f'[{mol_counter}/{total_molecules}] Molecule {idx} ({n_heavy} heavy atoms, {mol_idx_in_group}/{len(indices)})'
                )
                print(f'{"-" * 60}')

                if molecule is not None:
                    sys = examples.get(
                        molecule,
                        basis=basis,
                        alignment=alignment,
                        use_density_fitting=bool(use_density_fitting),
                        spin_restricted=True,
                        include_grid=True,
                        grid_level=grid_level,
                    )
                else:
                    # Load sample
                    _, (nuc_pos, atom_z, _, _, _), _ = dataset[idx]
                    nuc_pos = np.asarray(nuc_pos)
                    atom_z = np.asarray(atom_z)

                    print(
                        f'Atoms: {len(atom_z)}, Elements: {np.unique(atom_z)}, Formula: {_atom_z_to_formula(atom_z)}'
                    )

                    # Build system
                    sys = _build_system_from_sample(
                        nuc_pos,
                        atom_z,
                        basis,
                        grid_level,
                        alignment,
                        use_density_fitting=use_density_fitting,
                    )
                    n_electrons = int(sys.n_electrons)
                    print(f'Electrons: {n_electrons}, Basis: {basis}')

                # Pipeline 1: DEI-XC
                print('\n1) DEI-XC (repo SCF) + local TDDFT')
                result_deixc = _run_tddft_for_molecule(
                    sys,
                    basis,
                    cfg,
                    params_solver,
                    xc_params,
                    nstates,
                    conv_tol,
                    checkpoint_id=checkpoint_id,
                    molecule_idx=idx,
                    use_density_fitting=use_density_fitting,
                )
                result_deixc.update(cfg_metadata)
                results.append(result_deixc)
                log_dict = {f'deixc/{k}': v for k, v in result_deixc.items()}

                # Pipeline 2: SCAN-repo
                print('\n2) SCAN (repo SCF) + local TDDFT')
                result_scan_repo = _run_scan_repo_for_molecule(
                    sys,
                    basis,
                    grid_level,
                    nstates,
                    conv_tol,
                    checkpoint_id=checkpoint_id,
                    molecule_idx=idx,
                    use_density_fitting=use_density_fitting,
                )
                result_scan_repo.update(cfg_metadata)
                results.append(result_scan_repo)
                log_dict.update(
                    {f'scan_repo/{k}': v for k, v in result_scan_repo.items()}
                )

                # Log all pipelines for this molecule as one wandb step. The step is
                # globally unique and increasing across checkpoints (skipped
                # checkpoints only leave gaps).
                if use_wandb:
                    step = (ckpt_count - 1) * total_molecules + mol_counter - 1
                    wandb.log(log_dict, step=step)

        # Save results for this checkpoint
        _save_results_to_csv(results, output_path)

        try:
            print_mae_statistics(output_path)
        except Exception:
            # Statistics are informational only; never abort the sweep for them.
            traceback.print_exc()
        print(f'\nCheckpoint {checkpoint_id} complete')

    # Aggregate all checkpoint CSVs into one combined file
    _combine_checkpoint_csvs(Path(output_base_dir), use_wandb)

    if use_wandb:
        wandb.finish()

    print('\nAll checkpoints complete')
    return


def print_mae_statistics(csv_path: Path | str) -> None:
    """Print DEI-XC vs SCAN-repo error statistics for one TDDFT results CSV.

    Reports two comparisons between the ``DEI-XC`` and ``SCAN-repo`` pipelines:

    * ground-state SCF energy differences (Ha), one value per
      (checkpoint_id, molecule_idx) pair, and
    * TDDFT excitation energy differences (eV), per ``state_index`` and over
      all states.

    A comparison is skipped with a message (instead of raising ``KeyError``)
    when one of the two pipeline columns is absent, e.g. because only one
    pipeline was run.

    Args:
        csv_path: Path to the results CSV (``str`` is accepted as well).
    """
    csv_path = Path(csv_path)
    if not csv_path.exists():
        print(f'No results found for {csv_path}')
        return

    df = pd.read_csv(csv_path)

    print('df length', len(df))
    if df.empty:
        print(f'No rows in {csv_path}')
        return

    pipelines = ['DEI-XC', 'SCAN-repo']

    # Pivot to get energies by pipeline for comparison
    perpipeline: pd.DataFrame = df.pivot_table(
        index=['checkpoint_id', 'molecule_idx', 'state_index', 'method'],
        columns='pipeline',
        values='energy_eV',
    ).reset_index()
    print('perpipeline length', len(perpipeline))

    # Ground state energy differences. One row is kept per
    # (checkpoint, molecule, pipeline); this assumes scf_final_energy_Ha is the
    # same on all rows of that group.
    scf_pivot: pd.DataFrame = df.drop_duplicates(
        subset=['checkpoint_id', 'molecule_idx', 'pipeline']
    ).pivot_table(
        index=['checkpoint_id', 'molecule_idx'],
        columns='pipeline',
        values='scf_final_energy_Ha',
    )
    print('Ground state energy scf_pivot length', len(scf_pivot))
    print('\nGround state energy differences (Ha):')
    if set(pipelines).issubset(scf_pivot.columns):
        diff_repo = scf_pivot['DEI-XC'] - scf_pivot['SCAN-repo']
        print(
            f'  DEI-XC - SCAN-repo: MAE={np.abs(diff_repo).mean():.2e}, RelMAE={np.abs(diff_repo / scf_pivot["SCAN-repo"]).mean() * 100:.2f}%'
        )
    else:
        print(
            f'  Skipped: need pipelines {pipelines}, found {list(scf_pivot.columns)}'
        )

    print('\nExcitation energy differences (eV):')
    if not set(pipelines).issubset(perpipeline.columns):
        print(
            f'  Skipped: need pipelines {pipelines}, found {list(perpipeline.columns)}'
        )
        return

    # MAE between DEI-XC and SCAN-repo for TDDFT (only states present in both)
    tddft = perpipeline[perpipeline['method'] == 'TDDFT'].dropna(subset=pipelines)
    print('tddft length', len(tddft))
    print('tddft DEI-XC length', len(tddft[tddft['DEI-XC'].notna()]))
    print('tddft SCAN-repo length', len(tddft[tddft['SCAN-repo'].notna()]))
    if len(tddft) > 0:
        tddft_diff = np.abs(tddft['DEI-XC'] - tddft['SCAN-repo'])
        tddft_rel_diff = tddft_diff / np.abs(tddft['SCAN-repo'])
        print('\nDEI-XC vs SCAN-repo (TDDFT):')
        for state_idx in sorted(tddft['state_index'].unique()):
            state_mask = tddft['state_index'] == state_idx
            mae = tddft_diff[state_mask].mean()
            relmae = tddft_rel_diff[state_mask].mean() * 100
            print(f'  State {state_idx}: MAE = {mae:.2e} eV, RelMAE = {relmae:.2f}%')
        print(
            f'  All states: MAE = {tddft_diff.mean():.2e} eV, RelMAE = {tddft_rel_diff.mean() * 100:.2f}%'
        )

    return


def _parse_args():
    """Parse CLI args for the DEI-XC TDDFT size extrapolation eval."""
    p = argparse.ArgumentParser()
    p.add_argument(
        '--checkpoint-base-dir',
        type=str,
        default='./evaluations/1000/XCdiff_scan_qm5',
        help='Base directory containing checkpoint subdirectories',
    )
    p.add_argument('--nstates', type=int, default=5)
    p.add_argument('--conv-tol', type=float, default=1e-5)
    p.add_argument(
        '--samples-per-heavy-atom',
        type=int,
        default=10,
        help='Number of molecules to sample per heavy atom count',
    )
    p.add_argument(
        '--seed',
        type=int,
        default=42,
        help='Random seed for molecule sampling',
    )
    p.add_argument(
        '--data-dir',
        type=str,
        default='../Datastore/',
        help='Path to datasets directory',
    )
    p.add_argument(
        '--output-base-dir',
        type=str,
        default='./results_tddft',
        help='Base directory for output CSV files',
    )
    p.add_argument(
        '--redo',
        action='store_true',
        help='Redo existing results.',
    )
    p.add_argument(
        '--molecule',
        type=str,
        default=None,
        help='Optional molecule name (e.g. "water") to use instead of QM9 dataset.',
    )
    p.add_argument(
        '--wandb',
        action='store_true',
        help='Log results to Weights & Biases.',
    )
    return p.parse_args()


"""
Example usage:

uv run scripts/run_tddft_sizeextrapolation.py --checkpoint-base-dir evaluations/tddft/b3lyp --nstates 5 --samples-per-heavy-atom 10 --output-base-dir ./results_tddft/size_extrapolation

This will sample 10 molecules for each heavy atom count (1-9) from QM9 (excluding fluorine).
"""
if __name__ == '__main__':
    args = _parse_args()

    # set a directory name for the output
    dir_name = ''
    if Path(args.checkpoint_base_dir).is_file():
        dir_name += '/' + str(Path(args.checkpoint_base_dir).parent.parent.stem)
    else:
        dir_name += '/' + args.checkpoint_base_dir.split('/')[-1]
    dir_name += '/'
    if args.molecule is not None:
        dir_name += args.molecule
    else:
        dir_name += 'QM9_size_extrapolation'
    dir_name += '_n' + str(args.nstates)
    dir_name += '_s' + str(args.samples_per_heavy_atom)
    output_base_dir = args.output_base_dir + dir_name

    run_tddft_for_checkpoints(
        checkpoint_base_dir=args.checkpoint_base_dir,
        nstates=args.nstates,
        conv_tol=args.conv_tol,
        samples_per_heavy_atom=args.samples_per_heavy_atom,
        seed=args.seed,
        data_dir=args.data_dir,
        output_base_dir=output_base_dir,
        use_density_fitting=True,
        molecule=args.molecule,
        redo=args.redo,
        use_wandb=args.wandb,
    )
