#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
基于单个视频文件生成8通道脑电信号可视化
使用Video2EEG-SGGN-Diffusion模型

作者: 算法工程师
日期: 2025年1月17日
"""

import os
import sys
import cv2
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from pathlib import Path
import logging
import argparse
from typing import Dict, List, Tuple, Optional
import warnings

warnings.filterwarnings('ignore')

# 导入模型
try:
    from video2eeg_sggn_diffusion_model import Video2EEGSGGNDiffusion, create_video2eeg_sggn_diffusion
except ImportError:
    print("警告: 无法导入SGGN模型，将使用模拟数据")
    create_video2eeg_sggn_diffusion = None

# 设置日志
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# 设置中文字体
plt.rcParams['font.sans-serif'] = ['SimHei', 'Arial Unicode MS', 'DejaVu Sans']
plt.rcParams['axes.unicode_minus'] = False

class VideoProcessor:
    """
    视频处理器
    """
    
    def __init__(self, target_frames: int = 250, target_size: Tuple[int, int] = (224, 224)):
        """
        初始化视频处理器
        
        Args:
            target_frames: 目标帧数
            target_size: 目标尺寸
        """
        self.target_frames = target_frames
        self.target_size = target_size
    
    def load_video(self, video_path: str) -> torch.Tensor:
        """
        加载和预处理视频
        
        Args:
            video_path: 视频文件路径
            
        Returns:
            视频张量 (1, T, C, H, W)
        """
        if not os.path.exists(video_path):
            raise FileNotFoundError(f"视频文件不存在: {video_path}")
        
        cap = cv2.VideoCapture(video_path)
        frames = []
        
        logger.info(f"正在加载视频: {video_path}")
        
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            
            # 转换颜色空间和调整大小
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame = cv2.resize(frame, self.target_size)
            frames.append(frame)
        
        cap.release()
        
        if len(frames) == 0:
            raise ValueError(f"无法从视频文件加载帧: {video_path}")
        
        logger.info(f"加载了 {len(frames)} 帧")
        
        # 采样到目标帧数
        if len(frames) > self.target_frames:
            indices = np.linspace(0, len(frames)-1, self.target_frames, dtype=int)
            frames = [frames[i] for i in indices]
            logger.info(f"采样到 {self.target_frames} 帧")
        elif len(frames) < self.target_frames:
            # 如果帧数不足，重复最后一帧
            while len(frames) < self.target_frames:
                frames.append(frames[-1])
            logger.info(f"填充到 {self.target_frames} 帧")
        
        # 转换为tensor并归一化
        frames = np.stack(frames)
        frames = torch.from_numpy(frames).float() / 255.0
        frames = frames.permute(0, 3, 1, 2)  # (frames, channels, height, width)
        
        return frames.unsqueeze(0)  # 添加batch维度

class EEGGenerator:
    """
    EEG信号生成器
    """
    
    def __init__(self, model_path: Optional[str] = None, device: str = 'auto'):
        """
        初始化EEG生成器
        
        Args:
            model_path: 模型路径（可选）
            device: 设备类型
        """
        # 设置设备
        if device == 'auto':
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = torch.device(device)
        
        logger.info(f"使用设备: {self.device}")
        
        # 尝试加载模型
        self.model = None
        if model_path and os.path.exists(model_path) and create_video2eeg_sggn_diffusion:
            try:
                self.model = self._load_model(model_path)
                logger.info("模型加载成功")
            except Exception as e:
                logger.warning(f"模型加载失败: {e}，将使用模拟数据")
        else:
            logger.info("未提供模型路径或模型不可用，将使用模拟数据")
    
    def _load_model(self, model_path: str) -> nn.Module:
        """
        加载模型
        
        
        Args:
            model_path: 模型路径
            
        Returns:
            加载的模型
        """
        checkpoint = torch.load(model_path, map_location=self.device, weights_only=False)
        
        # 创建模型配置
        model_config = {
            'video_feature_dim': 512,
            'eeg_channels': 14,
            'signal_length': 250,
            'num_diffusion_steps': 1000,
            'hidden_dim': 256,
            'dropout': 0.1
        }
        
        # 创建模型
        model = create_video2eeg_sggn_diffusion(model_config)
        
        # 加载权重
        try:
            model.load_state_dict(checkpoint['model_state_dict'], strict=False)
        except Exception as e:
            logger.warning(f"权重加载失败: {e}")
        
        model = model.to(self.device)
        model.eval()
        
        return model
    
    def generate_eeg(self, video_frames: torch.Tensor, num_channels: int = 8) -> np.ndarray:
        """
        生成EEG信号
        
        Args:
            video_frames: 视频帧 (1, T, C, H, W)
            num_channels: EEG通道数
            
        Returns:
            生成的EEG信号 (num_channels, signal_length)
        """
        signal_length = video_frames.shape[1]  # 使用视频帧数作为信号长度
        
        if self.model is not None:
            try:
                # 使用真实模型生成
                with torch.no_grad():
                    video_frames = video_frames.to(self.device)
                    batch_size = video_frames.shape[0]
                    
                    # 创建随机时间步
                    t = torch.randint(0, 1000, (batch_size,)).to(self.device)
                    
                    # 创建噪声EEG作为输入
                    noisy_eeg = torch.randn(
                        batch_size, 
                        14,  # 模型默认通道数
                        250  # 模型默认信号长度
                    ).to(self.device)
                    
                    # 模型推理
                    output = self.model(video_frames, noisy_eeg, t)
                    
                    # 处理输出
                    if isinstance(output, tuple):
                        predicted_noise, _, _ = output
                        generated_eeg = noisy_eeg - predicted_noise
                    else:
                        generated_eeg = output
                    
                    # 转换为numpy
                    generated_eeg = generated_eeg.cpu().numpy()[0]  # (14, 250)
                    
                    # 调整到目标通道数和信号长度
                    if generated_eeg.shape[0] >= num_channels:
                        generated_eeg = generated_eeg[:num_channels]
                    else:
                        # 如果通道数不足，重复现有通道
                        repeat_times = (num_channels + generated_eeg.shape[0] - 1) // generated_eeg.shape[0]
                        generated_eeg = np.tile(generated_eeg, (repeat_times, 1))[:num_channels]
                    
                    # 调整信号长度
                    if generated_eeg.shape[1] != signal_length:
                        # 重采样到目标长度
                        from scipy import signal as scipy_signal
                        generated_eeg_resampled = np.zeros((num_channels, signal_length))
                        for ch in range(num_channels):
                            generated_eeg_resampled[ch] = scipy_signal.resample(
                                generated_eeg[ch], signal_length
                            )
                        generated_eeg = generated_eeg_resampled
                    
                    return generated_eeg
                    
            except Exception as e:
                logger.warning(f"模型推理失败: {e}，使用模拟数据")
        
        # 使用模拟数据
        logger.info("生成模拟EEG数据")
        return self._generate_simulated_eeg(num_channels, signal_length)
    
    def _generate_simulated_eeg(self, num_channels: int, signal_length: int) -> np.ndarray:
        """
        生成模拟EEG数据
        
        Args:
            num_channels: 通道数
            signal_length: 信号长度
            
        Returns:
            模拟的EEG信号
        """
        # 创建基础时间轴
        t = np.linspace(0, signal_length/250.0, signal_length)  # 假设250Hz采样率
        
        eeg_data = np.zeros((num_channels, signal_length))
        
        # 为每个通道生成不同频率的信号
        for ch in range(num_channels):
            # 基础频率（Alpha波段：8-13Hz）
            base_freq = 8 + (ch % 6)  # 8-13Hz
            
            # 主要信号
            signal_main = np.sin(2 * np.pi * base_freq * t)
            
            # 添加谐波
            signal_harmonic = 0.3 * np.sin(2 * np.pi * base_freq * 2 * t)
            
            # 添加低频成分（Delta波段：0.5-4Hz）
            low_freq = 1 + (ch % 3)
            signal_low = 0.5 * np.sin(2 * np.pi * low_freq * t)
            
            # 添加高频成分（Beta波段：13-30Hz）
            high_freq = 15 + (ch % 10)
            signal_high = 0.2 * np.sin(2 * np.pi * high_freq * t)
            
            # 添加噪声
            noise = 0.1 * np.random.randn(signal_length)
            
            # 组合信号
            eeg_data[ch] = signal_main + signal_harmonic + signal_low + signal_high + noise
            
            # 添加通道特异性的幅度调制
            amplitude_modulation = 0.8 + 0.4 * np.sin(2 * np.pi * 0.1 * t)  # 0.1Hz调制
            eeg_data[ch] *= amplitude_modulation
            
            # 归一化到合理范围
            eeg_data[ch] = eeg_data[ch] * (10 + ch * 2)  # 不同通道不同幅度
        
        return eeg_data

class EEGVisualizer:
    """
    EEG可视化器
    """
    
    def __init__(self, sampling_rate: float = 250.0):
        """
        初始化可视化器
        
        Args:
            sampling_rate: 采样率
        """
        self.sampling_rate = sampling_rate
        
        # 定义8种不同的颜色
        self.colors = [
            '#1f77b4',  # 蓝色
            '#ff7f0e',  # 橙色
            '#2ca02c',  # 绿色
            '#d62728',  # 红色
            '#9467bd',  # 紫色
            '#8c564b',  # 棕色
            '#e377c2',  # 粉色
            '#7f7f7f'   # 灰色
        ]
    
    def plot_eeg_signals(self, eeg_data: np.ndarray, output_path: str, 
                        title: str = "Generated EEG Signals from Video"):
        """
        绘制8通道EEG信号
        
        Args:
            eeg_data: EEG数据 (num_channels, signal_length)
            output_path: 输出路径
            title: 图表标题
        """
        num_channels, signal_length = eeg_data.shape
        time_axis = np.arange(signal_length) / self.sampling_rate
        
        # 创建图形
        fig, ax = plt.subplots(figsize=(15, 10))
        
        # 计算通道间的垂直偏移
        channel_offset = np.max(np.abs(eeg_data)) * 2.5
        
        # 绘制每个通道
        for ch in range(min(num_channels, 8)):  # 最多显示8个通道
            color = self.colors[ch % len(self.colors)]
            
            # 添加垂直偏移
            signal_with_offset = eeg_data[ch] + ch * channel_offset
            
            # 绘制信号
            ax.plot(time_axis, signal_with_offset, 
                   color=color, linewidth=1.5, 
                   label=f'Channel {ch+1}', alpha=0.8)
            
            # 添加通道标签
            ax.text(-0.02 * time_axis[-1], ch * channel_offset, 
                   f'Ch{ch+1}', 
                   verticalalignment='center',
                   horizontalalignment='right',
                   color=color, fontweight='bold')
        
        # 设置图表属性
        ax.set_xlabel('Time (s)', fontsize=12)
        ax.set_ylabel('Amplitude (μV)', fontsize=12)
        ax.set_title(title, fontsize=14, fontweight='bold')
        ax.grid(True, alpha=0.3)
        
        # 设置y轴刻度（隐藏，因为有偏移）
        ax.set_yticks([])
        
        # 添加图例
        ax.legend(loc='upper right', bbox_to_anchor=(1.15, 1))
        
        # 调整布局
        plt.tight_layout()
        
        # 保存图片
        plt.savefig(output_path, dpi=300, bbox_inches='tight')
        plt.close()
        
        logger.info(f"EEG可视化图已保存到: {output_path}")
    
    def plot_eeg_spectrogram(self, eeg_data: np.ndarray, output_path: str,
                           channel_idx: int = 0, title: str = "EEG Spectrogram"):
        """
        绘制EEG频谱图
        
        Args:
            eeg_data: EEG数据 (num_channels, signal_length)
            output_path: 输出路径
            channel_idx: 要显示的通道索引
            title: 图表标题
        """
        if channel_idx >= eeg_data.shape[0]:
            channel_idx = 0
        
        signal = eeg_data[channel_idx]
        
        # 计算频谱图
        plt.figure(figsize=(12, 6))
        Pxx, freqs, bins, im = plt.specgram(signal, Fs=self.sampling_rate, 
                                           NFFT=128, noverlap=64)
        plt.colorbar(im, label='Power (dB)')
        plt.ylabel('Frequency (Hz)')
        plt.xlabel('Time (s)')
        plt.title(f'{title} - Channel {channel_idx+1}')
        plt.ylim(0, 50)  # 限制频率范围到50Hz
        
        plt.tight_layout()
        plt.savefig(output_path, dpi=300, bbox_inches='tight')
        plt.close()
        
        logger.info(f"EEG频谱图已保存到: {output_path}")

def main():
    """
    主函数
    """
    parser = argparse.ArgumentParser(description='基于视频生成8通道EEG信号可视化')
    parser.add_argument('--video_path', type=str, 
                       default='/data0/GYF-projects/EEG2Video/data/panda01.mp4',
                       help='视频文件路径')
    parser.add_argument('--model_path', type=str, default=None,
                       help='模型检查点路径（可选）')
    parser.add_argument('--output_dir', type=str, 
                       default='./eeg_visualization_output',
                       help='输出目录')
    parser.add_argument('--num_channels', type=int, default=8,
                       help='EEG通道数')
    parser.add_argument('--device', type=str, default='auto',
                       help='设备类型')
    
    args = parser.parse_args()
    
    try:
        # 创建输出目录
        output_dir = Path(args.output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)
        
        logger.info(f"开始处理视频: {args.video_path}")
        
        # 1. 加载视频
        video_processor = VideoProcessor()
        video_frames = video_processor.load_video(args.video_path)
        
        logger.info(f"视频帧形状: {video_frames.shape}")
        
        # 2. 生成EEG信号
        eeg_generator = EEGGenerator(args.model_path, args.device)
        eeg_data = eeg_generator.generate_eeg(video_frames, args.num_channels)
        
        logger.info(f"生成的EEG形状: {eeg_data.shape}")
        
        # 3. 创建可视化
        visualizer = EEGVisualizer()
        
        # 生成主要的EEG信号图
        video_name = Path(args.video_path).stem
        eeg_plot_path = output_dir / f"{video_name}_eeg_8channels.png"
        visualizer.plot_eeg_signals(
            eeg_data, 
            str(eeg_plot_path),
            title=f"8-Channel EEG Signals Generated from {video_name}"
        )
        
        # 生成频谱图（第一个通道）
        spectrogram_path = output_dir / f"{video_name}_spectrogram_ch1.png"
        visualizer.plot_eeg_spectrogram(
            eeg_data,
            str(spectrogram_path),
            channel_idx=0,
            title=f"EEG Spectrogram from {video_name}"
        )
        
        # 保存EEG数据
        eeg_data_path = output_dir / f"{video_name}_eeg_data.npy"
        np.save(eeg_data_path, eeg_data)
        logger.info(f"EEG数据已保存到: {eeg_data_path}")
        
        # 生成报告
        report_path = output_dir / f"{video_name}_report.txt"
        with open(report_path, 'w', encoding='utf-8') as f:
            f.write(f"EEG信号生成报告\n")
            f.write(f"=" * 50 + "\n")
            f.write(f"视频文件: {args.video_path}\n")
            f.write(f"视频帧数: {video_frames.shape[1]}\n")
            f.write(f"EEG通道数: {eeg_data.shape[0]}\n")
            f.write(f"EEG信号长度: {eeg_data.shape[1]}\n")
            f.write(f"采样率: 250 Hz\n")
            f.write(f"信号持续时间: {eeg_data.shape[1]/250.0:.2f} 秒\n")
            f.write(f"\n生成的文件:\n")
            f.write(f"- EEG信号图: {eeg_plot_path.name}\n")
            f.write(f"- 频谱图: {spectrogram_path.name}\n")
            f.write(f"- EEG数据: {eeg_data_path.name}\n")
        
        logger.info(f"报告已保存到: {report_path}")
        logger.info("处理完成！")
        
        return 0
        
    except Exception as e:
        logger.error(f"处理过程中发生错误: {e}")
        import traceback
        traceback.print_exc()
        return 1

if __name__ == "__main__":
    exit_code = main()
    sys.exit(exit_code)