#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Feature Alignment and Fusion Structure Diagram Generator
Based on enhanced_data_preprocessing.py implementation
Visualize the alignment mechanism between Video, EEG, and Metadata

Author: Algorithm Engineer
Date: January 12, 2025
"""

import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.patches import FancyBboxPatch, ConnectionPatch, Rectangle
import numpy as np
from pathlib import Path
import logging

# Module-scoped logger used for progress messages; the root logger is
# configured at INFO level so those messages are shown when run as a script.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

class FeatureAlignmentDiagramGenerator:
    """
    Generate feature alignment and fusion structure diagrams
    based on enhanced_data_preprocessing.py implementation.

    Each ``create_*`` method draws one static figure on a canvas whose data
    coordinates match the figure size in inches (e.g. a 16x12 figure uses
    x in [0, 16] and y in [0, 12]). It then saves the figure as a 300-dpi PNG
    under ``output_dir`` and closes it. ``generate_all_diagrams`` runs all of
    them and writes a README that describes the outputs.
    """

    def __init__(self, output_dir: str = "./feature_alignment_diagrams"):
        """
        Args:
            output_dir: Directory the PNG files and README are written to.
                It is created, including missing parent directories, if needed.
        """
        self.output_dir = Path(output_dir)
        # parents=True so nested targets such as "out/diagrams" also work.
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # Color scheme
        self.colors = {
            'eeg': '#2E86AB',           # Blue for EEG
            'video': '#A23B72',         # Purple for Video
            'metadata': '#F18F01',      # Orange for Metadata
            'alignment': '#C73E1D',     # Red for Alignment
            'fusion': '#4CAF50',        # Green for Fusion
            'processing': '#FF9800',    # Orange for Processing
            'output': '#9C27B0',        # Purple for Output
            'background': '#F5F5F5',    # Light gray background
            'text': '#333333'           # Dark gray text
        }

        # Font properties. This mutates matplotlib's global rcParams.
        # Prefer Arial through the generic sans-serif family. If Arial is not
        # installed, matplotlib falls back to the next sans-serif font instead
        # of warning "font family not found" on every text call.
        other_sans = [name for name in plt.rcParams['font.sans-serif'] if name != 'Arial']
        plt.rcParams['font.family'] = 'sans-serif'
        plt.rcParams['font.sans-serif'] = ['Arial'] + other_sans
        plt.rcParams['font.size'] = 10

    def create_data_flow_diagram(self) -> None:
        """
        Create comprehensive data flow diagram showing the complete
        feature alignment and fusion pipeline.

        Saves ``feature_alignment_pipeline.png`` into ``output_dir``.
        """
        fig, ax = self._new_canvas(16, 12)

        # Title
        ax.text(8, 11.5, 'Video-EEG Feature Alignment and Fusion Pipeline',
                fontsize=18, fontweight='bold', ha='center')

        # Stage headers start at x=1 and span x = 0.5..15.5. That centers them
        # under the title (x = 8) and makes them cover the stage 2/3 boxes.
        header_width = 15

        # Stage 1: Raw Data Input
        self._draw_stage_header(ax, 1, 10.5, "Stage 1: Raw Data Input",
                                self.colors['background'], width=header_width)

        # EEG Data
        self._draw_data_box(ax, 0.5, 9, 3, 1.2, "EEG Data\n(62 channels, 200Hz)\nShape: (7, 62, 104000)",
                            self.colors['eeg'])

        # Video Data
        self._draw_data_box(ax, 4.5, 9, 3, 1.2, "Video Data\n(25 FPS, RGB)\nShape: (frames, H, W, 3)",
                            self.colors['video'])

        # Metadata
        self._draw_data_box(ax, 8.5, 9, 3, 1.2, "Metadata\n(Demographics, Video Info)\nSubject-Video Pairs",
                            self.colors['metadata'])

        # Stage 2: Temporal Alignment
        self._draw_stage_header(ax, 1, 8, "Stage 2: Temporal Alignment",
                                self.colors['alignment'], width=header_width)

        # Alignment process boxes
        self._draw_process_box(ax, 1, 6.5, 4.5, 1, "Duration Calculation\n• EEG: length/200Hz\n• Video: frames/25FPS",
                               self.colors['processing'])
        self._draw_process_box(ax, 6, 6.5, 4.5, 1, "Resampling\n• Target EEG: 200Hz\n• Target Video: 25FPS",
                               self.colors['processing'])
        self._draw_process_box(ax, 11, 6.5, 4.5, 1, "Interpolation\n• EEG: Cubic interpolation\n• Video: Frame indexing",
                               self.colors['processing'])

        # Stage 3: Feature Normalization
        self._draw_stage_header(ax, 1, 5.5, "Stage 3: Feature Normalization",
                                self.colors['processing'], width=header_width)

        self._draw_process_box(ax, 1, 4, 4.5, 1, "EEG Normalization\n• RobustScaler\n• Channel-wise scaling",
                               self.colors['eeg'])
        self._draw_process_box(ax, 6, 4, 4.5, 1, "Video Normalization\n• StandardScaler\n• Pixel-wise scaling",
                               self.colors['video'])
        self._draw_process_box(ax, 11, 4, 4.5, 1, "Metadata Integration\n• Subject info mapping\n• Video-EEG pairing",
                               self.colors['metadata'])

        # Stage 4: Window Creation
        self._draw_stage_header(ax, 1, 3, "Stage 4: Aligned Window Creation",
                                self.colors['fusion'], width=header_width)

        self._draw_process_box(ax, 3, 1.5, 10, 1,
                               "Sliding Window Creation\n• EEG Window: 8s × 200Hz = 1600 points\n• Video Window: 8s × 25FPS = 200 frames\n• Overlap: 50% (800 points step)",
                               self.colors['fusion'])

        # Stage 5: Data Augmentation (Optional) and final output
        self._draw_process_box(ax, 1, 0.2, 6, 0.8,
                               "Data Augmentation\n• EEG: Noise, scaling, filtering\n• Video: Brightness, contrast",
                               self.colors['output'])
        self._draw_process_box(ax, 9, 0.2, 6, 0.8,
                               "Output\n• Aligned EEG-Video pairs\n• Ready for model training",
                               self.colors['output'])

        # Draw connections
        self._draw_connections(ax)

        self._save_figure(fig, "feature_alignment_pipeline.png",
                          "Feature alignment pipeline diagram")

    def create_temporal_alignment_detail(self) -> None:
        """
        Create detailed temporal alignment mechanism diagram.

        Saves ``temporal_alignment_detail.png`` into ``output_dir``.
        """
        fig, ax = self._new_canvas(14, 10)

        # Title
        ax.text(7, 9.5, 'Temporal Alignment Mechanism Detail',
                fontsize=16, fontweight='bold', ha='center')

        # Original data timeline
        ax.text(1, 8.5, 'Original Data:', fontsize=12, fontweight='bold')

        # EEG timeline
        ax.add_patch(Rectangle((1, 7.5), 10, 0.5, facecolor=self.colors['eeg'], alpha=0.7))
        ax.text(6, 7.75, 'EEG: 104,000 points @ 200Hz = 520s', ha='center', va='center', color='white', fontweight='bold')

        # Video timeline
        # NOTE(review): this bar is drawn shorter (width 8) than the EEG bar
        # (width 10), yet both labels state 520s. Confirm which is intended.
        ax.add_patch(Rectangle((1, 6.8), 8, 0.5, facecolor=self.colors['video'], alpha=0.7))
        ax.text(5, 7.05, 'Video: ~13,000 frames @ 25FPS = 520s', ha='center', va='center', color='white', fontweight='bold')

        # Alignment process
        ax.text(1, 6, 'Alignment Process:', fontsize=12, fontweight='bold')

        # Step 1: Duration calculation
        self._draw_process_box(ax, 1, 5, 3.5, 0.8,
                               "Step 1: Calculate Duration\ntarget_duration = min(\n  eeg_duration, video_duration)",
                               self.colors['processing'])

        # Step 2: Target points calculation
        self._draw_process_box(ax, 5.5, 5, 3.5, 0.8,
                               "Step 2: Target Points\neeg_points = duration × 200\nvideo_frames = duration × 25",
                               self.colors['processing'])

        # Step 3: Resampling
        self._draw_process_box(ax, 10, 5, 3.5, 0.8,
                               "Step 3: Resampling\nEEG: Cubic interpolation\nVideo: Frame indexing",
                               self.colors['processing'])

        # Aligned data timeline
        ax.text(1, 3.5, 'Aligned Data:', fontsize=12, fontweight='bold')

        # Aligned EEG timeline
        ax.add_patch(Rectangle((1, 2.8), 8, 0.5, facecolor=self.colors['eeg'], alpha=0.9))
        ax.text(5, 3.05, 'Aligned EEG: target_points @ 200Hz', ha='center', va='center', color='white', fontweight='bold')

        # Aligned Video timeline
        ax.add_patch(Rectangle((1, 2.1), 8, 0.5, facecolor=self.colors['video'], alpha=0.9))
        ax.text(5, 2.35, 'Aligned Video: target_frames @ 25FPS', ha='center', va='center', color='white', fontweight='bold')

        # Window creation visualization
        ax.text(1, 1.5, 'Window Creation:', fontsize=12, fontweight='bold')

        # Show three sliding windows that really overlap by 50%
        # (step = half a window width), matching the label drawn below.
        # The semi-transparent fills make the shared regions visibly darker.
        window_width = 2.0
        window_starts = self._sliding_window_starts(3, 1.0, window_width, 0.5)
        for i, window_start in enumerate(window_starts):
            # Put each label in the window's left quarter. With a 50% step,
            # consecutive labels then sit between window edges, not on them.
            label_x = window_start + window_width / 4

            # EEG window
            ax.add_patch(Rectangle((window_start, 0.8), window_width, 0.3,
                                   facecolor=self.colors['eeg'], alpha=0.7,
                                   edgecolor='black', linewidth=1))
            ax.text(label_x, 0.95, f'EEG W{i+1}', ha='center', va='center',
                    color='white', fontsize=8, fontweight='bold')

            # Video window
            ax.add_patch(Rectangle((window_start, 0.4), window_width, 0.3,
                                   facecolor=self.colors['video'], alpha=0.7,
                                   edgecolor='black', linewidth=1))
            ax.text(label_x, 0.55, f'Video W{i+1}', ha='center', va='center',
                    color='white', fontsize=8, fontweight='bold')

        # Add overlap indication
        ax.text(8, 0.6, '50% Overlap\nBetween Windows', ha='center', va='center',
                fontsize=10, style='italic')

        self._save_figure(fig, "temporal_alignment_detail.png",
                          "Temporal alignment detail diagram")

    def create_feature_fusion_architecture(self) -> None:
        """
        Create feature fusion architecture diagram.

        Saves ``feature_fusion_architecture.png`` into ``output_dir``.
        """
        fig, ax = self._new_canvas(12, 10)

        # Title
        ax.text(6, 9.5, 'Feature Fusion Architecture',
                fontsize=16, fontweight='bold', ha='center')

        # Input layer
        ax.text(1, 8.5, 'Input Features:', fontsize=12, fontweight='bold')

        # EEG features
        self._draw_data_box(ax, 0.5, 7.5, 3, 0.8,
                            "EEG Features\n(62, 1600)\nNormalized",
                            self.colors['eeg'])

        # Video features
        self._draw_data_box(ax, 4.5, 7.5, 3, 0.8,
                            "Video Features\n(200, H, W, 3)\nNormalized",
                            self.colors['video'])

        # Metadata features
        self._draw_data_box(ax, 8.5, 7.5, 3, 0.8,
                            "Metadata\nSubject Info\nVideo Context",
                            self.colors['metadata'])

        # Feature processing layer
        ax.text(1, 6.5, 'Feature Processing:', fontsize=12, fontweight='bold')

        # EEG processing
        self._draw_process_box(ax, 0.5, 5.5, 3, 0.8,
                               "EEG Processing\n• Channel selection\n• Frequency filtering",
                               self.colors['eeg'])

        # Video processing
        self._draw_process_box(ax, 4.5, 5.5, 3, 0.8,
                               "Video Processing\n• Frame extraction\n• Spatial features",
                               self.colors['video'])

        # Metadata processing
        self._draw_process_box(ax, 8.5, 5.5, 3, 0.8,
                               "Metadata Processing\n• Subject encoding\n• Context embedding",
                               self.colors['metadata'])

        # Alignment layer
        ax.text(1, 4.5, 'Temporal Alignment:', fontsize=12, fontweight='bold')

        self._draw_process_box(ax, 2, 3.5, 8, 0.8,
                               "Temporal Alignment Module\n• Window synchronization • Time-series matching • Feature alignment",
                               self.colors['alignment'])

        # Fusion layer
        ax.text(1, 2.5, 'Feature Fusion:', fontsize=12, fontweight='bold')

        self._draw_process_box(ax, 3, 1.5, 6, 0.8,
                               "Multi-modal Fusion\n• Cross-attention • Feature concatenation • Joint representation",
                               self.colors['fusion'])

        # Output layer
        self._draw_data_box(ax, 4, 0.3, 4, 0.6,
                            "Fused Features\nReady for Model Training",
                            self.colors['output'])

        # Draw connections
        # Input to processing
        self._draw_arrow(ax, 2, 7.5, 2, 6.3)
        self._draw_arrow(ax, 6, 7.5, 6, 6.3)
        self._draw_arrow(ax, 10, 7.5, 10, 6.3)

        # Processing to alignment
        self._draw_arrow(ax, 2, 5.5, 4, 4.3)
        self._draw_arrow(ax, 6, 5.5, 6, 4.3)
        self._draw_arrow(ax, 10, 5.5, 8, 4.3)

        # Alignment to fusion
        self._draw_arrow(ax, 6, 3.5, 6, 2.3)

        # Fusion to output
        self._draw_arrow(ax, 6, 1.5, 6, 0.9)

        self._save_figure(fig, "feature_fusion_architecture.png",
                          "Feature fusion architecture diagram")

    def create_data_augmentation_overview(self) -> None:
        """
        Create data augmentation techniques overview.

        Saves ``data_augmentation_overview.png`` into ``output_dir``.
        """
        fig, ax = self._new_canvas(14, 8)

        # Title
        ax.text(7, 7.5, 'Data Augmentation Techniques',
                fontsize=16, fontweight='bold', ha='center')

        # EEG Augmentation
        ax.text(1, 6.8, 'EEG Data Augmentation:', fontsize=12, fontweight='bold')

        eeg_aug_techniques = [
            "Time Shifting\n±0.1s offset",
            "Amplitude Scaling\n0.9x, 1.1x factors",
            "Gaussian Noise\n5% std addition",
            "Frequency Filtering\n40Hz low-pass",
            "Channel Dropout\n10% random channels"
        ]

        # One 2.4-wide box per technique, laid out left to right with a 2.6 pitch.
        for i, technique in enumerate(eeg_aug_techniques):
            self._draw_process_box(ax, 0.5 + i * 2.6, 5.5, 2.4, 1, technique, self.colors['eeg'])

        # Video Augmentation
        ax.text(1, 4.3, 'Video Data Augmentation:', fontsize=12, fontweight='bold')

        video_aug_techniques = [
            "Brightness Adjust\n0.8x, 1.2x factors",
            "Contrast Adjust\n0.8x, 1.2x factors",
            "Gaussian Noise\nσ=5.0 addition",
            "Temporal Subsample\n50% frame reduction",
            "Spatial Crop\n80% center crop"
        ]

        for i, technique in enumerate(video_aug_techniques):
            self._draw_process_box(ax, 0.5 + i * 2.6, 3, 2.4, 1, technique, self.colors['video'])

        # Augmentation benefits
        ax.text(1, 1.8, 'Augmentation Benefits:', fontsize=12, fontweight='bold')

        self._draw_process_box(ax, 1, 0.5, 12, 1,
                               "• Increased dataset size (6x for EEG, 6x for Video)\n• Improved model robustness and generalization\n• Better handling of noise and variations\n• Reduced overfitting risk",
                               self.colors['fusion'])

        self._save_figure(fig, "data_augmentation_overview.png",
                          "Data augmentation overview diagram")

    def _new_canvas(self, width: float, height: float):
        """
        Create a figure of ``width`` x ``height`` inches with one hidden axis.

        The axis data limits are set to ``[0, width]`` x ``[0, height]``, so
        one data unit corresponds to one inch of figure size.

        Returns:
            ``(fig, ax)`` as returned by ``plt.subplots``.
        """
        fig, ax = plt.subplots(1, 1, figsize=(width, height))
        ax.set_xlim(0, width)
        ax.set_ylim(0, height)
        ax.axis('off')
        return fig, ax

    def _save_figure(self, fig, filename: str, label: str) -> Path:
        """
        Lay out and save ``fig`` to ``output_dir / filename``, then close it.

        The figure is passed in explicitly rather than taken from pyplot's
        implicit "current figure". It is closed in ``finally``, so a failed
        save does not leave an open figure behind.

        Args:
            fig: Figure to save.
            filename: File name (including extension) inside ``output_dir``.
            label: Human-readable name used in the log message.

        Returns:
            Path of the written image.
        """
        save_path = self.output_dir / filename
        try:
            fig.tight_layout()
            fig.savefig(save_path, dpi=300, bbox_inches='tight', facecolor='white')
        finally:
            plt.close(fig)

        logger.info(f"{label} saved: {save_path}")
        return save_path

    @staticmethod
    def _sliding_window_starts(count: int, start: float, width: float, overlap: float) -> list:
        """
        Return the left x-coordinates of ``count`` consecutive sliding windows.

        Consecutive windows are shifted by ``width * (1 - overlap)``.
        ``overlap=0.5`` gives 50% overlap and ``overlap=0`` gives adjacent,
        non-overlapping windows.

        Raises:
            ValueError: If ``count`` is negative or ``overlap`` is outside [0, 1).
        """
        if count < 0:
            raise ValueError(f"count must be non-negative, got {count}")
        if not 0 <= overlap < 1:
            raise ValueError(f"overlap must be in [0, 1), got {overlap}")
        step = width * (1 - overlap)
        return [start + i * step for i in range(count)]

    def _draw_stage_header(self, ax, x, y, text, color, width=14):
        """
        Draw a rounded stage header bar with centered bold text.

        The bar's left edge is at ``x - 0.5``, it is ``width`` units wide and
        0.4 units tall, and it is vertically centered on ``y``. The text is
        centered on the bar itself, so it follows ``x`` and ``width``.
        """
        left = x - 0.5
        header_box = FancyBboxPatch((left, y - 0.2), width, 0.4,
                                    boxstyle="round,pad=0.1",
                                    facecolor=color, edgecolor='black', linewidth=1)
        ax.add_patch(header_box)
        ax.text(left + width / 2, y, text, fontsize=12, fontweight='bold', ha='center', va='center')

    def _draw_data_box(self, ax, x, y, width, height, text, color):
        """Draw a rounded data box at ``(x, y)`` with centered white text; return the patch."""
        box = FancyBboxPatch((x, y), width, height,
                             boxstyle="round,pad=0.1",
                             facecolor=color, edgecolor='black', linewidth=1.5, alpha=0.8)
        ax.add_patch(box)
        ax.text(x + width/2, y + height/2, text,
                fontsize=9, ha='center', va='center', color='white', fontweight='bold')
        return box

    def _draw_process_box(self, ax, x, y, width, height, text, color):
        """Draw a square-cornered process box at ``(x, y)`` with centered white text; return the patch."""
        box = Rectangle((x, y), width, height,
                        facecolor=color, edgecolor='black', linewidth=1, alpha=0.7)
        ax.add_patch(box)
        ax.text(x + width/2, y + height/2, text,
                fontsize=8, ha='center', va='center', color='white', fontweight='bold')
        return box

    def _draw_arrow(self, ax, x1, y1, x2, y2, color='black'):
        """Draw a straight arrow from ``(x1, y1)`` to ``(x2, y2)``."""
        arrow = patches.FancyArrowPatch((x1, y1), (x2, y2),
                                        connectionstyle="arc3",
                                        arrowstyle='->',
                                        mutation_scale=15,
                                        color=color, linewidth=1.5)
        ax.add_patch(arrow)

    def _draw_connections(self, ax):
        """Draw the arrows between the stages of the pipeline diagram."""
        # Stage 1 to Stage 2
        self._draw_arrow(ax, 2, 9, 3, 7.5)
        self._draw_arrow(ax, 6, 9, 8, 7.5)
        self._draw_arrow(ax, 10, 9, 13, 7.5)

        # Stage 2 to Stage 3
        self._draw_arrow(ax, 3, 6.5, 3, 5)
        self._draw_arrow(ax, 8, 6.5, 8, 5)
        self._draw_arrow(ax, 13, 6.5, 13, 5)

        # Stage 3 to Stage 4
        self._draw_arrow(ax, 3, 4, 6, 2.5)
        self._draw_arrow(ax, 8, 4, 8, 2.5)
        self._draw_arrow(ax, 13, 4, 10, 2.5)

        # Stage 4 to Stage 5
        self._draw_arrow(ax, 6, 1.5, 4, 1)
        self._draw_arrow(ax, 10, 1.5, 12, 1)

    def generate_all_diagrams(self) -> None:
        """
        Generate all feature alignment diagrams and the accompanying README.
        """
        logger.info("Starting feature alignment diagram generation...")

        # Generate all diagrams
        self.create_data_flow_diagram()
        self.create_temporal_alignment_detail()
        self.create_feature_fusion_architecture()
        self.create_data_augmentation_overview()

        # Create README
        self._create_readme()

        logger.info("All feature alignment diagrams generated successfully!")
        logger.info(f"Diagrams saved to: {self.output_dir}")

    def _create_readme(self) -> None:
        """Write ``README.md`` (UTF-8) describing the generated diagrams into ``output_dir``."""
        readme_content = """# Feature Alignment and Fusion Diagrams

