"""
Feature Extraction for Biosignals Caption Generation

NOTE: This is a simplified demonstration version showing the basic approach.
The full feature extraction and caption generation system with all channels,
statistics, trend detection, and natural language templates will be 
open-sourced upon paper acceptance.
"""

import numpy as np
from scipy import signal as scipy_signal
import random

from caption_config import (
    CHANNEL_CONFIG, CHANNEL_DISPLAY_NAMES, STATISTIC_NAMES,
    EEG_FREQUENCY_BANDS, SLEEP_STAGES, CAPTION_TEMPLATE,
    SAMPLING_RATE, WINDOW_DURATION, SIGNAL_LENGTH, CHANNEL_NAMES,
)


# ============================================================================
# Frequency Domain Features
# ============================================================================

def compute_band_powers(signal, fs=SAMPLING_RATE):
    """
    Compute relative power in EEG frequency bands using Welch's method.

    The power spectral density is integrated with the trapezoidal rule over
    each band in EEG_FREQUENCY_BANDS (lower edge inclusive, upper edge
    exclusive) and divided by the power in the 0.5-30 Hz analysis range
    (both edges inclusive).

    Args:
        signal: 1D signal array
        fs: Sampling frequency

    Returns:
        Dictionary of {band_name}_power: band power as a fraction of the
        0.5-30 Hz power. Every value is 0.0 when the signal is empty or
        when the power in the analysis range is exactly zero.
    """
    n_samples = len(signal)
    if n_samples == 0:
        # Welch cannot build a segment from no samples; report "no power"
        # the same way the zero-total-power case below does.
        return {f'{band}_power': 0.0 for band in EEG_FREQUENCY_BANDS}

    # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid;
    # prefer the new name and fall back only on older NumPy releases.
    integrate = getattr(np, 'trapezoid', None) or np.trapz

    # Welch's method for PSD estimation (segments of at most 4 seconds)
    nperseg = min(n_samples, int(4 * fs))
    freqs, psd = scipy_signal.welch(signal, fs=fs, nperseg=nperseg)

    # Compute total power in analysis range (0.5-30 Hz)
    mask_total = (freqs >= 0.5) & (freqs <= 30)
    total_power = integrate(psd[mask_total], freqs[mask_total])

    if total_power == 0:
        return {f'{band}_power': 0.0 for band in EEG_FREQUENCY_BANDS}

    results = {}
    for band_name, (low, high) in EEG_FREQUENCY_BANDS.items():
        mask = (freqs >= low) & (freqs < high)
        band_power = integrate(psd[mask], freqs[mask])
        results[f'{band_name}_power'] = float(band_power / total_power)

    return results


# ============================================================================
# Feature Extraction
# ============================================================================

def extract_channel_features(signal, channel_name):
    """
    Collect the configured band-power statistics for one channel.

    Args:
        signal: 1D signal array
        channel_name: Name of the channel; selects the entry of
            CHANNEL_CONFIG whose 'statistics' list says what to report

    Returns:
        Dictionary {'statistics': [(stat_name, value), ...]}. The list is
        empty when the signal is all zeros or no statistics are configured.
    """
    channel_cfg = CHANNEL_CONFIG.get(channel_name, {})

    # All-zero signal: nothing to extract
    if np.all(signal == 0):
        return {'statistics': []}

    powers = compute_band_powers(signal)

    selected = []
    for wanted in channel_cfg.get('statistics', []):
        power = powers.get(wanted)
        # Drop statistics that were not computed or came out as NaN
        if power is None or np.isnan(power):
            continue
        selected.append((wanted, power))

    return {'statistics': selected}


# ============================================================================
# Caption Generation (Simplified - one stat per channel)
# ============================================================================

def generate_channel_caption(channel_name, features):
    """
    Build a one-statistic caption for a channel.

    A single (statistic, value) pair is drawn at random from the channel's
    features and rendered through CAPTION_TEMPLATE, with the value shown
    as a percentage with one decimal place.

    Args:
        channel_name: Name of the channel
        features: Dictionary whose 'statistics' entry is a list of
            (stat_name, value) pairs

    Returns:
        Caption string, or None when there are no statistics
    """
    label = CHANNEL_DISPLAY_NAMES.get(channel_name, channel_name)
    candidates = features.get('statistics', [])

    if candidates:
        # Demonstration version: report just one randomly chosen statistic
        picked_name, picked_value = random.choice(candidates)
        stat_label = STATISTIC_NAMES.get(picked_name, picked_name)
        percent_text = f"{picked_value * 100:.1f}%"
        return CAPTION_TEMPLATE.format(
            channel_name=label,
            statistic_name=stat_label,
            value=percent_text,
        )

    return None


