# -*- coding: utf-8 -*-
"""
RoBERTa with Block-Diagonal 2D Sparse Attention for Sequence Length Comparison
Implementing block-diagonal sparsity pattern (same as 3D but without higher-order interactions)
"""

import os
import math
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
from tqdm import tqdm
from datasets import load_dataset
from transformers import RobertaTokenizer, RobertaModel, RobertaConfig, RobertaForSequenceClassification
from torch.utils.data import DataLoader
from torch.optim import AdamW

# Select the compute device: CUDA when a GPU is visible, otherwise the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Make sure every output directory exists before anything is written to it
for _output_dir in ("results", "results/models", "results/metrics"):
    os.makedirs(_output_dir, exist_ok=True)

# Seed the NumPy and torch generators so repeated runs are comparable
np.random.seed(42)
torch.manual_seed(42)


class BlockDiagonalAttention(nn.Module):
    """Multi-head self-attention restricted to a block-diagonal pattern.

    The sequence is cut into consecutive chunks of ``block_size`` positions
    (the last chunk may be shorter). A query position may only attend to key
    positions inside its own chunk; every other score is set to ``-inf``
    before the softmax.

    The full ``seq_len x seq_len`` score matrix is still computed, so the
    pattern restricts what each position can see rather than the cost of the
    score matmul.

    Args:
        config: Object exposing ``hidden_size``, ``num_attention_heads`` and
            ``attention_probs_dropout_prob``.
        block_size: Number of consecutive positions per attention block.

    Raises:
        ValueError: If ``block_size`` is not positive, or ``hidden_size`` is
            not a multiple of ``num_attention_heads``.
    """

    def __init__(self, config, block_size=64):
        super().__init__()
        if block_size <= 0:
            raise ValueError(f"block_size must be positive, got {block_size}")
        if config.hidden_size % config.num_attention_heads != 0:
            # Truncating here would only surface later as an opaque view() error
            raise ValueError(
                f"hidden_size ({config.hidden_size}) must be a multiple of "
                f"num_attention_heads ({config.num_attention_heads})"
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = config.hidden_size // config.num_attention_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.block_size = block_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        """Split the last dimension into heads: (B, L, H*D) -> (B, H, L, D)."""
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def create_block_diagonal_mask(self, seq_len, device, dtype=None):
        """Build the additive block-diagonal mask.

        Args:
            seq_len: Sequence length; the mask is ``(seq_len, seq_len)``.
            device: Device on which the mask is allocated.
            dtype: Optional floating dtype of the mask. ``None`` keeps torch's
                default dtype, which is what earlier callers received.

        Returns:
            Tensor holding ``0.0`` where query and key share a block and
            ``-inf`` everywhere else.
        """
        # Two positions share a block exactly when their block indices match,
        # which avoids a Python loop over the blocks.
        positions = torch.arange(seq_len, device=device)
        block_ids = torch.div(positions, self.block_size, rounding_mode="floor")
        same_block = block_ids.unsqueeze(1) == block_ids.unsqueeze(0)

        mask = torch.full((seq_len, seq_len), float("-inf"), device=device, dtype=dtype)
        return mask.masked_fill(same_block, 0.0)

    def forward(self, hidden_states, attention_mask=None):
        """Apply block-diagonal self-attention.

        Args:
            hidden_states: Tensor of shape ``(batch, seq_len, hidden_size)``.
            attention_mask: Optional additive mask that broadcasts against the
                ``(batch, heads, seq_len, seq_len)`` score tensor.

        Returns:
            Context tensor of shape ``(batch, seq_len, all_head_size)``.
        """
        seq_len = hidden_states.size(1)

        # Linear projections, split into heads
        query_layer = self.transpose_for_scores(self.query(hidden_states))
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))

        # Scaled dot-product scores
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        # Block-diagonal pattern; the mask follows the score dtype so that
        # half-precision scores are not silently promoted to float32.
        block_mask = self.create_block_diagonal_mask(
            seq_len, hidden_states.device, dtype=attention_scores.dtype
        )
        attention_scores = attention_scores + block_mask.unsqueeze(0).unsqueeze(0)

        # Caller-supplied additive mask (e.g. padding)
        if attention_mask is not None:
            attention_scores = attention_scores + attention_mask

        # Every row keeps finite entries on its own diagonal block, so the
        # softmax is well defined.
        attention_probs = F.softmax(attention_scores, dim=-1)
        attention_probs = self.dropout(attention_probs)

        # Weighted sum of values, then merge the heads back together
        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        return context_layer.view(*new_context_layer_shape)


