# -*- coding: utf-8 -*-
"""
RoBERTa with Dilated Sliding Window Sparse Attention for Sequence Length Comparison
Implementing dilated sliding window sparsity pattern (local + strategic long-range connections)
"""

import json
import math
import os
import time
import traceback
from types import SimpleNamespace

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset
from torch.optim import AdamW
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import RobertaTokenizer, RobertaModel, RobertaConfig, RobertaForSequenceClassification

# Fixed seeds so weight initialisation and data shuffling are repeatable
np.random.seed(42)
torch.manual_seed(42)

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Output locations for saved models and metric files
for _output_dir in ("results", "results/models", "results/metrics"):
    os.makedirs(_output_dir, exist_ok=True)
del _output_dir


class DilatedSlidingWindowAttention(nn.Module):
    """Multi-head self-attention restricted to a dilated sliding-window pattern.

    Each query position may attend to:

    * every key at most ``window_size // 2`` positions away on either side
      (the local window, which always contains the position itself), and
    * for every rate ``d`` in ``dilation_rates[1:]``, the keys at offsets
      ``±d, ±2d, ..., ±(window_size // 2) * d`` that lie inside the sequence.

    The first entry of ``dilation_rates`` is skipped because the local window
    already plays the role of dilation 1.

    The score matrix is still computed densely; the pattern is imposed by
    adding ``-inf`` to the disallowed entries before the softmax, so this
    module changes *what* is attended to, not the cost of computing it.

    Args:
        config: Object exposing ``hidden_size``, ``num_attention_heads`` and
            ``attention_probs_dropout_prob``.
        window_size: Width of the local window (``window_size // 2`` keys on
            each side of the query).
        dilation_rates: Strides for the long-range connections. ``None``
            selects ``[1, 2, 4, 8]``.
    """

    def __init__(self, config, window_size=128, dilation_rates=None):
        super().__init__()
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = config.hidden_size // config.num_attention_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.window_size = window_size
        # Own copy: avoids a shared mutable default and keeps the memoised
        # mask valid even if the caller later mutates its list.
        self.dilation_rates = [1, 2, 4, 8] if dilation_rates is None else list(dilation_rates)

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

        # Single-entry memo of the last mask built (plain attributes, so they
        # are neither parameters nor buffers and never reach the state dict).
        self._mask_key = None
        self._mask = None

    def transpose_for_scores(self, x):
        """Reshape ``(batch, seq, all_head_size)`` to ``(batch, heads, seq, head_size)``."""
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def create_dilated_sliding_window_mask(self, seq_len, device):
        """Build the additive ``(seq_len, seq_len)`` attention mask.

        Entry ``[i, j]`` is ``0.0`` when query ``i`` may attend to key ``j``
        and ``-inf`` otherwise.

        The pattern depends only on ``|i - j|``, so it is computed from a
        distance matrix in a handful of tensor operations instead of writing
        mask entries one at a time from Python.
        """
        half_window = self.window_size // 2
        positions = torch.arange(seq_len, device=device)
        distance = (positions.unsqueeze(0) - positions.unsqueeze(1)).abs()

        # Local window: everything within half_window of the query
        allowed = distance <= half_window

        # Long-range links: multiples of each stride, at most half_window hops
        for dilation in self.dilation_rates[1:]:
            stride = abs(dilation)
            if stride == 0:
                # A zero stride can only point back at the query itself,
                # which the local window already covers.
                continue
            allowed |= (distance % stride == 0) & (distance <= half_window * stride)

        mask = torch.zeros((seq_len, seq_len), device=device)
        mask.masked_fill_(~allowed, float('-inf'))
        return mask

    def _cached_mask(self, seq_len, device):
        """Return the mask for ``(seq_len, device)``, rebuilding only when needed."""
        key = (seq_len, device, self.window_size, tuple(self.dilation_rates))
        if getattr(self, "_mask_key", None) != key:
            self._mask = self.create_dilated_sliding_window_mask(seq_len, device)
            self._mask_key = key
        return self._mask

    def forward(self, hidden_states, attention_mask=None):
        """Apply dilated sliding-window self-attention.

        Args:
            hidden_states: Tensor of shape ``(batch, seq_len, hidden_size)``.
            attention_mask: Optional additive mask that is added to the
                ``(batch, heads, seq_len, seq_len)`` scores and must therefore
                broadcast against that shape.

        Returns:
            Tensor of shape ``(batch, seq_len, all_head_size)``.
        """
        seq_len = hidden_states.size(1)

        # Linear projections, split into heads
        query_layer = self.transpose_for_scores(self.query(hidden_states))
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))

        # Scaled dot-product scores
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        # Impose the sparse pattern. Casting to the score dtype keeps the sum
        # from being promoted to float32 when the module runs in half precision.
        dilated_mask = self._cached_mask(seq_len, hidden_states.device)
        attention_scores = attention_scores + dilated_mask.to(dtype=attention_scores.dtype)

        # Caller-supplied additive mask (e.g. padding)
        if attention_mask is not None:
            attention_scores = attention_scores + attention_mask

        attention_probs = F.softmax(attention_scores, dim=-1)
        attention_probs = self.dropout(attention_probs)

        # Weighted sum of values, then merge the heads back together
        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        return context_layer


