"""
Training Script for Biosignals-Text Contrastive Learning

Trains the model with three objectives:
1. Contrastive loss (CLIP-style) for biosignal-text alignment
2. Captioning loss (cross-entropy) for text generation
3. Reconstruction loss (MSE) for self-supervised signal reconstruction
"""

import argparse
import logging
import os
import random
from datetime import datetime

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR

from biosignals_model import BiosignalsTextModel, BiosignalsCfg, TextCfg, DecoderCfg


# ============================================================================
# Loss Functions
# ============================================================================

def contrastive_loss(
    biosignal_features: torch.Tensor,
    text_features: torch.Tensor,
    logit_scale: torch.Tensor,
) -> torch.Tensor:
    """
    Symmetric CLIP-style cross-entropy over a batch of paired embeddings.

    Row i of ``biosignal_features`` is treated as the positive match for
    row i of ``text_features``; all other rows of the batch act as negatives.

    Args:
        biosignal_features: (B, D) normalized biosignal embeddings
        text_features: (B, D) normalized text embeddings
        logit_scale: Scale multiplied into the similarities
            (learnable temperature, exp of log scale)

    Returns:
        Scalar loss: the average of the biosignal->text and
        text->biosignal cross-entropy terms
    """
    # Scaled pairwise similarities, one row per biosignal, one column per text
    similarity = (logit_scale * biosignal_features) @ text_features.T

    # The correct "class" for row i is column i (the diagonal)
    targets = torch.arange(similarity.shape[0], device=similarity.device)

    signal_to_text = F.cross_entropy(similarity, targets)
    text_to_signal = F.cross_entropy(similarity.T, targets)

    return (signal_to_text + text_to_signal) / 2


def captioning_loss(
    logits: torch.Tensor,
    labels: torch.Tensor,
    pad_id: int = 0,
) -> torch.Tensor:
    """
    Token-level cross-entropy for caption generation, skipping padding.

    Args:
        logits: (B, seq_len, vocab_size) model predictions
        labels: (B, seq_len) ground truth token ids
        pad_id: Token id whose positions are excluded from the loss

    Returns:
        Scalar loss averaged over the non-padding positions
    """
    vocab_size = logits.shape[-1]

    # Collapse batch and sequence dims so each token is one classification
    return F.cross_entropy(
        logits.reshape(-1, vocab_size),
        labels.reshape(-1),
        ignore_index=pad_id,
    )


def reconstruction_loss(
    reconstructed: torch.Tensor,
    original: torch.Tensor,
    loss_type: str = 'mse',
) -> torch.Tensor:
    """
    Element-wise reconstruction error between a signal and its reconstruction.

    Args:
        reconstructed: (B, channels, length) reconstructed signal
        original: (B, channels, length) original signal
        loss_type: 'mse' (mean squared error) or 'mae' (mean absolute error)

    Returns:
        Scalar loss value

    Raises:
        ValueError: If ``loss_type`` is neither 'mse' nor 'mae'
    """
    if loss_type == 'mse':
        return F.mse_loss(reconstructed, original)

    if loss_type == 'mae':
        return F.l1_loss(reconstructed, original)

    raise ValueError(f"Unknown loss_type: {loss_type}")


# ============================================================================
# Dataset (Example Implementation)
# ============================================================================

class BiosignalsTextDataset(Dataset):
    """
    Example dataset for biosignal-text pairs.

    Expected data format:
    - signals: numpy array of shape (N, channels, length)
    - captions: list of tokenized captions, each (seq_len,)

    Every caption is returned with exactly ``max_length`` tokens: longer
    captions are truncated and shorter ones are right-padded with token id 0.

    Replace this with your actual data loading logic.
    """

    def __init__(
        self,
        signals_path: str,
        captions_path: str,
        max_length: int = 256,
    ):
        """
        Args:
            signals_path: Path to numpy file with signals
            captions_path: Path to numpy file with tokenized captions
            max_length: Maximum caption length

        Raises:
            ValueError: If the two files hold a different number of entries
        """
        self.max_length = max_length

        # Load data
        self.signals = np.load(signals_path)
        # SECURITY NOTE: allow_pickle=True is needed for ragged (object-dtype)
        # caption arrays, but unpickling can execute arbitrary code. Only load
        # caption files from a trusted source.
        self.captions = np.load(captions_path, allow_pickle=True)

        # Raise instead of assert so the check survives `python -O`
        if len(self.signals) != len(self.captions):
            raise ValueError(
                "Signals and captions must have same length "
                f"(got {len(self.signals)} signals and "
                f"{len(self.captions)} captions)"
            )

    def __len__(self):
        """Return the number of signal/caption pairs."""
        return len(self.signals)

    def __getitem__(self, idx):
        """
        Return the ``idx``-th pair as ``(signal, caption)`` tensors.

        The signal is converted to float32; the caption is an int64 tensor of
        length ``max_length``.
        """
        signal = torch.tensor(self.signals[idx], dtype=torch.float32)
        caption = torch.tensor(self.captions[idx], dtype=torch.long)

        # Pad or truncate caption
        if len(caption) > self.max_length:
            caption = caption[:self.max_length]
        elif len(caption) < self.max_length:
            # Zero padding matches the pad_id=0 passed to the captioning loss
            # by the training and evaluation loops in this file
            padding = torch.zeros(self.max_length - len(caption), dtype=torch.long)
            caption = torch.cat([caption, padding])

        return signal, caption


