"""Test script for the EEG sampler with interactive visualizations."""

import torch
import numpy as np
import matplotlib.pyplot as plt
from rich.console import Console
from rich.table import Table
from rich.panel import Panel
from rich.columns import Columns

from src.data.eeg_sampler import EEGSampler, PREDEFINED_NC_VALUES, BUFFER_SIZE

console = Console()


def print_batch_info(batch, task_mode):
    """Pretty-print tensor shapes and point counts for a sampled batch.

    Args:
        batch: Object exposing ``xc``, ``yc``, ``xb``, ``yb``, ``xt`` and ``yt``
            tensors. ``x*`` shapes are read as ``[B, n, ...]`` and ``y*``
            shapes as ``[B, n, channels]`` (the code reads ``shape[0..2]``).
        task_mode: Label echoed in the header. When it equals
            ``"forecasting"``, the boundary time stamps of batch element 0 are
            printed as well so the context/buffer/target ordering can be
            eyeballed.
    """
    console.print(f"\n[bold cyan]Task Mode: {task_mode}[/bold cyan]")

    shape_table = Table(title="Batch Tensor Shapes")
    for header, colour in (("Tensor", "cyan"), ("Shape", "green"), ("Description", "yellow")):
        shape_table.add_column(header, style=colour)

    # One entry per split: (x row label, y row label, x tensor, y tensor, size symbol).
    splits = (
        ("xc (context time)", "yc (context data)", batch.xc, batch.yc, "nc"),
        ("xb (buffer time)", "yb (buffer data)", batch.xb, batch.yb, "nb"),
        ("xt (target time)", "yt (target data)", batch.xt, batch.yt, "nt"),
    )
    for x_label, y_label, x_tensor, y_tensor, size_sym in splits:
        x_shape, y_shape = x_tensor.shape, y_tensor.shape
        shape_table.add_row(
            x_label, str(x_shape),
            f"[B={x_shape[0]}, {size_sym}={x_shape[1]}, dim=1]",
        )
        shape_table.add_row(
            y_label, str(y_shape),
            f"[B={y_shape[0]}, {size_sym}={y_shape[1]}, channels={y_shape[2]}]",
        )

    console.print(shape_table)

    n_context = batch.xc.shape[1]
    n_buffer = batch.xb.shape[1]
    n_target = batch.xt.shape[1]

    console.print("\n[bold magenta]Sample Statistics:[/bold magenta]")
    console.print(f"Context points: {n_context}")
    console.print(f"Buffer points: {n_buffer} (fixed)")
    console.print(f"Target points: {n_target}")
    console.print(f"Total points: {n_context + n_buffer + n_target}")

    if task_mode != "forecasting":
        return

    # Only batch element 0 is inspected; all indexing happens before any
    # output so an out-of-range index fails before the header is printed.
    boundaries = (
        ("Last context time", batch.xc[0, -1, 0]),
        ("First buffer time", batch.xb[0, 0, 0]),
        ("Last buffer time", batch.xb[0, -1, 0]),
        ("First target time", batch.xt[0, 0, 0]),
    )
    console.print("\n[yellow]Time continuity check:[/yellow]")
    for label, stamp in boundaries:
        console.print(f"{label}: {stamp.item():.3f}")


def visualize_batch(batch, task_mode, sample_idx=0):
    """Draw one batch element as three stacked scatter panels.

    Top to bottom the panels show the context, buffer and target points of
    element ``sample_idx``. Each channel is drawn as its own scatter series,
    and the panels share a time axis padded by 0.01 on both sides.

    Args:
        batch: Object with ``xc``/``xb``/``xt`` tensors (``[sample_idx, :, 0]``
            is plotted as time) and ``yc``/``yb``/``yt`` tensors
            (``[sample_idx, :, :]``, last axis iterated as channels).
        task_mode: Text inserted into the figure title.
        sample_idx: Index of the batch element to plot.

    Returns:
        The created matplotlib Figure; the caller saves/shows and closes it.
    """
    fig, axes = plt.subplots(3, 1, figsize=(12, 10), sharex=True)
    fig.suptitle(f'EEG Sample Visualization - {task_mode} Task', fontsize=16)
    ctx_ax, buf_ax, tgt_ax = axes

    def times(tensor):
        return tensor[sample_idx, :, 0].cpu().numpy()

    def values(tensor):
        return tensor[sample_idx, :, :].cpu().numpy()

    xc, xb, xt = times(batch.xc), times(batch.xb), times(batch.xt)
    yc, yb, yt = values(batch.yc), values(batch.yb), values(batch.yt)

    # Common x-range so all three panels line up in time.
    every_t = np.concatenate([xc, xb, xt])
    xlim = (every_t.min() - 0.01, every_t.max() + 0.01)

    def draw_channels(ax, x, y):
        # One series per channel; matplotlib's colour cycle distinguishes them.
        for ch in range(y.shape[1]):
            ax.scatter(x, y[:, ch], alpha=0.6, s=20, label=f'Ch {ch+1}')

    def finish(ax):
        ax.grid(True, alpha=0.3)
        ax.set_xlim(xlim)

    # Context panel: legend is always drawn.
    draw_channels(ctx_ax, xc, yc)
    ctx_ax.set_title('Context Data (all channels)')
    ctx_ax.set_ylabel('Signal')
    ctx_ax.legend(loc='upper right', ncol=4)
    finish(ctx_ax)

    # Buffer panel: legend only when there is more than one channel.
    draw_channels(buf_ax, xb, yb)
    buf_ax.set_title('Buffer Data')
    buf_ax.set_ylabel('Signal')
    if yb.shape[1] > 1:
        buf_ax.legend(loc='upper right', ncol=4)
    finish(buf_ax)

    # Target panel: a single-channel target is highlighted in red.
    if yt.shape[1] != 1:
        draw_channels(tgt_ax, xt, yt)
        tgt_ax.set_title('Target Data (all channels)')
        tgt_ax.legend(loc='upper right', ncol=4)
    else:
        tgt_ax.scatter(xt, yt[:, 0], alpha=0.8, s=20, color='red', label='Target Ch')
        tgt_ax.set_title('Target Data (single channel reconstruction)')
    tgt_ax.set_xlabel('Time')
    tgt_ax.set_ylabel('Signal')
    finish(tgt_ax)

    plt.tight_layout()
    return fig


