"""
Loss functions on Kohn-Sham orbital rotation directions, as used in direct-minimization approaches.

This module provides loss functions that measure the discrepancy between the predicted and target
orbital rotation directions in the context of direct minimization approaches to Kohn-Sham DFT.
The loss quantifies how well the predicted exchange-correlation (XC) potential (or its response)
produces orbital rotations that match a reference (e.g., from a higher-level functional or
from a converged SCF calculation).

These losses are used to train or evaluate functionals that aim to produce accurate
orbital rotation directions, which are crucial for achieving correct self-consistent
solutions in DFT and for learning functionals.
"""

from dataclasses import dataclass
from functools import partial
from typing import Literal

import jax
import jax.numpy as jnp

from deixc.orbital_transforms import dm_gradient_to_orbital_rotation_gradient
from egxc.training.loss.tensor import TensorLossConfig, TensorMeasures, tensor_loss
from egxc.utils.typing import (
    Float1,
    FloatB,
    FloatBxB,
    FloatOxV,
    FloatTx2xOxV,
    FloatTxOxV,
)


@dataclass(frozen=True, slots=True)
class OrbitalRotationTensorLossConfig(TensorLossConfig[TensorMeasures]):
    """
    Configuration for the orbital-rotation losses in this module.

    Extends ``TensorLossConfig``; the whole config object is also forwarded to ``tensor_loss``.

    oep_weighting: bool, if True, the error tensor is divided elementwise by the
        occupied-virtual orbital energy differences (``orbital_energy_differences``)
        before the loss is evaluated.
    """

    # If True, divide the error tensor by the occupied-virtual orbital energy gaps.
    oep_weighting: bool


@dataclass(frozen=True, slots=True)
class OrbitalRotationHessianLossConfig(OrbitalRotationTensorLossConfig):
    """
    Configuration for ``orbital_rotation_hessian_loss``.

    n_perturbations: int, number of perturbations.
        NOTE(review): not read in this file; presumably the size of the leading (T) axis
        of the linear responses — confirm against the caller.
    differentiate_through_ground_state: bool.
        NOTE(review): not read in this file; presumably controls whether gradients flow
        through the ground-state solution — confirm against the caller.
    normalization: forwarded as ``ord`` to ``jnp.linalg.norm`` when computing the norm of
        each target linear response, by which the residual is divided. That norm is taken
        over the last two axes (``axis=(-2, -1)``), so the matrix-norm options apply.

    Normalization options (semantics of ``ord`` in ``jnp.linalg.norm``):
        For **vector norms** (i.e. a single axis reduction):
            - ``ord=None`` (default) computes the 2-norm
            - ``ord=inf`` computes ``max(abs(x))``
            - ``ord=-inf`` computes ``min(abs(x))``
            - ``ord=0`` computes ``sum(x!=0)``
            - for other numerical values, computes ``sum(abs(x) ** ord)**(1/ord)``

        For **matrix norms** (i.e. two axes reductions):
            - ``ord='fro'`` or ``ord=None`` (default) computes the Frobenius norm
            - ``ord='nuc'`` computes the nuclear norm, or the sum of the singular values
            - ``ord=1`` computes ``max(abs(x).sum(0))``
            - ``ord=-1`` computes ``min(abs(x).sum(0))``
            - ``ord=2`` computes the 2-norm, i.e. the largest singular value
            - ``ord=-2`` computes the smallest singular value
    """

    # Number of perturbations (not read in this file).
    n_perturbations: int
    # Not read in this file; see NOTE(review) in the docstring.
    differentiate_through_ground_state: bool
    # ``ord`` argument for the matrix norm of the target linear responses.
    normalization: None | Literal['fro', 'nuc'] | int


def orbital_energy_differences(
    orbital_energies: FloatB, O: int, eps: float = 1e-12
) -> FloatOxV:
    """
    Occupied-virtual orbital energy gaps ``|eps_i - eps_a|``.

    The first ``O`` entries of ``orbital_energies`` are treated as occupied and the remaining
    entries as virtual. Entry ``(i, a)`` of the result is the absolute difference between
    occupied energy ``i`` and virtual energy ``a``. Each gap is floored at ``eps`` (not shifted
    by it), so the result can be used as a divisor even for (near-)degenerate orbitals.

    Returns:
        Array of shape ``(O, B - O)``.
    """
    occupied = orbital_energies[:O]
    virtual = orbital_energies[O:]
    # Column vector minus row vector broadcasts to the full (O, V) table of pairs.
    pair_gaps = jnp.abs(jnp.expand_dims(occupied, 1) - jnp.expand_dims(virtual, 0))
    return jnp.maximum(pair_gaps, eps)


