import torch
from utils.lightning_module import FlowMatchingLightningModule
from omegaconf import OmegaConf

def test_residual_finetune():
    """Smoke-test residual finetuning in ``FlowMatchingLightningModule``.

    Builds the module from a minimal in-memory config with
    ``execution.mode='res_finetune'`` and no pretrained checkpoint, then:

    1. runs ``residual_training_step`` on a random ``(obs_cond, x_traj)``
       batch and checks that the returned loss is finite (neither NaN nor
       inf), and
    2. calls the module directly (``model(x, t)``) and checks that the
       output contains only finite values.

    torch's global RNG is seeded first so that a failure on random inputs
    (or random parameter initialisation) can be reproduced.
    """
    # Seed before constructing the model so both the parameter init and the
    # random inputs are reproducible when a NaN/inf shows up.
    torch.manual_seed(0)

    batch_size, dim = 2, 8

    # Minimal config. NOTE(review): presumably these are the only keys the
    # module reads when pretrain_checkpoint is None — confirm against the
    # module's __init__.
    cfg = OmegaConf.create({
        'env': {'name': 'pusht'},
        'execution': {
            'mode': 'res_finetune',
            'pretrain_checkpoint': None,
            'solver': {'time_steps': 5}
        },
        'training': {
            'learning_rate': 0.001,
            'epochs': 10
        }
    })

    # Create the lightning module
    model = FlowMatchingLightningModule(cfg)

    # Dummy batch: both tensors are (batch_size, dim). NOTE(review): the real
    # pusht observation/trajectory shapes are not visible here; these are
    # placeholder sizes — confirm they match what the module expects.
    obs_cond = torch.randn(batch_size, dim)
    x_traj = torch.randn(batch_size, dim)
    batch = (obs_cond, x_traj)

    # Test residual training step
    loss = model.residual_training_step(batch, 0)
    print(f"Residual finetuning loss: {loss.item():.4f}")
    # isfinite rejects both NaN and +/-inf; a NaN-only check lets an
    # exploding (inf) loss pass silently.
    assert torch.isfinite(loss).all(), "Loss should be finite (not NaN or inf)"

    # Test model calling
    x = torch.randn(batch_size, dim)
    # NOTE(review): a single time value is passed for a batch of
    # `batch_size` samples; this relies on the model broadcasting t across
    # the batch — confirm the expected shape of t.
    t = torch.tensor([0.5])
    output = model(x, t)
    print(f"Model output shape: {output.shape}")
    assert torch.isfinite(output).all(), "Model output should be finite"

    print("All tests passed!")

if __name__ == "__main__":
    # Allow running this file directly as a script, outside of pytest.
    test_residual_finetune()
