"""Test the EEG dataloader integration."""

import torch
import torch.nn as nn
from src.data import EEGDataLoader
from src.utils import DataAttr
from src.models import AmortizedConditioningEngine
from src.models.modules import Embedder, Head, MultiChannelMixtureGaussian, build_mlp_with_linear_skipcon

def test_eeg_dataloader():
    """Test that the EEG dataloader works with the existing data interface.

    Exercises ``EEGDataLoader`` in three ways:

    1. A direct call, which returns a ``(context_target_batch, buffer_batch)``
       pair.
    2. ``create_generator``, whose batches expose ``xc/yc/xb/yb/xt/yt``.
    3. Shape checks: every ``y`` tensor has ``num_channels`` channels, the
       buffer holds ``buffer_size`` points, and context + buffer + target
       together account for all ``total_points`` points.
    """
    from itertools import islice

    # Expected EEG layout, kept in one place so the dataloader arguments and
    # the assertions below cannot drift apart.
    num_channels = 7
    buffer_size = 8
    total_points = 256
    min_register_points = 32
    max_register_points = 128

    print("Testing EEG DataLoader Integration")
    print("=" * 50)

    # Create dataloader
    dataloader = EEGDataLoader(
        subset="train",
        mode="random",
        total_points=total_points,
        device="cpu",
        seed=42,
    )

    # Test single batch generation
    print("\n1. Testing single batch generation:")
    context_target_batch, buffer_batch = dataloader(
        problem=None,  # Not used for EEG
        batch_size=8,
        num_register_points=64,  # nc
        num_latent=0,  # Not used
        min_register_points=min_register_points,
        max_register_points=max_register_points,
        x_range=(-2, 2),  # Not used
        max_buffer_points=buffer_size,
        num_target_partitions=1,  # Not used
        num_target_data_per_partition=100,  # Not used
    )

    print("\nContext-Target Batch:")
    print(f"  xc shape: {context_target_batch.xc.shape}")
    print(f"  yc shape: {context_target_batch.yc.shape}")
    print(f"  xt shape: {context_target_batch.xt.shape}")
    print(f"  yt shape: {context_target_batch.yt.shape}")
    print(f"  loss_mask shape: {context_target_batch.loss_mask.shape}")

    print("\nBuffer Batch:")
    print(f"  xc shape: {buffer_batch.xc.shape}")
    print(f"  yc shape: {buffer_batch.yc.shape}")
    print(f"  xt: {buffer_batch.xt}")
    print(f"  yt: {buffer_batch.yt}")

    # Test generator creation
    print("\n2. Testing generator creation:")
    generator = dataloader.create_generator(
        batch_size=16,
        num_batches=10,
        num_register_points="random",
        min_register_points=min_register_points,
        max_register_points=max_register_points,
    )

    print(f"Generator created with {len(generator)} batches")

    # Test a few batches from generator. islice stops after three batches
    # without generating an extra, discarded one.
    print("\n3. Testing generator iteration:")
    for i, batch in enumerate(islice(generator, 3)):
        print(f"\nBatch {i}:")
        print(f"  xc shape: {batch.xc.shape}")
        print(f"  yc shape: {batch.yc.shape}")
        print(f"  xb shape: {batch.xb.shape}")
        print(f"  yb shape: {batch.yb.shape}")
        print(f"  xt shape: {batch.xt.shape}")
        print(f"  yt shape: {batch.yt.shape}")

        # Check nc value: with num_register_points="random" the context size
        # should stay within the requested bounds.
        nc = batch.xc.shape[1]
        print(f"  Context size (nc): {nc}")
        assert min_register_points <= nc <= max_register_points, (
            f"Batch {i}: context size {nc} outside "
            f"[{min_register_points}, {max_register_points}]"
        )

        # Every y tensor from the generator must carry all EEG channels too.
        for field in ("yc", "yb", "yt"):
            got = getattr(batch, field).shape[-1]
            assert got == num_channels, (
                f"Batch {i}: {field} should have {num_channels} channels, got {got}"
            )

    print("\n4. Testing dimension consistency:")
    # Verify all y dimensions match the EEG channel count
    assert context_target_batch.yc.shape[-1] == num_channels, f"Context should have {num_channels} channels"
    assert context_target_batch.yt.shape[-1] == num_channels, f"Target should have {num_channels} channels"
    assert buffer_batch.yc.shape[-1] == num_channels, f"Buffer should have {num_channels} channels"
    print(f"✓ All y dimensions are consistent ({num_channels} channels)")

    # Verify buffer size matches max_buffer_points
    assert buffer_batch.yc.shape[1] == buffer_size, f"Buffer should have {buffer_size} points"
    print(f"✓ Buffer size is fixed at {buffer_size}")

    # Verify context + buffer + target cover every sampled point
    total = (
        context_target_batch.xc.shape[1]
        + buffer_batch.xc.shape[1]
        + context_target_batch.xt.shape[1]
    )
    assert total == total_points, f"Total points should be {total_points}, got {total}"
    print(f"✓ Total points: {total}")

    print("\n✅ All tests passed!")