def generate_sleep_stage_caption(stage_id):
    """
    Generate caption for sleep stage.

    Args:
        stage_id: Key looked up in SLEEP_STAGES

    Returns:
        "Sleep stage: <name>", using 'Unknown' for IDs missing from
        SLEEP_STAGES
    """
    label = SLEEP_STAGES.get(stage_id, 'Unknown')
    return f"Sleep stage: {label}"


# ============================================================================
# Main Caption Generation Function
# ============================================================================

def generate_captions(signals, channel_names=None, sleep_stage=None):
    """
    Generate captions for biosignal data.

    NOTE: This demonstration generates one statistic per channel.
    The full system generates comprehensive multi-feature captions.

    Args:
        signals: array-like of shape (num_channels, signal_length) or
                 (batch_size, num_channels, signal_length)
        channel_names: List of channel names, in the same order as the
                       channel axis of signals (defaults to CHANNEL_NAMES).
                       Names without an entry in CHANNEL_CONFIG are skipped,
                       as are names or signal rows without a counterpart.
        sleep_stage: Sleep stage ID (optional). For batched input this may
                     also be a list or 1D array holding one stage ID per
                     sample; a single ID is applied to every sample.

    Returns:
        For a single sample, a dictionary with 'channel_captions',
        'stage_caption', and 'full_caption'. For batched (3D) input, a list
        of such dictionaries, one per sample.

    Raises:
        ValueError: If per-sample sleep stages are given for batched input
            and their count differs from the batch size.
    """
    # Accept nested lists as well as arrays
    signals = np.asarray(signals)

    # Handle batch dimension
    if signals.ndim == 3:
        batch_size = signals.shape[0]
        per_sample = (
            isinstance(sleep_stage, (list, np.ndarray))
            and np.ndim(sleep_stage) == 1
        )
        if per_sample:
            if len(sleep_stage) != batch_size:
                raise ValueError(
                    f"Got {len(sleep_stage)} sleep stages for a batch of "
                    f"{batch_size} samples"
                )
            stages = list(sleep_stage)
        else:
            # Single stage (or None) shared by the whole batch
            stages = [sleep_stage] * batch_size
        return [
            generate_captions(sample, channel_names, stage)
            for sample, stage in zip(signals, stages)
        ]

    if channel_names is None:
        channel_names = CHANNEL_NAMES

    channel_captions = []

    # zip stops at the shorter input, so names without a signal row (and
    # rows without a name) are ignored.
    for channel_name, channel_signal in zip(channel_names, signals):
        if channel_name not in CHANNEL_CONFIG:
            continue

        features = extract_channel_features(channel_signal, channel_name)
        caption = generate_channel_caption(channel_name, features)

        if caption:
            channel_captions.append(caption)

    # Sleep stage caption
    stage_caption = None
    if sleep_stage is not None:
        stage_caption = generate_sleep_stage_caption(sleep_stage)

    # Combine captions: channel sentences first, sleep stage last
    all_captions = list(channel_captions)
    if stage_caption:
        all_captions.append(stage_caption)

    return {
        'channel_captions': channel_captions,
        'stage_caption': stage_caption,
        'full_caption': " ".join(all_captions),
    }


# ============================================================================
# Example Usage
# ============================================================================

if __name__ == '__main__':
    # Fixed seed so the demo prints the same synthetic data on every run
    np.random.seed(42)

    demo_channels = ['EEG_C3', 'EEG_C4', 'EEG_F3', 'EEG_F4', 'EEG_O1', 'EEG_O2']

    # Synthetic stand-in for EEG: one row of Gaussian noise per channel
    demo_signals = np.random.randn(len(demo_channels), SIGNAL_LENGTH)

    demo = generate_captions(
        demo_signals,
        channel_names=demo_channels,
        sleep_stage=2,
    )

    print("Channel Captions:")
    for line in demo['channel_captions']:
        print(f"  {line}")

    print(f"\n{demo['stage_caption']}")
    print(f"\nFull Caption:\n{demo['full_caption']}")
