from functools import partial
from typing import Literal, Tuple, overload

import flax.linen as nn
import jax
import jax.numpy as jnp

from egxc.solver import fock
from egxc.solver.scf.diis import DiisState, diis_update
from egxc.systems.base import System
from egxc.utils import linalg
from egxc.utils.typing import (
    Float2xBxB,
    FloatBxB,
    FloatSCF,
    FloatSCFx2xBxB,
    FloatSCFxBxB,
)
from egxc.xc_energy import XCModule

# State threaded through the SCF loop by the convergence-acceleration step.
# 'Vanilla' uses None; 'DIIS' uses whatever DiisState.init returns (presumably
# a DiisState).
# NOTE(review): the 'Momentum' branch of the SCF solver in this file stores a
# bare Fock matrix as its state, not a (matrix, int) tuple -- confirm which
# form is intended.
ConvAccState = DiisState | Tuple[FloatBxB | Float2xBxB, int] | None

# Carry of the scan over SCF cycles:
# (current Fock matrix, system, convergence-acceleration state).
ScfCycleCarry = Tuple[
    FloatBxB | Float2xBxB,
    System,
    ConvAccState,
]


class DerivativeInformedSolver[SR: Literal[True] | Literal[False]](nn.Module):
    """
    Abstract base class for all solvers.
    """

    xc_module: XCModule  # only this module contains trainable parameters
    spin_restricted: SR

    @overload
    def __call__(
        self: 'DerivativeInformedSolver[Literal[True]]',
        initial_density_matrix: FloatBxB,
        sys: System,
    ) -> Tuple[
        Tuple[FloatSCF, FloatSCF],
        Tuple[FloatSCFxBxB, FloatSCFxBxB, FloatSCFxBxB, FloatSCFxBxB],
    ]: ...
    @overload
    def __call__(
        self: 'DerivativeInformedSolver[Literal[False]]',
        initial_density_matrix: Float2xBxB,
        sys: System,
    ) -> Tuple[
        Tuple[FloatSCF, FloatSCF],
        Tuple[FloatSCFx2xBxB, FloatSCFx2xBxB, FloatSCFx2xBxB, FloatSCFx2xBxB],
    ]: ...
    def __call__(self, initial_density_matrix, sys):
        raise NotImplementedError('abstract method')

    def xc_potential_and_linear_response(self, *args, **kw):  # forwarded from xc_module
        return self.xc_module.xc_potential_and_linear_response(*args, **kw)

    def xc_potential_linear_responses(self, *args, **kw):  # forwarded from xc_module
        return self.xc_module.xc_potential_linear_responses(*args, **kw)

    def xc_rotation_hvp(self, *args, **kw):  # forwarded from xc_module
        return self.xc_module.xc_rotation_hvp(*args, **kw)


