"""
TDDFT/TDA response builders and block-Davidson eigensolvers for the TDA and full TDDFT (Casida) problems.
Adapted from https://github.com/John-zzh/TDDFT_Davidson
Reported to be more numerically stable than the PySCF implementation; this has not been verified here.

build_cassida_mv(..., tda_approx=False): returns a callable that computes
the Casida (TDDFT) matrix-vector product (AX+BY, BX+AY). If tda_approx is
True, returns the TDA operator that maps X -> AX.
"""

from __future__ import annotations
import time
from re import L
from typing import Any, Callable, Tuple, Dict

import os
from pathlib import Path


import jax
import jax.numpy as jnp
import numpy as np
# import pandas as pd
# import yaml

from deixc.scf import DerivativeInformedSelfConsistentFieldSolver
from egxc.dataloading import QM9
from egxc.dataloading.io import unpickle_dictionary
from egxc.discretization import (GTOBasis, get_grid_fn, get_gto_grid_eval_fn,
                                 get_gto_preloader)
from egxc.systems import Grid, System, examples, nuclear_energy_fn
from egxc.systems.preload import preload_system_using_pyscf
from egxc.utils.linalg import modified_generalized_eigenvalue_problem
from egxc.utils.typing import Alignment, cast_to_integer_tuple
from egxc.xc_energy import DensityFeatures
from egxc.xc_energy.functionals import get_functional
from egxc.xc_energy.functionals.classical import (BaseRangeSeparatedHybrid,
                                                  Hybrid)
from egxc.xc_energy.xc_module import XCModule


@jax.jit
def _hartree_energy(P: jax.Array, eri4: jax.Array) -> jax.Array:
    """Return the restricted Hartree energy for an AO density matrix.

    Evaluates 0.5 * sum_{ijkl} P[i, j] * P[k, l] * eri4[i, j, k, l] as a
    single contraction and returns the resulting scalar.
    """
    coulomb = jnp.einsum('ij,kl,ijkl->', P, P, eri4)
    half = jnp.asarray(0.5)
    return half * coulomb


def _non_local_kwargs(sys, functional, eri4: jax.Array) -> Dict[str, Any]:
    """Collect the optional keyword arguments a functional needs beyond the density.

    * Hybrid / range-separated hybrid functionals receive the ERI tensor
      under ``'eri_tensor'``.
    * Functionals whose ``is_graph_based`` attribute is truthy additionally
      receive ``'atom_mask'``, ``'nuc_pos'`` and ``'grid_coords'`` read from
      `sys`.

    Returns an empty dict when neither condition applies.
    """
    extra: Dict[str, Any] = {}

    is_hybrid = isinstance(functional, (Hybrid, BaseRangeSeparatedHybrid))
    if is_hybrid:
        extra['eri_tensor'] = eri4

    if getattr(functional, 'is_graph_based', False):
        extra.update(
            atom_mask=sys.atom_mask,
            nuc_pos=sys._nuc_pos,
            grid_coords=sys.grid.coords,
        )

    return extra


def build_total_energy_and_vresp(
    sys, xc_module: XCModule, xc_params: Any, P_ref: jax.Array
):
    """Create the total-energy function and its linearized AO potential response.

    The response around `P_ref` is a Hessian-vector product, obtained as the
    JVP of the energy gradient:
        dV = (d/dP) grad_P(E_xc(P) + E_H(P)) [dP]

    Args:
        sys: System object; its 4-index electron repulsion tensor and its
            grid are read here.
        xc_module: XC energy module evaluated through `xc_module.apply`.
        xc_params: Parameters passed to `xc_module.apply`.
        P_ref: AO density matrix at which the response is linearized.

    Returns:
        (energy_total, vresp)
        - energy_total(P): XC energy plus Hartree energy for density `P`.
        - vresp(dms): response potentials for a stack of density
          perturbations, mapped over the leading axis of `dms`.

    Raises:
        ValueError: If the electron repulsion tensor is not 4-dimensional.
    """
    eri = jnp.asarray(sys.fock_tensors.electron_repulsion_tensor)
    if eri.ndim != 4:
        raise ValueError('Need 4-index ERIs: use_density_fitting=False')

    P0 = jnp.asarray(P_ref)
    extra_kwargs = _non_local_kwargs(sys, xc_module.functional, eri)

    @jax.jit
    def energy_total(P: jax.Array) -> jax.Array:
        e_hartree = _hartree_energy(P, eri)
        e_xc = xc_module.apply(xc_params, P, sys.grid, **extra_kwargs)
        return e_xc + e_hartree

    def hv(dP: jax.Array) -> jax.Array:
        # Forward-mode derivative of the gradient == Hessian applied to dP.
        _, tangent = jax.jvp(jax.grad(energy_total), (P0,), (dP,))
        return tangent

    @jax.jit
    def vresp(dms):
        # One Hessian-vector product per entry along the leading axis.
        return jax.vmap(hv, in_axes=0)(dms)

    return energy_total, vresp


