"""
Debug the actual model output structure to understand the correct slicing.
"""

import torch
import numpy as np
from dataset_generator import load_mnist_rotation_datasets
from oa_infer import OAWrapper

def debug_model_output():
    """Print the raw OA model output for one test sample and propose head slices.

    Loads the ``base`` OA model, pushes the first rotated-MNIST test image
    (``seed=42``, ``batch_size=1``) through it, and prints the output width
    together with a candidate split of that width into prediction heads.

    The split is a hypothesis used to choose slice bounds, not a verified
    contract: an output width of 722 is consistent with
    azimuth(360) + polar(180) + rotation(180) + confidence(2), whereas the
    originally assumed layout 360 + 180 + 360 + 2 sums to 902.
    TODO(review): confirm the layout against the OA model definition.
    """
    print("Loading OA model...")
    wrapper = OAWrapper(model_size='base')

    # A single, reproducible test sample.
    _, test_loader = load_mnist_rotation_datasets(
        rotation_range=(0.0, 360.0),
        augmentation_factor=1,
        batch_size=1,
        seed=42,
    )

    images, angles_0_360 = next(iter(test_loader))
    image = images[0]

    print(f"Ground truth angle: {angles_0_360[0].item():.2f}°")

    # Convert to PIL and process with the wrapper's own helpers.
    pil = wrapper._mnist_tensor_to_rgb_pil(image)
    pil = wrapper.to_rgb_224(pil)

    # Call wrapper.model directly so the raw, un-sliced output is visible.
    image_inputs = wrapper.processor(images=pil)
    image_inputs['pixel_values'] = torch.from_numpy(
        np.array(image_inputs['pixel_values'])
    ).to(wrapper.device)

    with torch.no_grad():
        dino_pred = wrapper.model(image_inputs)

    # Assumes a (batch, features) output; TODO(review) confirm.
    total_size = dino_pred.shape[1]
    print(f"Raw model output shape: {dino_pred.shape}")
    print(f"Total output size: {total_size}")
    print("Expected: 360 + 180 + 360 + 2 = 902")
    print(f"Actual: {total_size}")

    # Candidate head layout as (name, width) in output order. Both the
    # printed ranges and the slice bounds are derived from this table so
    # they cannot drift apart.
    layout = (
        ("Azimuth", 360),
        ("Polar", 180),
        ("Rotation", 180),
        ("Confidence", 2),
    )
    layout_size = sum(width for _, width in layout)
    if total_size != layout_size:
        # Report the mismatch rather than exiting without any diagnosis.
        print(
            f"Output size {total_size} does not match the {layout_size}-wide "
            "candidate layout; no slicing hypothesis to test."
        )
        return

    print(f"Model has {total_size} outputs")
    print("Possible structure:")
    spans = {}
    start = 0
    for name, width in layout:
        stop = start + width
        spans[name] = (start, stop)
        # Inclusive range for display; `stop` is the exclusive slice bound.
        print(f"  - {name}: {start}-{stop - 1} ({width})")
        start = stop

    rot_start, rot_stop = spans["Rotation"]
    rotation_head = dino_pred[:, rot_start:rot_stop]
    print(f"Rotation head shape: {rotation_head.shape}")
    print(f"Rotation head range: [{rotation_head.min():.3f}, {rotation_head.max():.3f}]")

    # Slicing the rotation head through to the end of the output would
    # yield 182 columns and pull these trailing outputs into it.
    conf_start, conf_stop = spans["Confidence"]
    print(f"Confidence outputs: {dino_pred[0, conf_start:conf_stop].tolist()}")

    max_idx = torch.argmax(rotation_head, dim=-1)
    print(f"Argmax index: {max_idx.item()}")
    # Assumes the 180 rotation bins map to -90..+89 degrees in 1-degree
    # steps; TODO(review): confirm the bin-to-angle mapping.
    print(f"Argmax as angle: {max_idx.item() - 90}°")

if __name__ == '__main__':
    # Run the diagnostic only when executed as a script, not on import.
    debug_model_output()
