import einops
import jax.numpy as jnp
import pytest
from utils import PyscfSystemWrapper, set_jax_testing_config

from egxc.systems.examples import get
from egxc.training.loss import (
    DensityFieldLossConfig,
    LossConfig,
    RelativeLossWeights,
    get_loss_fns,
)

# Apply the shared JAX test configuration from the local test `utils` module at
# import time, so it is in effect before the test functions below run.
# NOTE(review): presumably sets options such as float64 precision — confirm in utils.
set_jax_testing_config()


@pytest.mark.quick
def test_density_loss():
    """Check the density-matrix term of the training loss on water.

    The loss is configured with zero energy weight and an MSE density measure.
    Two properties are checked:

    * comparing the reference density matrix against a copy of itself for
      every SCF cycle yields a (numerically) zero loss, and
    * comparing the initial-guess density matrix against those copies yields
      a strictly positive, finite loss.
    """
    n_cycles = 10
    # ``reference_basis_is_same=True`` below requires the system and the PySCF
    # reference to use the same basis, so define it in exactly one place.
    basis = '6-31G(d)'
    # Tolerance for the identical-inputs case: exact float equality is fragile
    # when the loss is evaluated through (density-fitted) contractions.
    zero_atol = 1e-10

    loss_config = LossConfig(
        RelativeLossWeights(energy=0, density=1.0),
        # NOTE(review): presumably one weight per SCF cycle — confirm in LossConfig.
        jnp.ones(n_cycles),
        max_energy_volatility=jnp.inf,
        density=DensityFieldLossConfig(
            measure='mse',
            scale_per_electron=True,
            spin_restricted=True,
            is_density_fitted=True,
        ),
        reference_basis_is_same=True,
    )
    loss_fns = get_loss_fns(loss_config)

    # Named ``system`` rather than ``sys`` to avoid shadowing the stdlib module.
    system = get('water', basis=basis, alignment=1)
    reference = PyscfSystemWrapper(system, basis=basis)
    P0 = reference.initial_density_matrix
    P1 = reference.density_matrix

    # Stack the reference density once per SCF cycle: (i, j) -> (n_cycles, i, j).
    P1_predicted = einops.repeat(P1, 'i j -> scf i j', scf=n_cycles)

    # Identical target and per-cycle predictions: the loss must vanish.
    # NOTE(review): assumes the first argument is the target and the second the
    # stacked predictions — confirm against get_loss_fns.
    density_loss = loss_fns.density(
        P1,  # type: ignore
        P1_predicted,  # type: ignore
        system,
    )
    assert jnp.isclose(density_loss, 0.0, atol=zero_atol)

    # Initial guess vs. reference-density predictions: the loss must be positive
    # and finite (``> 0.0`` alone would let ``inf`` through).
    density_loss = loss_fns.density(
        P0,  # type: ignore
        P1_predicted,  # type: ignore
        system,
    )
    assert jnp.isfinite(density_loss)
    assert density_loss > 0.0