# def build_cassida_mv_nodensityfitting(
#     sys, xc_module: XCModule, xc_params: Any, orbo, orbv, e_ia, P_ref, *, tda_approx: bool
# ):
#     """Build Casida/TDA matrix-vector products in the particle-hole (i->a) space.

#     Args:
#         sys: EGXC `System`.
#         xc_module: Trainable XC module.
#         xc_params: Parameters for `xc_module`.
#         orbo: Occupied MO coefficients, shape (nao, nocc).
#         orbv: Virtual MO coefficients, shape (nao, nvir).
#         e_ia: Orbital energy gaps eps_a - eps_i, shape (nocc, nvir).
#         P_ref: AO density matrix reference point.
#         tda_approx: If True, return TDA MV `mv(X)`. Else return TDDFT MV `mv(X,Y)`.

#     Returns:
#         Callable implementing the requested MV product(s).
#     """
#     _, vresp = build_total_energy_and_vresp(sys, xc_module, xc_params, P_ref)
#     nocc, nvir = orbo.shape[1], orbv.shape[1]

#     @jax.jit
#     def tda_mv(X):
#         m = X.shape[1]
#         xs = X.T.reshape(m, nocc, nvir)
#         dms = jnp.einsum('xov,pv,qo->xpq', xs, orbv, orbo.conj() * 2.0)
#         v1ao = vresp(dms)
#         v1mo = jnp.einsum('xpq,qo,pv->xov', v1ao, orbo, orbv.conj())
#         v1mo = v1mo + jnp.einsum('xia,ia->xia', xs, e_ia)
#         return v1mo.reshape(m, -1).T

#     if tda_approx:
#         return tda_mv

#     @jax.jit
#     def tddft_mv(X, Y):
#         m = X.shape[1]
#         xs = X.T.reshape(m, nocc, nvir)
#         ys = Y.T.reshape(m, nocc, nvir)
#         dms = jnp.einsum('xov,pv,qo->xpq', xs, orbv, orbo.conj() * 2.0)
#         dms = dms + jnp.einsum('xov,qv,po->xpq', ys, orbv.conj(), orbo * 2.0)
#         v1ao = vresp(dms)
#         top = jnp.einsum('xpq,qo,pv->xov', v1ao, orbo, orbv.conj())
#         bot = jnp.einsum('xpq,po,qv->xov', v1ao, orbo.conj(), orbv)
#         top = top + jnp.einsum('xia,ia->xia', xs, e_ia)
#         bot = bot + jnp.einsum('xia,ia->xia', ys, e_ia)
#         return top.reshape(m, -1).T, bot.reshape(m, -1).T

#     return tddft_mv


def _utriangle_sym(A: np.ndarray) -> np.ndarray:
    """Mirror the strict upper triangle of `A` onto its lower triangle, in place.

    The diagonal is left untouched. `A` itself is modified and returned, so
    passing a view also updates the array that owns the data.
    """
    rows, cols = np.triu_indices(A.shape[0], k=1)
    A[cols, rows] = A[rows, cols]
    return A


def _tda_init_guess(V: np.ndarray, n: int, hdiag: np.ndarray) -> np.ndarray:
    """Write unit-vector starting guesses into the first `n` columns of `V`.

    Column j gets a 1.0 in the row that holds the j-th smallest entry of the
    flattened `hdiag`. `V` is modified in place and returned; no other entry
    is touched.
    """
    lowest_rows = np.argsort(hdiag.reshape(-1))[:n]
    guess_cols = np.arange(n)
    V[lowest_rows, guess_cols] = 1.0
    return V


