"""
Unified DiT (Diffusion Transformer) Model
Supports brain signal generation for both EEG and MEG modalities
Uses Cross-Attention for cross-modal fusion
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
from timm.models.vision_transformer import Attention, Mlp

from diffusion import create_diffusion
from cross_attention import MultiModalCrossAttention


# Default signal geometry per modality (keys are lowercase modality names):
# raw channel count, number of time samples, and the (channel, time) patch
# size handed to the patch embedding.
DEFAULT_CONFIG = dict(
    eeg=dict(
        channels=63,         # Number of EEG electrodes
        length=250,          # Number of EEG sampling points
        patch_size=(4, 4),
    ),
    meg=dict(
        channels=271,        # Number of MEG channels
        length=200,          # Number of MEG sampling points
        patch_size=(4, 4),
    ),
)


#################################################################################
#               Brain Region Configuration                                       #
#################################################################################

# EEG channel names, listed in the order of chan_order from preprocessing_utils.py.
# A name's position in this list is the channel index used by EEG_REGION_CHANNELS.
EEG_CHANNEL_NAMES = (
    'Fp1 Fp2 AF7 AF3 AFz AF4 AF8 F7 F5 F3 '
    'F1 F2 F4 F6 F8 FT9 FT7 FC5 FC3 FC1 '
    'FCz FC2 FC4 FC6 FT8 FT10 T7 C5 C3 C1 '
    'Cz C2 C4 C6 T8 TP9 TP7 CP5 CP3 CP1 '
    'CPz CP2 CP4 CP6 TP8 TP10 P7 P5 P3 P1 '
    'Pz P2 P4 P6 P8 PO7 PO3 POz PO4 PO8 '
    'O1 Oz O2'
).split()

# EEG channel indices (positions in EEG_CHANNEL_NAMES, i.e. original order) per brain region.
# Regions: 0=frontal, 1=central, 2=temporal, 3=parietal, 4=occipital
EEG_REGION_CHANNELS = {
    # frontal (22 channels): Fp1..F8 (0-14) and FC5..FC6 (17-23)
    0: list(range(0, 15)) + list(range(17, 24)),
    # central (14 channels): C5..C6 (27-33) and CP5..CP6 (37-43)
    1: list(range(27, 34)) + list(range(37, 44)),
    # temporal (10 channels): FT9, FT7, FT8, FT10, T7, T8, TP9, TP7, TP8, TP10
    2: [15, 16, 24, 25, 26, 34, 35, 36, 44, 45],
    # parietal (14 channels): P7..P8 (46-54) and PO7..PO8 (55-59)
    3: list(range(46, 60)),
    # occipital (3 channels): O1, Oz, O2
    4: list(range(60, 63)),
}

# Per-region EEG channel count ('size') and the count after zero-padding ('padded').
# Every padded count is a multiple of 4 so that, with a patch height of 4,
# each patch row holds channels from a single region.
EEG_REGION_CONFIG = {
    region: {'size': size, 'padded': padded}
    for region, (size, padded) in enumerate((
        (22, 24),  # frontal (includes prefrontal)
        (14, 16),  # central
        (10, 12),  # temporal
        (14, 16),  # parietal
        (3, 4),    # occipital
    ))
}
EEG_NUM_REGIONS = len(EEG_REGION_CONFIG)
EEG_PADDED_CHANNELS = sum(cfg['padded'] for cfg in EEG_REGION_CONFIG.values())  # 72

# MEG channel indices for each brain region (in preprocessed order).
# Regions: 0=frontal, 1=central, 2=parietal, 3=occipital, 4=temporal
# Each region gathers several disjoint index runs; together they cover 0..270 exactly once.
MEG_REGION_CHANNELS = {
    0: [*range(24, 56), *range(155, 187), *range(264, 267)],   # frontal: 67
    1: [*range(0, 24), *range(131, 155), *range(260, 264)],    # central: 52
    2: [*range(75, 97), *range(204, 226), 270],                # parietal: 45
    3: [*range(56, 75), *range(187, 204), *range(267, 270)],   # occipital: 39
    4: [*range(97, 131), *range(226, 260)],                    # temporal: 68
}

# Per-region MEG channel count ('size') and the count after zero-padding ('padded');
# every padded count is a multiple of 4 (regions already at a multiple of 4 get no padding).
MEG_REGION_CONFIG = {
    region: {'size': size, 'padded': padded}
    for region, (size, padded) in enumerate((
        (67, 68),  # frontal
        (52, 52),  # central
        (45, 48),  # parietal
        (39, 40),  # occipital
        (68, 68),  # temporal
    ))
}
MEG_NUM_REGIONS = len(MEG_REGION_CONFIG)
MEG_PADDED_CHANNELS = sum(cfg['padded'] for cfg in MEG_REGION_CONFIG.values())  # 276


def get_region_config(modality):
    """Return the brain-region layout for a modality.

    Args:
        modality: 'eeg' or 'meg' (case-insensitive).

    Returns:
        Tuple ``(region_channels, region_config, num_regions, padded_channels)``:
        the per-region channel index lists, the per-region ``{'size', 'padded'}``
        counts, the number of regions, and the total channel count after padding.

    Raises:
        ValueError: if ``modality`` is neither 'eeg' nor 'meg'. Silently falling
            back to one layout would pair it with the other modality's signal sizes.
    """
    key = modality.lower()
    if key == 'eeg':
        return EEG_REGION_CHANNELS, EEG_REGION_CONFIG, EEG_NUM_REGIONS, EEG_PADDED_CHANNELS
    if key == 'meg':
        return MEG_REGION_CHANNELS, MEG_REGION_CONFIG, MEG_NUM_REGIONS, MEG_PADDED_CHANNELS
    raise ValueError(f"Unsupported modality {modality!r}; expected 'eeg' or 'meg'")


def build_reorder_indices(modality):
    """Build the channel permutation that groups channels by brain region.

    Args:
        modality: 'eeg' or 'meg' (case-insensitive); resolved through
            ``get_region_config`` so every helper shares one modality dispatch.

    Returns:
        List of original channel indices: all channels of region 0, then
        region 1, and so on. Indexing the channel axis with this list yields
        the region-ordered layout.
    """
    region_channels, _, num_regions, _ = get_region_config(modality)
    return [
        channel
        for region in range(num_regions)
        for channel in region_channels[region]
    ]


def build_inverse_indices(reorder_indices):
    """Invert a channel permutation.

    ``reorder_indices[new] == old`` means original channel ``old`` is moved to
    position ``new``. The returned list maps each original position back to its
    new position, i.e. ``result[old] == new``. Indexing a reordered tensor with
    the result restores the original order.
    """
    count = len(reorder_indices)
    restored = [0] * count
    for position, source in enumerate(reorder_indices):
        restored[source] = position
    return restored


def build_patch_row_regions(modality, patch_size_h=4):
    """Build the brain-region id of every patch row.

    Region ``r`` occupies ``padded // patch_size_h`` consecutive patch rows, so the
    result has one entry per patch row of the padded, region-ordered input.

    Args:
        modality: 'eeg' or 'meg' (case-insensitive); resolved via ``get_region_config``.
        patch_size_h: patch height along the channel axis.

    Returns:
        List of region ids, one per patch row, in top-to-bottom order.

    Raises:
        ValueError: if a region's padded channel count is not a multiple of
            ``patch_size_h``. A patch row would then straddle two regions, and the
            list would be shorter than the actual patch grid. That previously
            surfaced later as a shape mismatch in the position embedding.
    """
    _, region_config, num_regions, _ = get_region_config(modality)
    patch_row_regions = []
    for region in range(num_regions):
        padded = region_config[region]['padded']
        num_rows, leftover = divmod(padded, patch_size_h)
        if leftover:
            raise ValueError(
                f"Region {region} of {modality!r} has {padded} padded channels, "
                f"which is not a multiple of patch_size_h={patch_size_h}"
            )
        patch_row_regions.extend([region] * num_rows)
    return patch_row_regions


#################################################################################
#               Timestep Embedder                                                #
#################################################################################

class TimestepEmbedder(nn.Module):
    """
    Map diffusion timesteps to ``hidden_size``-dimensional vectors.

    Timesteps are expanded into fixed sinusoidal features of width
    ``frequency_embedding_size`` and then passed through a two-layer MLP
    (Linear -> SiLU -> Linear).
    """
    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Sinusoidal features for a 1-D tensor of timesteps.

        Args:
            t: [N] timesteps (cast to float).
            dim: output width. The first dim // 2 columns are cosines and the next
                dim // 2 are sines; a zero column is appended when dim is odd.
            max_period: sets the slowest frequency (approaching 1 / max_period).
        Returns:
            [N, dim] float32 tensor on t's device.
        """
        half_dim = dim // 2
        steps = torch.arange(start=0, end=half_dim, dtype=torch.float32)
        # Geometrically spaced frequencies, starting at 1
        freqs = torch.exp(-math.log(max_period) * steps / half_dim).to(device=t.device)
        phases = t[:, None].float() * freqs[None]
        features = torch.cat([torch.cos(phases), torch.sin(phases)], dim=-1)
        if dim % 2:
            zero_column = torch.zeros_like(features[:, :1])
            features = torch.cat([features, zero_column], dim=-1)
        return features

    def forward(self, t):
        """t: [N] timesteps -> [N, hidden_size] embeddings."""
        return self.mlp(self.timestep_embedding(t, self.frequency_embedding_size))