def test_all_modes():
    """Sample one batch per task mode, print its layout and save a plot.

    For each of interpolation, forecasting and reconstruction:
    - build a seeded (seed=42) ``EEGSampler`` on the ``"train"`` subset, with
      the context size pinned to nc=64;
    - print the batch summary;
    - write ``eeg_sample_<mode>.png`` to the current working directory.
    """
    console.print(Panel.fit("[bold green]Testing EEG Sampler - All Modes[/bold green]"))

    modes = ["interpolation", "forecasting", "reconstruction"]
    batch_size = 4
    total_points = 256

    # Position of nc=64 (a medium-sized context) in the predefined sizes.
    # .index raises if 64 is not predefined (ValueError for a list/tuple).
    nc_idx = PREDEFINED_NC_VALUES.index(64)

    for mode in modes:
        console.print(f"\n[bold blue]{'='*60}[/bold blue]")
        console.print(f"[bold blue]Testing {mode.upper()} mode[/bold blue]")
        console.print(f"[bold blue]{'='*60}[/bold blue]")

        # Create sampler with fixed medium context size
        sampler = EEGSampler(
            subset="train",
            mode=mode,
            batch_size=batch_size,
            num_tasks=batch_size,  # Just one batch for testing
            total_points=total_points,
            nc_idx=nc_idx,  # Use nc=64
            device="cpu",
            seed=42
        )

        batch = sampler.generate_batch()
        print_batch_info(batch, mode)

        # Save and close the figure we were handed rather than relying on
        # pyplot's implicit "current figure"; closing it explicitly keeps
        # figures from accumulating across modes.
        fig = visualize_batch(batch, mode)
        out_path = f'eeg_sample_{mode}.png'
        fig.savefig(out_path, dpi=150, bbox_inches='tight')
        console.print(f"\n[green]Saved visualization to {out_path}[/green]")
        plt.close(fig)


def test_combination_tracking():
    """Draw a fixed number of batches and list the (nc, nb, nt) shapes used.

    Runs the sampler in ``mode="random"`` for ``num_batches`` batches. It then
    prints the combinations returned by ``sampler.get_used_combinations()``
    and how many distinct ones there were. The printed guidance says the
    count should stay small "to avoid compilation issues" (presumably shape
    recompilation of a compiled model — confirm).
    """
    import itertools

    console.print("\n[bold cyan]Testing Combination Tracking[/bold cyan]")

    batch_size = 16
    num_batches = 20

    sampler = EEGSampler(
        subset="train",
        mode="random",
        batch_size=batch_size,
        num_tasks=batch_size * num_batches,
        total_points=256,
        device="cpu"
    )

    # islice draws at most num_batches batches. An index check inside the
    # loop body would only run after one extra batch had already been
    # generated (and presumably recorded as a used combination).
    for _batch in itertools.islice(sampler, num_batches):
        pass

    combinations = sampler.get_used_combinations()

    table = Table(title="Used (nc, nb, nt) Combinations")
    table.add_column("nc", style="cyan")
    table.add_column("nb", style="green")
    table.add_column("nt", style="yellow")
    table.add_column("Total", style="magenta")

    for nc, nb, nt in combinations:
        table.add_row(str(nc), str(nb), str(nt), str(nc + nb + nt))

    console.print(table)
    console.print(f"\n[green]Total unique combinations used: {len(combinations)}[/green]")
    console.print("[yellow]This should be ≤ 20 to avoid compilation issues[/yellow]")