class SyntheticDataset(Dataset):
    """
    Synthetic dataset for testing the training pipeline.

    Samples are generated on the fly from the global ``torch`` and ``random``
    generators, so the same index yields different data on each access.
    """

    def __init__(
        self,
        num_samples: int = 1000,
        num_channels: int = 12,
        signal_length: int = 3840,
        vocab_size: int = 50257,
        max_length: int = 77,
    ):
        """
        Args:
            num_samples: Reported dataset length
            num_channels: Number of channels in each generated signal
            signal_length: Number of time steps in each generated signal
            vocab_size: Exclusive upper bound for random caption token ids
            max_length: Fixed length of every returned caption
        """
        self.num_samples = num_samples
        self.num_channels = num_channels
        self.signal_length = signal_length
        self.vocab_size = vocab_size
        self.max_length = max_length

    def __len__(self):
        """Return the configured number of samples."""
        return self.num_samples

    def __getitem__(self, idx):
        """
        Generate a random ``(signal, caption)`` pair; ``idx`` is not used.

        The caption is laid out as BOS (1), random tokens in
        ``[3, vocab_size)``, EOS (2), then zero padding up to ``max_length``.
        """
        # Random signal
        signal = torch.randn(self.num_channels, self.signal_length)

        # Random caption (random tokens with BOS/EOS structure).
        # Two positions are reserved for BOS/EOS. The usual minimum of 10 body
        # tokens is clamped so that max_length < 12 does not hand randint an
        # empty range (which raises ValueError).
        longest_body = max(self.max_length - 2, 0)
        caption_len = random.randint(min(10, longest_body), longest_body)
        caption = torch.randint(3, self.vocab_size, (caption_len,))
        caption = torch.cat([
            torch.tensor([1]),  # BOS
            caption,
            torch.tensor([2]),  # EOS
        ])

        # Pad to max_length (truncation only happens when max_length < 2)
        if len(caption) < self.max_length:
            padding = torch.zeros(self.max_length - len(caption), dtype=torch.long)
            caption = torch.cat([caption, padding])
        else:
            caption = caption[:self.max_length]

        return signal, caption


# ============================================================================
# Training Loop
# ============================================================================