#################################################################################
#               Brain Signal Patch Embedding                                     #
#################################################################################

class BrainSignalPatchEmbed(nn.Module):
    """
    Convolutional patch embedding for non-square brain-signal maps.

    The input is zero-padded at the bottom (channel axis) and right (time axis)
    up to the next multiple of ``patch_size``. It is then cut into
    non-overlapping patches by a strided Conv2d, e.g. EEG [B, 1, 63, 250] or
    MEG [B, 1, 271, 200].
    """
    def __init__(self, signal_channels=63, signal_length=250, patch_size=(4, 4),
                 in_channels=1, embed_dim=768):
        super().__init__()
        patch_h, patch_w = patch_size[0], patch_size[1]

        self.signal_channels = signal_channels
        self.signal_length = signal_length
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.embed_dim = embed_dim

        # Zero padding needed to reach the next multiple of the patch size
        self.pad_h = (-signal_channels) % patch_h
        self.pad_w = (-signal_length) % patch_w

        self.padded_h = signal_channels + self.pad_h
        self.padded_w = signal_length + self.pad_w
        self.num_patches_h = self.padded_h // patch_h
        self.num_patches_w = self.padded_w // patch_w
        self.num_patches = self.num_patches_h * self.num_patches_w

        # One Conv2d step per patch: kernel == stride == patch size
        self.proj = nn.Conv2d(in_channels, embed_dim,
                              kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        """
        x: [B, in_channels, signal_channels, signal_length]
        return: [B, num_patches, embed_dim], patches in row-major (channel, time) order
        """
        if self.pad_h or self.pad_w:
            # F.pad takes (left, right, top, bottom) for the last two dims
            x = F.pad(x, (0, self.pad_w, 0, self.pad_h))
        grid = self.proj(x)                      # [B, embed_dim, H/p, W/p]
        return grid.flatten(2).transpose(1, 2)   # [B, num_patches, embed_dim]


#################################################################################
#               DiT Block with Cross-Attention                                   #
#################################################################################

class DiTBlockXAttention(nn.Module):
    """
    Pre-norm transformer block with three residual sub-layers:
    self-attention over the patch tokens, cross-attention from the tokens to a
    conditioning sequence, and an MLP.

    There is no adaLN modulation here; the timestep is not an input to this block.
    """
    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, dropout=0.1):
        super().__init__()
        # Keep submodule construction order fixed: it determines the order in which
        # parameters are created (and thus which random draws each one receives).
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
        self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True)
        self.cross_attn = MultiModalCrossAttention(dim=hidden_size, num_heads=num_heads, dropout=dropout)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
        self.norm3 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
        self.mlp = Mlp(
            in_features=hidden_size,
            hidden_features=int(hidden_size * mlp_ratio),
            act_layer=lambda: nn.GELU(approximate="tanh"),
            drop=dropout,
        )

    def forward(self, x, cond):
        """
        x: [B, N, D] Brain signal patch sequence
        cond: [B, N_cond, D] Conditioning features (projected image/text embedding)
        returns: [B, N, D]
        """
        hidden = x + self.attn(self.norm1(x))
        hidden = hidden + self.cross_attn(self.norm2(hidden), cond)
        return hidden + self.mlp(self.norm3(hidden))


