"""
Simple test script to verify MDAG implementation
"""

import torch
import numpy as np
from mdag_grouping import (
    compute_viewpoint_diversity,
    orthogonal_grouping,
    compute_dynamic_threshold,
    mdag_grouping_strategy,
    extract_positions_from_pairwise_matrix
)


def test_viewpoint_diversity():
    """Test viewpoint diversity matrix computation.

    Places four vehicles on the axes around the origin and passes the origin
    explicitly as the center point. Vehicles on opposite sides of the center
    (left/right, top/bottom) should be 180 degrees apart, and orthogonal ones
    90 degrees apart. Each of those angles is checked against a small
    tolerance rather than only being printed.

    Raises:
        AssertionError: if any checked angle deviates from its expected value
            by ``tolerance_deg`` or more.
    """
    print("=" * 60)
    print("Test 1: Viewpoint Diversity Matrix")
    print("=" * 60)

    # Create sample positions: 4 vehicles around a center point
    # Center point will be at (0, 0) (the centroid)
    # Vehicle 0 (ego) to the left
    # Vehicle 1 to the right (opposite to ego, 180°)
    # Vehicle 2 above (orthogonal to ego, 90°)
    # Vehicle 3 below (orthogonal to ego, 90°)
    positions = torch.tensor([
        [-10.0, 0.0],   # Ego (left of center)
        [10.0, 0.0],    # Right (opposite to ego)
        [0.0, 10.0],    # Top (orthogonal to ego)
        [0.0, -10.0]    # Bottom (orthogonal to ego)
    ], dtype=torch.float32)

    # Center point is at origin (0, 0) which is the centroid
    center = torch.tensor([0.0, 0.0], dtype=torch.float32)

    diversity_matrix = compute_viewpoint_diversity(positions, center)

    print("Positions:")
    print(positions)
    print(f"\nCenter Point: {center}")
    print("\nViewpoint Diversity Matrix (in radians):")
    print(diversity_matrix)
    print("\nViewpoint Diversity Matrix (in degrees):")
    print(diversity_matrix * 180 / np.pi)

    # (label, row, col, expected angle in degrees) for the center-relative
    # vectors of each pair of vehicles.
    expected_angles = [
        ("Ego-Right", 0, 1, 180.0),   # opposite directions
        ("Ego-Top", 0, 2, 90.0),      # perpendicular
        ("Ego-Bottom", 0, 3, 90.0),   # perpendicular
        ("Top-Bottom", 2, 3, 180.0),  # opposite directions
    ]
    # arccos is ill-conditioned near cos = -1 (i.e. near 180 degrees), so
    # float32 rounding, or any epsilon clamping the implementation may apply,
    # can shift the result slightly. 1 degree of slack is ample.
    tolerance_deg = 1.0

    print("\nExpected angles:")
    for label, row, col, expected_deg in expected_angles:
        actual_deg = float(diversity_matrix[row, col]) * 180 / np.pi
        print(f"  {label}: ~{expected_deg:.0f}° (actual: {actual_deg:.1f}°)")
        assert abs(actual_deg - expected_deg) < tolerance_deg, (
            f"{label} angle is {actual_deg:.2f}°, expected ~{expected_deg:.0f}°"
        )

    print("\n✓ Test passed\n")


def test_orthogonal_grouping():
    """Test orthogonal adversarial grouping.

    Builds the viewpoint diversity matrix for four vehicles. No explicit
    center point is passed, so ``compute_viewpoint_diversity``'s default is
    used. The vehicles are then split into two groups with agent 0 as ego.
    Both groups are printed, and the ego vehicle is checked to be in Group 1.

    Raises:
        AssertionError: if the ego index is not a member of Group 1.
    """
    print("=" * 60)
    print("Test 2: Orthogonal Adversarial Grouping")
    print("=" * 60)

    # Create diversity matrix
    positions = torch.tensor([
        [0.0, 0.0],    # Ego
        [10.0, 0.0],   # Front
        [0.0, 10.0],   # Side
        [-10.0, 0.0]   # Behind
    ], dtype=torch.float32)

    diversity_matrix = compute_viewpoint_diversity(positions)

    ego_idx = 0
    group1, group2 = orthogonal_grouping(
        diversity_matrix, agent_num=len(positions), ego_idx=ego_idx
    )

    print(f"Group 1 (Strong Adversarial): {group1}")
    print(f"Group 2 (Weak Adversarial): {group2}")

    # Expected: Group 1 should contain ego and the vehicle with max diversity
    # (side or behind). Only ego membership is asserted here; which non-ego
    # vehicle joins it depends on the grouping internals. int() normalises the
    # members whether group1 is a list of ints or a 1-D index tensor.
    group1_members = [int(idx) for idx in group1]
    assert ego_idx in group1_members, (
        f"Ego {ego_idx} missing from Group 1: {group1_members}"
    )

    print("\n✓ Test passed\n")