def train_one_epoch(
    model: nn.Module,
    dataloader: DataLoader,
    optimizer: torch.optim.Optimizer,
    scheduler,
    device: torch.device,
    epoch: int,
    args,
) -> dict:
    """
    Train for one epoch.

    For each batch the model is called as
    ``model(signals, captions, output_labels=True)`` and must return a dict
    with 'biosignal_features', 'text_features', 'logit_scale', 'logits' and
    'labels'. If the dict also contains 'reconstructed_signal' (with
    'original_signal'), a reconstruction term is added to the loss.

    Args:
        model: The model to train
        dataloader: Training data loader yielding (signals, captions) batches
        optimizer: Optimizer
        scheduler: Learning rate scheduler (stepped once per batch)
        device: Device to train on
        epoch: Current epoch number (used for logging only)
        args: Training arguments; reads contrastive_weight, caption_weight,
            recon_weight, recon_loss_type, grad_clip and log_interval

    Returns:
        Dictionary with the per-batch mean of 'loss', 'contrastive',
        'caption' and 'reconstruction'. All values are NaN if the dataloader
        has no batches.
    """
    model.train()

    totals = {
        'loss': 0.0,
        'contrastive': 0.0,
        'caption': 0.0,
        'reconstruction': 0.0,
    }
    num_batches = len(dataloader)

    for batch_idx, (signals, captions) in enumerate(dataloader):
        signals = signals.to(device)
        captions = captions.to(device)

        optimizer.zero_grad()

        # Forward pass
        outputs = model(signals, captions, output_labels=True)

        # Compute losses
        loss_contrastive = contrastive_loss(
            outputs['biosignal_features'],
            outputs['text_features'],
            outputs['logit_scale'],
        )

        loss_caption = captioning_loss(
            outputs['logits'],
            outputs['labels'],
            pad_id=0,
        )

        # Reconstruction loss (if enabled); stays a zero tensor otherwise so
        # the weighted sum and the logging below need no special case
        loss_recon = torch.tensor(0.0, device=device)
        if 'reconstructed_signal' in outputs:
            loss_recon = reconstruction_loss(
                outputs['reconstructed_signal'],
                outputs['original_signal'],
                loss_type=args.recon_loss_type,
            )

        # Combined loss with weights
        loss = (
            args.contrastive_weight * loss_contrastive +
            args.caption_weight * loss_caption +
            args.recon_weight * loss_recon
        )

        # Backward pass
        loss.backward()

        # Gradient clipping
        if args.grad_clip > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)

        optimizer.step()
        scheduler.step()

        # Read each scalar back from the device once and reuse it for both
        # accumulation and logging (.item() forces a host/device sync)
        batch_loss = loss.item()
        batch_contrastive = loss_contrastive.item()
        batch_caption = loss_caption.item()
        batch_recon = loss_recon.item()

        # Accumulate losses
        totals['loss'] += batch_loss
        totals['contrastive'] += batch_contrastive
        totals['caption'] += batch_caption
        totals['reconstruction'] += batch_recon

        # Logging
        if batch_idx % args.log_interval == 0:
            logging.info(
                f"Epoch {epoch} [{batch_idx}/{num_batches}] "
                f"Loss: {batch_loss:.4f} "
                f"(C: {batch_contrastive:.4f}, "
                f"Cap: {batch_caption:.4f}, "
                f"Rec: {batch_recon:.4f})"
            )

    # An empty loader would otherwise fail with ZeroDivisionError
    if num_batches == 0:
        logging.warning("Training dataloader has no batches; returning NaN metrics")
        return {key: float('nan') for key in totals}

    return {key: total / num_batches for key, total in totals.items()}


def evaluate(
    model: nn.Module,
    dataloader: DataLoader,
    device: torch.device,
    args,
) -> dict:
    """
    Evaluate the model.

    Runs the same loss computation as training (contrastive, captioning and,
    when the model output contains 'reconstructed_signal', reconstruction)
    with gradients disabled and the model in eval mode. The model is left in
    eval mode on return.

    Args:
        model: The model to evaluate
        dataloader: Validation data loader yielding (signals, captions) batches
        device: Device to evaluate on
        args: Training arguments; reads contrastive_weight, caption_weight,
            recon_weight and recon_loss_type

    Returns:
        Dictionary with the per-batch mean of 'loss', 'contrastive',
        'caption' and 'reconstruction'. All values are NaN if the dataloader
        has no batches.
    """
    model.eval()

    totals = {
        'loss': 0.0,
        'contrastive': 0.0,
        'caption': 0.0,
        'reconstruction': 0.0,
    }
    num_batches = len(dataloader)

    with torch.no_grad():
        for signals, captions in dataloader:
            signals = signals.to(device)
            captions = captions.to(device)

            outputs = model(signals, captions, output_labels=True)

            loss_contrastive = contrastive_loss(
                outputs['biosignal_features'],
                outputs['text_features'],
                outputs['logit_scale'],
            )

            loss_caption = captioning_loss(
                outputs['logits'],
                outputs['labels'],
                pad_id=0,
            )

            # Zero tensor when the model produced no reconstruction
            loss_recon = torch.tensor(0.0, device=device)
            if 'reconstructed_signal' in outputs:
                loss_recon = reconstruction_loss(
                    outputs['reconstructed_signal'],
                    outputs['original_signal'],
                    loss_type=args.recon_loss_type,
                )

            loss = (
                args.contrastive_weight * loss_contrastive +
                args.caption_weight * loss_caption +
                args.recon_weight * loss_recon
            )

            totals['loss'] += loss.item()
            totals['contrastive'] += loss_contrastive.item()
            totals['caption'] += loss_caption.item()
            totals['reconstruction'] += loss_recon.item()

    # An empty loader (e.g. a validation split with zero samples) would
    # otherwise fail with ZeroDivisionError. NaN compares false against any
    # threshold, so it cannot be mistaken for a real (or best) score.
    if num_batches == 0:
        logging.warning("Evaluation dataloader has no batches; returning NaN metrics")
        return {key: float('nan') for key in totals}

    return {key: total / num_batches for key, total in totals.items()}