def test_ace_forward_pass():
    """Smoke-test the EEG embed -> backbone -> decode pipeline.

    Builds an ``Embedder``, a plain ``nn.TransformerEncoder`` backbone and a
    ``MultiChannelMixtureGaussian`` decoder sized for 7-channel EEG. It then
    draws one forecasting batch from ``EEGDataLoader`` and checks that:

    * context, buffer and target embeddings have the expected batch, sequence
      and model dimensions;
    * the backbone output at the target positions can be decoded into
      mixture means/sds/weights, a log-likelihood and a single-element loss;
    * the decoder can be called in sampling mode (``yt=None``).

    Note: the backbone is a vanilla encoder over the concatenated context and
    target tokens, not ACE's attention pattern. It only verifies that real
    embeddings flow through to the decoder.
    """
    # Deterministic weight initialisation so failures are reproducible.
    torch.manual_seed(0)

    print("\n\nTesting ACE Model Forward Pass with EEG Data")
    print("=" * 50)

    # Model configuration for EEG
    dim_x = 1  # Time dimension
    dim_y = 7  # 7 EEG channels
    hidden_dim = 128
    dim_model = 128
    num_layers = 4
    num_heads = 8
    num_mixture_components = 20
    feedforward_dim = 512
    device = "cpu"

    print("\nModel Configuration:")
    print(f"  dim_x: {dim_x}")
    print(f"  dim_y: {dim_y}")
    print(f"  hidden_dim: {hidden_dim}")
    print(f"  dim_model: {dim_model}")
    print(f"  num_layers: {num_layers}")
    print(f"  num_heads: {num_heads}")

    # Create embedder
    embedder = Embedder(
        dim_x=dim_x,
        dim_y=dim_y,
        hidden_dim=hidden_dim,
        out_dim=dim_model,
        depth=2,
        mlp_builder=build_mlp_with_linear_skipcon,
        pos_emb_init=False,
    )

    # Create backbone (transformer encoder)
    encoder_layer = nn.TransformerEncoderLayer(
        d_model=dim_model,
        nhead=num_heads,
        dim_feedforward=feedforward_dim,
        activation="gelu",
        batch_first=True,
    )
    backbone = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
    backbone.dim_model = dim_model  # Add dim_model attribute
    backbone.eval()  # disable dropout so the encoder output is deterministic

    # The decoder used below is MultiChannelMixtureGaussian, which accepts all
    # dim_y output channels; per the original note, the plain MixtureGaussian
    # head does not support multi-dimensional outputs.

    print("\n✓ Model created successfully")

    # Create dataloader
    dataloader = EEGDataLoader(
        subset="train",
        mode="forecasting",  # Use forecasting for clear temporal structure
        total_points=256,
        device=device,
        seed=42,
    )

    # Generate a batch
    context_target_batch, buffer_batch = dataloader(
        problem=None,
        batch_size=4,
        num_register_points=32,
        num_latent=0,
        min_register_points=32,
        max_register_points=128,
        x_range=(-2, 2),
        max_buffer_points=8,
        num_target_partitions=1,
        num_target_data_per_partition=100,
    )

    print("\n✓ EEG batch generated")
    print(f"  Context: {context_target_batch.xc.shape}")
    print(f"  Buffer: {buffer_batch.xc.shape}")
    print(f"  Target: {context_target_batch.xt.shape}")

    # Combine batches for ACE model: the dataloader returns the buffer as a
    # separate batch whose xc/yc become xb/yb here.
    combined_batch = DataAttr()
    combined_batch.xc = context_target_batch.xc
    combined_batch.yc = context_target_batch.yc
    combined_batch.xb = buffer_batch.xc
    combined_batch.yb = buffer_batch.yc
    combined_batch.xt = context_target_batch.xt
    combined_batch.yt = context_target_batch.yt
    combined_batch.mask = context_target_batch.loss_mask

    batch_size = combined_batch.xt.shape[0]
    num_targets = combined_batch.xt.shape[1]

    # Test embedder forward pass
    print("\nTesting embedder forward pass...")
    try:
        with torch.no_grad():
            # Embed context
            context_emb = embedder.embed_context(combined_batch)
            print(f"✓ Context embedding shape: {context_emb.shape}")

            # Embed buffer
            buffer_emb = embedder.embed_buffer(combined_batch)
            print(f"✓ Buffer embedding shape: {buffer_emb.shape}")

            # Embed target
            target_emb = embedder.embed_target(combined_batch)
            print(f"✓ Target embedding shape: {target_emb.shape}")

            # Check embedding dimensions
            assert context_emb.shape[-1] == dim_model, f"Context embedding dim should be {dim_model}"
            assert buffer_emb.shape[-1] == dim_model, f"Buffer embedding dim should be {dim_model}"
            assert target_emb.shape[-1] == dim_model, f"Target embedding dim should be {dim_model}"

            # Check batch and sequence dimensions match for all three streams
            assert context_emb.shape[0] == combined_batch.xc.shape[0], "Batch size mismatch"
            assert buffer_emb.shape[0] == combined_batch.xb.shape[0], "Buffer batch size mismatch"
            assert target_emb.shape[0] == batch_size, "Target batch size mismatch"
            assert context_emb.shape[1] == combined_batch.xc.shape[1], "Context sequence length mismatch"
            assert buffer_emb.shape[1] == combined_batch.xb.shape[1], "Buffer sequence length mismatch"
            assert target_emb.shape[1] == num_targets, "Target sequence length mismatch"

            print("\n✓ All embedding dimensions correct")

            # Run the backbone over the real context + target embeddings and
            # keep only the target positions as decoder input (previously the
            # decoder was fed random noise and the backbone was never used).
            tokens = torch.cat([context_emb, target_emb], dim=1)
            transformer_output = backbone(tokens)[:, -num_targets:, :]
            print(f"✓ Backbone output shape: {transformer_output.shape}")

            # Test with MultiChannelMixtureGaussian
            print("\nTesting MultiChannelMixtureGaussian decoder...")

            # Create multi-channel decoder
            decoder = MultiChannelMixtureGaussian(
                dim_y=dim_y,
                dim_model=dim_model,
                dim_feedforward=feedforward_dim,
                num_components=num_mixture_components,
            )

            # Decoder forward pass (without mask since EEG always predicts all points)
            decoder_output = decoder(
                zt=transformer_output,
                yt=combined_batch.yt,
                loss_mask=None,  # No mask needed for EEG
            )

            print("✓ Decoder output:")
            print(f"  Means shape: {decoder_output.means.shape}")
            print(f"  Stds shape: {decoder_output.sds.shape}")
            print(f"  Weights shape: {decoder_output.weights.shape}")
            print(f"  Log likelihood shape: {decoder_output.log_likelihood.shape}")
            print(f"  Loss: {decoder_output.loss.item():.4f}")

            # Test sampling
            print("\nTesting sampling...")
            decoder_output_sample = decoder(
                zt=transformer_output,
                yt=None,  # No target for sampling
                num_samples=10,
            )

            samples = decoder_output_sample.samples
            if samples is not None:
                print(f"✓ Samples shape: {samples.shape}")
                print(f"  Expected: (B={batch_size}, T={num_targets}, num_samples=10, dim_y={dim_y})")
            else:
                # Report explicitly instead of silently skipping the check.
                print("⚠ Decoder returned no samples (samples is None)")

            print(f"\n✓ MultiChannelMixtureGaussian successfully handles {dim_y}-channel EEG data!")

    except Exception as e:
        # Script-mode banner; the exception is re-raised so pytest still fails.
        print(f"\n❌ Forward pass failed: {e}")
        raise

    print("\n✅ ACE forward pass test complete!")


if __name__ == "__main__":
    # Script entry point: run the dataloader check, then the forward-pass check.
    for _check in (test_eeg_dataloader, test_ace_forward_pass):
        _check()