class DerivativeInformedSelfConsistentFieldSolver(DerivativeInformedSolver):
    """
    Self-consistent-field solver that runs a fixed number of cycles inside a
    ``jax.lax.scan`` and returns the full trajectory of every cycle.

    Attributes:
        cycles: Number of SCF cycles executed by the scan.
        use_density_fitting: Forwarded to ``fock.FockMatrix``.
        convergence_acceleration_method: How the raw Fock matrix of each cycle
            is post-processed before the next cycle: 'Vanilla' (unchanged),
            'Momentum' (mixing with the previous Fock matrix) or 'DIIS'.
    """

    cycles: int
    use_density_fitting: bool
    convergence_acceleration_method: Literal['Vanilla', 'Momentum', 'DIIS'] = 'DIIS'

    def setup(self) -> None:
        """
        Build the Fock sub-module and select the acceleration and
        density-matrix helper functions used by ``scf_loop``.

        Raises:
            ValueError: If ``convergence_acceleration_method`` is not one of
                'Vanilla', 'Momentum' or 'DIIS'.
        """
        self.FockModule = fock.FockMatrix(
            self.xc_module, self.use_density_fitting, self.spin_restricted
        )  # type: ignore
        # set up specified convergence acceleration method to dampen oscillations of the fock matrix
        # Every method provides two callables:
        #   init_convergence_acc_state(F_0, P_0, fock_tensors) -> state
        #   convergence_acc_fn(cycle, F_raw, state, P, fock_tensors) -> (F, state)
        if self.convergence_acceleration_method == 'DIIS':
            init_fn = partial(DiisState.init, self.cycles)
            if not self.spin_restricted:  # vmap over spin
                # cycle index and fock tensors are shared; Fock matrix, state
                # and density matrix are mapped over their leading axis
                self.convergence_acc_fn = jax.vmap(
                    diis_update, in_axes=(None, 0, 0, 0, None)
                )
                init_fn = jax.vmap(init_fn, in_axes=(0, 0, None))
            else:
                self.convergence_acc_fn = diis_update
            self.init_convergence_acc_state = init_fn
        elif self.convergence_acceleration_method == 'Momentum':
            # the state is simply the previously mixed Fock matrix
            self.init_convergence_acc_state = lambda F, *args: F

            def update_fn(cycle, F_raw, state, *args):
                F_previous = state
                # NOTE(review): alpha equals 1.3 at cycle 0 (0.3**0 + 0.3), so
                # the first step extrapolates beyond F_raw; it decays towards
                # 0.3 afterwards. Confirm that this schedule is intended.
                alpha = 0.3**cycle + 0.3  # FIXME: make this a hyperparameter
                F = alpha * F_raw + (1 - alpha) * F_previous
                return F, F

            self.convergence_acc_fn = update_fn
        elif self.convergence_acceleration_method == 'Vanilla':
            # Default to vanilla SCF
            self.init_convergence_acc_state = lambda *args: None
            self.convergence_acc_fn = lambda _, F, *args: (F, None)
        else:
            raise ValueError(
                f'Invalid convergence acceleration method: {self.convergence_acceleration_method}'
            )

        def new_density_matrix_and_molecular_coefficients(F, X, occupancies):
            # solve the eigenvalue problem for F (with X), keep only the
            # coefficients and turn them into a density matrix
            _, C = linalg.modified_generalized_eigenvalue_problem(F, X)
            return linalg.coeffs_to_density_matrix(C, occupancies), C

        if not self.spin_restricted:  # vmap over spin
            # F and occupancies are mapped over their leading axis, X is shared
            new_density_matrix_and_molecular_coefficients = jax.vmap(
                new_density_matrix_and_molecular_coefficients, in_axes=(0, None, 0)
            )
        self.new_dm_and_mo_coeffs = new_density_matrix_and_molecular_coefficients

    def __call__(self, initial_density_matrix: FloatBxB | Float2xBxB, sys: System):
        """
        Build the initial Fock matrix from the initial density matrix and run
        the SCF loop.

        Args:
            initial_density_matrix: Density matrix used as the starting guess
            sys: System
        Returns:
            The result of ``scf_loop``:
            ``(energies, (mo_coeffs, density_matrices, fock_matrices, v_xc_matrices))``.
        """
        initial_fock_matrix = self.FockModule.fock_matrix(
            sys._nuc_pos, initial_density_matrix, sys
        )
        return self.scf_loop(initial_fock_matrix, initial_density_matrix, sys)

    def scf_loop(
        self, F_0: FloatBxB | Float2xBxB, P_0: FloatBxB | Float2xBxB, sys: System
    ):
        """
        Run ``self.cycles`` SCF cycles using the configured convergence
        acceleration method (Vanilla, Momentum or DIIS).

        Args:
            F_0: Initial Fock matrix
            P_0: Initial density matrix (only used to initialise the
                convergence acceleration state)
            sys: System
        Returns:
            Energies: Result of ``FockModule.energy`` mapped over the density
                matrices of all cycles
            A tuple of per-cycle trajectories, each stacked along a leading
            cycle axis by ``jax.lax.scan``:
                MO coefficients: Coefficients obtained from the Fock matrix
                    entering each cycle
                Density matrices: Density matrices built from those coefficients
                Fock matrices: Fock matrices after convergence acceleration
                V_xc matrices: V_xc contribution to the raw Fock matrix of
                    each cycle
        """

        def loop_body(
            carry: ScfCycleCarry, cycle: int
        ) -> Tuple[
            ScfCycleCarry,
            Tuple[FloatBxB, FloatBxB, FloatBxB, FloatBxB]
            | Tuple[Float2xBxB, Float2xBxB, Float2xBxB, Float2xBxB],
        ]:
            F, sys, acc_state = carry
            # new density matrix and coefficients from the current Fock matrix
            P, C = self.new_dm_and_mo_coeffs(
                F, sys.fock_tensors.diagonal_overlap, sys.fock_tensors.occupancies
            )
            H_core, J, V_xc = self.FockModule.fock_matrix_contributions(
                sys._nuc_pos, P, sys
            )
            F_new_raw = H_core + J + V_xc
            # post-process the raw Fock matrix before it enters the next cycle
            F_new_damped, acc_state = self.convergence_acc_fn(
                cycle,
                F_new_raw,
                acc_state,  # type: ignore
                P,
                sys.fock_tensors,
            )
            return (F_new_damped, sys, acc_state), (C, P, F_new_damped, V_xc)

        acc_state = self.init_convergence_acc_state(F_0, P_0, sys.fock_tensors)
        init_state = (F_0, sys, acc_state)

        # the final carry is not needed; only the stacked per-cycle outputs are
        _, (mo_coeffs, density_matrices, fock_matrices, v_xc_matrices) = jax.lax.scan(
            loop_body,
            init_state,
            xs=jnp.arange(self.cycles),  # type: ignore
        )
        energies = self.__calc_energies_along_scf_trajectory(
            sys._nuc_pos, density_matrices, sys
        )
        return energies, (mo_coeffs, density_matrices, fock_matrices, v_xc_matrices)

    def __calc_energies_along_scf_trajectory(self, nuc_pos, density_matrices, sys):
        """
        Evaluate ``FockModule.energy`` for every density matrix of the
        trajectory; ``nuc_pos`` and ``sys`` are shared across cycles.
        """
        energy_fn = jax.vmap(self.FockModule.energy, in_axes=(None, 0, None))
        return energy_fn(nuc_pos, density_matrices, sys)