# ============================================================================
# Main Training Function
# ============================================================================

def _save_checkpoint(path, epoch, model, optimizer, scheduler, val_loss):
    """Write model, optimizer and scheduler state plus bookkeeping to ``path``."""
    torch.save(
        {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            # Lets a resumed run continue the LR schedule where it stopped
            'scheduler_state_dict': scheduler.state_dict(),
            'val_loss': val_loss,
        },
        path,
    )


def main(args):
    """
    Build the model and data loaders, then run the full training loop.

    Each epoch trains on the training loader, evaluates on the validation
    loader, saves 'best_model.pt' whenever the validation loss improves, and
    saves 'checkpoint_epoch_{N}.pt' every ``args.save_interval`` epochs.
    Nothing is written if ``args.output_dir`` is empty.

    Args:
        args: Parsed command-line namespace (see ``parse_args``)

    Raises:
        ValueError: If real data is requested (``args.synthetic`` is false)
            but one of the four dataset paths is missing
    """
    # Setup logging
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
    )

    # Set random seeds
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    # Device setup
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f"Using device: {device}")

    # Create model
    biosignals_cfg = BiosignalsCfg(
        input_channels=args.input_channels,
        signal_length=args.signal_length,
        transformer_layers=args.encoder_layers,
        transformer_width=args.model_width,
        transformer_heads=args.num_heads,
    )

    text_cfg = TextCfg(
        vocab_size=args.vocab_size,
        context_length=args.max_length,
        width=args.model_width,
        heads=args.num_heads,
        layers=args.text_encoder_layers,
    )

    decoder_cfg = DecoderCfg(
        context_length=args.max_length,
        width=args.model_width,
        heads=args.num_heads,
        layers=args.decoder_layers,
    )

    model = BiosignalsTextModel(
        embed_dim=args.embed_dim,
        biosignals_cfg=biosignals_cfg,
        text_cfg=text_cfg,
        decoder_cfg=decoder_cfg,
        use_signal_decoder=args.use_signal_decoder,
    ).to(device)

    logging.info(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

    # Create datasets
    if args.synthetic:
        # Use synthetic data for testing
        train_dataset = SyntheticDataset(
            num_samples=args.synthetic_samples,
            num_channels=args.input_channels,
            signal_length=args.signal_length,
            vocab_size=args.vocab_size,
            max_length=args.max_length,
        )
        val_dataset = SyntheticDataset(
            num_samples=args.synthetic_samples // 10,
            num_channels=args.input_channels,
            signal_length=args.signal_length,
            vocab_size=args.vocab_size,
            max_length=args.max_length,
        )
    else:
        # Fail early with a readable message instead of an obscure error
        # from np.load(None) inside the dataset constructor
        required_paths = {
            '--train-signals': args.train_signals,
            '--train-captions': args.train_captions,
            '--val-signals': args.val_signals,
            '--val-captions': args.val_captions,
        }
        missing = [flag for flag, value in required_paths.items() if not value]
        if missing:
            raise ValueError(
                "Missing required data path(s): " + ", ".join(missing) +
                " (or pass --synthetic to use generated data)"
            )

        # Load real data
        train_dataset = BiosignalsTextDataset(
            signals_path=args.train_signals,
            captions_path=args.train_captions,
            max_length=args.max_length,
        )
        val_dataset = BiosignalsTextDataset(
            signals_path=args.val_signals,
            captions_path=args.val_captions,
            max_length=args.max_length,
        )

    # Pinned host memory only helps host->GPU copies; skip it on CPU-only runs
    pin_memory = device.type == 'cuda'

    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=pin_memory,
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=pin_memory,
    )

    # Optimizer and scheduler
    optimizer = AdamW(
        model.parameters(),
        lr=args.lr,
        weight_decay=args.weight_decay,
        betas=(args.beta1, args.beta2),
    )

    # The scheduler is stepped once per batch, so the cosine period is the
    # total number of optimizer steps across all epochs
    total_steps = len(train_loader) * args.epochs
    scheduler = CosineAnnealingLR(optimizer, T_max=total_steps, eta_min=args.lr * 0.01)

    # Create the output directory once, up front, so that periodic checkpoints
    # do not depend on the best-model branch having run first
    if args.output_dir:
        os.makedirs(args.output_dir, exist_ok=True)

    # Training loop
    best_val_loss = float('inf')

    for epoch in range(1, args.epochs + 1):
        logging.info(f"Starting epoch {epoch}/{args.epochs}")

        # Train
        train_metrics = train_one_epoch(
            model, train_loader, optimizer, scheduler, device, epoch, args
        )

        logging.info(
            f"Epoch {epoch} Train - "
            f"Loss: {train_metrics['loss']:.4f}, "
            f"Contrastive: {train_metrics['contrastive']:.4f}, "
            f"Caption: {train_metrics['caption']:.4f}, "
            f"Reconstruction: {train_metrics['reconstruction']:.4f}"
        )

        # Evaluate
        val_metrics = evaluate(model, val_loader, device, args)

        logging.info(
            f"Epoch {epoch} Val - "
            f"Loss: {val_metrics['loss']:.4f}, "
            f"Contrastive: {val_metrics['contrastive']:.4f}, "
            f"Caption: {val_metrics['caption']:.4f}, "
            f"Reconstruction: {val_metrics['reconstruction']:.4f}"
        )

        # Save best model
        if val_metrics['loss'] < best_val_loss:
            best_val_loss = val_metrics['loss']
            if args.output_dir:
                _save_checkpoint(
                    os.path.join(args.output_dir, 'best_model.pt'),
                    epoch, model, optimizer, scheduler, val_metrics['loss'],
                )
                logging.info(f"Saved best model (val_loss: {best_val_loss:.4f})")

        # Save checkpoint
        if args.output_dir and epoch % args.save_interval == 0:
            _save_checkpoint(
                os.path.join(args.output_dir, f'checkpoint_epoch_{epoch}.pt'),
                epoch, model, optimizer, scheduler, val_metrics['loss'],
            )

    logging.info("Training complete!")

