from functools import partial
from typing import Tuple

import jax
import jax.numpy as jnp
from flax.struct import dataclass

from egxc.systems.base import FockTensors
from egxc.utils.typing import FloatBxB, FloatSCF, FloatSCFxBxB, FloatSCFxSCF


def compute_residual(F: FloatBxB, P: FloatBxB, fock_tensors: FockTensors) -> FloatBxB:
    """
    Computes the DIIS residual (error) matrix belonging to a Fock matrix.

    With S = fock_tensors.overlap and X = fock_tensors.diagonal_overlap the
    result is X^T (FPS - (FPS)^T) X, explicitly anti-symmetrized at the end.
    For symmetric F, P and S the inner term equals the commutator FPS - SPF.

    F: Fock matrix
    P: Density matrix
    fock_tensors: constant system tensors; only `overlap` and
        `diagonal_overlap` are read here (presumably `diagonal_overlap` is
        the transformation to an orthonormal basis — confirm against
        FockTensors)

    returns:
        anti-symmetric residual matrix with the same shape as F
    """
    S = fock_tensors.overlap
    X = fock_tensors.diagonal_overlap
    fps = jnp.einsum('ab,bc,cd->ad', F, P, S)
    commutator = fps - fps.T
    transformed = X.T @ commutator @ X
    # The transformation may break exact anti-symmetry through round-off,
    # so project back onto the anti-symmetric part.
    return (transformed - transformed.T) / 2


@partial(jax.jit, static_argnames=('eps_tikhonov',))
def solve_pulay_equation(
    current_cycle: int,
    overlap: FloatSCFxSCF,
    eps_tikhonov: float = 1e-6,  # the regularization greatly aids the back-propagation stability, while introducing a negligibly small deviation (on the order of 1e-3 mHa) when compared to pyscf (for example see test_deixc_data_gen.py::test_custom_pipeline_consistency)
) -> FloatSCF:
    """
    Solves the (Tikhonov-regularized) Pulay equations for the DIIS coefficients.

    The constraint row/column of the bordered system is written in place at
    index `current_cycle + 1` of `overlap`, so the array keeps a fixed shape
    for every cycle, which keeps the function jit friendly.

    current_cycle: index of the most recent SCF cycle stored in `overlap`
    overlap: square matrix of residual overlaps with at least one spare
        row/column after index `current_cycle` for the constraint
    eps_tikhonov: relative strength of the ridge added to the diagonal
        (static under jit)

    returns:
        solution vector laid out as (x0, ..., x_{n-1}, lambda, 0, ..., 0),
        with the Lagrange multiplier at index `current_cycle + 1`
    """
    dim = overlap.shape[0]
    lagrange_idx = current_cycle + 1
    # -1 in front of the constraint index, 0 from the constraint index onwards
    border = -1 * (jnp.arange(dim) < lagrange_idx)
    pulay = overlap.at[:, lagrange_idx].set(border)
    pulay = pulay.at[lagrange_idx, :].set(border)
    pulay = pulay.at[lagrange_idx, lagrange_idx].set(0)
    pulay = (pulay + pulay.T) / 2  # enforce exact symmetry
    # Ridge term scaled with the mean magnitude of the diagonal; the +1 keeps
    # it strictly positive even if the diagonal vanishes.
    ridge = eps_tikhonov * (jnp.mean(jnp.abs(jnp.diag(pulay))) + 1)
    pulay = pulay + ridge * jnp.eye(dim)
    # Only the constraint row has a non-zero right-hand side.
    target = jnp.zeros(dim).at[lagrange_idx].set(-1)
    return jax.scipy.linalg.solve(pulay, target, assume_a='sym')