@partial(jax.jit, static_argnames=['n_occ', 'config'])
def orbital_rotation_gradient_loss(
    target_fock_or_xc_pot: FloatBxB,
    predicted_fock_or_xc_pot: FloatBxB,
    predicted_mo_coeffs: FloatBxB,
    n_occ: int,
    orbital_energies: FloatB,
    config: OrbitalRotationTensorLossConfig,
) -> Float1:
    """
    Loss on the orbital-rotation gradient induced by an error in a Fock / XC potential matrix.

    The potential error ``predicted - target`` is mapped to an orbital-rotation gradient with
    ``dm_gradient_to_orbital_rotation_gradient``, using the predicted MO coefficients. It is
    optionally divided elementwise by the occupied-virtual orbital energy gaps, and then
    reduced with ``tensor_loss``.

    Args:
        target_fock_or_xc_pot: Reference Fock or XC potential matrix (B, B).
        predicted_fock_or_xc_pot: Predicted Fock or XC potential matrix (B, B).
        predicted_mo_coeffs: Predicted molecular orbital coefficients (B, B).
        n_occ: Number of occupied orbitals (static under ``jax.jit``).
        orbital_energies: Orbital energies (B,). Only read when ``config.oep_weighting``
            is set.
        config: Loss configuration (static under ``jax.jit``). It selects the gap weighting
            and is forwarded to ``tensor_loss``.

    Returns:
        The value returned by ``tensor_loss`` for the (optionally gap-weighted)
        rotation-gradient error; the measure itself is chosen by ``config``.
    """
    potential_error = predicted_fock_or_xc_pot - target_fock_or_xc_pot
    rotation_error = dm_gradient_to_orbital_rotation_gradient(
        potential_error, predicted_mo_coeffs, n_occ
    )
    if not config.oep_weighting:
        return tensor_loss(rotation_error, config)

    # Presumably mirrors the first-order (OEP-style) energy denominator; the gaps are
    # floored inside orbital_energy_differences, so this division is always finite.
    gaps = orbital_energy_differences(orbital_energies, n_occ)
    return tensor_loss(rotation_error / gaps, config)


@partial(jax.jit, static_argnames=['config'])
def orbital_rotation_hessian_loss(
    target_linear_responses: FloatTxOxV | FloatTx2xOxV,
    predicted_linear_responses: FloatTxOxV | FloatTx2xOxV,
    orbital_energies: FloatB,
    config: OrbitalRotationHessianLossConfig,
) -> Float1:
    """
    Relative loss between target and predicted orbital-rotation linear responses.

    The residual ``target - predicted`` is processed as follows:
      1. It is optionally divided elementwise by the occupied-virtual orbital energy gaps.
      2. It is divided by the norm of the corresponding target response (last two axes).
      3. It is reduced with ``tensor_loss`` separately for each slice along the leading axis.
    The per-slice losses are then averaged.

    Args:
        target_linear_responses: Reference responses, (T, O, V) or (T, 2, O, V).
        predicted_linear_responses: Predicted responses, same shape as the target.
        orbital_energies: Orbital energies (B,); the first O entries are treated as occupied.
            Only read when ``config.oep_weighting`` is set.
        config: Loss configuration (static under ``jax.jit``). Its fields are used as follows:
            ``config.normalization`` is the ``ord`` of the target norm;
            ``config.oep_weighting`` enables the gap division;
            the config is also forwarded to ``tensor_loss``.

    Returns:
        Mean of the ``tensor_loss`` values over the leading (perturbation) axis.
    """
    n_occupied = target_linear_responses.shape[-2]

    # One norm per (O, V) matrix; the two trailing singleton axes let it broadcast.
    target_scale = jnp.linalg.norm(
        target_linear_responses, ord=config.normalization, axis=(-2, -1)
    )[..., None, None]

    residual = target_linear_responses - predicted_linear_responses
    if config.oep_weighting:
        gaps = orbital_energy_differences(orbital_energies, n_occupied)
        residual = residual / gaps
    # NOTE(review): the gap-weighted residual is scaled by the *unweighted* target norm;
    # confirm this mix is intended. The tiny offset avoids division by an exactly-zero norm.
    residual = residual / (target_scale + 1e-15)

    per_perturbation = jax.vmap(lambda slice_residual: tensor_loss(slice_residual, config))(
        residual
    )
    return jnp.mean(per_perturbation)