def parse_args(argv=None):
    """
    Parse command-line options for the training script.

    Args:
        argv: Optional list of argument strings. When None (the default),
            argparse reads the process command line.

    Returns:
        argparse.Namespace with one attribute per option (dashes in option
        names become underscores)
    """
    parser = argparse.ArgumentParser(description='Train biosignals-text model')

    # Data arguments
    parser.add_argument('--train-signals', type=str, help='Path to training signals')
    parser.add_argument('--train-captions', type=str, help='Path to training captions')
    parser.add_argument('--val-signals', type=str, help='Path to validation signals')
    parser.add_argument('--val-captions', type=str, help='Path to validation captions')
    parser.add_argument('--synthetic', action='store_true', help='Use synthetic data')
    parser.add_argument('--synthetic-samples', type=int, default=1000)

    # Model arguments
    parser.add_argument('--input-channels', type=int, default=12)
    parser.add_argument('--signal-length', type=int, default=3840)
    parser.add_argument('--embed-dim', type=int, default=512)
    parser.add_argument('--model-width', type=int, default=768)
    parser.add_argument('--num-heads', type=int, default=12)
    parser.add_argument('--encoder-layers', type=int, default=6)
    parser.add_argument('--text-encoder-layers', type=int, default=12)
    parser.add_argument('--decoder-layers', type=int, default=6)
    parser.add_argument('--vocab-size', type=int, default=50257)
    parser.add_argument('--max-length', type=int, default=256)
    # The decoder is on by default, so a store_true flag alone could never
    # turn it off; the paired --no-... flag writes False to the same dest.
    parser.add_argument(
        '--use-signal-decoder',
        dest='use_signal_decoder',
        action='store_true',
        default=True,
        help='Enable the signal reconstruction decoder (default)',
    )
    parser.add_argument(
        '--no-use-signal-decoder',
        dest='use_signal_decoder',
        action='store_false',
        help='Disable the signal reconstruction decoder',
    )

    # Training arguments
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--batch-size', type=int, default=32)
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--weight-decay', type=float, default=0.01)
    parser.add_argument('--beta1', type=float, default=0.9)
    parser.add_argument('--beta2', type=float, default=0.999)
    parser.add_argument('--grad-clip', type=float, default=1.0)

    # Loss weights
    parser.add_argument('--contrastive-weight', type=float, default=1.0)
    parser.add_argument('--caption-weight', type=float, default=1.0)
    parser.add_argument('--recon-weight', type=float, default=0.1)
    parser.add_argument('--recon-loss-type', type=str, default='mse', choices=['mse', 'mae'])

    # Other arguments
    parser.add_argument('--output-dir', type=str, default='./checkpoints')
    parser.add_argument('--log-interval', type=int, default=10)
    parser.add_argument('--save-interval', type=int, default=10)
    parser.add_argument('--num-workers', type=int, default=4)
    parser.add_argument('--seed', type=int, default=42)

    return parser.parse_args(argv)


if __name__ == '__main__':
    # Script entry point: parse the command line, then run training
    main(parse_args())

