import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers import DDPMScheduler, UNet2DConditionModel


class DiffusionEEGModelWithConcat(nn.Module):
    """
    Image-conditioned DDPM over EEG signals built on ``UNet2DConditionModel``.

    An EEG recording is treated as a single-channel 2D map of shape
    ``(eeg_channels, eeg_length)``. The raw image embedding is handed to the
    UNet as a one-token ``encoder_hidden_states`` sequence.

    NOTE(review): ``image_proj`` and ``concat_proj`` are constructed (and are
    therefore part of the state dict), but no method of this class routes
    their output into the UNet. The concatenation fusion the class name
    refers to is not wired in here.
    """
    def __init__(self, 
                 eeg_channels=63, 
                 eeg_length=250, 
                 hidden_dim=768,
                 num_train_timesteps=1000,
                 device='cuda'):
        """
        Build the noise scheduler, the UNet and the auxiliary projections.

        Args:
            eeg_channels: Height of the EEG map (number of electrodes).
            eeg_length: Width of the EEG map (number of time samples).
            hidden_dim: Size of the image embedding; also used as the UNet's
                ``cross_attention_dim``.
            num_train_timesteps: Number of diffusion steps used for training.
            device: Stored on ``self.device`` for backward compatibility only.
                Tensors created in ``forward`` / ``generate_eeg`` follow the
                device of the input tensors, so the model keeps working after
                being moved with ``.to()``.
        """
        super().__init__()
        
        self.eeg_channels = eeg_channels
        self.eeg_length = eeg_length
        self.hidden_dim = hidden_dim
        self.device = device
        
        # Diffusion noise scheduler (linear beta schedule, epsilon prediction)
        self.noise_scheduler = DDPMScheduler(
            num_train_timesteps=num_train_timesteps,
            beta_start=0.0001,
            beta_end=0.02,
            beta_schedule="linear",
            trained_betas=None,
            variance_type="fixed_small",
            clip_sample=True,
            prediction_type="epsilon"
        )
        
        # Image embedding projection layer kept for the intended concatenation
        # fusion; currently not consumed by forward() or generate_eeg().
        self.image_proj = nn.Linear(hidden_dim, 256)
        
        # Noise-prediction backbone. Down/up blocks are the plain (non
        # cross-attention) variants; mid_block_type is left at the library
        # default.
        self.unet = UNet2DConditionModel(
            sample_size=(eeg_channels, eeg_length),  # (63, 250) by default
            in_channels=1,                           # EEG map is fed as a single-channel image
            out_channels=1,                          # Predicted noise has the same single channel
            layers_per_block=2,                      # Number of layers per block
            block_out_channels=(128, 256, 512, 512), # Channel numbers for each level
            down_block_types=(
                "DownBlock2D",
                "DownBlock2D", 
                "DownBlock2D",
                "DownBlock2D",
            ),
            up_block_types=(
                "UpBlock2D",
                "UpBlock2D",
                "UpBlock2D",
                "UpBlock2D",
            ),
            cross_attention_dim=hidden_dim,          # Must match the last dim of encoder_hidden_states
            attention_head_dim=64,                   # Attention head dimension
            use_linear_projection=True,              # Use linear projection
            norm_num_groups=32,                      # Number of groups for group normalization
        )
        
        # Additional layers reserved for Concat fusion after the mid block;
        # currently not consumed by forward() or generate_eeg().
        self.concat_proj = nn.Sequential(
            nn.Linear(256, 512),
            nn.SiLU(),
            nn.Linear(512, 512)
        )
        
    def forward(self, eeg_data, image_embedding, timesteps=None):
        """
        Run one training-style diffusion step: noise the EEG and predict the noise.

        Args:
            eeg_data: EEG signals [batch_size, 1, 63, 250]
            image_embedding: Image embedding [batch_size, 768]
            timesteps: Time steps [batch_size] (tensor). When ``None``, one
                step per sample is drawn uniformly from the training range.

        Returns:
            Tuple ``(noise_pred, noise)``: the UNet output and the Gaussian
            noise that was added to ``eeg_data``.
        """
        batch_size = eeg_data.shape[0]
        # Follow the data rather than the constructor's device string, which
        # goes stale as soon as the module is moved with .to().
        device = eeg_data.device
        
        if timesteps is None:
            timesteps = torch.randint(
                0, self.noise_scheduler.config.num_train_timesteps,
                (batch_size,), device=device
            ).long()
        else:
            timesteps = timesteps.to(device)
        
        # Sample the target noise and corrupt the clean EEG with it
        noise = torch.randn_like(eeg_data)
        noisy_eeg = self.noise_scheduler.add_noise(eeg_data, noise, timesteps)
        
        # The raw image embedding is passed as a length-1 conditioning sequence.
        # NOTE(review): the down/up blocks are not cross-attention blocks, but
        # the mid block is left at the diffusers default, which is believed to
        # be cross-attention based - so this conditioning may well be consumed
        # there. Confirm against the installed diffusers version.
        encoder_hidden_states = image_embedding.unsqueeze(1)
        
        noise_pred = self.unet(
            sample=noisy_eeg,
            timestep=timesteps,
            encoder_hidden_states=encoder_hidden_states,
            return_dict=False
        )[0]
        
        return noise_pred, noise
    
    @torch.no_grad()
    def generate_eeg(self, image_embedding, num_inference_steps=50):
        """
        Generate EEG signals from an image embedding by iterative denoising.

        Args:
            image_embedding: Image embedding [batch_size, 768]
            num_inference_steps: Number of inference steps

        Returns:
            Tensor [batch_size, 1, eeg_channels, eeg_length] on the same
            device as ``image_embedding``.
        """
        batch_size = image_embedding.shape[0]
        # Same reasoning as in forward(): create tensors where the inputs live
        device = image_embedding.device
        
        # Set inference steps
        self.noise_scheduler.set_timesteps(num_inference_steps)
        
        # Start from pure Gaussian noise
        eeg_sample = torch.randn(
            batch_size, 1, self.eeg_channels, self.eeg_length,
            device=device
        )
        
        # Conditioning is constant across steps, so build it once
        encoder_hidden_states = image_embedding.unsqueeze(1)
        
        # Gradually denoise
        for t in self.noise_scheduler.timesteps:
            # The scheduler yields a scalar step; the UNet gets one per sample
            timesteps = t.expand(batch_size).to(device)
            
            noise_pred = self.unet(
                sample=eeg_sample,
                timestep=timesteps,
                encoder_hidden_states=encoder_hidden_states,
                return_dict=False
            )[0]
            
            # Denoise one step
            eeg_sample = self.noise_scheduler.step(
                noise_pred, t, eeg_sample
            ).prev_sample
        
        return eeg_sample


