#!/usr/bin/env python3
"""Compare two TDDFT/TDA implementations on a simple water example (SCAN).

Compares:
- Library implementation in `src/tddft/response.py`
- Demo implementation embedded in `scripts/run_tddft_deixc.py`

We compare (at the same reference point):
- Total energy E(P_ref)
- Batched AO potential response vresp(dP)
- TDA matrix-vector product
- Full TDDFT (Casida) matrix-vector product
"""

from __future__ import annotations

import os
import sys
from dataclasses import dataclass
from pathlib import Path

import jax
import jax.numpy as jnp
import numpy as np
from pyscf import dft

# Keep behavior consistent with the demo script: run on CPU in float64.
# NOTE: JAX reads JAX_PLATFORM_NAME from the environment when `jax` is first
# imported (above), so only setting the variable here would be too late to
# change the platform. When we supply the default ourselves, also push it into
# the live config; when the user already exported it, JAX picked it up at import.
if "JAX_PLATFORM_NAME" not in os.environ:
    os.environ["JAX_PLATFORM_NAME"] = "cpu"  # still visible to child processes
    jax.config.update("jax_platform_name", "cpu")
jax.config.update("jax_enable_x64", True)

# Ensure repo root is importable so `scripts.run_tddft_deixc` can be imported when
# this file is executed as `python scripts/...py` (sys.path[0] is then `scripts/`).
# NOTE(review): per the module docstring `tddft` lives under `src/`; nothing here
# adds `src/` to sys.path, so this presumably relies on `tddft` being installed.
_REPO_ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(_REPO_ROOT))

from egxc.systems import examples
from egxc.utils.typing import Alignment
from egxc.xc_energy import DensityFeatures
from egxc.xc_energy.functionals import get_functional
from egxc.xc_energy.xc_module import XCModule
from tddft import response as lib_response

# Import the demo builders (re-implementation).
from scripts import run_tddft_deixc as demo_impl


@dataclass(frozen=True)
class Settings:
    """Immutable configuration for one comparison run.

    Attributes:
        basis: AO basis name, used both for the EGXC system and for the
            PySCF molecule built from it.
        grid_level: Integration-grid level forwarded to ``examples.get``.
        alignment: Molecular alignment forwarded to ``examples.get``.
        seed: Seed of the NumPy RNG that draws every random test input.
        t_resp: Number of density perturbations batched into one vresp call.
        m: Number of trial-vector columns per TDA/TDDFT matrix-vector product.
    """

    basis: str = "sto-3g"
    grid_level: int = 1
    alignment: Alignment = Alignment()
    seed: int = 0
    t_resp: int = 3
    m: int = 2


def _relerr(a: np.ndarray, b: np.ndarray, *, eps: float = 1e-14) -> float:
    na = float(np.linalg.norm(a.ravel()))
    nb = float(np.linalg.norm(b.ravel()))
    denom = max(na, nb, eps)
    return float(np.linalg.norm((a - b).ravel()) / denom)


def _symmetrize(a: np.ndarray) -> np.ndarray:
    return 0.5 * (a + a.T)