class DilatedSlidingWindowRobertaLayer(nn.Module):
    """Transformer encoder layer built around dilated sliding-window attention.

    The layer applies sparse self-attention followed by a linear projection,
    dropout and a residual LayerNorm, then a GELU feed-forward block with its
    own dropout and residual LayerNorm.

    Args:
        config: Object exposing ``hidden_size``, ``intermediate_size``,
            ``hidden_dropout_prob`` and ``layer_norm_eps``, plus whatever the
            attention module reads from it.
        window_size: Local window width forwarded to the attention module.
        dilation_rates: Dilation strides forwarded to the attention module.
            ``None`` selects ``[1, 2, 4, 8]``.
    """

    def __init__(self, config, window_size=128, dilation_rates=None):
        super().__init__()
        # Resolve the default here so no list object is shared between calls
        if dilation_rates is None:
            dilation_rates = [1, 2, 4, 8]

        self.attention = DilatedSlidingWindowAttention(config, window_size, dilation_rates)
        self.attention_output = nn.Linear(config.hidden_size, config.hidden_size)
        self.attention_dropout = nn.Dropout(config.hidden_dropout_prob)
        self.attention_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        self.intermediate = nn.Linear(config.hidden_size, config.intermediate_size)
        self.output = nn.Linear(config.intermediate_size, config.hidden_size)
        self.output_dropout = nn.Dropout(config.hidden_dropout_prob)
        self.output_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states, attention_mask=None):
        """Run the layer.

        Args:
            hidden_states: Input tensor whose last dimension is ``hidden_size``.
            attention_mask: Optional additive mask passed through to the
                attention module unchanged.

        Returns:
            Tensor with the same shape as ``hidden_states``.
        """
        # Self-attention sub-layer with residual connection
        attention_output = self.attention(hidden_states, attention_mask)
        attention_output = self.attention_output(attention_output)
        attention_output = self.attention_dropout(attention_output)
        attention_output = self.attention_layer_norm(attention_output + hidden_states)

        # Feed-forward sub-layer with residual connection
        intermediate_output = F.gelu(self.intermediate(attention_output))
        layer_output = self.output(intermediate_output)
        layer_output = self.output_dropout(layer_output)
        layer_output = self.output_layer_norm(layer_output + attention_output)

        return layer_output