#################################################################################
#               Final Layer                                                      #
#################################################################################

class FinalLayer(nn.Module):
    """
    Output head of the DiT.

    The head applies a LayerNorm whose output is shifted and scaled by values
    predicted from the timestep embedding (adaLN). A linear layer then projects
    each token to one flattened patch of ``patch_h * patch_w * out_channels`` values.
    """
    def __init__(self, hidden_size, patch_size, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
        patch_values = patch_size[0] * patch_size[1] * out_channels
        self.linear = nn.Linear(hidden_size, patch_values, bias=True)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True),
        )

    def forward(self, x, c):
        """
        x: [B, N, D] token sequence
        c: [B, D] timestep conditioning
        returns: [B, N, patch_h * patch_w * out_channels]
        """
        modulation = self.adaLN_modulation(c)
        shift, scale = modulation.chunk(2, dim=1)
        normalized = self.norm_final(x)
        modulated = normalized * (1 + scale[:, None, :]) + shift[:, None, :]
        return self.linear(modulated)


#################################################################################
#               DiT Brain Signal Model                                           #
#################################################################################

class DiTBrainSignalModel(nn.Module):
    """
    DiT-based brain signal diffusion model (the denoising network).
    Uses Cross-Attention to condition on image (and optionally text) features.
    Supports both EEG and MEG modalities.

    Position encoding: region_embed + temporal_embed
    - region_embed: Brain region embedding, tells the model which region each patch row belongs to
    - temporal_embed: Temporal embedding, tells the model which time window (patch column) each patch is in

    Channel processing:
    - Reordering: Arrange channels by brain region (channels in the same region are adjacent)
    - Padding: Zero-pad each region to its configured 'padded' channel count (a multiple of 4),
      so that with a patch height of 4 every patch row belongs to a single region

    Timestep conditioning enters only through the adaLN shift/scale of the final layer;
    the DiT blocks see the patch tokens and the projected image condition.

    Raises (from __init__):
        ValueError: if the modality is not 'eeg'/'meg', if ``signal_channels`` differs from the
            channel count of the modality's region layout, or if the patch height does not
            evenly tile every region's padded channel count.
    """
    def __init__(
        self,
        modality='eeg',
        signal_channels=None,
        signal_length=None,
        patch_size=None,
        in_channels=1,
        hidden_size=768,         # DiT-B configuration
        depth=12,                # DiT-B configuration
        num_heads=12,            # DiT-B configuration
        mlp_ratio=4.0,
        dropout=0.1,
        learn_sigma=False,
        cond_dim=None,
    ):
        """
        Args:
            modality: 'eeg' or 'meg' (case-insensitive).
            signal_channels: raw channel count (default DEFAULT_CONFIG[modality]['channels']);
                must match the modality's region layout.
            signal_length: samples per channel (default DEFAULT_CONFIG[modality]['length']).
            patch_size: (patch_h, patch_w) (default DEFAULT_CONFIG[modality]['patch_size']).
            in_channels: number of input maps.
            hidden_size, depth, num_heads, mlp_ratio, dropout: transformer settings.
            learn_sigma: if True the output has 2 * in_channels maps.
            cond_dim: last-dim width of ``image_embedding``. Defaults to hidden_size,
                which was previously the only supported width.
        """
        super().__init__()

        # Resolve modality-specific defaults. Reject unknown modalities: a fallback would
        # pair one modality's default sizes with another modality's region layout.
        modality = modality.lower()
        if modality not in DEFAULT_CONFIG:
            raise ValueError(
                f"Unsupported modality {modality!r}; expected one of {sorted(DEFAULT_CONFIG)}"
            )
        config = DEFAULT_CONFIG[modality]

        self.signal_channels = signal_channels or config['channels']
        self.signal_length = signal_length or config['length']
        self.patch_size = patch_size or config['patch_size']
        self.modality = modality

        self.in_channels = in_channels
        self.out_channels = in_channels * 2 if learn_sigma else in_channels
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.learn_sigma = learn_sigma
        self.cond_dim = cond_dim or hidden_size

        # Get brain region configuration
        _, region_config, num_regions, padded_channels = get_region_config(modality)
        self.num_regions = num_regions
        self.original_channels = self.signal_channels
        self.padded_channels = padded_channels

        # Build channel reordering indices
        reorder_indices = build_reorder_indices(modality)
        if self.signal_channels != len(reorder_indices):
            # The reorder/padding logic is built for a fixed channel count; any other count
            # either fails at indexing time or silently drops / zero-fills channels.
            raise ValueError(
                f"signal_channels={self.signal_channels} does not match the "
                f"{len(reorder_indices)} channels of the {modality!r} region layout"
            )
        inverse_indices = build_inverse_indices(reorder_indices)
        self.register_buffer('reorder_indices', torch.tensor(reorder_indices, dtype=torch.long))
        self.register_buffer('inverse_indices', torch.tensor(inverse_indices, dtype=torch.long))

        # Build brain region ID for each patch row
        patch_row_regions = build_patch_row_regions(modality, self.patch_size[0])
        if len(patch_row_regions) * self.patch_size[0] != padded_channels:
            # Some region's padded size is not a multiple of the patch height, so patch
            # rows would straddle regions and the region ids would not cover the patch grid.
            raise ValueError(
                f"patch height {self.patch_size[0]} does not evenly divide every padded "
                f"region of the {modality!r} layout"
            )
        self.register_buffer('patch_row_regions', torch.tensor(patch_row_regions, dtype=torch.long))

        # Number of zero rows appended to each region
        self.region_pad_sizes = [
            region_config[r]['padded'] - region_config[r]['size'] for r in range(num_regions)
        ]

        # Patch embedding (using padded channel count)
        self.x_embedder = BrainSignalPatchEmbed(
            signal_channels=padded_channels,
            signal_length=self.signal_length,
            patch_size=self.patch_size,
            in_channels=in_channels,
            embed_dim=hidden_size
        )

        # Timestep embedding
        self.t_embedder = TimestepEmbedder(hidden_size)

        # Position encoding: region_embed + temporal_embed
        self.region_embed = nn.Embedding(num_regions, hidden_size)
        self.temporal_embed = nn.Embedding(self.x_embedder.num_patches_w, hidden_size)

        # Condition projection: maps each conditioning token (e.g. CLIP features) to hidden_size
        self.cond_proj = nn.Sequential(
            nn.Linear(self.cond_dim, hidden_size),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size)
        )

        # DiT blocks with Cross-Attention
        self.blocks = nn.ModuleList([
            DiTBlockXAttention(hidden_size, num_heads, mlp_ratio=mlp_ratio, dropout=dropout)
            for _ in range(depth)
        ])

        # Final layer
        self.final_layer = FinalLayer(hidden_size, self.patch_size, self.out_channels)

        # Save parameters needed for unpatchify
        self.num_patches_h = self.x_embedder.num_patches_h
        self.num_patches_w = self.x_embedder.num_patches_w
        self.padded_h = self.x_embedder.padded_h
        self.padded_w = self.x_embedder.padded_w

        self.initialize_weights()

    def initialize_weights(self):
        """Initialize parameters. The final layer is zeroed, so the untrained model outputs zeros."""
        # Initialize position encodings
        nn.init.normal_(self.region_embed.weight, std=0.02)
        nn.init.normal_(self.temporal_embed.weight, std=0.02)

        # Xavier-uniform weights and zero biases for every nn.Linear in the model
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)

        # Initialize patch embed like a linear layer over the flattened kernel
        w = self.x_embedder.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        nn.init.constant_(self.x_embedder.proj.bias, 0)

        # Initialize timestep embedding
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

        # Zero-out final layer
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)

    def _reorder_channels(self, x):
        """Reorder channels (dim 2) so that each brain region is contiguous."""
        return x[:, :, self.reorder_indices, :]

    def _inverse_reorder_channels(self, x):
        """Inverse reorder (restore the original channel order on dim 2)."""
        return x[:, :, self.inverse_indices, :]

    def _add_region_padding(self, x):
        """
        Append zero rows after each brain region so every region reaches its padded size.
        x: [B, C, original_channels, L] (already reordered)
        return: [B, C, padded_channels, L]
        """
        B, C, _, W = x.shape
        region_config = get_region_config(self.modality)[1]

        # new_zeros keeps x's dtype and device; padding rows stay 0
        padded_x = x.new_zeros(B, C, self.padded_channels, W)

        src_idx = 0
        dst_idx = 0
        for r in range(self.num_regions):
            original = region_config[r]['size']
            padded_x[:, :, dst_idx:dst_idx + original, :] = x[:, :, src_idx:src_idx + original, :]
            src_idx += original
            dst_idx += region_config[r]['padded']

        return padded_x

    def _remove_region_padding(self, x):
        """
        Remove the per-region zero rows added by _add_region_padding.
        x: [B, C, padded_channels, L]
        return: [B, C, original_channels, L]
        """
        B, C, _, W = x.shape
        region_config = get_region_config(self.modality)[1]

        original_x = x.new_zeros(B, C, self.original_channels, W)

        src_idx = 0
        dst_idx = 0
        for r in range(self.num_regions):
            original = region_config[r]['size']
            # Copy the real channels, skip the padding rows
            original_x[:, :, dst_idx:dst_idx + original, :] = x[:, :, src_idx:src_idx + original, :]
            src_idx += region_config[r]['padded']
            dst_idx += original

        return original_x

    def _get_position_embed(self, device):
        """
        Get position encoding: region_embed + temporal_embed
        return: [1, N, D] with N = num_patches_h * num_patches_w in row-major patch order
        """
        # region_ids: each patch row's region id repeated num_patches_w times
        region_ids = self.patch_row_regions.repeat_interleave(self.num_patches_w)
        region_emb = self.region_embed(region_ids)  # [N, D]

        # temporal_ids: [0,1,...,W-1, 0,1,...,W-1, ...] repeated num_patches_h times
        temporal_ids = torch.arange(self.num_patches_w, device=device)
        temporal_ids = temporal_ids.repeat(self.num_patches_h)
        temporal_emb = self.temporal_embed(temporal_ids)  # [N, D]

        pos_embed = region_emb + temporal_emb  # [N, D]
        return pos_embed.unsqueeze(0)  # [1, N, D]

    def unpatchify(self, x):
        """
        Restore a patch sequence to a brain signal in the original channel order.
        x: [B, N, patch_h * patch_w * out_channels]
        return: [B, out_channels, signal_channels, signal_length]
        """
        c = self.out_channels
        p_h, p_w = self.patch_size
        h, w = self.num_patches_h, self.num_patches_w

        x = x.reshape(x.shape[0], h, w, p_h, p_w, c)
        x = torch.einsum('nhwpqc->nchpwq', x)
        x = x.reshape(x.shape[0], c, h * p_h, w * p_w)

        # Crop the temporal padding added by the patch embedding
        x = x[:, :, :, :self.signal_length]

        # Remove brain region padding
        x = self._remove_region_padding(x)

        # Inverse reorder (restore original channel order)
        x = self._inverse_reorder_channels(x)

        return x

    def forward(self, x, t, y=None, image_embedding=None):
        """
        Forward pass
        Args:
            x: [B, C, H, W] Noisy brain signal in original channel order (H = signal_channels)
            t: [B] Timesteps
            y: Unused, kept for compatibility
            image_embedding: [B, M, cond_dim] conditioning sequence (e.g. CLIP hidden states
                [B, 257, 768]) or [B, cond_dim] (e.g. a pooled CLIP embedding [B, 768]).
                If None, a single all-zero condition token is used.
        Returns:
            [B, out_channels, signal_channels, signal_length]
        """
        # ===== Input preprocessing =====
        # 1. Reorder channels (by brain region)
        x = self._reorder_channels(x)

        # 2. Add brain region padding
        x = self._add_region_padding(x)

        # ===== Patch embedding =====
        x = self.x_embedder(x)  # [B, N, D]

        # ===== Position encoding =====
        x = x + self._get_position_embed(x.device)

        # ===== Timestep embedding =====
        t_emb = self.t_embedder(t)  # [B, D]

        # ===== Process image condition =====
        if image_embedding is not None:
            if image_embedding.dim() == 2:
                image_embedding = image_embedding.unsqueeze(1)
            cond = self.cond_proj(image_embedding)
        else:
            # Match the tokens' dtype/device (a plain torch.zeros would be float32 even
            # under half precision).
            cond = x.new_zeros(x.shape[0], 1, self.hidden_size)

        # ===== DiT blocks =====
        for block in self.blocks:
            x = block(x, cond)

        # ===== Final layer =====
        x = self.final_layer(x, t_emb)  # [B, N, p*p*c]

        # ===== Unpatchify =====
        return self.unpatchify(x)  # [B, C, H, W]


