
import torch
import sys
import os

# Make the project directory 'XML Exp' importable. The original lookup is
# relative to the current working directory; the same folder next to this
# script is also tried so the import works regardless of where the script is
# launched from. The cwd-relative entry keeps its original priority.
_PKG_DIR_NAME = 'XML Exp'
_candidate_dirs = [os.path.abspath(_PKG_DIR_NAME)]
if '__file__' in globals():  # not defined in interactive sessions
    _candidate_dirs.append(
        os.path.join(os.path.dirname(os.path.abspath(__file__)), _PKG_DIR_NAME)
    )
for _path in _candidate_dirs:
    if _path not in sys.path:
        sys.path.append(_path)

try:
    from network_clifford import Network
except ImportError as e:
    # sys.exit with a string argument writes it to stderr and exits with
    # status 1, so failures are visible to shells/CI rather than on stdout.
    sys.exit(f"Import Error: {e}")
else:
    print("Successfully imported Network from network_clifford")

# Config
BATCH_SIZE = 5
NUM_LABELS = 10
IN_FEATURES = 100
HIDDEN = 50
OUT_FEATURES = 64  # 8x8 squares
LABELS_PER_ITEM = 3  # number of true labels drawn for each batch item
SEED = 0

# Seed before building the model so weights, inputs and labels are
# reproducible across runs.
torch.manual_seed(SEED)

# Initialize
device = torch.device('cpu')
print("Initializing Network...")
model = Network(device, IN_FEATURES, HIDDEN, OUT_FEATURES, NUM_LABELS)

# Test Forward
print("Testing Forward Pass...")
# Create inputs on `device` so changing the device above does not produce a
# device-mismatch error in the forward pass.
x = torch.randn(BATCH_SIZE, IN_FEATURES, device=device)
logits = model(x)
print(f"Logits Shape: {logits.shape}")

# Test Loss
print("Testing Loss...")
# Each row holds LABELS_PER_ITEM *distinct* label indices in [0, NUM_LABELS).
# torch.randint samples with replacement and could repeat a label within a
# row; argsort of random keys yields a random permutation per row, and its
# first LABELS_PER_ITEM columns are a random subset without repeats.
true_labels = torch.rand(BATCH_SIZE, NUM_LABELS, device=device).argsort(dim=1)
true_labels = true_labels[:, :LABELS_PER_ITEM]
loss = model.loss(logits, true_labels)
print(f"Loss: {loss.item()}")
# Fail loudly (non-zero exit status) instead of silently printing nan/inf.
if not torch.isfinite(loss).all():
    sys.exit(f"Loss is not finite: {loss.item()}")

# Test Inference
print("Testing Inference...")
scores = model.inference(logits)
print(f"Scores Shape: {scores.shape} (Expected: {BATCH_SIZE}, {NUM_LABELS+1})")
# NOTE(review): cls1/cls2 are reportedly initialised with NUM_LABELS + 1
# outputs, so the label dimension of `scores` may be NUM_LABELS + 1 rather
# than NUM_LABELS — confirm against network_clifford.