This directory contains visualization diagrams for the video-EEG feature alignment and fusion pipeline based on `enhanced_data_preprocessing.py`.

## Diagram Descriptions

### 1. feature_alignment_pipeline.png - Complete Pipeline Overview

**Description**: Shows the complete feature alignment and fusion pipeline from raw data to training-ready output

**Stages**:
1. **Raw Data Input**: EEG (62 channels, 200Hz), Video (25 FPS), Metadata
2. **Temporal Alignment**: Duration calculation, resampling, interpolation
3. **Feature Normalization**: RobustScaler for EEG, StandardScaler for Video
4. **Aligned Window Creation**: 8-second sliding windows with 50% overlap
5. **Data Augmentation**: Optional enhancement techniques

### 2. temporal_alignment_detail.png - Temporal Alignment Mechanism

**Description**: Detailed view of the temporal alignment process

**Features**:
- Original data timelines (EEG: 520s, Video: 520s)
- Alignment steps: duration calculation, target points, resampling
- Aligned data representation
- Sliding window creation with overlap visualization

### 3. feature_fusion_architecture.png - Feature Fusion Architecture

**Description**: Multi-modal feature fusion architecture

**Components**:
- Input features (EEG, Video, Metadata)
- Feature processing layers
- Temporal alignment module
- Multi-modal fusion layer
- Output representation