#################################################################################
#               Diffusion Brain Signal Model                                     #
#################################################################################

class DiffusionBrainSignalModel(nn.Module):
    """
    Complete DiT-based diffusion brain signal model.
    Wraps a DiTBrainSignalModel noise predictor together with the diffusion process.
    Supports both EEG and MEG modalities.
    Uses the DiT-XAttention-B/4 configuration by default (hidden 768, depth 12, 12 heads).
    """
    def __init__(self,
                 modality='eeg',
                 signal_channels=None,
                 signal_length=None,
                 patch_size=None,
                 hidden_dim=768,
                 depth=12,
                 num_heads=12,
                 mlp_ratio=4.0,
                 dropout=0.1,
                 num_train_timesteps=1000,
                 learn_sigma=False,
                 device='cuda'):
        super().__init__()

        # Set default parameters based on modality. Unknown modalities are rejected
        # rather than silently given EEG defaults.
        modality = modality.lower()
        if modality not in DEFAULT_CONFIG:
            raise ValueError(
                f"Unsupported modality {modality!r}; expected one of {sorted(DEFAULT_CONFIG)}"
            )
        config = DEFAULT_CONFIG[modality]

        self.signal_channels = signal_channels or config['channels']
        self.signal_length = signal_length or config['length']
        patch_size = patch_size or config['patch_size']

        self.modality = modality
        self.hidden_dim = hidden_dim
        # Stored for reference only; this constructor does not move any module to `device`.
        self.device = device
        self.num_train_timesteps = num_train_timesteps
        self.learn_sigma = learn_sigma

        # DiT model
        self.dit = DiTBrainSignalModel(
            modality=modality,
            signal_channels=self.signal_channels,
            signal_length=self.signal_length,
            patch_size=patch_size,
            in_channels=1,
            hidden_size=hidden_dim,
            depth=depth,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            dropout=dropout,
            learn_sigma=learn_sigma,
        )

        # Create diffusion process (training: all timesteps, linear schedule)
        self.diffusion = create_diffusion(
            timestep_respacing="",  # Use all timesteps
            noise_schedule="linear",
            learn_sigma=learn_sigma,
            diffusion_steps=num_train_timesteps,
        )

    def forward(self, signal_data, image_embedding, timesteps=None):
        """
        Forward pass during training
        Args:
            signal_data: Brain signal [batch_size, 1, channels, length]
            image_embedding: Image embedding [batch_size, 768] or [batch_size, 257, 768]
            timesteps: Timesteps [batch_size] (sampled uniformly from [0, num_train_timesteps) if None)
        Returns:
            noise_pred: Predicted noise
            noise: Ground truth noise
            x0_hat: Predicted x0 from noise prediction
            timesteps: Used timesteps
        """
        batch_size = signal_data.shape[0]

        # Randomly sample timesteps if not provided
        if timesteps is None:
            timesteps = torch.randint(
                0, self.num_train_timesteps,
                (batch_size,), device=signal_data.device
            ).long()

        # Generate noise
        noise = torch.randn_like(signal_data)

        # Add noise using diffusion process
        noisy_signal = self.diffusion.q_sample(signal_data, timesteps, noise=noise)

        # Predict noise
        model_kwargs = {'image_embedding': image_embedding}
        noise_pred = self.dit(noisy_signal, timesteps, **model_kwargs)

        # If learn_sigma, only take the first half (the noise channel) as noise prediction
        if self.learn_sigma:
            noise_pred = noise_pred[:, :1, :, :]

        # Predict x0_hat from noise prediction (for signal-level loss)
        x0_hat = self._predict_x0_from_eps(noisy_signal, timesteps, noise_pred)

        return noise_pred, noise, x0_hat, timesteps

    def _predict_x0_from_eps(self, x_t, t, eps):
        """Predict x0 from predicted noise: x0 = sqrt(1/a_bar_t) * x_t - sqrt(1/a_bar_t - 1) * eps."""
        # Schedule tables are numpy arrays on the diffusion object; copy them to x_t's device.
        sqrt_recip_alphas_cumprod = torch.from_numpy(
            self.diffusion.sqrt_recip_alphas_cumprod
        ).to(x_t.device).float()
        sqrt_recipm1_alphas_cumprod = torch.from_numpy(
            self.diffusion.sqrt_recipm1_alphas_cumprod
        ).to(x_t.device).float()

        # Gather per-sample coefficients and broadcast over [B, C, H, W]
        return (
            sqrt_recip_alphas_cumprod[t].view(-1, 1, 1, 1) * x_t
            - sqrt_recipm1_alphas_cumprod[t].view(-1, 1, 1, 1) * eps
        )

    @torch.no_grad()
    def generate(self, image_embedding, num_inference_steps=50, eta=0.0, clip_denoised=True):
        """
        Generate brain signal from image embedding (using DDIM sampling)
        Args:
            image_embedding: Image embedding [batch_size, 768] or [batch_size, 257, 768]
            num_inference_steps: Number of inference steps (timestep respacing)
            eta: DDIM stochasticity parameter (0=fully deterministic, 1=close to DDPM)
            clip_denoised: forwarded to ddim_sample_loop (previously always True). In the
                reference DiT/guided-diffusion code this clamps the predicted x0 to [-1, 1];
                pass False if the signals are not normalized to that range (verify against
                the local diffusion module).
        Returns:
            [batch_size, 1, signal_channels, signal_length] generated signals

        NOTE: this method does not switch self.dit to eval mode; dropout stays active
        if the module is in training mode.
        """
        batch_size = image_embedding.shape[0]
        device = image_embedding.device

        # Create diffusion process for inference
        inference_diffusion = create_diffusion(
            timestep_respacing=str(num_inference_steps),
            noise_schedule="linear",
            learn_sigma=self.learn_sigma,
            diffusion_steps=self.num_train_timesteps,
        )

        # Initialize random noise
        signal_sample = torch.randn(
            batch_size, 1, self.signal_channels, self.signal_length,
            device=device
        )

        # Set model parameters
        model_kwargs = {'image_embedding': image_embedding}

        # Use DDIM sampling (faster and more stable than DDPM)
        signal_sample = inference_diffusion.ddim_sample_loop(
            self.dit,
            signal_sample.shape,
            noise=signal_sample,
            clip_denoised=clip_denoised,
            model_kwargs=model_kwargs,
            device=device,
            progress=False,
            eta=eta,  # 0=deterministic DDIM, 1=stochasticity close to DDPM
        )

        return signal_sample


