#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
SEED-VD Data Collection Flow Diagram Generator
Recreates the data collection session and video block structure diagrams from the SEED-VD paper, and adds processing-pipeline and configuration-overview diagrams

Author: Algorithm Engineer
Date: January 12, 2025
"""

import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.patches import FancyBboxPatch, Rectangle
import numpy as np
from pathlib import Path

class SEEDVDFlowDiagramGenerator:
    """
    SEED-VD Data Flow Diagram Generator

    Renders four explanatory PNG figures (data collection session, video
    block structure, processing pipeline, configuration overview) into
    ``output_dir`` with matplotlib.
    """

    # Output file names, shared by the plot methods and the summary printout.
    SESSION_FILENAME = "seed_vd_data_collection_session.png"
    BLOCK_STRUCTURE_FILENAME = "seed_vd_video_block_structure.png"
    PIPELINE_FILENAME = "seed_vd_processing_pipeline.png"
    CONFIG_OVERVIEW_FILENAME = "seed_vd_configuration_overview.png"

    def __init__(self, output_dir: str = "./seed_vd_diagrams"):
        """
        Args:
            output_dir: Directory the PNG files are written to. It is created,
                including any missing parent directories, if it does not exist.

        Note:
            Sets global ``plt.rcParams`` font options, which affect every
            figure created in this process afterwards.
        """
        self.output_dir = Path(output_dir)
        # parents=True so nested destinations such as "out/figs/seed_vd" work.
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # Set font for better compatibility
        plt.rcParams['font.sans-serif'] = ['DejaVu Sans', 'Arial', 'sans-serif']
        plt.rcParams['axes.unicode_minus'] = False

        # Color configuration
        self.colors = {
            'video_block': '#4CAF50',
            'rest_phase': '#FFC107',
            'eeg_data': '#2196F3',
            'processing': '#FF5722',
            'generation': '#9C27B0',
            'background': '#F5F5F5',
            'text': '#333333'
        }

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _session_layout(n_blocks, block_duration, rest_duration):
        """
        Compute the timeline of a recording session.

        A rest phase follows every video block except the last one.

        Args:
            n_blocks: Number of video blocks.
            block_duration: Length of one video block (any time unit).
            rest_duration: Length of one rest phase (same unit).

        Returns:
            ``(segments, total)``. ``segments`` is a list of
            ``(kind, start, duration)`` tuples with ``kind`` either
            ``'video'`` or ``'rest'``. ``total`` is the session end time.
        """
        segments = []
        t = 0.0
        for i in range(n_blocks):
            segments.append(('video', t, block_duration))
            t += block_duration
            if i < n_blocks - 1:
                segments.append(('rest', t, rest_duration))
                t += rest_duration
        return segments, t

    @staticmethod
    def _horizontal_link(start_center, end_center, half_width):
        """
        Return ``(start, end)`` arrow points joining the facing edges of two boxes.

        Both boxes share ``half_width`` (centre-to-visible-edge distance). The
        arrow leaves the start box on the side facing the end box and stops at
        the near edge of the end box, so it never runs across the box labels.
        """
        (x1, y1), (x2, y2) = start_center, end_center
        direction = 1 if x2 >= x1 else -1
        return (x1 + direction * half_width, y1), (x2 - direction * half_width, y2)

    @staticmethod
    def _hide_spines(ax, sides=('top', 'right', 'bottom', 'left')):
        """Make the named axis spines invisible."""
        for side in sides:
            ax.spines[side].set_visible(False)

    def _save_figure(self, fig, filename):
        """
        Apply tight layout, write ``fig`` to ``output_dir / filename`` and close it.

        The figure is closed even if saving fails, so repeated calls do not
        accumulate open figures.

        Returns:
            The path of the written file.
        """
        path = self.output_dir / filename
        try:
            fig.tight_layout()
            fig.savefig(path, dpi=300, bbox_inches='tight', facecolor='white')
        finally:
            plt.close(fig)
        return path

    # ------------------------------------------------------------------
    # Diagrams
    # ------------------------------------------------------------------
    def plot_data_collection_session(self) -> None:
        """
        Plot data collection session diagram (C)

        Draws seven 5-minute video blocks separated by 30-second rest phases
        on a time axis in minutes, with a continuous EEG band above them.
        Saves ``SESSION_FILENAME`` into ``output_dir``.
        """
        fig, ax = plt.subplots(figsize=(16, 8))

        # Set background
        ax.set_facecolor(self.colors['background'])

        # Every duration is in minutes, the x-axis unit, so rest phases are
        # drawn to the same scale as the video blocks.
        n_blocks = 7
        block_duration = 5.0  # Each video block 5 minutes
        rest_duration = 0.5   # Rest 30 seconds
        segments, total_time = self._session_layout(n_blocks, block_duration, rest_duration)

        y_pos = 0.5
        height = 0.3

        block_number = 0
        for kind, start, duration in segments:
            if kind == 'video':
                block_number += 1
                ax.add_patch(FancyBboxPatch(
                    (start, y_pos - height/2), duration, height,
                    boxstyle="round,pad=0.02",
                    facecolor=self.colors['video_block'],
                    edgecolor='black',
                    linewidth=2
                ))
                ax.text(start + duration/2, y_pos, f'Video Block {block_number}\n(5 video clips)',
                        ha='center', va='center', fontsize=10, fontweight='bold',
                        color='white')
            else:
                # Rest phases are drawn half as tall as the video blocks.
                ax.add_patch(FancyBboxPatch(
                    (start, y_pos - height/4), duration, height/2,
                    boxstyle="round,pad=0.01",
                    facecolor=self.colors['rest_phase'],
                    edgecolor='black',
                    linewidth=1
                ))
                ax.text(start + duration/2, y_pos, 'Rest\n≥30s',
                        ha='center', va='center', fontsize=8,
                        color='black')

        # EEG data collection indicator spanning the whole session
        eeg_rect = Rectangle((0, y_pos + 0.4), total_time, 0.15,
                             facecolor=self.colors['eeg_data'], alpha=0.7)
        ax.add_patch(eeg_rect)
        ax.text(total_time/2, y_pos + 0.475, 'Continuous EEG Recording (200Hz)',
                ha='center', va='center', fontsize=12, fontweight='bold',
                color='white')

        # Set coordinate axes
        ax.set_xlim(-0.5, total_time + 0.5)
        ax.set_ylim(0, 1.2)
        ax.set_xlabel('Time (minutes)', fontsize=12)
        ax.set_title('(C) SEED-VD Data Collection Session Diagram', fontsize=16, fontweight='bold', pad=20)

        # One tick per block length (every 5 minutes)
        time_ticks = np.arange(0, total_time + 1, block_duration)
        ax.set_xticks(time_ticks)
        ax.set_xticklabels([f'{int(t)}' for t in time_ticks])

        ax.set_yticks([])
        self._hide_spines(ax, ('top', 'right', 'left'))

        self._save_figure(fig, self.SESSION_FILENAME)

        print("✅ Data collection session diagram saved")

    def plot_video_block_structure(self) -> None:
        """
        Plot video block structure diagram (D)

        Draws a 3-second concept hint followed by five 60-second video clips
        on a time axis in seconds, with an EEG band above and a concept band
        below. Saves ``BLOCK_STRUCTURE_FILENAME`` into ``output_dir``.
        """
        fig, ax = plt.subplots(figsize=(14, 6))

        # Set background
        ax.set_facecolor(self.colors['background'])

        # Time axis parameters (seconds)
        hint_duration = 3
        clip_duration = 60  # Each clip 1 minute
        total_clips = 5

        y_pos = 0.5
        height = 0.4

        # Hint phase
        hint_rect = FancyBboxPatch(
            (0, y_pos - height/2), hint_duration, height,
            boxstyle="round,pad=0.02",
            facecolor=self.colors['rest_phase'],
            edgecolor='black',
            linewidth=2
        )
        ax.add_patch(hint_rect)
        ax.text(hint_duration/2, y_pos, 'Concept Hint\n(3s)',
                ha='center', va='center', fontsize=10, fontweight='bold')

        current_time = hint_duration

        # 5 video clips
        for i in range(total_clips):
            clip_rect = FancyBboxPatch(
                (current_time, y_pos - height/2), clip_duration, height,
                boxstyle="round,pad=0.02",
                facecolor=self.colors['video_block'],
                edgecolor='black',
                linewidth=2
            )
            ax.add_patch(clip_rect)

            ax.text(current_time + clip_duration/2, y_pos, f'Video Clip {i+1}\n(Same Concept)',
                    ha='center', va='center', fontsize=10, fontweight='bold',
                    color='white')

            current_time += clip_duration

        # EEG data collection indicator
        eeg_rect = Rectangle((0, y_pos + 0.3), current_time, 0.1,
                             facecolor=self.colors['eeg_data'], alpha=0.7)
        ax.add_patch(eeg_rect)
        ax.text(current_time/2, y_pos + 0.35, 'Synchronized EEG Recording',
                ha='center', va='center', fontsize=11, fontweight='bold',
                color='white')

        # Concept identifier spanning the clips (not the hint)
        concept_rect = Rectangle((hint_duration, y_pos - 0.35),
                                 current_time - hint_duration, 0.15,
                                 facecolor=self.colors['processing'], alpha=0.8)
        ax.add_patch(concept_rect)
        ax.text(hint_duration + (current_time - hint_duration)/2, y_pos - 0.275,
                '5 Different Video Clips of Same Concept',
                ha='center', va='center', fontsize=11, fontweight='bold',
                color='white')

        # Set coordinate axes
        ax.set_xlim(-5, current_time + 5)
        ax.set_ylim(0, 1)
        ax.set_xlabel('Time (seconds)', fontsize=12)
        ax.set_title('(D) Video Block Structure Diagram', fontsize=16, fontweight='bold', pad=20)

        # Time ticks every 60 seconds
        time_ticks = np.arange(0, current_time + 1, 60)
        ax.set_xticks(time_ticks)

        ax.set_yticks([])
        self._hide_spines(ax, ('top', 'right', 'left'))

        self._save_figure(fig, self.BLOCK_STRUCTURE_FILENAME)

        print("✅ Video block structure diagram saved")

    def plot_data_processing_pipeline(self) -> None:
        """
        Plot data processing pipeline diagram

        Lays out nine step boxes in three rows (preprocessing, training,
        generation), links neighbouring boxes within a row, adds arrows
        between rows and labels each phase. Saves ``PIPELINE_FILENAME`` into
        ``output_dir``.
        """
        fig, ax = plt.subplots(figsize=(16, 10))

        # Set background
        ax.set_facecolor(self.colors['background'])

        # Process steps (positions are box centres in data coordinates)
        steps = [
            {'name': 'Raw Data\n(Video + EEG)', 'color': self.colors['video_block'], 'pos': (2, 8)},
            {'name': 'Data Preprocessing\n(Segmentation, Alignment)', 'color': self.colors['processing'], 'pos': (6, 8)},
            {'name': 'Video-EEG Pairing\nCreation', 'color': self.colors['eeg_data'], 'pos': (10, 8)},
            {'name': 'Dataset Split\n(Train/Val/Test)', 'color': self.colors['generation'], 'pos': (14, 8)},

            {'name': 'SGGN Model\nTraining', 'color': self.colors['processing'], 'pos': (4, 5)},
            {'name': 'Model Validation\n& Optimization', 'color': self.colors['eeg_data'], 'pos': (8, 5)},
            {'name': 'Best Model\nSaving', 'color': self.colors['generation'], 'pos': (12, 5)},

            {'name': 'EEG Dataset\nGeneration', 'color': self.colors['video_block'], 'pos': (6, 2)},
            {'name': 'Quality Assessment\n& Validation', 'color': self.colors['processing'], 'pos': (10, 2)}
        ]

        box_width = 2.5
        box_height = 1.2
        # The round boxstyle grows each box by `pad` on every side, so the
        # visible edge is at half the nominal size plus the pad.
        box_pad = 0.1
        half_w = box_width / 2 + box_pad
        half_h = box_height / 2 + box_pad

        # Draw step boxes
        for step in steps:
            x, y = step['pos']
            ax.add_patch(FancyBboxPatch(
                (x - box_width/2, y - box_height/2), box_width, box_height,
                boxstyle=f"round,pad={box_pad}",
                facecolor=step['color'],
                edgecolor='black',
                linewidth=2
            ))
            ax.text(x, y, step['name'],
                    ha='center', va='center', fontsize=10, fontweight='bold',
                    color='white')

        # Within-row links (index pairs into `steps`), drawn edge to edge so
        # the arrows do not run over the box labels.
        row_links = [(0, 1), (1, 2), (2, 3), (4, 5), (5, 6)]
        arrows = [
            self._horizontal_link(steps[a]['pos'], steps[b]['pos'], half_w)
            for a, b in row_links
        ]
        # Between-row links: bottom edge of one row to top edge of the next.
        arrows += [
            ((8, 8 - half_h), (4, 5 + half_h)),    # first row -> SGGN Model Training
            ((8, 5 - half_h), (6, 2 + half_h)),    # second row -> EEG Dataset Generation
            ((10, 5 - half_h), (10, 2 + half_h)),  # second row -> Quality Assessment
        ]

        for start, end in arrows:
            ax.annotate('', xy=end, xytext=start,
                        arrowprops=dict(arrowstyle='->', lw=2, color='black'))

        # Add phase labels
        ax.text(8, 9.5, 'Data Preprocessing Phase', ha='center', va='center',
                fontsize=14, fontweight='bold',
                bbox=dict(boxstyle="round,pad=0.3", facecolor='lightblue', alpha=0.7))

        ax.text(8, 6.5, 'Model Training Phase', ha='center', va='center',
                fontsize=14, fontweight='bold',
                bbox=dict(boxstyle="round,pad=0.3", facecolor='lightgreen', alpha=0.7))

        ax.text(8, 3.5, 'Data Generation Phase', ha='center', va='center',
                fontsize=14, fontweight='bold',
                bbox=dict(boxstyle="round,pad=0.3", facecolor='lightcoral', alpha=0.7))

        # Set coordinate axes
        ax.set_xlim(0, 16)
        ax.set_ylim(0, 11)
        ax.set_title('SEED-VD Data Processing Pipeline', fontsize=18, fontweight='bold', pad=20)

        ax.set_xticks([])
        ax.set_yticks([])
        self._hide_spines(ax)

        self._save_figure(fig, self.PIPELINE_FILENAME)

        print("✅ Data processing pipeline diagram saved")

    def plot_configuration_overview(self) -> None:
        """
        Plot configuration parameters overview

        A 2x2 grid of text panels (data processing, model, generation,
        running modes). Saves ``CONFIG_OVERVIEW_FILENAME`` into ``output_dir``.
        """
        fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(16, 12))

        # Panel bodies are joined line lists so the boxed text has no blank
        # leading line or whitespace-only trailing line.
        processing_text = "\n".join([
            "• Subject Count: Configurable (default first 2)",
            "• Video Count: Configurable (default first 2)",
            "• Segment Duration: 5.0 seconds",
            "• Samples per Video: 2",
            "• Overlap Ratio: 0.5",
            "• Dataset Split:",
            "  - Training: 60%",
            "  - Validation: 20%",
            "  - Testing: 20%",
        ])
        model_text = "\n".join([
            "• Model Type: SGGN (Spatial-Graph-Guided Network)",
            "• Model Path: ./sggn_training_output/best_model.pth",
            "• Device: Auto-detect (GPU/CPU)",
            "• Input: Video Features",
            "• Output: EEG Signals",
            "• Sampling Rate: 200Hz",
        ])
        generation_text = "\n".join([
            "• Generated Samples: Configurable (test default 2)",
            "• Batch Size: 1",
            "• Dataset Split: test",
            "• Dataset Name: seed_vd_test",
            "• Output Format: .npy files",
            "• Quality Assessment: Automatic",
        ])
        modes_text = "\n".join([
            "• process: Data processing only",
            "• generate: Data generation only",
            "• both: Complete pipeline",
            "• quick_test: Quick testing",
            "• Config File: seed_vd_config.yaml",
            "• Log Level: INFO/DEBUG/WARNING/ERROR",
        ])

        panels = [
            (ax1, 'Data Processing Config', processing_text, 'video_block'),
            (ax2, 'Model Configuration', model_text, 'eeg_data'),
            (ax3, 'Generation Config', generation_text, 'generation'),
            (ax4, 'Running Modes', modes_text, 'processing'),
        ]
        for ax, title, body, color_key in panels:
            ax.text(0.5, 0.9, title, ha='center', va='top',
                    fontsize=16, fontweight='bold', transform=ax.transAxes)
            ax.text(0.1, 0.8, body, ha='left', va='top',
                    fontsize=12, transform=ax.transAxes,
                    bbox=dict(boxstyle="round,pad=0.5", facecolor=self.colors[color_key], alpha=0.3))
            # Text-only panel: remove ticks and frame
            ax.set_xticks([])
            ax.set_yticks([])
            self._hide_spines(ax)

        fig.suptitle('SEED-VD Data Generation Configuration Overview', fontsize=20, fontweight='bold')
        self._save_figure(fig, self.CONFIG_OVERVIEW_FILENAME)

        print("✅ Configuration overview diagram saved")

    def generate_all_diagrams(self) -> None:
        """
        Generate all diagrams

        Calls each plot method in turn and prints a summary of the files
        written to ``output_dir``.
        """
        print("\n🎨 Starting SEED-VD data flow diagram generation...")
        print("=" * 60)

        self.plot_data_collection_session()
        self.plot_video_block_structure()
        self.plot_data_processing_pipeline()
        self.plot_configuration_overview()

        print("\n" + "=" * 60)
        print("✅ All diagrams generated successfully!")
        print(f"📁 Images saved to: {self.output_dir}")
        print("\n📊 Generated charts:")
        print(f"  • {self.SESSION_FILENAME} - Data collection session diagram")
        print(f"  • {self.BLOCK_STRUCTURE_FILENAME} - Video block structure diagram")
        print(f"  • {self.PIPELINE_FILENAME} - Data processing pipeline")
        print(f"  • {self.CONFIG_OVERVIEW_FILENAME} - Configuration parameters overview")

def main():
    """
    Script entry point.

    Builds a generator with its default output directory and renders every
    diagram it provides.
    """
    SEEDVDFlowDiagramGenerator().generate_all_diagrams()


if __name__ == "__main__":
    # Only render when executed as a script, not when imported.
    main()