from functools import partial

import einops
import jax
import jax.numpy as jnp
from jax.scipy.linalg import expm

from egxc.utils.typing import FloatBxB, FloatOxV, FloatSCFxBxB, UInt1


@partial(jax.jit, static_argnames=['n_occ', 'symmetrize'])
def mo_to_ao(
    mo_tensor: FloatOxV,
    C: FloatBxB,
    S: FloatBxB,
    n_occ: int,
    symmetrize: bool = True,
) -> FloatBxB:
    """
    Transform an MO-basis occupied-virtual response tensor to the AO basis.

    Evaluates ``-(S @ C_occ) @ mo_tensor @ (S @ C_virt).T`` and, if requested,
    averages the result with its transpose.

    Args:
        mo_tensor:  array with shape (O, V)
        C:          MO coefficients array with shape (B, B)
        S:          overlap matrix array with shape (B, B)
        n_occ:      number of occupied orbitals (static under jit, it sets slice sizes)
        symmetrize: if True (default), return ``(A + A^T) / 2``. Static under jit
            because it is consumed by a Python-level ``if``; a traced boolean
            could not be branched on.

    Returns:
        ao_tensor: array with shape (B, B)
    """
    SC = S @ C
    SC_occ = SC[..., :n_occ]  # (B, O)
    SC_virt = SC[..., n_occ:]  # (B, V)
    ao_tensor = -jnp.einsum('...ia,...mi,...na->...mn', mo_tensor, SC_occ, SC_virt)
    if symmetrize:  # improves numerical stability
        ao_tensor = (ao_tensor + ao_tensor.swapaxes(-2, -1)) / 2
    return ao_tensor


@partial(jax.jit, static_argnames=['n_occ'])
def ao_to_mo(
    ao_tensor: FloatBxB,  # shape (B, B)
    C: FloatBxB,  # shape (B, B)
    n_occ: int,
) -> FloatOxV:
    """
    Contract an AO-basis tensor with the occupied and virtual MO coefficients.

    Evaluates ``-2 * C_occ^T @ ao_tensor @ C_virt``, where ``C_occ`` holds the
    first ``n_occ`` columns of ``C`` and ``C_virt`` the remaining ones.

    Args:
        ao_tensor: array with shape (B, B)
        C:         array with shape (B, B)
        n_occ:     number of occupied orbitals (static under jit)

    Returns:
        mo_tensor: array with shape (O, V)
    """
    # Column split of the coefficients: (B, O) occupied, (B, V) virtual.
    occupied, virtual = C[..., :n_occ], C[..., n_occ:]

    projected = jnp.einsum('...mn,...mi,...na->...ia', ao_tensor, occupied, virtual)
    return -2 * projected


@jax.jit
def expand_occupied_virtual_block_to_full_basis(tensor: FloatOxV) -> FloatBxB:
    """
    Embed an occupied-virtual block into an antisymmetric full-size matrix.

    The result has the block layout ``[[0, tensor], [-tensor.T, 0]]``.

    Args:
        tensor: array with shape (O, V)

    Returns:
        array with shape (O + V, O + V)
    """
    O, V = tensor.shape  # noqa: E741
    # Create the zero blocks in the input's own floating dtype so that e.g. a
    # float16 input is not silently promoted to the default float dtype.
    # Integer inputs still promote to the default float, as before.
    zeros_dtype = jnp.result_type(tensor.dtype, 0.0)
    return jnp.block(
        [
            [jnp.zeros((O, O), dtype=zeros_dtype), tensor],
            [-tensor.T, jnp.zeros((V, V), dtype=zeros_dtype)],
        ]
    )


@jax.jit
def orbital_rotation(C: FloatBxB, angles: FloatOxV) -> FloatBxB:
    """
    Apply the rotation generated by the given angles to the orbitals.

    The angles are expanded to a full-size matrix with
    ``expand_occupied_virtual_block_to_full_basis``; its matrix exponential is
    then applied to ``C`` from the right.

    Args:
        C: the MO coefficients matrix
        angles: the angles to rotate the orbitals by.
            Only the occupied-virtual block is used, since the density matrix
            is invariant to rotations among the occupied orbitals and among
            the virtual orbitals.

    Returns:
        the rotated MO coefficients matrix
    """
    generator = expand_occupied_virtual_block_to_full_basis(angles)
    rotation = expm(generator)
    return C @ rotation