#################################################################################
#               Image to Brain Signal Model                                      #
#################################################################################

class ImageToBrainSignalModel(nn.Module):
    """
    End-to-end image -> brain signal model: a CLIP encoder supplies the
    conditioning sequence for a DiT diffusion model (EEG or MEG).

    When captions are given, CLIP image and text embeddings are concatenated
    along the token axis and used as Key/Value for cross-attention:
    - Image embeddings: [B, 257, 768]  (ViT-L/14 patch features)
    - Text embeddings: [B, 77, 768]    (CLIP text encoder output)
    - Unified embeddings: [B, 334, 768] (concatenated)
    """
    def __init__(self, clip_model, diffusion_model, modality='eeg'):
        super().__init__()
        self.clip_model = clip_model
        self.diffusion_model = diffusion_model
        self.modality = modality.lower()

    def forward(self, images, signal_data=None, captions=None, mode='train'):
        """
        Args:
            images: Input images [batch_size, 3, 224, 224]
            signal_data: Brain signal data [batch_size, 1, channels, length] (required for training)
            captions: Captions for the images (optional); their text embeddings are appended
            mode: 'train' or 'test'. The training branch runs only when mode == 'train'
                AND signal_data is given; otherwise signals are generated.
        Returns:
            dict with the diffusion training outputs, or 'generated_signal', plus the embeddings.
        """
        image_embedding = self.clip_model.embed_image(images).float()

        # Default to image-only conditioning; append text tokens when captions exist
        unified_embedding = image_embedding
        if captions is not None:
            text_embedding = self.clip_model.embed_text(captions).float()
            unified_embedding = torch.cat([image_embedding, text_embedding], dim=1)

        should_train = mode == 'train' and signal_data is not None
        if not should_train:
            # Test mode: sample brain signals from the conditioning
            generated = self.diffusion_model.generate(unified_embedding)
            return {
                'generated_signal': generated,
                'image_embedding': image_embedding,
                'unified_embedding': unified_embedding
            }

        # Training mode: run the diffusion forward pass for the loss
        noise_pred, noise, x0_hat, timesteps = self.diffusion_model(signal_data, unified_embedding)
        return {
            'noise_pred': noise_pred,
            'noise': noise,
            'x0_hat': x0_hat,
            'timesteps': timesteps,
            'image_embedding': image_embedding,
            'unified_embedding': unified_embedding
        }