def test_nc_values():
    """Show the predefined context sizes and example (nc, nb, nt) splits.

    Prints PREDEFINED_NC_VALUES split into exact powers of two and the
    remaining values. It then prints a table of example splits: for each of
    several ``total_points`` values, it picks the predefined nc closest to
    ``total // 3``, among the sizes that still leave at least one target
    point after the fixed BUFFER_SIZE buffer. A total for which no predefined
    nc fits is listed with "n/a" instead of aborting the run.
    """
    console.print("\n[bold magenta]Predefined NC Values (powers of 2 based):[/bold magenta]")

    sorted_nc = sorted(PREDEFINED_NC_VALUES)

    # n & (n - 1) clears the lowest set bit, so it is zero exactly for powers
    # of two -- and for 0, which is excluded explicitly.
    powers_of_2 = [n for n in sorted_nc if n & (n - 1) == 0 and n != 0]
    power_set = set(powers_of_2)  # O(1) membership for the complement below
    others = [n for n in sorted_nc if n not in power_set]

    console.print(f"\n[cyan]Context size range: {min(sorted_nc)} to {max(sorted_nc)} points[/cyan]")
    console.print(f"[cyan]Total available sizes: {len(sorted_nc)}[/cyan]")
    console.print("\n[cyan]Pure powers of 2:[/cyan]", powers_of_2)
    console.print("[yellow]Other values:[/yellow]", others)

    # Show example total_points values that work well
    console.print("\n[bold]Example total_points values that work well:[/bold]")
    example_totals = [128, 256, 384, 512, 768, 1024]

    table = Table(title="Example Configurations")
    table.add_column("total_points", style="cyan")
    table.add_column("Example nc", style="green")
    table.add_column("nb (fixed)", style="yellow")
    table.add_column("Resulting nt", style="magenta")

    for total in example_totals:
        # Strict "<" keeps nt = total - nc - BUFFER_SIZE at 1 or more.
        candidates = [n for n in PREDEFINED_NC_VALUES if n < total - BUFFER_SIZE]
        if not candidates:
            # No predefined nc fits this total; report it instead of letting
            # min() raise ValueError on an empty sequence.
            table.add_row(str(total), "n/a", str(BUFFER_SIZE), "n/a")
            continue
        # Pick the predefined nc closest to one third of the total.
        nc = min(candidates, key=lambda n: abs(n - total // 3))
        nt = total - nc - BUFFER_SIZE
        table.add_row(str(total), str(nc), str(BUFFER_SIZE), str(nt))

    console.print(table)


def interactive_test():
    """Prompt for sampler settings, print one batch and optionally plot it.

    Empty answers fall back to the defaults shown in brackets. Ctrl-C or
    end-of-input (e.g. closed stdin) aborts with "Test interrupted". Any
    other error is printed rather than raised, so a bad answer does not
    crash the surrounding test run.
    """
    console.print("\n[bold green]Interactive EEG Sampler Test[/bold green]")

    try:
        # markup=False: rich would otherwise parse "[random]" / "[y]" in the
        # prompts as markup tags and silently drop the default hints.
        # strip() keeps stray whitespace from defeating the defaults or int().
        mode = console.input(
            "Task mode (interpolation/forecasting/reconstruction/random) [random]: ",
            markup=False,
        ).strip().lower() or "random"
        total_points = int(console.input("Total points [256]: ", markup=False).strip() or "256")
        batch_size = int(console.input("Batch size [8]: ", markup=False).strip() or "8")

        sampler = EEGSampler(
            subset="train",
            mode=mode,
            batch_size=batch_size,
            num_tasks=batch_size,
            total_points=total_points,
            device="cpu"
        )

        batch = sampler.generate_batch()
        print_batch_info(batch, mode if mode != "random" else "random (will vary)")

        if console.input("\nVisualize batch? (y/n) [y]: ", markup=False).strip().lower() != 'n':
            fig = visualize_batch(batch, mode)
            plt.show()
            plt.close(fig)  # release the figure once the window is done with it

    except (KeyboardInterrupt, EOFError):
        console.print("\n[red]Test interrupted[/red]")
    except Exception as e:
        # Deliberate best-effort: report and return instead of crashing the run.
        console.print(f"\n[red]Error: {e}[/red]")


if __name__ == "__main__":
    # Run all tests
    test_nc_values()
    test_all_modes()
    test_combination_tracking()

    # Optional interactive test. End-of-input (e.g. stdin closed when the
    # script runs non-interactively) is treated as "no", so the run still
    # finishes and prints the completion message.
    try:
        run_interactive = console.input("\nRun interactive test? (y/n): ").strip().lower() == 'y'
    except EOFError:
        run_interactive = False
    if run_interactive:
        interactive_test()

    console.print("\n[bold green]All tests completed![/bold green]")