@jax.jit
def first_order_orbital_rotation(C: FloatBxB, angles: FloatOxV) -> FloatBxB:
    """
    First-order approximation of the orbital rotation.

    Computes ``C @ (I + theta)``, where ``theta`` is the full-size matrix built
    from ``angles`` by ``expand_occupied_virtual_block_to_full_basis``.

    Args:
        C: the MO coefficients matrix
        angles: occupied-virtual rotation angles with shape (O, V)

    Returns:
        the MO coefficients after the linearized rotation
    """
    theta = expand_occupied_virtual_block_to_full_basis(angles)
    # Size (and type) the identity from theta, the matrix it is added to,
    # rather than from the row count of C. Identical for square C, and stays
    # valid if C has fewer columns (orbitals) than rows.
    identity = jnp.eye(theta.shape[-1], dtype=theta.dtype)
    return C @ (identity + theta)


@partial(jax.jit, static_argnames=['n_occ'])
def dm_gradient_to_orbital_rotation_gradient(
    dX_dP: FloatBxB, C: FloatBxB, n_occ: int
) -> FloatOxV:
    """
    Turn a density-matrix gradient into an occupied-virtual rotation gradient.

    This is the (unnormalized) occupied-virtual gradient direction on the Grassmann
    manifold, corresponding to steepest descent for minimization of the total
    energy or XC energy. The work is delegated to ``ao_to_mo``.

    Args:
        dX_dP: E_tot / E_xc gradient w.r.t. the density matrix,
            e.g. Fock or XC potential matrix (shape: [B, B])
        C: Molecular orbital coefficients (shape: [B, B])
        n_occ: Number of occupied orbitals

    Returns:
        Occ-Virt gradient (shape: [O, V]), i.e., direction in orbital rotation space
        that lowers the energy.
    """
    occ_virt_gradient = ao_to_mo(ao_tensor=dX_dP, C=C, n_occ=n_occ)
    return occ_virt_gradient


@partial(jax.jit, static_argnames=['n_occ', 'normalize'])
def ao_density_perturbation_from_occupied_virtual_rotation(
    direction: FloatOxV,
    C: FloatBxB,
    S: FloatBxB,
    n_occ: UInt1,
    normalize: bool,
    epsilon: float = 1e-21,
) -> FloatBxB:
    """
    Perturb the density matrix along the direction specified on the Grassmannian.

    The direction is optionally rescaled, passed through ``mo_to_ao`` and the
    result is multiplied by 2.

    Args:
        direction: occupied-virtual direction with shape (O, V)
        C:         MO coefficients array with shape (B, B)
        S:         overlap matrix array with shape (B, B)
        n_occ:     number of occupied orbitals (static under jit)
        normalize: if True, divide ``direction`` by ``norm(direction) + epsilon``
            first (static under jit)
        epsilon:   small offset added to the norm before dividing

    Returns:
        AO-basis perturbation with shape (B, B)
    """
    # NOTE(review): n_occ is annotated UInt1 here but int in the sibling
    # functions; as a static jit argument it must be hashable. Confirm the
    # intended annotation.
    if normalize:
        # epsilon keeps the denominator non-zero for an all-zero direction
        scale = jnp.linalg.norm(direction) + epsilon
        direction = direction / scale
    return 2 * mo_to_ao(direction, C, S, n_occ)


@jax.jit
def orthonormalize_delta_densities(
    delta_densities: FloatSCFxBxB,
) -> FloatSCFxBxB:
    """
    Orthonormalize a stack of delta densities with a reduced QR decomposition.

    Each (B, B) matrix is flattened into one column of a (B * B, SCF) matrix,
    the Q factor of its reduced QR decomposition is taken, and the columns of Q
    are reshaped back into (B, B) matrices.

    Args:
        delta_densities: array with shape (SCF, B, B)

    Returns:
        the Q factor reshaped to a stack of (B, B) matrices
    """
    n_basis, _ = delta_densities[0].shape
    # One flattened delta density per column.
    columns = einops.rearrange(delta_densities, 'SCF B1 B2 -> (B1 B2) SCF')
    # Only the Q factor is needed; R is discarded.
    q_columns, _ = jnp.linalg.qr(columns, mode='reduced')
    return einops.rearrange(
        q_columns, '(B1 B2) SCF -> SCF B1 B2', B1=n_basis, B2=n_basis
    )