class ImageToEEGModelWithConcat(nn.Module):
    """
    End-to-end image-to-EEG wrapper: an image encoder feeding a diffusion model.

    ``clip_model`` must expose ``embed_image(images)``; ``diffusion_model`` must
    be callable as ``diffusion_model(eeg_data, image_embedding)`` and expose
    ``generate_eeg(image_embedding)``.
    """
    def __init__(self, clip_model, diffusion_model):
        super().__init__()
        # Assignment order is kept as-is: it determines submodule registration
        # order when these are nn.Module instances.
        self.clip_model = clip_model
        self.diffusion_model = diffusion_model
        
    def forward(self, images, eeg_data=None, mode='train'):
        """
        Embed the images, then either score a diffusion step or sample EEG.

        Args:
            images: Input images [batch_size, 3, 224, 224]
            eeg_data: EEG data [batch_size, 1, 63, 250] (required during training)
            mode: 'train' or 'test'

        Returns:
            A dict. With ``mode == 'train'`` and ``eeg_data`` given, it holds
            ``'noise_pred'``, ``'noise'`` and ``'image_embedding'``. In every
            other case (including ``mode='train'`` without ``eeg_data``) it
            holds ``'generated_eeg'`` and ``'image_embedding'``.
        """
        # The embedding is cast to float32 before it reaches the diffusion model
        embedding = self.clip_model.embed_image(images).float()
        
        training_step = mode == 'train' and eeg_data is not None
        
        if not training_step:
            # Sampling path: synthesize EEG conditioned on the embedding
            sampled = self.diffusion_model.generate_eeg(embedding)
            return {
                'generated_eeg': sampled,
                'image_embedding': embedding
            }
        
        # Training path: return the prediction and its target noise
        predicted, target = self.diffusion_model(eeg_data, embedding)
        return {
            'noise_pred': predicted,
            'noise': target,
            'image_embedding': embedding
        }