def test_dynamic_threshold():
    """Test dynamic threshold computation.

    Builds two synthetic 4-agent feature stacks and prints the threshold that
    ``compute_dynamic_threshold`` produces for each, so it can be compared
    with the base threshold of 0.5:

    * high similarity: every agent shares one base feature map, plus small
      independent per-agent noise;
    * high diversity: every agent's feature map is drawn independently.

    A local, seeded ``torch.Generator`` makes the inputs reproducible without
    touching the global RNG state used by the other tests.
    """
    print("=" * 60)
    print("Test 3: Dynamic Adaptive Threshold")
    print("=" * 60)

    generator = torch.Generator().manual_seed(0)

    # Case 1: High similarity (all features similar).
    # A single shared base map is broadcast over all 4 agents and only a small
    # per-agent perturbation is added, so the agents' features stay close.
    shared_base = torch.randn(1, 64, 32, 32, generator=generator)
    per_agent_noise = torch.randn(4, 64, 32, 32, generator=generator) * 0.1
    features_similar = shared_base + per_agent_noise  # -> (4, 64, 32, 32)

    threshold_similar = compute_dynamic_threshold(
        features_similar,
        base_threshold=0.5,
        sensitivity=1.0
    )

    print("Case 1 - High Similarity Features:")
    print(f"  Dynamic Threshold: {threshold_similar:.4f}")
    print("  Expected: Higher than base (0.5) due to high similarity")

    # Case 2: High diversity (features very different).
    # Independent draws per agent are what make the features dissimilar. The
    # x10 scale only matters if the similarity measure is not scale-invariant
    # (e.g. it would not affect a cosine similarity) -- TODO confirm.
    features_diverse = torch.randn(4, 64, 32, 32, generator=generator) * 10.0

    threshold_diverse = compute_dynamic_threshold(
        features_diverse,
        base_threshold=0.5,
        sensitivity=1.0
    )

    print("\nCase 2 - High Diversity Features:")
    print(f"  Dynamic Threshold: {threshold_diverse:.4f}")
    print("  Expected: Lower than base (0.5) due to high diversity")

    print("\n✓ Test passed\n")


def test_mdag_full_pipeline():
    """Exercise ``mdag_grouping_strategy`` end to end in both threshold modes.

    The strategy runs twice on the same four positions and random features:
    once with a fixed threshold, and once with the dynamic threshold, which
    additionally receives CPS scores for the three non-ego agents. The
    resulting groups and threshold are printed for each run.
    """
    banner = "=" * 60
    print(banner)
    print("Test 4: Full MDAG Pipeline")
    print(banner)

    agent_xy = torch.tensor(
        [[0.0, 0.0], [10.0, 0.0], [0.0, 10.0], [-10.0, 0.0]],
        dtype=torch.float32,
    )
    agent_feats = torch.randn(4, 64, 32, 32)

    # Arguments common to both runs of the strategy.
    common_kwargs = {
        "positions": agent_xy,
        "features": agent_feats,
        "ego_idx": 0,
        "base_threshold": 0.5,
    }

    # --- Static threshold ---------------------------------------------------
    print("Test 4a: MDAG with Static Threshold")
    static_out = mdag_grouping_strategy(
        use_dynamic_threshold=False, **common_kwargs
    )
    print(f"  Group 1: {static_out['group1']}")
    print(f"  Group 2: {static_out['group2']}")
    print(f"  Threshold: {static_out['threshold']:.4f}")

    # --- Dynamic threshold --------------------------------------------------
    print("\nTest 4b: MDAG with Dynamic Threshold")
    # One CPS score per non-ego agent (agents 1, 2 and 3); ego is excluded.
    non_ego_cps = [0.6, 0.7, 0.65]
    dynamic_out = mdag_grouping_strategy(
        use_dynamic_threshold=True,
        sensitivity=1.0,
        cps_scores=non_ego_cps,
        **common_kwargs,
    )
    print(f"  Group 1: {dynamic_out['group1']}")
    print(f"  Group 2: {dynamic_out['group2']}")
    print(f"  Dynamic Threshold: {dynamic_out['threshold']:.4f}")

    print("\n✓ Test passed\n")


def test_extract_positions():
    """Check that positions are read back from a pairwise transform tensor.

    Builds an (N, N, 4, 4) stack of identity homogeneous transforms. It
    writes each agent's x/y offset into the translation column of row 0
    (the ego's view of every agent). It then asserts that
    ``extract_positions_from_pairwise_matrix`` recovers exactly those offsets.

    Raises:
        AssertionError: if the extracted positions differ from the offsets.
    """
    print("=" * 60)
    print("Test 5: Extract Positions from Pairwise Matrix")
    print("=" * 60)

    # (x, y) translation of each agent as seen from the ego (agent 0).
    agent_offsets = [
        (0.0, 0.0),     # Ego
        (10.0, 0.0),    # Vehicle 1
        (0.0, 10.0),    # Vehicle 2
        (-10.0, 0.0),   # Vehicle 3
    ]
    num_agents = len(agent_offsets)

    # One 4x4 identity per (i, j) agent pair.
    pairwise_t_matrix = torch.eye(4).repeat(num_agents, num_agents, 1, 1)

    # Column 3 of a homogeneous transform holds the translation.
    for agent_idx, (dx, dy) in enumerate(agent_offsets):
        pairwise_t_matrix[0, agent_idx, 0, 3] = dx
        pairwise_t_matrix[0, agent_idx, 1, 3] = dy

    positions = extract_positions_from_pairwise_matrix(pairwise_t_matrix)

    print("Extracted Positions:")
    print(positions)

    expected = torch.tensor(agent_offsets)

    assert torch.allclose(positions, expected, atol=1e-5), "Position extraction failed!"

    print("\n✓ Test passed\n")


if __name__ == "__main__":
    # Script entry point: run every test in sequence and stop at the first
    # failure (an exception, including an AssertionError from a check).
    print("\n" + "=" * 60)
    print("MDAG Implementation Tests")
    print("=" * 60 + "\n")

    try:
        test_viewpoint_diversity()
        test_orthogonal_grouping()
        test_dynamic_threshold()
        test_extract_positions()
        test_mdag_full_pipeline()

        print("=" * 60)
        print("All Tests Passed! ✓")
        print("=" * 60)

    except Exception as e:
        import sys
        import traceback

        print(f"\n❌ Test failed with error: {e}")
        traceback.print_exc()
        # Exit non-zero so shells and CI treat a failing run as a failure;
        # otherwise the script would report success after printing the error.
        sys.exit(1)

