"""
Unified Utility Functions Module
Supports both EEG and MEG modalities
"""
import torch
import numpy as np
import random
import os
from typing import Optional


def seed_everything(seed: int):
    """
    Seed the Python, NumPy and PyTorch RNGs and force deterministic cuDNN.

    Args:
        seed: Seed value applied to every generator.

    Note: setting PYTHONHASHSEED at runtime does not change hashing in the
    already-running interpreter; it is inherited by child processes.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seeder in (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    ):
        seeder(seed)

    # Trade cuDNN autotuning speed for run-to-run reproducibility
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def calculate_alignment_loss(signal_emb, img_emb, temperature=0.07):
    """
    Symmetric CLIP-style contrastive loss between brain-signal and image embeddings.

    Row i of ``signal_emb`` is treated as the positive match for row i of
    ``img_emb``; every other row in the batch acts as a negative. Works for
    both EEG and MEG embeddings.

    Args:
        signal_emb: Brain-signal embeddings, [batch_size, dim].
        img_emb: Image embeddings, [batch_size, dim].
        temperature: Divisor applied to the cosine-similarity logits.

    Returns:
        torch.Tensor: Scalar mean of the signal->image and image->signal
        cross-entropy losses.
    """
    functional = torch.nn.functional

    # Unit-normalize so the dot product below is cosine similarity
    signal_unit = functional.normalize(signal_emb, dim=-1)
    image_unit = functional.normalize(img_emb, dim=-1)

    logits = signal_unit @ image_unit.t() / temperature

    # Positive pairs sit on the diagonal
    targets = torch.arange(signal_unit.size(0), device=signal_unit.device)

    signal_to_image = functional.cross_entropy(logits, targets)
    image_to_signal = functional.cross_entropy(logits.t(), targets)
    return 0.5 * (signal_to_image + image_to_signal)


def average_trials(signal_data):
    """
    Collapse repeated trials by taking their mean over dimension 1.

    Shapes:
        input  -> [batch_size, trials, channels, length]
        output -> [batch_size, channels, length]

    Used for both EEG (4 trials) and MEG (variable number of trials).
    """
    trial_axis = 1
    return signal_data.mean(dim=trial_axis)


# Backward compatible interface aliases
def average_eeg_trials(eeg_data):
    """
    Backward-compatible alias: average the repeated EEG trials.

    Delegates to ``average_trials`` (mean over dimension 1).
    Expected input:  [batch_size, 4, 63, 250]
    Expected output: [batch_size, 63, 250]
    """
    return average_trials(signal_data=eeg_data)


def average_meg_trials(meg_data):
    """
    Backward-compatible alias: average the repeated MEG trials.

    Delegates to ``average_trials`` (mean over dimension 1).
    Expected input:  [batch_size, trials, 271, 200]
    Expected output: [batch_size, 271, 200]
    """
    return average_trials(signal_data=meg_data)


def save_checkpoint(model, optimizer, epoch, loss, path):
    """
    Write a training checkpoint to ``path`` with ``torch.save``.

    The saved dict has the keys 'epoch', 'model_state_dict',
    'optimizer_state_dict' and 'loss' (in that order), matching what
    ``load_checkpoint`` reads back.
    """
    torch.save(
        dict(
            epoch=epoch,
            model_state_dict=model.state_dict(),
            optimizer_state_dict=optimizer.state_dict(),
            loss=loss,
        ),
        path,
    )


def load_checkpoint(model, optimizer, path, device):
    """
    Restore model (and optionally optimizer) state from a checkpoint file.

    Args:
        model: Module whose state dict is overwritten in place.
        optimizer: Optimizer to restore, or None to skip optimizer state.
        path: Checkpoint file, as written by ``save_checkpoint``.
        device: ``map_location`` passed to ``torch.load``.

    Returns:
        tuple: (model, optimizer, epoch, loss). model and optimizer are the
        same objects that were passed in.

    Note: ``torch.load`` unpickles the file, so only load trusted checkpoints.
    """
    state = torch.load(path, map_location=device)

    model.load_state_dict(state['model_state_dict'])
    if optimizer is not None:
        optimizer.load_state_dict(state['optimizer_state_dict'])

    return model, optimizer, state['epoch'], state['loss']


def compute_correlation(pred_signal, true_signal):
    """
    Compute the per-sample Pearson correlation between predicted and true signals.

    Applicable to both EEG and MEG. All non-batch dimensions are flattened, so
    inputs may be [batch_size, channels, length] or
    [batch_size, 1, channels, length].

    Args:
        pred_signal: Predicted signal tensor; dimension 0 is the batch.
        true_signal: Ground-truth tensor with the same shape as pred_signal.

    Returns:
        torch.Tensor: Correlation per sample, shape [batch_size], clipped to
        [-1, 1] like ``torch.corrcoef``. A sample whose flattened signal is
        constant yields NaN (zero variance). An empty batch yields an empty
        tensor instead of raising.
    """
    pred_flat = pred_signal.flatten(start_dim=1)
    true_flat = true_signal.flatten(start_dim=1)

    # Pearson correlation requires floating-point arithmetic
    if not pred_flat.is_floating_point():
        pred_flat = pred_flat.float()
    if not true_flat.is_floating_point():
        true_flat = true_flat.float()

    # Vectorized over the batch: cov(x, y) / (||x - mean|| * ||y - mean||);
    # the 1/(n-1) factors of covariance and variance cancel out.
    pred_centered = pred_flat - pred_flat.mean(dim=1, keepdim=True)
    true_centered = true_flat - true_flat.mean(dim=1, keepdim=True)

    covariance = (pred_centered * true_centered).sum(dim=1)
    scale = pred_centered.norm(dim=1) * true_centered.norm(dim=1)

    # Clamp guards against rounding slightly outside [-1, 1]; NaN propagates.
    return (covariance / scale).clamp(-1.0, 1.0)


def compute_cosine_similarity(pred_signal, true_signal):
    """
    Per-sample cosine similarity between predicted and ground-truth signals.

    Works for EEG and MEG alike: every dimension after the batch dimension is
    flattened before comparing.

    Args:
        pred_signal: Predicted signal, [batch_size, channels, length] or
            [batch_size, 1, channels, length].
        true_signal: Ground truth with the same shape as pred_signal.

    Returns:
        torch.Tensor: Cosine similarity of each sample, shape [batch_size].
    """
    # L2-normalize each flattened sample ([batch_size, channels * length])
    pred_unit, true_unit = (
        torch.nn.functional.normalize(t.flatten(start_dim=1), p=2, dim=1)
        for t in (pred_signal, true_signal)
    )

    # Dot product of unit vectors == cosine similarity
    return torch.sum(pred_unit * true_unit, dim=1)


def compute_synchronization_likelihood(pred_signal, true_signal, m=10, tau=1, w1=10, w2=50, p_ref=0.05, num_channels=8, rng=None):
    """
    Compute Synchronization Likelihood (SL) per sample.
    Based on the Stam & van Dijk (2002) algorithm.

    Note: This function is not differentiable. Inputs are detached and moved to
    CPU/numpy, so gradients do not flow through the SL computation.

    Args:
        pred_signal: Predicted signal [batch_size, channels, length] or [batch_size, 1, channels, length]
        true_signal: Ground truth signal, same shape as pred_signal
        m: Embedding dimension (default 10)
        tau: Time delay (default 1)
        w1: Minimum time window, avoids autocorrelation (default 10)
        w2: Maximum time window (default 50, reduced for speed)
        p_ref: Reference probability (default 0.05)
        num_channels: Number of randomly sampled channels per sample (default 8, to reduce computation)
        rng: Optional numpy random source (e.g. ``np.random.default_rng(seed)``)
            used to sample channels. None keeps the previous behaviour of
            drawing from the global ``np.random`` state.

    Returns:
        list[float]: Synchronization likelihood for each sample (length batch_size).
            1 indicates perfect synchronization. All zeros if the signal is
            too short for the chosen m/tau/w2.

    Raises:
        ValueError: If pred_signal and true_signal do not have the same shape
            (after squeezing an optional singleton dimension 1).
    """
    # Drop the optional singleton dim so both inputs are [batch, channels, length]
    if pred_signal.dim() == 4:
        pred_signal = pred_signal.squeeze(1)
    if true_signal.dim() == 4:
        true_signal = true_signal.squeeze(1)

    pred_np = pred_signal.detach().cpu().numpy()
    true_np = true_signal.detach().cpu().numpy()

    if pred_np.shape != true_np.shape:
        raise ValueError(
            f"pred_signal and true_signal must have the same shape, "
            f"got {pred_np.shape} and {true_np.shape}"
        )

    batch_size, channels, length = pred_np.shape

    # Not enough embedded points to fit a full reference window around any index
    embed_length = length - (m - 1) * tau
    if embed_length <= w2 * 2:
        return [0.0] * batch_size

    sampler = np.random if rng is None else rng
    sl_values = []

    for b in range(batch_size):
        # Randomly select a subset of channels to reduce computation
        if channels > num_channels:
            selected_channels = sampler.choice(channels, num_channels, replace=False)
        else:
            selected_channels = np.arange(channels)

        channel_sl = [
            _compute_sl_single_channel(
                _time_delay_embedding(pred_np[b, ch], m, tau),
                _time_delay_embedding(true_np[b, ch], m, tau),
                w1, w2, p_ref,
            )
            for ch in selected_channels
        ]

        # Plain Python float so every returned element has the same type
        sl_values.append(float(np.mean(channel_sl)) if channel_sl else 0.0)

    return sl_values


def _time_delay_embedding(signal, m, tau):
    """
    Time delay embedding: reconstruct the phase space of a 1D series.

    Row i of the result is
    ``[signal[i], signal[i + tau], ..., signal[i + (m - 1) * tau]]``.

    Args:
        signal: 1D time series [length]
        m: Embedding dimension
        tau: Time delay

    Returns:
        np.ndarray: float64 array [embed_length, m], where
        embed_length = length - (m - 1) * tau. Shape (0, m) if the series
        is too short.
    """
    signal = np.asarray(signal, dtype=float)
    embed_length = len(signal) - (m - 1) * tau

    if embed_length <= 0:
        return np.empty((0, m))

    # Vectorized gather instead of a Python double loop:
    # index[i, j] = i + j * tau
    index = np.arange(embed_length)[:, np.newaxis] + tau * np.arange(m)[np.newaxis, :]
    return signal[index]


def _compute_sl_single_channel(x_embedded, y_embedded, w1, w2, p_ref):
    """
    Compute synchronization likelihood between two embedded time series.

    For each index i whose full reference window fits inside the series, the
    reference set is {j : w1 < |i - j| <= w2}. A separate critical distance is
    chosen for each signal (eps_x from x distances, eps_y from y distances)
    so that about a fraction p_ref of the reference points (at least one)
    count as neighbours of i, as in Stam & van Dijk (2002). The per-index
    likelihood is the fraction of i's x-neighbours that are also
    y-neighbours, and the result is the mean over all scored indices.
    Because eps is chosen per signal, the value does not change when either
    signal is rescaled in amplitude.

    Args:
        x_embedded: Embedded series [N, m].
        y_embedded: Embedded series [N, m] (same N as x_embedded).
        w1: Pairs with |i - j| <= w1 are excluded (avoids autocorrelation).
        w2: Half-width of the reference window.
        p_ref: Target fraction of reference points treated as neighbours.

    Returns:
        float: SL in [0, 1], or 0.0 if the series is too short or the
        reference set is empty.
    """
    embed_length = x_embedded.shape[0]

    if embed_length < w2 * 2:
        return 0.0

    # Vectorized pairwise Euclidean distance matrices, [N, N] each
    x_diff = x_embedded[:, np.newaxis, :] - x_embedded[np.newaxis, :, :]
    distances_x = np.sqrt((x_diff ** 2).sum(axis=-1))

    y_diff = y_embedded[:, np.newaxis, :] - y_embedded[np.newaxis, :, :]
    distances_y = np.sqrt((y_diff ** 2).sum(axis=-1))

    # Reference-set offsets relative to i: w1 < |j - i| <= w2
    offsets = np.arange(-w2, w2 + 1)
    ref_offsets = offsets[np.abs(offsets) > w1]
    n_ref = ref_offsets.size
    if n_ref == 0:
        return 0.0

    # 1-based rank of the critical distance within the sorted reference set
    k = min(max(1, int(n_ref * p_ref)), n_ref)

    sl_sum = 0.0
    valid_points = 0

    # Only indices whose full reference window lies inside the series
    for i in range(w2, embed_length - w2):
        ref_idx = i + ref_offsets
        ref_dx = distances_x[i, ref_idx]
        ref_dy = distances_y[i, ref_idx]

        # Per-signal critical distances (k-th smallest reference distance)
        eps_x = np.partition(ref_dx, k - 1)[k - 1]
        eps_y = np.partition(ref_dy, k - 1)[k - 1]

        # Neighbours are searched only within the reference window
        x_neighbors = ref_dx <= eps_x
        y_neighbors = ref_dy <= eps_y

        total_neighbors = np.count_nonzero(x_neighbors)
        if total_neighbors == 0:
            # Only reachable when distances contain NaN
            continue

        # Sync = proportion of x's neighbours that are also neighbours in y
        sl_sum += np.count_nonzero(x_neighbors & y_neighbors) / total_neighbors
        valid_points += 1

    return sl_sum / valid_points if valid_points > 0 else 0.0


def normalize_signal(signal_data):
    """
    Z-score brain-signal data along its last (time) dimension.

    Each trace is centered by its mean and divided by its standard deviation
    (torch's default, Bessel-corrected) plus a small epsilon to avoid
    division by zero. Used for both EEG and MEG.
    """
    eps = 1e-8
    mu, sigma = (
        signal_data.mean(dim=-1, keepdim=True),
        signal_data.std(dim=-1, keepdim=True),
    )
    return (signal_data - mu) / (sigma + eps)


# Backward compatible interface aliases
def normalize_eeg(eeg_data):
    """Backward-compatible alias: z-score EEG data via ``normalize_signal``."""
    return normalize_signal(signal_data=eeg_data)


def normalize_meg(meg_data):
    """Backward-compatible alias: z-score MEG data via ``normalize_signal``."""
    return normalize_signal(signal_data=meg_data)


def get_modality_config(modality):
    """
    Return the default data/model settings for a recording modality.

    Args:
        modality: 'eeg' or 'meg' (case-insensitive).

    Returns:
        dict: Fresh dict with 'channels', 'length', 'patch_size' and
        'default_batch_size'; callers may mutate it freely.

    Raises:
        ValueError: For any other modality name.
    """
    key = modality.lower()

    if key == 'eeg':
        return {
            'channels': 63,
            'length': 250,
            'patch_size': (4, 4),
            'default_batch_size': 16,
        }
    if key == 'meg':
        return {
            'channels': 271,
            'length': 200,
            'patch_size': (4, 4),
            'default_batch_size': 4,
        }

    raise ValueError(f"Unsupported modality: {key}, please choose 'eeg' or 'meg'")


def print_model_info(model, modality):
    """
    Print a framed summary of the total and trainable parameter counts.

    Args:
        model: Object exposing ``parameters()`` (e.g. a torch.nn.Module).
        modality: Modality name; shown upper-cased in the header.
    """
    total_params = 0
    trainable_params = 0
    for param in model.parameters():
        count = param.numel()
        total_params += count
        if param.requires_grad:
            trainable_params += count

    separator = '=' * 50
    report = [
        '',
        separator,
        f"Model Info ({modality.upper()})",
        separator,
        f"Total parameters: {total_params:,}",
        f"Trainable parameters: {trainable_params:,}",
        separator,
        '',
    ]
    print('\n'.join(report))


def create_experiment_name(args):
    """
    Build an underscore-separated run name from the experiment arguments.

    Format: ``{modality}_sub{subject}_d{depth}_h{num_heads}_dim{hidden_dim}``
    ``_p{patch_size_h}x{patch_size_w}_{loss_type}``.
    """
    return (
        f"{args.modality}"
        f"_sub{args.subject}"
        f"_d{args.depth}"
        f"_h{args.num_heads}"
        f"_dim{args.hidden_dim}"
        f"_p{args.patch_size_h}x{args.patch_size_w}"
        f"_{args.loss_type}"
    )