def main() -> None:
    """Build both implementations at a PySCF SCAN reference and print their differences.

    Steps:
      1. Build the EGXC ``water`` system (full 4-index ERIs, integration grid)
         and run PySCF RKS/SCAN on the matching molecule to obtain MO
         coefficients, energies, occupations and the reference density ``P_ref``.
      2. Build the total-energy / ``vresp`` operators and the TDA and full
         TDDFT (Casida) matrix-vector operators from both the library
         (``tddft.response``) and the demo script.
      3. Evaluate each pair on identical seeded random inputs and print the
         absolute/relative differences. No tolerance is enforced; the printed
         numbers are the result.
    """
    s = Settings()
    rng = np.random.default_rng(s.seed)

    # Build an EGXC System with full 4-index ERIs (required by demo implementation).
    # Named `system` rather than `sys` so the module-level `sys` import stays usable.
    system = examples.get(
        "water",
        basis=s.basis,
        alignment=s.alignment,
        use_density_fitting=False,
        spin_restricted=True,
        include_grid=True,
        grid_level=int(s.grid_level),
    )

    # Build a PySCF SCAN reference on the same `mol` (same basis/alignment).
    mol = system.to_pyscf(s.basis)
    mf = dft.RKS(mol, xc="scan")
    mf.kernel()
    if not mf.converged:
        # Both implementations still receive identical inputs, so the
        # comparison stays meaningful, but P_ref is not a stationary point.
        print("WARNING: PySCF SCAN SCF did not converge", file=sys.stderr)

    C = np.asarray(mf.mo_coeff)
    mo_energy = np.asarray(mf.mo_energy)
    occ = np.asarray(mf.mo_occ)
    occidx = np.where(occ > 0)[0]
    viridx = np.where(occ == 0)[0]
    orbo, orbv = C[:, occidx], C[:, viridx]
    e_ia = mo_energy[viridx] - mo_energy[occidx, None]  # (nocc, nvir)
    P_ref = np.asarray(mf.make_rdm1())

    # Convert the shared reference data to JAX once; every builder reuses it.
    P_ref_j = jnp.asarray(P_ref)
    orbo_j = jnp.asarray(orbo)
    orbv_j = jnp.asarray(orbv)
    e_ia_j = jnp.asarray(e_ia)

    # SCAN XC module (no trainable params, but still goes through XCModule interface).
    functional = get_functional("scan", spin_restricted=True, use_density_fitting=False)
    xc_module = XCModule(functional, DensityFeatures(spin_restricted=True))
    xc_params = xc_module.init(jax.random.PRNGKey(0), P_ref_j, system.grid)

    # --- Build operators ---
    lib_E, lib_vresp = lib_response.build_total_energy_and_vresp(
        sys=system,
        xc_module=xc_module,
        params=xc_params,
        P_ref=P_ref_j,
        spin_restricted=True,
        use_density_fitting=False,
    )
    demo_E, demo_vresp = demo_impl.build_total_energy_and_vresp(
        system, xc_module, xc_params, P_ref_j
    )

    def build_lib_mv(tda_approx: bool):
        # Library Casida MV builder (keyword API).
        return lib_response.build_cassida_mv(
            sys=system,
            xc_module=xc_module,
            params=xc_params,
            occupied_orbs=orbo_j,
            virtual_orbs=orbv_j,
            e_ia=e_ia_j,
            P_ref=P_ref_j,
            spin_restricted=True,
            use_density_fitting=False,
            tda_approx=tda_approx,
        )

    def build_demo_mv(tda_approx: bool):
        # Demo Casida MV builder (positional API).
        return demo_impl.build_cassida_mv(
            system,
            xc_module,
            xc_params,
            orbo_j,
            orbv_j,
            e_ia_j,
            P_ref_j,
            tda_approx=tda_approx,
        )

    lib_tda_mv = build_lib_mv(True)
    lib_tddft_mv = build_lib_mv(False)
    demo_tda_mv = build_demo_mv(True)
    demo_tddft_mv = build_demo_mv(False)

    # --- Compare ---
    print("=== settings ===")
    print(f"basis={s.basis} grid_level={s.grid_level} alignment={s.alignment} seed={s.seed}")
    print(f"nao={P_ref.shape[0]} nocc={orbo.shape[1]} nvir={orbv.shape[1]}")

    E_lib = float(np.asarray(lib_E(P_ref_j)))
    E_demo = float(np.asarray(demo_E(P_ref_j)))
    print("\n=== energy_total(P_ref) ===")
    print(f"E_lib  = {E_lib:.16e}")
    print(f"E_demo = {E_demo:.16e}")
    print(f"absdiff = {abs(E_lib - E_demo):.3e}")
    print(f"relerr  = {_relerr(np.array([E_lib]), np.array([E_demo])):.3e}")

    # vresp comparison on symmetric random perturbations.
    # NOTE: the RNG draw order (dP, then X, then Y) is part of the
    # reproducibility contract for a given seed; keep it fixed.
    T = int(s.t_resp)
    dP = rng.standard_normal((T, P_ref.shape[0], P_ref.shape[1]))
    dP = np.stack([_symmetrize(mat) for mat in dP], axis=0)
    dP_j = jnp.asarray(dP)
    V_lib = np.asarray(lib_vresp(dP_j))
    V_demo = np.asarray(demo_vresp(dP_j))
    print("\n=== vresp(dP_batch) ===")
    print(f"relerr = {_relerr(V_lib, V_demo):.3e}")

    # TDA MV comparison
    OV = int(orbo.shape[1] * orbv.shape[1])
    M = int(s.m)
    X = rng.standard_normal((OV, M))
    X_j = jnp.asarray(X)
    AX_lib = np.asarray(lib_tda_mv(X_j))
    AX_demo = np.asarray(demo_tda_mv(X_j))
    print("\n=== TDA MV ===")
    print(f"relerr = {_relerr(AX_lib, AX_demo):.3e}")

    # Full TDDFT MV comparison (same X as the TDA check, plus a fresh Y).
    Y = rng.standard_normal((OV, M))
    Y_j = jnp.asarray(Y)
    U1_lib, U2_lib = (np.asarray(u) for u in lib_tddft_mv(X_j, Y_j))
    U1_demo, U2_demo = (np.asarray(u) for u in demo_tddft_mv(X_j, Y_j))
    print("\n=== TDDFT MV (Casida) ===")
    print(f"U1 relerr = {_relerr(U1_lib, U1_demo):.3e}")
    print(f"U2 relerr = {_relerr(U2_lib, U2_demo):.3e}")


if __name__ == "__main__":
    # Script entry point: the comparison runs only when executed directly,
    # not when this module is imported.
    main()