class DilatedSlidingWindowRobertaModel(nn.Module):
    """RoBERTa-style encoder whose layers use dilated sliding-window attention.

    The embedding and pooler modules are taken from the pretrained
    ``roberta-base`` checkpoint; the encoder layers are created fresh, so
    their weights are *not* pretrained.

    Args:
        config: Model configuration; ``num_hidden_layers`` sets the number of
            encoder layers and the remaining fields are read by the layers.
        window_size: Local attention window width used by every layer.
        dilation_rates: Dilation strides used by every layer. ``None``
            selects ``[1, 2, 4, 8]``.
    """

    def __init__(self, config, window_size=128, dilation_rates=None):
        super().__init__()
        self.config = config
        # Resolve the default here so no list object is shared between calls
        if dilation_rates is None:
            dilation_rates = [1, 2, 4, 8]

        # Only the embeddings and pooler are kept from the pretrained model;
        # its encoder is discarded once __init__ returns.
        pretrained_model = RobertaModel.from_pretrained("roberta-base")
        self.embeddings = pretrained_model.embeddings
        self.pooler = pretrained_model.pooler

        # Sparse-attention encoder stack
        self.layers = nn.ModuleList([
            DilatedSlidingWindowRobertaLayer(config, window_size, dilation_rates)
            for _ in range(config.num_hidden_layers)
        ])

    def forward(self, input_ids, attention_mask=None):
        """Encode a batch of token ids.

        Args:
            input_ids: Token id tensor passed to the embedding module.
            attention_mask: Optional ``(batch, seq_len)`` tensor with 1 for
                positions to keep and 0 for positions to ignore.

        Returns:
            Tuple ``(hidden_states, pooled_output)``: the output of the last
            encoder layer and the pooler applied to it.
        """
        hidden_states = self.embeddings(input_ids)

        # Turn the 0/1 mask into an additive bias that broadcasts over heads
        # and query positions: 0 for kept keys, -10000 for ignored keys.
        if attention_mask is not None:
            extended_attention_mask = attention_mask[:, None, None, :]
            extended_attention_mask = extended_attention_mask.to(dtype=hidden_states.dtype)
            extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        else:
            extended_attention_mask = None

        for layer in self.layers:
            hidden_states = layer(hidden_states, extended_attention_mask)

        pooled_output = self.pooler(hidden_states)

        return hidden_states, pooled_output