#################################################################################
#               Backward Compatible Aliases                                      #
#################################################################################

# EEG model aliases
class DiTEEGModel(DiTBrainSignalModel):
    """Backward-compatible EEG alias of DiTBrainSignalModel with EEG-named size arguments."""
    def __init__(self, eeg_channels=63, eeg_length=250, **kwargs):
        eeg_settings = dict(
            modality='eeg',
            signal_channels=eeg_channels,
            signal_length=eeg_length,
        )
        super().__init__(**eeg_settings, **kwargs)

class DiffusionEEGModel(DiffusionBrainSignalModel):
    """Backward-compatible EEG alias of DiffusionBrainSignalModel with EEG-named size arguments."""
    def __init__(self, eeg_channels=63, eeg_length=250, **kwargs):
        super().__init__(
            modality='eeg',
            signal_channels=eeg_channels,
            signal_length=eeg_length,
            **kwargs
        )

    def generate_eeg(self, image_embedding, num_inference_steps=50, eta=0.0):
        """Alias of generate(); also forwards the DDIM ``eta`` (default 0.0, as before)."""
        return self.generate(image_embedding, num_inference_steps, eta=eta)

class ImageToEEGModel(ImageToBrainSignalModel):
    """EEG alias of ImageToBrainSignalModel that returns generated signals under 'generated_eeg'."""
    def __init__(self, clip_model, diffusion_model):
        super().__init__(clip_model, diffusion_model, modality='eeg')

    def forward(self, images, eeg_data=None, mode='train', captions=None):
        """
        Args:
            images: Input images
            eeg_data: EEG data for training (optional)
            mode: 'train' or 'test'
            captions: optional captions, forwarded to the parent model
        """
        # Pass by keyword: the parent signature is (images, signal_data, captions, mode), so a
        # positional `mode` would land in `captions` and be embedded as text.
        result = super().forward(images, signal_data=eeg_data, captions=captions, mode=mode)
        if 'generated_signal' in result:
            result['generated_eeg'] = result.pop('generated_signal')
        return result