@dataclass
class DiisState:
    """
    Fixed-size buffers carried between DIIS iterations.

    All arrays are allocated once for the full number of SCF cycles and
    filled cycle by cycle.
    """

    # Matrix of residual overlaps for the Pulay equations, do not confuse with
    # the basis set overlap matrix. `init` builds it with shape
    # (total_cycles + 1, total_cycles + 1), although the annotation reads BxB.
    overlap: FloatBxB
    # Fock matrices of the individual cycles, slot i belongs to cycle i.
    fock_trajectory: FloatSCFxBxB
    # Residual matrices of the individual cycles, slot i belongs to cycle i.
    res_trajectory: FloatSCFxBxB

    @classmethod
    def init(
        cls,
        total_cycles: int,
        fock_matrix: FloatBxB,
        density_matrix: FloatBxB,
        fock_tensors: FockTensors,
    ):
        """
        Creates the state for an SCF run of `total_cycles` cycles.

        The given Fock matrix and its residual are stored in slot 0 of the
        trajectories, all remaining slots are zero.

        total_cycles: number of SCF cycles the buffers have to hold
        fock_matrix: initial Fock matrix
        density_matrix: density matrix used to compute the initial residual
        fock_tensors: constant system tensors passed on to compute_residual
        """
        n_basis = fock_matrix.shape[0]
        history_shape = (total_cycles, n_basis, n_basis)
        first_residual = compute_residual(fock_matrix, density_matrix, fock_tensors)
        fock_history = jnp.zeros(history_shape).at[0].set(fock_matrix)
        residual_history = jnp.zeros(history_shape).at[0].set(first_residual)
        # Diagonal padding 1/(n+1), 2/(n+1), ..., 1 such that the matrix is
        # invertible and well conditioned before any overlaps are written.
        # NOTE(review): entry [0, 0] keeps this padding value rather than the
        # self-overlap of the first residual — confirm the caller overwrites it.
        padded_size = total_cycles + 1
        pulay_overlap = jnp.diag((1 + jnp.arange(padded_size)) / padded_size)
        return cls(
            overlap=pulay_overlap,
            fock_trajectory=fock_history,
            res_trajectory=residual_history,
        )


@partial(jax.jit, static_argnames=('stop_gradient'))
def diis_update(
    current_cycle: int,
    raw_fock_matrix: FloatBxB,
    state: DiisState,
    density_matrix: FloatBxB,
    fock_tensors: FockTensors,
    stop_gradient: bool = False,  # Note: this parameter is actually very important, when training with DIIS we observe best performance when setting this to False
) -> Tuple[FloatBxB, DiisState]:
    """
    Direct Inversion of the Iterative Subspace (DIIS) to accelerate the
    convergence of the Self-Consistent Field (SCF) method.
    Returns the DIIS update to the Fock matrix.

    current_cycle: current cycle of the SCF method, used as the slot index
        into the trajectories stored in `state`
    raw_fock_matrix: standard Fock matrix
    state: DIIS state based on the previous cycles
    density_matrix: density matrix
    fock_tensors: constant system tensors containing the overlap matrix
    stop_gradient: whether to stop the gradient w.r.t. the Pulay coefficients,
        when training with DIIS we observe best performance when setting this to False
        (static under jit)

    returns:
        (Fock matrix updated by DIIS, updated DIIS state)
        If the extrapolated Fock matrix contains any non-finite entry
        (NaN or +/-inf), `raw_fock_matrix` is returned in its place.

    Implementation inspired by
    https://github.com/psi4/psi4numpy/blob/master/Tutorials/03_Hartree-Fock/3b_rhf-diis.ipynb
    but adapted to be jax compile friendly.
    """
    i = current_cycle
    residual = compute_residual(raw_fock_matrix, density_matrix, fock_tensors)
    res_trajectory = state.res_trajectory.at[i].set(residual)
    fock_trajectory = state.fock_trajectory.at[i].set(raw_fock_matrix)

    # Overlap of the new residual with every stored residual (including
    # itself); written symmetrically into row and column i. The last
    # row/column of `overlap` is left untouched.
    new_overlap = jnp.einsum('ikl,kl->i', res_trajectory, residual)
    overlap = state.overlap.at[i, :-1].set(new_overlap)
    overlap = overlap.at[:-1, i].set(new_overlap)

    fock_coeffs = solve_pulay_equation(i, overlap)
    if stop_gradient:
        fock_coeffs = jax.lax.stop_gradient(fock_coeffs)

    # fock_coeffs has one entry more than there are trajectory slots. The
    # Lagrange multiplier sits at index i + 1: in the last cycle it is dropped
    # by the slice, otherwise it multiplies slot i + 1, which is still zero
    # when the state comes from DiisState.init and cycles are visited in order.
    F_out = jnp.einsum('i,ijk->jk', fock_coeffs[:-1], fock_trajectory)

    # Fall back to the raw Fock matrix whenever the extrapolation is not
    # finite. This is necessary, since B becomes singular once the SCF has
    # converged; checking isfinite also catches +/-inf, not only NaN.
    # NOTE(review): jnp.where only guards the forward value — confirm that
    # back-propagated gradients stay finite when the fallback is taken.
    F_out = jnp.where(jnp.isfinite(F_out).all(), F_out, raw_fock_matrix)
    return F_out, DiisState(overlap, fock_trajectory, res_trajectory)