class DilatedSlidingWindowRobertaForSequenceClassification(nn.Module):
    """Sequence classifier on top of the dilated sliding-window RoBERTa encoder.

    Args:
        config: Model configuration; ``hidden_size``, ``hidden_dropout_prob``
            and ``num_labels`` are read here, the rest by the encoder.
        window_size: Local attention window width used by the encoder.
        dilation_rates: Dilation strides used by the encoder. ``None``
            selects ``[1, 2, 4, 8]``.
    """

    def __init__(self, config, window_size=128, dilation_rates=None):
        super().__init__()
        # Resolve the default here so no list object is shared between calls
        if dilation_rates is None:
            dilation_rates = [1, 2, 4, 8]

        self.roberta = DilatedSlidingWindowRobertaModel(config, window_size, dilation_rates)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Classify a batch of sequences.

        Args:
            input_ids: Token id tensor for the encoder.
            attention_mask: Optional 0/1 padding mask for the encoder.
            labels: Optional class-index tensor; when given, the
                cross-entropy loss against the logits is computed.

        Returns:
            An object with two attributes: ``logits`` and ``loss`` (``None``
            when ``labels`` is not supplied).
        """
        _, pooled_output = self.roberta(input_ids, attention_mask)
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss = F.cross_entropy(logits, labels)

        # A plain namespace instance instead of a class created on every
        # call: class objects are only reclaimed by the cyclic garbage
        # collector, which would keep the attached tensors alive longer
        # than necessary.
        return SimpleNamespace(loss=loss, logits=logits)

def load_yelp_data(max_length=512, batch_size=8, train_samples=100000, val_samples=20000):
    """Build tokenized train/validation dataloaders for Yelp polarity.

    The validation data is drawn from the dataset's ``test`` split. A sample
    count that is falsy, or not smaller than the split, keeps the whole split;
    otherwise the split is shuffled with seed 42 and truncated to that count.
    Every text is tokenized with the ``roberta-base`` tokenizer and padded or
    truncated to ``max_length``.

    Returns:
        Tuple ``(tokenizer, train_loader, val_loader)``.
    """
    print(f"Loading Yelp polarity dataset with max_length={max_length}...")
    print(f"Using {train_samples} training samples and {val_samples} validation samples")

    tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
    dataset = load_dataset("yelp_polarity")

    def take_subset(split, limit):
        # No limit, or a limit that does not shrink the split: keep it whole
        if not limit or limit >= len(split):
            return split
        return split.shuffle(seed=42).select(range(limit))

    def encode(examples):
        return tokenizer(
            examples["text"],
            padding="max_length",
            truncation=True,
            max_length=max_length
        )

    def to_torch(split):
        # Tokenize, then expose only the tensors the training loop reads
        encoded = split.map(encode, batched=True)
        encoded.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
        return encoded

    train_subset = take_subset(dataset["train"], train_samples)
    test_subset = take_subset(dataset["test"], val_samples)

    tokenized_train = to_torch(train_subset)
    tokenized_val = to_torch(test_subset)

    # Only the training data is reshuffled each epoch
    train_loader = DataLoader(tokenized_train, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(tokenized_val, batch_size=batch_size)

    print(f"Train samples: {len(tokenized_train)}, Validation samples: {len(tokenized_val)}")
    return tokenizer, train_loader, val_loader
    
def _batch_to_device(batch, target_device):
    """Pull inputs and labels out of a batch dict and move them to ``target_device``."""
    return (
        batch["input_ids"].to(target_device),
        batch["attention_mask"].to(target_device),
        batch["label"].to(target_device),
    )


def _train_one_epoch(model, train_loader, optimizer, scheduler, target_device, desc):
    """Run one optimisation pass over ``train_loader``; return ``(mean_loss, accuracy)``."""
    model.train()
    loss_sum = 0.0
    correct = 0
    seen = 0

    progress_bar = tqdm(train_loader, desc=desc)
    for batch in progress_bar:
        input_ids, attention_mask, labels = _batch_to_device(batch, target_device)

        optimizer.zero_grad()
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        scheduler.step()  # the schedule is defined per optimisation step

        # Read the scalar once; every .item() forces a device synchronisation
        batch_loss = loss.item()
        batch_size = labels.size(0)
        loss_sum += batch_loss * batch_size  # weight by batch size for a per-sample mean
        correct += (torch.argmax(outputs.logits, dim=-1) == labels).sum().item()
        seen += batch_size

        progress_bar.set_postfix({
            "loss": batch_loss,
            "acc": correct / max(1, seen)
        })

    # max(1, ...) keeps an empty loader from raising ZeroDivisionError
    return loss_sum / max(1, seen), correct / max(1, seen)


def _evaluate(model, val_loader, target_device):
    """Score ``model`` on ``val_loader`` without gradients; return ``(mean_loss, accuracy)``."""
    model.eval()
    loss_sum = 0.0
    correct = 0
    seen = 0

    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Validating", leave=False):
            input_ids, attention_mask, labels = _batch_to_device(batch, target_device)
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)

            batch_size = labels.size(0)
            loss_sum += outputs.loss.item() * batch_size
            correct += (torch.argmax(outputs.logits, dim=-1) == labels).sum().item()
            seen += batch_size

    return loss_sum / max(1, seen), correct / max(1, seen)


def train_model(model, train_loader, val_loader, epochs=3, learning_rate=2e-5):
    """Train ``model`` and report summary metrics.

    Optimisation uses AdamW (weight decay 0.01) with gradient-norm clipping
    at 1.0. The learning rate warms up linearly over the first 10% of all
    steps and then follows a cosine decay to zero.

    Batches are moved to the device that holds the model's parameters, so the
    model must already be on its target device when this is called.

    Args:
        model: Module called as ``model(input_ids=..., attention_mask=...,
            labels=...)`` that returns an object with ``loss`` and ``logits``.
        train_loader: Sized iterable of dict batches with ``"input_ids"``,
            ``"attention_mask"`` and ``"label"`` tensors.
        val_loader: Iterable of batches in the same format.
        epochs: Number of passes over ``train_loader``.
        learning_rate: Peak learning rate reached at the end of warm-up.

    Returns:
        Dict with ``'accuracy'`` (best validation accuracy over all epochs),
        ``'train_loss'`` and ``'val_loss'`` (per-sample means from the *last*
        epoch) and ``'training_time'`` (wall-clock seconds). Note that the
        accuracy and the losses may therefore come from different epochs.
    """
    print(f"Training model for {epochs} epochs...")

    optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=0.01)
    model_device = next(model.parameters()).device

    # Warm-up + cosine schedule, expressed as a multiplier on the base rate
    total_steps = len(train_loader) * epochs
    warmup_steps = int(0.1 * total_steps)

    def warmup_cosine_scheduler(step):
        if step < warmup_steps:
            return float(step) / float(max(1, warmup_steps))
        progress = float(step - warmup_steps) / float(max(1, total_steps - warmup_steps))
        return 0.5 * (1.0 + math.cos(math.pi * progress))

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_cosine_scheduler)

    start_time = time.time()
    best_val_acc = 0.0
    final_train_loss = 0.0
    final_val_loss = 0.0

    for epoch in range(epochs):
        epoch_train_loss, epoch_train_acc = _train_one_epoch(
            model, train_loader, optimizer, scheduler, model_device,
            desc=f"Epoch {epoch + 1}/{epochs}",
        )
        epoch_val_loss, epoch_val_acc = _evaluate(model, val_loader, model_device)

        final_train_loss = epoch_train_loss
        final_val_loss = epoch_val_loss
        best_val_acc = max(best_val_acc, epoch_val_acc)

        print(f"Epoch {epoch + 1}/{epochs}:")
        print(f"  Train Loss: {epoch_train_loss:.3f}, Train Acc: {epoch_train_acc:.3f}")
        print(f"  Val Loss: {epoch_val_loss:.3f}, Val Acc: {epoch_val_acc:.3f}")

    training_time = time.time() - start_time

    return {
        'accuracy': best_val_acc,
        'train_loss': final_train_loss,
        'val_loss': final_val_loss,
        'training_time': training_time
    }


def format_time(seconds):
    """Render a duration in seconds as whole minutes, e.g. ``"12m"``.

    Any partial minute is dropped (floor division by 60).
    """
    whole_minutes, _ = divmod(seconds, 60)
    return f"{int(whole_minutes)}m"


def calculate_dilated_sparsity(seq_length, window_size, dilation_rates):
    """Measure how sparse the dilated sliding-window pattern is.

    For each query position this counts the distinct key positions it can
    reach: the local window of ``window_size // 2`` neighbours on each side
    (plus itself), and, for every rate in ``dilation_rates[1:]``, up to
    ``window_size // 2`` hops of that stride in both directions.

    Returns:
        Tuple ``(sparsity, avg_connections)``: the fraction of the full
        ``seq_length x seq_length`` matrix that is *not* attended, and the
        mean number of attended keys per query.
    """
    half = window_size // 2
    strides = dilation_rates[1:]  # the first rate is the local window itself
    attended_total = 0

    for pos in range(seq_length):
        # Start from the contiguous local neighbourhood
        reachable = set(range(max(0, pos - half), min(seq_length, pos + half + 1)))

        # Add the strided long-range hops in each direction
        for stride in strides:
            for sign in (-1, 1):
                for hop in range(1, half + 1):
                    neighbour = pos + sign * hop * stride
                    if not 0 <= neighbour < seq_length:
                        break  # walked off the end of the sequence
                    reachable.add(neighbour)

        attended_total += len(reachable)

    density = attended_total / (seq_length * seq_length)
    return 1 - density, attended_total / seq_length

def run_experiment(seq_length, window_size=128, dilation_rates=None, epochs=3):
    """Train and evaluate the sparse-attention classifier at one sequence length.

    Data settings are fixed inside the function: batch size 8, 100,000
    training samples and 20,000 validation samples. The model is a 3-layer
    encoder with two output labels and dropout 0.1, trained with a peak
    learning rate of 2e-5.

    Args:
        seq_length: Maximum token length used for padding/truncation.
        window_size: Local attention window width.
        dilation_rates: Dilation strides for long-range connections.
            ``None`` selects ``[1, 2, 4, 8]``.
        epochs: Number of training epochs.

    Returns:
        The metrics dict from ``train_model`` extended with
        ``'sparsity_ratio'``, ``'avg_connections_per_token'``,
        ``'window_size'`` and ``'dilation_rates'``.
    """
    if dilation_rates is None:
        dilation_rates = [1, 2, 4, 8]

    print(f"\n=== Running experiment with seq_length={seq_length}, window_size={window_size} ===")
    print(f"Dilation rates: {dilation_rates}")

    # Report how sparse the attention pattern is before spending time training
    sparsity, avg_connections = calculate_dilated_sparsity(seq_length, window_size, dilation_rates)
    print(f"Sparsity ratio: {sparsity:.1%}")
    print(f"Average connections per token: {avg_connections:.1f}")

    # Fixed data parameters
    batch_size = 8
    train_samples = 100000  # 100K training samples
    val_samples = 20000    # 20K validation samples

    # The tokenizer is returned as well but is not needed here
    _, train_loader, val_loader = load_yelp_data(
        max_length=seq_length,
        batch_size=batch_size,
        train_samples=train_samples,
        val_samples=val_samples
    )

    # Model configuration
    config = RobertaConfig.from_pretrained("roberta-base")
    config.num_hidden_layers = 3  # Smaller model for faster training
    config.num_labels = 2
    config.hidden_dropout_prob = 0.1
    config.attention_probs_dropout_prob = 0.1

    model = DilatedSlidingWindowRobertaForSequenceClassification(
        config,
        window_size=window_size,
        dilation_rates=dilation_rates
    )
    model.to(device)

    # The caller's ``epochs`` value is honoured (it is no longer overwritten)
    metrics = train_model(
        model,
        train_loader,
        val_loader,
        epochs=epochs,
        learning_rate=2e-5
    )

    # Attach the pattern description to the training metrics
    metrics['sparsity_ratio'] = sparsity
    metrics['avg_connections_per_token'] = avg_connections
    metrics['window_size'] = window_size
    metrics['dilation_rates'] = dilation_rates

    return metrics


_RESULTS_CSV_PATH = "results/metrics/yelp_dilated_sliding_window_roberta_comparison.csv"
_CONFIG_JSON_PATH = "results/metrics/yelp_dilated_sliding_window_roberta_config.json"


def _print_comparison_table(sequence_length, metrics):
    """Print the one-row comparison table; ``metrics`` is None when the run failed."""
    print("\n" + "="*90)
    print("COMPARISON TABLE - RoBERTa with Dilated Sliding Window Sparse Attention")
    print("="*90)
    print(f"{'Dataset':<10} {'Classes':<8} {'L':<6} {'Dilated HOBA (Ours)':<35} {'Sparsity':<10}")
    print(f"{'':>32} {'Acc (Loss)':<15} {'Time':<8} {'Conn/Tok':<10}")
    print("-"*90)

    if metrics is None:
        print(f"{'Yelp':<10} {'2':<8} {sequence_length:<6} {'–':<30} {'–':<10}")
        return

    acc = metrics['accuracy'] * 100  # Convert to percentage
    loss = metrics['val_loss']
    time_str = format_time(metrics['training_time'])
    sparsity = metrics['sparsity_ratio'] * 100
    conn_per_tok = metrics['avg_connections_per_token']

    print(f"{'Yelp':<10} {'2':<8} {sequence_length:<6} "
          f"{acc:.1f} ({loss:.2f}){'':<6} {time_str:<8} "
          f"{conn_per_tok:.1f}{'':<6} {sparsity:.0f}%")


def _save_outputs(sequence_length, window_size, dilation_rates, metrics):
    """Write the run configuration, and the metrics CSV when the run succeeded."""
    if metrics is not None:
        row = {
            'sequence_length': sequence_length,
            'accuracy': metrics['accuracy'],
            'loss': metrics['val_loss'],
            'training_time_seconds': metrics['training_time'],
            'training_time_formatted': format_time(metrics['training_time']),
            'sparsity_ratio': metrics['sparsity_ratio'],
            'avg_connections_per_token': metrics['avg_connections_per_token'],
            'window_size': window_size,
            'dilation_rates': str(dilation_rates),
            'attention_type': 'dilated_sliding_window'
        }
        pd.DataFrame([row]).to_csv(_RESULTS_CSV_PATH, index=False)
        print(f"\nDetailed results saved to: {_RESULTS_CSV_PATH}")
    else:
        # Never write made-up numbers: a failed run produces no metrics file
        print("\nNo metrics file written because the experiment did not complete.")

    config_info = {
        'model_type': 'RoBERTa with Dilated Sliding Window Sparse Attention',
        'window_size': window_size,
        'dilation_rates': dilation_rates,
        'dataset': 'Yelp Polarity',
        'classes': 2,
        'sequence_length_tested': sequence_length,
        'model_layers': 3,
        'attention_type': 'dilated_sliding_window',
        'optimizer': 'AdamW',
        'learning_rate': 2e-5,
        'sparsity_pattern': 'dilated_sliding_window',
        'epochs': 3
    }
    with open(_CONFIG_JSON_PATH, "w", encoding="utf-8") as f:
        json.dump(config_info, f, indent=2)
    print(f"Configuration saved to: {_CONFIG_JSON_PATH}")


def _print_final_summary(sequence_length, window_size, metrics):
    """Print final metrics and a comparison against a plain sliding window."""
    if metrics is None:
        print("\nNo results to display due to errors during training.")
        return

    sparsity = metrics['sparsity_ratio']
    conn_per_tok = metrics['avg_connections_per_token']

    print(f"\nFinal Results:")
    print(f"  Accuracy: {metrics['accuracy'] * 100:.1f}%")
    print(f"  Loss: {metrics['val_loss']:.2f}")
    print(f"  Training Time: {format_time(metrics['training_time'])}")
    print(f"  Sparsity Ratio: {sparsity:.1%}")
    print(f"  Avg Connections per Token: {conn_per_tok:.1f}")
    print(f"  Effective Attention Reduction: {(1-sparsity):.1%} of full attention")

    # Baseline = the same pattern with the local window only. Measuring it
    # with the same routine accounts for the window's real width and for the
    # clipping at both ends of the sequence, which a plain
    # window_size / sequence_length estimate ignores.
    basic_sparsity, basic_conn_per_tok = calculate_dilated_sparsity(
        sequence_length, window_size, [1]
    )
    print(f"\nComparison with Basic Sliding Window:")
    print(f"  Basic Sliding Window Sparsity: {basic_sparsity:.1%}")
    print(f"  Dilated Sliding Window Sparsity: {sparsity:.1%}")
    print(f"  Additional Long-range Connections: {conn_per_tok - basic_conn_per_tok:.1f} per token")


def main():
    """Run the 512-token experiment, print the report and save the outputs.

    A failure inside the experiment is reported (message plus traceback)
    rather than propagated, so the table and the configuration file are
    still produced; in that case the table shows placeholders and no metrics
    CSV is written.
    """
    print("=== RoBERTa with Dilated Sliding Window Sparse Attention ===\n")

    # Configuration for dilated sliding window attention
    window_size = 128
    dilation_rates = [1, 2, 4, 8]  # Local + 3 dilated patterns
    sequence_length = 512

    metrics = None  # stays None when the experiment fails
    try:
        metrics = run_experiment(sequence_length, window_size, dilation_rates)
    except Exception as e:
        # Top-level boundary: keep going so the report still gets written,
        # but leave a full traceback for diagnosis.
        print(f"Error with seq_len={sequence_length}: {e}")
        traceback.print_exc()
    else:
        print(f"\nResults for seq_len={sequence_length}:")
        print(f"  Accuracy: {metrics['accuracy']:.1%}")
        print(f"  Loss: {metrics['val_loss']:.2f}")
        print(f"  Training Time: {format_time(metrics['training_time'])}")
        print(f"  Sparsity: {metrics['sparsity_ratio']:.1%}")
        print(f"  Avg Connections/Token: {metrics['avg_connections_per_token']:.1f}")

    _print_comparison_table(sequence_length, metrics)
    _save_outputs(sequence_length, window_size, dilation_rates, metrics)
    _print_final_summary(sequence_length, window_size, metrics)

# Run the experiment only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