# MEG model aliases
class DiTMEGModel(DiTBrainSignalModel):
    """Backward-compatible MEG alias of DiTBrainSignalModel with MEG-named size arguments."""
    def __init__(self, meg_channels=271, meg_length=200, **kwargs):
        meg_settings = dict(
            modality='meg',
            signal_channels=meg_channels,
            signal_length=meg_length,
        )
        super().__init__(**meg_settings, **kwargs)

class DiffusionMEGModel(DiffusionBrainSignalModel):
    """Backward-compatible MEG alias of DiffusionBrainSignalModel with MEG-named size arguments."""
    def __init__(self, meg_channels=271, meg_length=200, **kwargs):
        super().__init__(
            modality='meg',
            signal_channels=meg_channels,
            signal_length=meg_length,
            **kwargs
        )

    def generate_meg(self, image_embedding, num_inference_steps=50, eta=0.0):
        """Alias of generate(); also forwards the DDIM ``eta`` (default 0.0, as before)."""
        return self.generate(image_embedding, num_inference_steps, eta=eta)

class ImageToMEGModel(ImageToBrainSignalModel):
    """MEG alias of ImageToBrainSignalModel that returns generated signals under 'generated_meg'."""
    def __init__(self, clip_model, diffusion_model):
        super().__init__(clip_model, diffusion_model, modality='meg')

    def forward(self, images, meg_data=None, mode='train', captions=None):
        """
        Args:
            images: Input images
            meg_data: MEG data for training (optional)
            mode: 'train' or 'test'
            captions: optional captions, forwarded to the parent model
        """
        # Pass by keyword: the parent signature is (images, signal_data, captions, mode), so a
        # positional `mode` would land in `captions` and be embedded as text.
        result = super().forward(images, signal_data=meg_data, captions=captions, mode=mode)
        if 'generated_signal' in result:
            result['generated_meg'] = result.pop('generated_signal')
        return result