class BlockDiagonalRobertaLayer(nn.Module):
    """One encoder block: block-diagonal self-attention, then a feed-forward net.

    Both sub-layers use the same residual layout: sub-layer output -> dropout
    -> add the sub-layer input -> LayerNorm.
    """

    def __init__(self, config, block_size=64):
        super().__init__()
        hidden_dim = config.hidden_size
        inner_dim = config.intermediate_size
        drop_prob = config.hidden_dropout_prob
        norm_eps = config.layer_norm_eps

        # Attention sub-layer: sparse attention, output projection, dropout, norm
        self.attention = BlockDiagonalAttention(config, block_size)
        self.attention_output = nn.Linear(hidden_dim, hidden_dim)
        self.attention_dropout = nn.Dropout(drop_prob)
        self.attention_layer_norm = nn.LayerNorm(hidden_dim, eps=norm_eps)

        # Feed-forward sub-layer: expand, project back, dropout, norm
        self.intermediate = nn.Linear(hidden_dim, inner_dim)
        self.output = nn.Linear(inner_dim, hidden_dim)
        self.output_dropout = nn.Dropout(drop_prob)
        self.output_layer_norm = nn.LayerNorm(hidden_dim, eps=norm_eps)

    def forward(self, hidden_states, attention_mask=None):
        """Run the block on ``hidden_states`` and return a tensor of the same shape."""
        # Attention with residual connection around it
        attended = self.attention(hidden_states, attention_mask)
        attended = self.attention_dropout(self.attention_output(attended))
        post_attention = self.attention_layer_norm(attended + hidden_states)

        # GELU feed-forward with residual connection around it
        expanded = F.gelu(self.intermediate(post_attention))
        projected = self.output_dropout(self.output(expanded))
        return self.output_layer_norm(projected + post_attention)


class BlockDiagonalRobertaModel(nn.Module):
    """RoBERTa-style encoder whose layers use block-diagonal sparse attention.

    The embedding and pooler modules are taken from a loaded ``RobertaModel``
    checkpoint. The encoder layers are freshly constructed
    ``BlockDiagonalRobertaLayer`` modules and are not loaded from the
    checkpoint.

    Args:
        config: Model configuration; ``num_hidden_layers`` sets the number of
            encoder layers and ``hidden_size`` must match the checkpoint.
        block_size: Block size forwarded to every encoder layer.
        pretrained_name: Checkpoint name or path handed to
            ``RobertaModel.from_pretrained``.

    Raises:
        ValueError: If the checkpoint's hidden size differs from
            ``config.hidden_size``.
    """

    def __init__(self, config, block_size=64, pretrained_name="roberta-base"):
        super().__init__()
        self.config = config

        # Load the checkpoint only to borrow its embeddings and pooler
        pretrained_model = RobertaModel.from_pretrained(pretrained_name)
        pretrained_hidden = pretrained_model.config.hidden_size
        if pretrained_hidden != config.hidden_size:
            # Fail here rather than with a shape error deep inside the first layer
            raise ValueError(
                f"Checkpoint '{pretrained_name}' has hidden_size={pretrained_hidden}, "
                f"but config.hidden_size={config.hidden_size}"
            )
        self.embeddings = pretrained_model.embeddings
        self.pooler = pretrained_model.pooler

        # Create block-diagonal encoder layers
        self.layers = nn.ModuleList([
            BlockDiagonalRobertaLayer(config, block_size)
            for _ in range(config.num_hidden_layers)
        ])

    def forward(self, input_ids, attention_mask=None):
        """Encode ``input_ids``.

        Args:
            input_ids: Token ids passed to the embedding module.
            attention_mask: Optional ``(batch, seq_len)`` mask with 1 for
                positions to keep and 0 for positions to ignore.

        Returns:
            Tuple ``(hidden_states, pooled_output)``: the output of the last
            encoder layer and the pooler's output.
        """
        hidden_states = self.embeddings(input_ids)

        # Turn the 0/1 mask into an additive one that broadcasts over heads
        # and query positions. A large finite penalty is used instead of -inf.
        if attention_mask is not None:
            extended_attention_mask = attention_mask[:, None, None, :]
            extended_attention_mask = extended_attention_mask.to(dtype=hidden_states.dtype)
            extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        else:
            extended_attention_mask = None

        # Pass through block-diagonal encoder layers
        for layer in self.layers:
            hidden_states = layer(hidden_states, extended_attention_mask)

        pooled_output = self.pooler(hidden_states)

        return hidden_states, pooled_output