def _tda_precond(res: np.ndarray, w: np.ndarray, hdiag: np.ndarray) -> np.ndarray:
    """Apply the diagonal preconditioner ``res / (hdiag - w)`` column-wise.

    Args:
        res: Residual vectors, one per column.
        w: Ritz values, one per column of `res`.
        hdiag: Diagonal approximation of the operator; flattened to a column
            and broadcast against `w`.

    Returns:
        The preconditioned residuals, same shape as `res`.

    Denominators with magnitude below 1e-8 are replaced by +/-1e-8. An exactly
    zero denominator is treated as positive: ``np.sign(0)`` is 0, so using it
    for the clamp would leave the zero in place and yield inf/nan.
    """
    t = 1e-8
    D = hdiag.reshape(-1, 1) - w
    # Sign-preserving floor that maps an exact zero to +t rather than 0.
    floor = np.where(D < 0, -t, t)
    D = np.where(np.abs(D) < t, floor, D)
    return res / D


def _tddft_precond(Rx: np.ndarray, Ry: np.ndarray, w: np.ndarray, hdiag: np.ndarray):
    """Apply the diagonal preconditioner to a pair of TDDFT residual blocks.

    Args:
        Rx: Upper residual block, one vector per column.
        Ry: Lower residual block, same shape as `Rx`.
        w: Ritz values, one per column.
        hdiag: Diagonal approximation; flattened to a column and broadcast
            against `w`.

    Returns:
        ``(Rx / (hdiag - w), Ry / (hdiag + w))``.

    Denominators with magnitude below 1e-14 are replaced by +/-1e-14. An
    exactly zero denominator is treated as positive: ``np.sign(0)`` is 0, so
    using it for the clamp would leave the zero in place and yield inf/nan.
    """
    t = 1e-14
    d = hdiag.reshape(-1, 1)

    def _floored(D: np.ndarray) -> np.ndarray:
        # Sign-preserving floor that maps an exact zero to +t rather than 0.
        return np.where(np.abs(D) < t, np.where(D < 0, -t, t), D)

    Dx = _floored(d - w)
    Dy = _floored(d + w)
    return Rx / Dx, Ry / Dy


