# tests/test_egp.py
"""
Integration-style unit tests for EGP components:
 - energy function returns finite energy and gradient
 - inner-loop gradient_correction reduces energy in a few steps (sanity check)
Run with: pytest tests/test_egp.py -q
"""
import torch
from energy.energy_fn import EnergyFunction
from clip.clip_embedder import CLIPEmbedder
from models.vae_decoder import DummyVAEDecoder
from optimization.gradient_correction import apply_corrections
from configs import Config
import numpy as np

def test_energy_and_gradients_small():
    """Energy and gradient for a small batch are finite and correctly shaped.

    Builds an EnergyFunction from a CLIPEmbedder and a DummyVAEDecoder on
    CPU and evaluates it on a random batch of two 64-dim latents. It uses
    two placeholder negative-text embeddings. The test checks that:

    - the energy is a torch.Tensor with only finite entries;
    - the gradient has the same shape as the input latents;
    - the gradient has only finite entries.
    """
    # Seed before building the models and sampling latents so any failure
    # is reproducible.
    torch.manual_seed(0)
    device = "cpu"
    clip = CLIPEmbedder(device=device)
    vae = DummyVAEDecoder(latent_dim=64, out_shape=(3, 32, 32), device=device)
    energy = EnergyFunction(
        clip_embedder=clip, vae_decoder=vae, beta=1.0, tau=0.1, device=device
    )
    x = torch.randn(2, 64, requires_grad=True, device=device)
    neg_texts = ["bad prompt", "negative prompt"]
    neg_embs = clip._placeholder_text_to_vec(neg_texts)
    E, g = energy.energy_and_grad(
        x, x_target_latent=None, negatives_text_embs=neg_embs
    )
    # Assert type and finiteness as separate conditions. OR-ing them lets a
    # NaN/inf energy tensor pass.
    assert isinstance(E, torch.Tensor)
    assert torch.isfinite(E).all(), f"non-finite energy: {E}"
    assert g.shape == x.shape
    # torch.as_tensor is a no-op for tensors and also tolerates array-likes.
    assert torch.isfinite(torch.as_tensor(g)).all(), "non-finite gradient"

def test_inner_loop_decreases_energy():
    """A few inner-loop correction steps do not increase the energy.

    The test runs these steps:

    1. Configure a Config for CPU with step-size settings (eta0=0.05,
       gamma=0.5) and ne=3 correction steps.
    2. Evaluate the energy of a random 64-dim latent against one placeholder
       negative-text embedding.
    3. Run apply_corrections on that latent.
    4. Re-evaluate the energy and assert it did not rise by more than a small
       numeric tolerance.

    This is a sanity check on the direction of the correction, not a
    convergence test.
    """
    # Seed so the energy comparison is reproducible rather than flaky.
    torch.manual_seed(0)
    device = "cpu"
    cfg = Config()
    cfg.device = device
    cfg.eta0 = 0.05
    cfg.gamma = 0.5
    cfg.ne = 3
    clip = CLIPEmbedder(device=device)
    vae = DummyVAEDecoder(latent_dim=64, out_shape=(3, 32, 32), device=device)
    energy = EnergyFunction(
        clip_embedder=clip, vae_decoder=vae, beta=1.0, tau=0.1, device=device
    )
    x0 = torch.randn(1, 64, device=device)
    neg_embs = clip._placeholder_text_to_vec(["bad"])
    E0, _ = energy.energy_and_grad(
        x0, x_target_latent=None, negatives_text_embs=neg_embs
    )
    # Pass cfg.ne so the explicit step count cannot drift from the config.
    x_new, info = apply_corrections(
        x0, energy, neg_embs, x_target=None, cfg=cfg, t_idx=10, ne=cfg.ne
    )
    E1, _ = energy.energy_and_grad(
        x_new, x_target_latent=None, negatives_text_embs=neg_embs
    )
    assert info["steps_taken"] >= 0
    # Energy must not increase beyond small numeric jitter. Asserted on its
    # own: OR-ing it with an always-true condition disables the check.
    e0, e1 = float(E0), float(E1)
    assert e1 <= e0 + 1e-4, f"energy increased: {e0} -> {e1}"