class _ClassificationOutputs:
    """Plain container exposing ``loss`` and ``logits`` as attributes."""

    __slots__ = ("loss", "logits")

    def __init__(self, loss, logits):
        self.loss = loss
        self.logits = logits


class BlockDiagonalRobertaForSequenceClassification(nn.Module):
    """Block-diagonal RoBERTa encoder with a linear classification head.

    Args:
        config: Model configuration; ``hidden_size``, ``hidden_dropout_prob``
            and ``num_labels`` are read here, the rest by the encoder.
        block_size: Attention block size forwarded to the encoder.
    """

    def __init__(self, config, block_size=64):
        super().__init__()
        self.roberta = BlockDiagonalRobertaModel(config, block_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Classify a batch of sequences.

        Args:
            input_ids: Token ids forwarded to the encoder.
            attention_mask: Optional mask forwarded to the encoder.
            labels: Optional class indices; when given, the cross-entropy loss
                against ``logits`` is computed.

        Returns:
            Object with attributes ``loss`` (``None`` when ``labels`` is not
            given) and ``logits``.
        """
        _, pooled_output = self.roberta(input_ids, attention_mask)
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss = F.cross_entropy(logits, labels)

        # One shared container class is used on purpose. Building a new class
        # with type(...) on every call stores the tensors on a class object,
        # which is only reclaimed by the cyclic garbage collector.
        return _ClassificationOutputs(loss, logits)


def load_yelp_data(max_length=512, batch_size=8, max_samples=None):
    """Load the Yelp polarity dataset and wrap it in PyTorch dataloaders.

    Every text is tokenized with the ``roberta-base`` tokenizer, truncated to
    ``max_length`` tokens and padded up to exactly ``max_length``.

    Args:
        max_length: Token length every example is padded/truncated to.
        batch_size: Batch size of both dataloaders.
        max_samples: Optional cap on the number of examples kept from each
            split. ``None`` keeps the full splits. When set, each split is
            shuffled with a fixed seed before the cap is applied.

    Returns:
        Tuple ``(tokenizer, train_loader, val_loader)``. The validation loader
        is built from the dataset's ``test`` split.

    Raises:
        ValueError: If ``max_samples`` is given but not positive.
    """
    print(f"Loading Yelp polarity dataset with max_length={max_length}...")

    if max_samples is not None and max_samples <= 0:
        raise ValueError(f"max_samples must be positive or None, got {max_samples}")

    tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
    dataset = load_dataset("yelp_polarity")

    train_subset = dataset["train"]
    test_subset = dataset["test"]

    # Honour the cap before tokenizing so the skipped examples cost nothing
    if max_samples is not None:
        train_subset = train_subset.shuffle(seed=42).select(
            range(min(max_samples, len(train_subset)))
        )
        test_subset = test_subset.shuffle(seed=42).select(
            range(min(max_samples, len(test_subset)))
        )

    def tokenize_function(examples):
        return tokenizer(
            examples["text"],
            padding="max_length",
            truncation=True,
            max_length=max_length
        )

    # Tokenize data
    tokenized_train = train_subset.map(tokenize_function, batched=True)
    tokenized_val = test_subset.map(tokenize_function, batched=True)

    # Expose only the columns the training loop reads, as torch tensors
    tokenized_train.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
    tokenized_val.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])

    # Only the training split is shuffled
    train_loader = DataLoader(tokenized_train, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(tokenized_val, batch_size=batch_size)

    print(f"Train samples: {len(tokenized_train)}, Validation samples: {len(tokenized_val)}")
    return tokenizer, train_loader, val_loader


def train_model(model, train_loader, val_loader, epochs=3, learning_rate=2e-5):
    """Fine-tune ``model`` and return summary metrics.

    Uses AdamW (weight decay 0.01) with a per-step schedule: linear warmup
    over the first 10% of all steps, cosine decay afterwards. Gradients are
    clipped to a norm of 1.0. The validation set is evaluated after every
    epoch.

    Args:
        model: Callable as ``model(input_ids=..., attention_mask=..., labels=...)``
            returning an object with ``loss`` and ``logits``.
        train_loader: Iterable of batches with keys ``input_ids``,
            ``attention_mask`` and ``label``; must support ``len()``.
        val_loader: Iterable of validation batches with the same keys.
        epochs: Number of passes over ``train_loader``.
        learning_rate: Peak learning rate.

    Returns:
        Dict with ``accuracy`` (best validation accuracy over all epochs),
        ``train_loss`` and ``val_loss`` (sample-weighted means of the last
        epoch) and ``training_time`` (wall-clock seconds, validation included).
    """
    print(f"Training model for {epochs} epochs...")

    optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=0.01)

    # Step budget for the schedule
    total_steps = len(train_loader) * epochs
    warmup_steps = int(0.1 * total_steps)

    def lr_multiplier(step):
        # Cosine decay once warmup is over, linear ramp before that
        if step >= warmup_steps:
            decay_span = float(max(1, total_steps - warmup_steps))
            fraction_done = float(step - warmup_steps) / decay_span
            return 0.5 * (1.0 + math.cos(math.pi * fraction_done))
        return float(step) / float(max(1, warmup_steps))

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_multiplier)

    def to_device(batch):
        # Move the three tensors the model needs onto the global device
        ids = batch["input_ids"].to(device)
        mask = batch["attention_mask"].to(device)
        targets = batch["label"].to(device)
        return ids, mask, targets

    def count_correct(logits, targets):
        return (torch.argmax(logits, dim=-1) == targets).sum().item()

    clock_start = time.time()
    best_val_acc = 0.0
    last_train_loss = 0.0
    last_val_loss = 0.0

    for epoch_idx in range(epochs):
        epoch_label = f"{epoch_idx + 1}/{epochs}"

        # ---- training pass ----
        model.train()
        loss_sum, hits, seen = 0, 0, 0

        bar = tqdm(train_loader, desc=f"Epoch {epoch_label}")
        for batch in bar:
            ids, mask, targets = to_device(batch)

            optimizer.zero_grad()
            outputs = model(input_ids=ids, attention_mask=mask, labels=targets)
            outputs.loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            scheduler.step()

            # Weight the batch loss by batch size so the epoch mean is per-sample
            batch_loss = outputs.loss.item()
            batch_count = targets.size(0)
            loss_sum += batch_loss * batch_count
            hits += count_correct(outputs.logits, targets)
            seen += batch_count

            bar.set_postfix({
                "loss": batch_loss,
                "acc": hits / seen
            })

        mean_train_loss = loss_sum / seen
        train_acc = hits / seen
        last_train_loss = mean_train_loss

        # ---- validation pass ----
        model.eval()
        val_loss_sum, val_hits, val_seen = 0, 0, 0

        with torch.no_grad():
            for batch in tqdm(val_loader, desc="Validating", leave=False):
                ids, mask, targets = to_device(batch)
                outputs = model(input_ids=ids, attention_mask=mask, labels=targets)

                batch_count = targets.size(0)
                val_loss_sum += outputs.loss.item() * batch_count
                val_hits += count_correct(outputs.logits, targets)
                val_seen += batch_count

        mean_val_loss = val_loss_sum / val_seen
        val_acc = val_hits / val_seen
        last_val_loss = mean_val_loss

        # Keep the best validation accuracy seen so far
        best_val_acc = max(best_val_acc, val_acc)

        print(f"Epoch {epoch_label}:")
        print(f"  Train Loss: {mean_train_loss:.3f}, Train Acc: {train_acc:.3f}")
        print(f"  Val Loss: {mean_val_loss:.3f}, Val Acc: {val_acc:.3f}")

    elapsed = time.time() - clock_start

    return {
        'accuracy': best_val_acc,
        'train_loss': last_train_loss,
        'val_loss': last_val_loss,
        'training_time': elapsed
    }


def format_time(seconds):
    """Render a duration in seconds as whole minutes, e.g. ``125.7`` -> ``"2m"``."""
    # Floor division drops the leftover seconds instead of rounding
    return str(int(seconds // 60)) + "m"


def run_experiment(seq_length, block_size=64, epochs=3):
    """Train and evaluate the block-diagonal model for one sequence length.

    Args:
        seq_length: Token length every example is padded/truncated to.
        block_size: Attention block size of the model.
        epochs: Number of training epochs.

    Returns:
        The metrics dict produced by ``train_model``.
    """
    print(f"\n=== Running experiment with seq_length={seq_length}, block_size={block_size} ===")

    # Fixed data settings: small batches over the full Yelp dataset
    batch_size = 8
    max_samples = None  # Use full Yelp dataset

    # Load data (the tokenizer is not needed past this point)
    _, train_loader, val_loader = load_yelp_data(
        max_length=seq_length,
        batch_size=batch_size,
        max_samples=max_samples
    )

    # Start from the roberta-base configuration and shrink it
    config = RobertaConfig.from_pretrained("roberta-base")
    config.num_hidden_layers = 3  # Smaller model for faster training
    config.num_labels = 2
    config.hidden_dropout_prob = 0.1
    config.attention_probs_dropout_prob = 0.1

    # Create model with block-diagonal sparse attention
    model = BlockDiagonalRobertaForSequenceClassification(config, block_size=block_size)
    model.to(device)

    # Train with the caller's epoch count
    metrics = train_model(
        model,
        train_loader,
        val_loader,
        epochs=epochs,
        learning_rate=2e-5
    )

    return metrics


def _block_diagonal_sparsity(seq_len, block_size):
    """Return the fraction of the seq_len x seq_len score matrix that is masked out."""
    # The last block is smaller when seq_len is not a multiple of block_size
    full_blocks, remainder = divmod(seq_len, block_size)
    attended = full_blocks * block_size ** 2 + remainder ** 2
    return 1 - attended / (seq_len ** 2)


def main():
    """Run the 512-token experiment, print a summary and save the results.

    Writes a CSV with the measured metrics (only when the run succeeded) and a
    JSON file describing the configuration, both under ``results/metrics``.
    A failing run is reported with its traceback and produces no metrics row.
    """
    import json
    import traceback

    print("=== RoBERTa with Block-Diagonal 2D Sparse Attention ===\n")

    # Block size for block-diagonal attention (adjust as needed)
    BLOCK_SIZE = 64  # Each block will be 64x64

    # Sequence length to test (only 512)
    sequence_length = 512

    csv_path = "results/metrics/yelp_block_diagonal_roberta_comparison.csv"
    config_path = "results/metrics/yelp_block_diagonal_roberta_config.json"

    # Metrics of successful runs only, keyed by sequence length
    results = {}

    try:
        metrics = run_experiment(sequence_length, BLOCK_SIZE)
    except Exception as e:
        # Report the failure but still print the table and save the
        # configuration; no placeholder metrics are stored.
        print(f"Error with seq_len={sequence_length}: {e}")
        traceback.print_exc()
    else:
        results[sequence_length] = metrics
        print(f"\nResults for seq_len={sequence_length}:")
        # Accuracy is stored as a fraction; show it as a percentage
        print(f"  Accuracy: {metrics['accuracy'] * 100:.1f}%")
        print(f"  Loss: {metrics['val_loss']:.2f}")
        print(f"  Training Time: {format_time(metrics['training_time'])}")

    # Create comparison table
    print("\n" + "="*80)
    print("COMPARISON TABLE - RoBERTa with Block-Diagonal 2D Sparse Attention")
    print("="*80)
    # NOTE(review): the column is labelled 'HOBA (Ours)' although this script
    # trains the block-diagonal 2D model - confirm the intended label.
    print(f"{'Dataset':<10} {'Classes':<8} {'L':<6} {'HOBA (Ours)':<25}")
    print(f"{'':>32} {'Acc (Loss)':<12} {'Time':<8}")
    print("-"*80)

    if sequence_length in results:
        metrics = results[sequence_length]
        acc = metrics['accuracy'] * 100  # Convert to percentage
        loss = metrics['val_loss']
        time_str = format_time(metrics['training_time'])

        print(f"{'Yelp':<10} {'2':<8} {sequence_length:<6} {acc:.1f} ({loss:.2f}){'':<8} {time_str:<8}")
    else:
        print(f"{'Yelp':<10} {'2':<8} {sequence_length:<6} {'–':<20} {'–':<8}")

    # Collect detailed results
    detailed_results = []
    if sequence_length in results:
        metrics = results[sequence_length]
        detailed_results.append({
            'sequence_length': sequence_length,
            'accuracy': metrics['accuracy'],
            'loss': metrics['val_loss'],
            'training_time_seconds': metrics['training_time'],
            'training_time_formatted': format_time(metrics['training_time']),
            'block_size': BLOCK_SIZE,
            'attention_type': 'block_diagonal_2d'
        })

    # Save to CSV only when there is something to record
    if detailed_results:
        pd.DataFrame(detailed_results).to_csv(csv_path, index=False)

    # Save configuration
    config_info = {
        'model_type': 'RoBERTa with Block-Diagonal 2D Sparse Attention',
        'block_size': BLOCK_SIZE,
        'dataset': 'Yelp Polarity',
        'classes': 2,
        'sequence_length_tested': sequence_length,
        'model_layers': 3,
        'attention_type': 'block_diagonal_2d',
        'optimizer': 'AdamW',
        'learning_rate': 2e-5,
        'sparsity_pattern': 'block_diagonal',
        'epochs': 3
    }

    with open(config_path, "w", encoding="utf-8") as f:
        json.dump(config_info, f, indent=2)

    if detailed_results:
        print(f"\nDetailed results saved to: {csv_path}")
    else:
        print("\nNo metrics were recorded; the results CSV was not written.")
    print(f"Configuration saved to: {config_path}")

    # Performance summary
    if sequence_length in results:
        acc = results[sequence_length]['accuracy'] * 100
        time_taken = results[sequence_length]['training_time']
        loss = results[sequence_length]['val_loss']

        print("\nFinal Results:")
        print(f"  Accuracy: {acc:.1f}%")
        print(f"  Loss: {loss:.2f}")
        print(f"  Training Time: {format_time(time_taken)}")

        # Share of the score matrix removed by the block-diagonal pattern
        num_blocks = (sequence_length + BLOCK_SIZE - 1) // BLOCK_SIZE
        sparsity = _block_diagonal_sparsity(sequence_length, BLOCK_SIZE)
        print(f"  Sparsity: {sparsity:.1%}")

        print(f"  Block Size: {BLOCK_SIZE}x{BLOCK_SIZE}")
        print(f"  Number of Blocks: {num_blocks}")
    else:
        print("\nNo results to display due to errors during training.")

# Run the experiment only when this file is executed as a script, not on import
if __name__ == "__main__":
    main()