def _gs_bvec(A: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Remove from `b` its projection onto the columns of `A`.

    Computes ``b - A (A^T b)``. When `A` has no columns, `b` is returned
    unchanged (the same object).
    """
    if A.shape[1] == 0:
        return b
    overlaps = A.T @ b
    return b - A @ overlaps


def _gs_fill(V: np.ndarray, count: int, vecs: np.ndarray):
    """Append orthonormalized columns of `vecs` to the basis stored in `V`.

    Each candidate is projected twice against the first `count` columns of
    `V`, then normalized and written to column `count`. Candidates whose
    remaining norm is not above 1e-14 are dropped.

    Returns:
        (V, count): `V` modified in place, and the updated number of filled
        columns.
    """
    for j in range(vecs.shape[1]):
        candidate = vecs[:, j : j + 1]
        basis = V[:, :count]
        # Two projection passes against the same basis.
        for _ in range(2):
            candidate = _gs_bvec(basis, candidate)
        length = np.linalg.norm(candidate)
        if not (length > 1e-14):
            continue
        V[:, count] = candidate[:, 0] / length
        count += 1
    return V, count


def _vw_gs(x: np.ndarray, y: np.ndarray, V: np.ndarray, W: np.ndarray):
    """Project the pair (x, y) against the paired basis (V, W) and (W, V).

    Two sets of overlaps are formed,
        m = V^T x + W^T y   and   n = W^T x + V^T y,
    and the corresponding components are subtracted from both vectors.

    Returns:
        (x', y') with x' = x - V m - W n and y' = y - W m - V n.
    """
    overlap_direct = V.T @ x + W.T @ y
    overlap_swapped = W.T @ x + V.T @ y
    x_out = x - V @ overlap_direct - W @ overlap_swapped
    y_out = y - W @ overlap_direct - V @ overlap_swapped
    return x_out, y_out


def _s_sym_ortho(x: np.ndarray, y: np.ndarray):
    """Rebalance (x, y) so that the returned pair is mutually orthogonal.

    The pair is split into its sum ``x + y`` and difference ``x - y``. The
    difference is rescaled to the norm of the sum, both are halved, and the
    pair is rebuilt as (half_sum + half_diff, half_sum - half_diff).
    """
    total = x + y
    delta = x - y
    ratio = np.linalg.norm(total) / np.linalg.norm(delta)
    half_total = total / 2
    half_delta = delta * (ratio / 2)
    return half_total + half_delta, half_total - half_delta


def _vw_fill(Vh: np.ndarray, Wh: np.ndarray, m: int, Xn: np.ndarray, Yn: np.ndarray):
    """Append new paired basis vectors built from the columns of (Xn, Yn).

    Each column pair is projected against the first `m` columns of
    (Vh, Wh), passed through `_s_sym_ortho`, and scaled by the combined norm
    sqrt(x.x + y.y). It is then stored in column `m` of `Vh` and `Wh`. Pairs
    whose norm is not above 1e-14 are dropped.

    Returns:
        (Vh, Wh, m): both arrays modified in place, and the updated number of
        filled columns.
    """
    for col in range(Xn.shape[1]):
        x_new = Xn[:, col : col + 1]
        y_new = Yn[:, col : col + 1]
        x_new, y_new = _vw_gs(x_new, y_new, Vh[:, :m], Wh[:, :m])
        x_new, y_new = _s_sym_ortho(x_new, y_new)
        squared = (x_new.T @ x_new + y_new.T @ y_new)[0, 0]
        nrm = float(np.sqrt(squared))
        if not (nrm > 1e-14):
            continue
        Vh[:, m] = x_new[:, 0] / nrm
        Wh[:, m] = y_new[:, 0] / nrm
        m += 1
    return Vh, Wh, m


def _gen_VW(
    sub: np.ndarray, Vh: np.ndarray, Wh: np.ndarray, old: int, new: int, sym: bool
):
    """Extend the projected matrix ``Vh^T Wh`` from size `old` to size `new`.

    Only the entries involving columns ``old:new`` are computed; the
    ``[:old, :old]`` block of `sub` is assumed to be filled already.

    * The column strip ``sub[:new, old:new]`` is always computed explicitly.
    * The row strip ``sub[old:new, :old]`` is the transpose of the new upper
      strip when `sym` is true, and is computed explicitly otherwise.

    `sub` is modified in place and returned.
    """
    fresh = slice(old, new)
    sub[:new, fresh] = Vh[:, :new].T @ Wh[:, fresh]
    if sym:
        lower_strip = sub[:old, fresh].T
    else:
        lower_strip = Vh[:, fresh].T @ Wh[:, :old]
    sub[fresh, :old] = lower_strip
    return sub


def _matrix_power(S: np.ndarray, a: float) -> np.ndarray:
    """Raise a symmetric matrix to the power `a` via its eigendecomposition.

    With ``S = U diag(s) U^T`` from `np.linalg.eigh`, returns
    ``U diag(s**a) U^T``.
    """
    eigvals, eigvecs = np.linalg.eigh(S)
    scaled_vecs = eigvecs * (eigvals**a)
    return scaled_vecs @ eigvecs.T


def _gen_sub_ab(
    Vh: np.ndarray,
    Wh: np.ndarray,
    U1h: np.ndarray,
    U2h: np.ndarray,
    VU1: np.ndarray,
    WU2: np.ndarray,
    VU2: np.ndarray,
    WU1: np.ndarray,
    VV: np.ndarray,
    WW: np.ndarray,
    VW: np.ndarray,
    old: int,
    new: int,
):
    """Update all projected blocks and assemble the subspace matrices.

    The seven stored products (V^T U1, V^T U2, W^T U1, W^T U2, V^T V, W^T W,
    V^T W) are extended from size `old` to size `new` with `_gen_VW`. From
    their leading ``new x new`` parts the following are formed:

        A     = V^T U1 + W^T U2   (upper triangle mirrored to lower)
        B     = V^T U2 + W^T U1   (upper triangle mirrored to lower)
        sigma = V^T V  - W^T W    (upper triangle mirrored to lower)
        pi    = V^T W  - (V^T W)^T

    Returns:
        (A, B, sigma, pi, VU1, WU2, VU2, WU1, VV, WW, VW).
    """
    # (storage, left factor, right factor) for every projected product.
    products = (
        (VU1, Vh, U1h),
        (VU2, Vh, U2h),
        (WU1, Wh, U1h),
        (WU2, Wh, U2h),
        (VV, Vh, Vh),
        (WW, Wh, Wh),
        (VW, Vh, Wh),
    )
    VU1, VU2, WU1, WU2, VV, WW, VW = (
        _gen_VW(store, left, right, old, new, False)
        for store, left, right in products
    )

    lead = (slice(None, new), slice(None, new))
    # The sums/differences are fresh arrays, so mirroring them does not touch
    # the stored products.
    A = _utriangle_sym(VU1[lead] + WU2[lead])
    B = _utriangle_sym(VU2[lead] + WU1[lead])
    sigma = _utriangle_sym(VV[lead] - WW[lead])
    vw_lead = VW[lead]
    pi = vw_lead - vw_lead.T
    return A, B, sigma, pi, VU1, WU2, VU2, WU1, VV, WW, VW


def _tddft_sub_eigh(
    a: np.ndarray, b: np.ndarray, sigma: np.ndarray, pi: np.ndarray, k: int
):
    """Solve the projected TDDFT eigenproblem for `k` roots.

    Builds the block matrices
        A = [[a, b], [b, a]]   and   B = [[sigma, pi], [-pi, -sigma]],
    and diagonalizes the symmetric matrix ``A^{-1/2} B A^{-1/2}``. The `k`
    largest eigenvalues are taken in descending order and inverted to give
    `w`. The matching eigenvectors are scaled by ``sqrt(w)`` and transformed
    back with ``A^{-1/2}``.

    Returns:
        (w, x, y): the `k` values `w`, and the upper and lower halves of the
        back-transformed eigenvectors.
    """
    half = a.shape[0]
    big_a = np.block([[a, b], [b, a]])
    big_b = np.block([[sigma, pi], [-pi, -sigma]])
    a_inv_sqrt = _matrix_power(big_a, -0.5)

    inv_w, vecs = np.linalg.eigh(a_inv_sqrt @ big_b @ a_inv_sqrt)
    # eigh sorts ascending: take the last k entries and reverse them.
    w = 1 / inv_w[-k:][::-1]
    vecs = vecs[:, -k:][:, ::-1] * (w**0.5)

    full = a_inv_sqrt @ vecs
    return w, full[:half, :], full[half:, :]


def Davidson(
    mv, hdiag: np.ndarray, N_states: int = 5, conv_tol: float = 1e-5, max_iter: int = 40
):
    """Block-Davidson solver for the lowest roots of a symmetric operator (TDA-style).

    Args:
        mv: Matrix-vector product `mv(X) -> AX`, with X shaped (dim, block_size).
            It is always called with the same number of columns; unused
            columns are zero-padded.
        hdiag: Diagonal approximation used for the initial guess and the
            preconditioner; flattened to shape (dim,).
        N_states: Number of lowest roots to compute.
        conv_tol: Residual norm tolerance.
        max_iter: Maximum Davidson iterations.

    Returns:
        (evals, X): Ritz values and approximate eigenvectors. If the
        iteration limit is reached, or no new basis vector can be added, the
        latest (unconverged) estimates are returned.
    """
    hdiag = hdiag.reshape(-1)
    dim = hdiag.shape[0]

    n_done = 0
    n_basis = min(N_states + 8, 2 * N_states, dim)
    block_size = n_basis
    capacity = max_iter * N_states + n_basis

    V = np.zeros((dim, capacity))
    W = np.zeros_like(V)
    sub = np.zeros((capacity, capacity))
    V = _tda_init_guess(V, n_basis, hdiag)

    for _ in range(max_iter):
        # Apply the operator to the new basis vectors only, padded to a
        # fixed block width.
        fresh = V[:, n_done:n_basis]
        n_fresh = fresh.shape[1]
        padded = np.zeros((dim, block_size))
        padded[:, :n_fresh] = fresh
        applied = mv(padded)
        W[:, n_done:n_basis] = applied[:, :n_fresh]

        # Rayleigh-Ritz step in the current subspace.
        sub = _gen_VW(sub, V, W, n_done, n_basis, True)
        ritz_vals, ritz_vecs = np.linalg.eigh(
            _utriangle_sym(sub[:n_basis, :n_basis])
        )
        evals = ritz_vals[:N_states]
        coeffs = ritz_vecs[:, :N_states]
        X = V[:, :n_basis] @ coeffs
        AX = W[:, :n_basis] @ coeffs

        residual = AX - X * evals
        res_norms = np.linalg.norm(residual, axis=0)
        if float(np.max(res_norms)) < conv_tol:
            return evals, X

        # Extend the basis with preconditioned residuals of unconverged roots.
        pending = np.where(res_norms > conv_tol)[0]
        corrections = _tda_precond(residual[:, pending], evals[pending], hdiag)
        n_done = n_basis
        V, n_basis = _gs_fill(V, n_done, corrections)
        if n_basis == n_done:
            break

    return evals, X


def Davidson_Casida(
    mv, hdiag: np.ndarray, N_states: int = 5, conv_tol: float = 1e-5, max_iter: int = 60
):
    """Davidson-like solver for the full TDDFT (Casida) eigenproblem.

    The operator is accessed through a coupled MV product:
        mv(X, Y) -> (U1, U2)
    It is always called with `block_size` columns; unused columns are
    zero-padded.

    Args:
        mv: Coupled matrix-vector product.
        hdiag: Orbital energy gaps (diagonal preconditioner), shape (dim,).
        N_states: Number of roots.
        conv_tol: Convergence tolerance.
        max_iter: Maximum iterations.

    Returns:
        (w, X, Y): Excitation energies and corresponding eigenvectors. If the
        iteration limit is reached, or no new basis vector can be added, the
        latest (unconverged) estimates are returned.
    """
    hdiag = hdiag.reshape(-1)
    n = hdiag.shape[0]
    size_old, size_new = 0, min(N_states + 8, 2 * N_states, n)
    block_size = size_new
    # The start block holds `size_new` vectors and each of the `max_iter`
    # iterations can add up to `N_states` more. Sizing the workspace as
    # (max_iter + 1) * N_states is too small whenever size_new > N_states and
    # makes the final basis extension index past the last column.
    max_mv = size_new + max_iter * N_states
    V = np.zeros((n, max_mv))
    W = np.zeros_like(V)
    U1 = np.zeros_like(V)
    U2 = np.zeros_like(V)
    VU1 = np.zeros((max_mv, max_mv))
    VU2 = np.zeros_like(VU1)
    WU1 = np.zeros_like(VU1)
    WU2 = np.zeros_like(VU1)
    VV = np.zeros_like(VU1)
    VW = np.zeros_like(VU1)
    WW = np.zeros_like(VU1)
    V = _tda_init_guess(V, size_new, hdiag)
    for _ in range(max_iter):
        # Apply the operator to the new basis pairs only, padded to a fixed
        # block width.
        pV, pW = V[:, size_old:size_new], W[:, size_old:size_new]
        padV, padW = np.zeros((n, block_size)), np.zeros((n, block_size))
        padV[:, : pV.shape[1]] = pV
        padW[:, : pW.shape[1]] = pW
        U1t, U2t = mv(padV, padW)
        U1[:, size_old:size_new] = U1t[:, : pV.shape[1]]
        U2[:, size_old:size_new] = U2t[:, : pV.shape[1]]

        # Project onto the current subspace and solve the small problem.
        (a, b, sigma, pi, VU1, WU2, VU2, WU1, VV, WW, VW) = _gen_sub_ab(
            V, W, U1, U2, VU1, WU2, VU2, WU1, VV, WW, VW, size_old, size_new
        )
        w, x, y = _tddft_sub_eigh(a, b, sigma, pi, N_states)

        # Expand the subspace solution to full-space vectors and residuals.
        Vc, Wc = V[:, :size_new], W[:, :size_new]
        X = Vc @ x + Wc @ y
        Y = Wc @ x + Vc @ y
        U1c, U2c = U1[:, :size_new], U2[:, :size_new]
        Rx = U1c @ x + U2c @ y - X * w
        Ry = U2c @ x + U1c @ y + Y * w
        r = np.linalg.norm(np.vstack((Rx, Ry)), axis=0)
        if float(np.max(r)) < conv_tol:
            return w, X, Y

        # Extend the basis with preconditioned residuals of unconverged roots.
        idx = np.where(r > conv_tol)[0]
        Xn, Yn = _tddft_precond(Rx[:, idx], Ry[:, idx], w[idx], hdiag)
        size_old = size_new
        V, W, size_new = _vw_fill(V, W, size_old, Xn, Yn)
        if size_new == size_old:
            break
    return w, X, Y