### 4. data_augmentation_overview.png - Data Augmentation Techniques

**Description**: Overview of data augmentation methods for both EEG and Video

**EEG Augmentation**:
- Time shifting, amplitude scaling, Gaussian noise
- Frequency filtering, channel dropout

**Video Augmentation**:
- Brightness/contrast adjustment, Gaussian noise
- Temporal subsampling, spatial cropping

## Technical Implementation

**Based on**: `enhanced_data_preprocessing.py`
**Key Classes**: `EnhancedVideoEEGDataProcessor`
**Core Methods**:
- `temporal_alignment()`: Time synchronization
- `feature_normalization()`: Data standardization
- `create_aligned_windows()`: Window generation
- `eeg_data_augmentation()` / `video_data_augmentation()`: Data enhancement

## Usage

These diagrams can be used for:
1. **Research Papers**: Illustrate the data preprocessing pipeline
2. **Technical Documentation**: Explain the alignment mechanism
3. **Team Communication**: Visualize the feature fusion process
4. **Educational Purposes**: Teach multi-modal data processing

## Parameters

- **EEG Sampling Rate**: 200 Hz
- **Video Frame Rate**: 25 FPS
- **Time Window**: 8.0 seconds
- **Overlap Ratio**: 50%
- **EEG Window Size**: 1600 points
- **Video Window Size**: 200 frames

---

**Author**: Algorithm Engineer  
**Date**: January 12, 2025  
**Project**: EEG2Video - Enhanced Data Preprocessing
"""

        readme_path = self.output_dir / "README.md"
        readme_path.write_text(readme_content, encoding='utf-8')

        logger.info(f"README file created: {readme_path}")

def main():
    """
    Script entry point: build every feature-alignment diagram (and the README)
    in the generator's default output directory.
    """
    FeatureAlignmentDiagramGenerator().generate_all_diagrams()

if __name__ == "__main__":
    main()