# test_model_dimensions.py
import torch
from model import EnhancedCVAE
import config as cfg

def test_model_dimensions() -> None:
    """Build an EnhancedCVAE from ``config`` and check its forward-pass shapes.

    Prints the relevant configuration values and the model's parameter
    count, runs one forward pass on random ``(target, condition)`` tensors,
    and then checks that:

    * ``recon`` has the same shape as the ``target`` batch fed in;
    * ``mu`` and ``log_var`` are both ``(batch_size, cfg.LATENT_DIM)``.

    Failures are not swallowed. An exception raised by the forward pass is
    printed and then re-raised, and a shape mismatch raises
    ``AssertionError``. Either way pytest reports the test as failed, and a
    direct script run exits with a non-zero status. Nothing is returned,
    following pytest's convention for test functions.
    """
    print("Configuration:")
    print(f"INPUT_FRAMES: {cfg.INPUT_FRAMES}")
    print(f"OUTPUT_FRAMES: {cfg.OUTPUT_FRAMES}")
    print(f"NUM_JOINTS: {cfg.NUM_JOINTS}")
    print(f"CONDITION_DIM: {cfg.CONDITION_DIM}")
    print(f"TARGET_DIM: {cfg.TARGET_DIM}")

    batch_size = 32
    condition = torch.randn(batch_size, cfg.CONDITION_DIM)
    target = torch.randn(batch_size, cfg.TARGET_DIM)

    model = EnhancedCVAE(
        input_dim=cfg.TARGET_DIM,
        cond_dim=cfg.CONDITION_DIM,
        latent_dim=cfg.LATENT_DIM,
        hidden_dim=cfg.HIDDEN_DIM
    )

    param_count = sum(p.numel() for p in model.parameters())
    print(f"Number of model parameters: {param_count:,}")

    try:
        # Gradients are not needed to inspect output shapes.
        with torch.no_grad():
            recon, mu, log_var = model(target, condition)
    except Exception as e:
        print(f"Forward pass failed: {e}")
        # Re-raise so the failure is visible to pytest / the caller instead
        # of being converted into a silently-ignored return value.
        raise

    print("Forward pass successful!")
    print(f"Input shapes: target={target.shape}, condition={condition.shape}")
    print(f"Output shapes: recon={recon.shape}, mu={mu.shape}, log_var={log_var.shape}")

    # The model is built with input_dim=cfg.TARGET_DIM, so its reconstruction
    # is expected to line up with the target batch it was given.
    assert recon.shape == target.shape, (
        f"recon shape {tuple(recon.shape)} != target shape {tuple(target.shape)}"
    )
    expected_latent = (batch_size, cfg.LATENT_DIM)
    assert mu.shape == expected_latent, (
        f"mu shape {tuple(mu.shape)} != expected {expected_latent}"
    )
    assert log_var.shape == expected_latent, (
        f"log_var shape {tuple(log_var.shape)} != expected {expected_latent}"
    )

if __name__ == "__main__":
    # A failing check either raises (Python prints the traceback and exits
    # with status 1) or returns False; treat the latter as a failure too so
    # the process exit status always reflects the outcome.
    if test_model_dimensions() is False:
        raise SystemExit(1)
