import traceback

import torch

from model import EnhancedCVAE
import config as cfg

def test_model_dimensions(batch_size=32):
    """Smoke-test EnhancedCVAE with config-sized inputs and check output shapes.

    Builds the model from the hyper-parameters in ``config``, runs one forward
    pass on random ``condition``/``target`` tensors, prints the configuration,
    parameter count and tensor shapes, then verifies the output shapes.

    Args:
        batch_size: Number of random samples in the test batch (default 32).

    Returns:
        True if the forward pass succeeded and every shape check passed,
        False otherwise (the reason, and any traceback, is printed).
    """
    print("Configuration:")
    print(f"INPUT_FRAMES: {cfg.INPUT_FRAMES}")
    print(f"OUTPUT_FRAMES: {cfg.OUTPUT_FRAMES}")
    print(f"NUM_JOINTS: {cfg.NUM_JOINTS}")
    print(f"CONDITION_DIM: {cfg.CONDITION_DIM}")
    print(f"TARGET_DIM: {cfg.TARGET_DIM}")

    condition = torch.randn(batch_size, cfg.CONDITION_DIM)
    target = torch.randn(batch_size, cfg.TARGET_DIM)

    model = EnhancedCVAE(
        input_dim=cfg.TARGET_DIM,
        cond_dim=cfg.CONDITION_DIM,
        latent_dim=cfg.LATENT_DIM,
        hidden_dim=cfg.HIDDEN_DIM
    )

    param_count = sum(p.numel() for p in model.parameters())
    print(f"Model parameters: {param_count:,}")

    try:
        # Only shapes are inspected, so skip building an autograd graph.
        with torch.no_grad():
            recon, mu, log_var = model(target, condition)
    except Exception as e:
        # Top-level smoke-test boundary: report the error with its full
        # traceback (not just str(e)) and signal failure to the caller.
        print(f"Forward pass failed: {e}")
        traceback.print_exc()
        return False

    print("Forward pass successful!")
    print(f"Input shapes: target={target.shape}, condition={condition.shape}")
    print(f"Output shapes: recon={recon.shape}, mu={mu.shape}, log_var={log_var.shape}")

    problems = []

    # NOTE(review): assumes the encoder emits one flat latent vector of size
    # LATENT_DIM per sample (model is built with latent_dim=cfg.LATENT_DIM) —
    # confirm against EnhancedCVAE if it ever produces per-frame latents.
    expected_latent = (batch_size, cfg.LATENT_DIM)
    if tuple(mu.shape) != expected_latent:
        problems.append(f"mu shape {tuple(mu.shape)} != expected {expected_latent}")
    if tuple(log_var.shape) != expected_latent:
        problems.append(
            f"log_var shape {tuple(log_var.shape)} != expected {expected_latent}"
        )

    # Compare batch size and total element count rather than the exact shape,
    # so a decoder that reshapes its output (e.g. frames x joints) still passes
    # as long as it reconstructs TARGET_DIM values per sample.
    if recon.ndim == 0 or recon.shape[0] != batch_size or recon.numel() != target.numel():
        problems.append(
            f"recon shape {tuple(recon.shape)} is incompatible with target "
            f"shape {tuple(target.shape)}"
        )

    if problems:
        print("Dimension check failed:")
        for problem in problems:
            print(f"  - {problem}")
        return False
    return True

if __name__ == "__main__":
    # Use the smoke-test result as the process exit status (0 = pass,
    # 1 = fail) so shells and CI jobs can detect a broken model.
    raise SystemExit(0 if test_model_dimensions() else 1)
