#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
RoBERTa with 3D block-diagonal attention for IMDB binary classification,
for comparison with Longformer and BigBird baselines.
"""
import os
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
from tqdm import tqdm
import time
import json
import warnings
warnings.filterwarnings('ignore')

# Disable datasets caching to avoid stale-cache issues
import datasets
datasets.disable_caching()

from datasets import load_dataset
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report

from transformers import (
    RobertaTokenizer, 
    RobertaModel, 
    RobertaConfig,
    RobertaForSequenceClassification
)

from transformers.models.roberta.modeling_roberta import (
    RobertaAttention, RobertaSelfAttention, RobertaLayer,
    RobertaEncoder, RobertaClassificationHead
)

from torch.utils.data import DataLoader, Dataset
from torch.optim import AdamW

# Pick the GPU when one is visible, otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Make sure every output directory exists before anything is written.
for _output_dir in ("results", "results/models", "results/metrics"):
    os.makedirs(_output_dir, exist_ok=True)

# Seed the torch and numpy RNGs so runs are repeatable.
torch.manual_seed(42)
np.random.seed(42)


class IMDBDataset(Dataset):
    """Map-style dataset that tokenizes one text per lookup.

    Args:
        texts: Indexable collection of review texts.
        labels: Indexable collection of integer class labels, aligned with
            ``texts``.
        tokenizer: Callable accepting ``(text, truncation, padding,
            max_length, return_tensors)`` and returning a mapping with
            ``'input_ids'`` and ``'attention_mask'`` tensors.
        max_length: Length every example is truncated/padded to.
    """

    def __init__(self, texts, labels, tokenizer, max_length=512):
        self.texts, self.labels = texts, labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        """Return the flattened token ids, attention mask and label for ``idx``."""
        review = str(self.texts[idx])
        target = self.labels[idx]

        tokenized = self.tokenizer(
            review,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt',
        )

        # The tokenizer output is flattened to 1-D so the DataLoader can
        # stack examples along a new batch dimension.
        sample = {
            field: tokenized[field].flatten()
            for field in ('input_ids', 'attention_mask')
        }
        sample['labels'] = torch.tensor(target, dtype=torch.long)
        return sample


class BlockDiagonal3DAttention(nn.Module):
    """
    3D block-diagonal higher-order attention.

    The sequence is cut into fixed-size, optionally overlapping blocks. Inside
    each block every query scores *pairs* of keys through a trilinear product
    ``q . k1 . k2``; the softmax runs over the allowed key pairs and the
    result weights the element-wise product of the two value projections.
    Outputs of overlapping blocks are averaged per position.

    Args:
        config: Object exposing ``hidden_size``, ``num_attention_heads`` and
            ``attention_probs_dropout_prob``.
        order: Stored on the instance; not otherwise used by this class.
        block_size: Number of positions per block (>= 1).
        overlap_ratio: Fraction of a block shared with the next one, in
            ``[0, 1)``.

    Raises:
        ValueError: If ``block_size`` or ``overlap_ratio`` is out of range.
    """

    # Upper bound on the key positions that enter the trilinear score of one
    # block; longer blocks are sub-sampled with a regular stride.
    MAX_3D_KEYS = 12

    def __init__(self, config, order=2, block_size=64, overlap_ratio=0.25):
        super().__init__()
        if block_size < 1:
            raise ValueError(f"block_size must be >= 1, got {block_size}")
        if not 0.0 <= overlap_ratio < 1.0:
            # A ratio >= 1 gives a non-positive stride (blocks never advance).
            raise ValueError(f"overlap_ratio must be in [0, 1), got {overlap_ratio}")

        self.order = order
        self.block_size = block_size
        self.overlap_size = int(block_size * overlap_ratio)
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        # Projections for query and keys
        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key1 = nn.Linear(config.hidden_size, self.all_head_size)
        self.key2 = nn.Linear(config.hidden_size, self.all_head_size)

        # Projections for values
        self.value1 = nn.Linear(config.hidden_size, self.all_head_size)
        self.value2 = nn.Linear(config.hidden_size, self.all_head_size)

        # Output projection
        self.output_projection = nn.Linear(self.all_head_size, config.hidden_size)

        # NOTE: cross_block_attention and boundary_scorer are not used by
        # forward(); they are kept so existing state dicts still load.
        self.cross_block_attention = nn.MultiheadAttention(
            config.hidden_size,
            config.num_attention_heads,
            dropout=config.attention_probs_dropout_prob,
            batch_first=True
        )
        self.boundary_scorer = nn.Linear(config.hidden_size, 1)

        # Dropout on the attention probabilities
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

        # Layer normalization applied to the input of forward()
        self.layer_norm = nn.LayerNorm(config.hidden_size)

        # Logit scaling, 1/sqrt(head_dim)
        self.scale_factor = 1.0 / math.sqrt(self.attention_head_size)

        self.initialize_weights()

    def initialize_weights(self):
        """Initialize projections with reduced-gain Xavier and zero biases."""
        for module in [self.query, self.key1, self.key2, self.value1, self.value2,
                       self.output_projection]:
            nn.init.xavier_uniform_(module.weight, gain=0.5)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

        # Boundary scorer starts with very small weights
        nn.init.normal_(self.boundary_scorer.weight, mean=0.0, std=0.01)
        nn.init.zeros_(self.boundary_scorer.bias)

    def transpose_for_scores(self, x):
        """Reshape ``(batch, seq, all_head)`` to ``(batch, heads, seq, head_dim)``."""
        new_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_shape)
        return x.permute(0, 2, 1, 3)

    def get_fixed_blocks(self, seq_len):
        """Return ``(start, end)`` pairs of fixed blocks covering ``seq_len``.

        Consecutive blocks start ``block_size - overlap_size`` positions
        apart; the stride is floored at 1 so the walk always terminates.
        """
        stride = max(1, self.block_size - self.overlap_size)
        return [
            (start, min(start + self.block_size, seq_len))
            for start in range(0, seq_len, stride)
        ]

    @staticmethod
    def _block_keep_mask(block_mask, block_len):
        """Convert a per-block mask to a 1-D boolean "attend to this key" mask.

        Boolean masks are used as-is (True = keep). Any other dtype is treated
        as an additive mask where negative entries mark masked positions.
        """
        if block_mask is None:
            return None
        flat = block_mask.reshape(-1)
        if flat.numel() != block_len:
            raise ValueError(
                f"block_mask has {flat.numel()} entries, expected {block_len}"
            )
        if flat.dtype == torch.bool:
            return flat
        return flat >= 0

    @staticmethod
    def _pair_pattern(sampled_len, device):
        """Boolean ``(S, S)`` pattern of key pairs allowed to interact.

        Keys are split into groups of ``max(1, S // 4)``; a pair is allowed
        when the two keys sit in the same or in adjacent groups.
        """
        group = torch.arange(sampled_len, device=device) // max(1, sampled_len // 4)
        return (group.unsqueeze(0) - group.unsqueeze(1)).abs() <= 1

    def _standard_attention(self, q_block, k_block, v_block, keep=None):
        """Scaled dot-product attention used for tiny blocks and as fallback."""
        scores = torch.matmul(q_block, k_block.transpose(-1, -2)) * self.scale_factor
        if keep is not None:
            # Mask along the key axis. A finite fill keeps the softmax defined
            # even when every key of the block is padding.
            scores = scores.masked_fill(
                ~keep.reshape(1, 1, -1), torch.finfo(scores.dtype).min
            )
        probs = F.softmax(scores, dim=-1)
        return torch.matmul(probs, v_block)

    def compute_3d_block_attention(self, q_block, k1_block, k2_block, v1_block, v2_block, block_mask=None):
        """
        Compute higher-order attention inside a single block.

        Args:
            q_block, k1_block, k2_block, v1_block, v2_block: Tensors of shape
                ``(heads, block_len, head_dim)``.
            block_mask: Optional 1-D mask over the block's positions. Boolean
                (True = attend) or additive (negative = masked).

        Returns:
            Context tensor of shape ``(heads, block_len, head_dim)``.

        Raises:
            ValueError: If ``block_mask`` does not have ``block_len`` entries.
        """
        block_len = q_block.size(1)
        keep = self._block_keep_mask(block_mask, block_len)

        # Clip over-long blocks to bound the memory of the 3D score tensor
        if block_len > self.block_size:
            q_block = q_block[:, :self.block_size, :]
            k1_block = k1_block[:, :self.block_size, :]
            k2_block = k2_block[:, :self.block_size, :]
            v1_block = v1_block[:, :self.block_size, :]
            v2_block = v2_block[:, :self.block_size, :]
            if keep is not None:
                keep = keep[:self.block_size]
            block_len = self.block_size

        # For very small blocks, use standard attention
        if block_len <= 2:
            return self._standard_attention(q_block, k1_block, v1_block, keep)

        try:
            max_3d_size = min(block_len, self.MAX_3D_KEYS)

            if block_len > max_3d_size:
                # Sub-sample key/value positions with a regular stride
                step = max(1, block_len // max_3d_size)
                indices = torch.arange(
                    0, block_len, step, dtype=torch.long, device=q_block.device
                )[:max_3d_size]
                k1_sampled = k1_block.index_select(1, indices)
                k2_sampled = k2_block.index_select(1, indices)
                v1_sampled = v1_block.index_select(1, indices)
                v2_sampled = v2_block.index_select(1, indices)
                keep_sampled = keep[indices] if keep is not None else None
            else:
                k1_sampled, k2_sampled = k1_block, k2_block
                v1_sampled, v2_sampled = v1_block, v2_block
                keep_sampled = keep

            sampled_len = k1_sampled.size(1)

            # Trilinear scores: (heads, block_len, sampled_len, sampled_len)
            scores = torch.einsum(
                "HiD,HjD,HkD->Hijk", q_block, k1_sampled, k2_sampled
            ) * self.scale_factor

            if torch.isnan(scores).any():
                print("Warning: NaN detected in 3D attention computation, using fallback")
                return self._standard_attention(q_block, k1_block, v1_block, keep)

            # Clamp values to prevent overflow in the softmax
            scores = torch.clamp(scores, min=-50, max=50)

            # Exclude disallowed key pairs from the softmax. Multiplying the
            # logits by zero would leave them with weight exp(0), so they are
            # filled with the most negative finite value instead. Pairs that
            # touch a padded key are excluded the same way.
            allowed = self._pair_pattern(sampled_len, q_block.device)
            if keep_sampled is not None:
                allowed = allowed & keep_sampled.unsqueeze(0) & keep_sampled.unsqueeze(1)
            scores = scores.masked_fill(~allowed, torch.finfo(scores.dtype).min)

            # Softmax jointly over all key pairs
            probs = F.softmax(scores.flatten(2), dim=-1).reshape(scores.shape)
            probs = self.dropout(probs)

            # Pairwise value products, then the attention-weighted sum
            v1v2_3d = torch.einsum("HiD,HjD->HijD", v1_sampled, v2_sampled)
            context = torch.einsum("Hijk,HjkD->HiD", probs, v1v2_3d)

            if torch.isnan(context).any():
                print("Warning: NaN in 3D context, using standard attention fallback")
                context = self._standard_attention(q_block, k1_block, v1_block, keep)

            return context

        except RuntimeError as e:
            # Best effort: tensor-level failures (e.g. out of memory) degrade
            # to standard attention instead of aborting the run.
            print(f"Error in 3D block attention computation: {e}, using fallback")
            return self._standard_attention(q_block, k1_block, v1_block, keep)

    @staticmethod
    def _normalize_attention_mask(attention_mask, batch_size, seq_length, device):
        """Convert the supported mask layouts to a ``(batch, seq)`` boolean mask.

        Accepted layouts:
          * ``(batch, 1, q, seq)`` / ``(batch, q, seq)``: additive masks
            (0 = attend, negative = masked); the last query row is used.
          * ``(batch, seq)`` boolean: True = attend.
          * ``(batch, seq)`` integer, or floating with positive entries:
            binary mask, non-zero = attend.
          * ``(batch, seq)`` floating without positive entries: additive.
        """
        if attention_mask is None:
            return None

        mask = attention_mask
        extended = mask.dim() > 2
        if mask.dim() == 4:
            mask = mask[:, 0, -1, :]
        elif mask.dim() == 3:
            mask = mask[:, -1, :]
        elif mask.dim() != 2:
            raise ValueError(
                f"attention_mask must have 2, 3 or 4 dims, got {mask.dim()}"
            )

        if mask.size(-1) != seq_length or mask.size(0) not in (1, batch_size):
            raise ValueError(
                f"attention_mask of shape {tuple(attention_mask.shape)} does not "
                f"match inputs of shape ({batch_size}, {seq_length})"
            )
        mask = mask.expand(batch_size, seq_length).to(device)

        if mask.dtype == torch.bool:
            return mask
        if mask.is_floating_point() and (extended or not bool((mask > 0).any())):
            return mask >= 0
        return mask != 0

    def forward(self, hidden_states, attention_mask=None):
        """Apply block-wise higher-order attention.

        Args:
            hidden_states: Tensor of shape ``(batch, seq, hidden_size)``.
            attention_mask: Optional padding mask; see
                ``_normalize_attention_mask`` for the accepted layouts.

        Returns:
            Tensor of shape ``(batch, seq, hidden_size)``.
        """
        batch_size, seq_length = hidden_states.size(0), hidden_states.size(1)

        # Apply layer normalization for stability
        hidden_states = self.layer_norm(hidden_states)

        # Project inputs and split into heads
        query_layer = self.transpose_for_scores(self.query(hidden_states))
        key1_layer = self.transpose_for_scores(self.key1(hidden_states))
        key2_layer = self.transpose_for_scores(self.key2(hidden_states))
        value1_layer = self.transpose_for_scores(self.value1(hidden_states))
        value2_layer = self.transpose_for_scores(self.value2(hidden_states))

        keep_mask = self._normalize_attention_mask(
            attention_mask, batch_size, seq_length, hidden_states.device
        )

        blocks = self.get_fixed_blocks(seq_length)

        # How many blocks cover each position. The count starts at zero so a
        # position covered once is divided by one; it is identical for every
        # batch item and therefore computed once.
        coverage = torch.zeros(
            seq_length, device=hidden_states.device, dtype=query_layer.dtype
        )
        for start, end in blocks:
            coverage[start:end] += 1
        coverage = coverage.clamp(min=1).view(1, -1, 1)

        batch_outputs = []
        for b in range(batch_size):
            batch_context = torch.zeros_like(query_layer[b])

            for start, end in blocks:
                block_mask = None if keep_mask is None else keep_mask[b, start:end]
                block_context = self.compute_3d_block_attention(
                    query_layer[b, :, start:end, :],
                    key1_layer[b, :, start:end, :],
                    key2_layer[b, :, start:end, :],
                    value1_layer[b, :, start:end, :],
                    value2_layer[b, :, start:end, :],
                    block_mask,
                )
                batch_context[:, start:end, :] += block_context

            # Average the contributions of overlapping blocks
            batch_outputs.append(batch_context / coverage)

        context_layer = torch.stack(batch_outputs)

        # Merge heads back: (batch, seq, all_head_size)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_shape)

        attention_output = self.output_projection(context_layer)

        # Last-resort guard: hand back the normalized input if the output is NaN
        if torch.isnan(attention_output).any():
            print("Warning: NaN in final attention output, replacing with input")
            attention_output = hidden_states

        return attention_output


class Block3DRobertaSelfAttention(RobertaSelfAttention):
    """RobertaSelfAttention whose forward pass runs BlockDiagonal3DAttention.

    Args:
        config: Model configuration forwarded to the parent class and to
            ``BlockDiagonal3DAttention``.
        position_embedding_type: Forwarded to the parent constructor.
        order: Forwarded to ``BlockDiagonal3DAttention``.
        block_size: Forwarded to ``BlockDiagonal3DAttention``.
        self_attn: Optional attention module exposing ``query``, ``key`` and
            ``value`` linear layers whose weights seed the new projections.
    """

    def __init__(self, config, position_embedding_type=None, order=2, block_size=64, self_attn=None):
        super().__init__(config, position_embedding_type)

        self.block_3d_attention = BlockDiagonal3DAttention(
            config, order=order, block_size=block_size
        )

        # Seed the new projections from the donor attention if provided
        if self_attn is not None:
            target = self.block_3d_attention
            with torch.no_grad():
                # Copy in place so the existing Parameter objects (and their
                # device/dtype) are kept rather than having .data swapped out.
                target.query.weight.copy_(self_attn.query.weight)
                target.query.bias.copy_(self_attn.query.bias)

                target.key1.weight.copy_(self_attn.key.weight)
                target.key1.bias.copy_(self_attn.key.bias)

                # key2 starts as key plus small noise so the two key
                # projections are not identical
                noise = torch.randn_like(self_attn.key.weight.data) * 0.01
                target.key2.weight.copy_(self_attn.key.weight + noise)
                target.key2.bias.copy_(self_attn.key.bias)

                target.value1.weight.copy_(self_attn.value.weight)
                target.value1.bias.copy_(self_attn.value.bias)

                # value2 starts as value plus small noise, for the same reason
                noise = torch.randn_like(self_attn.value.weight.data) * 0.01
                target.value2.weight.copy_(self_attn.value.weight + noise)
                target.value2.bias.copy_(self_attn.value.bias)

    def forward(self, hidden_states, attention_mask=None, head_mask=None,
                encoder_hidden_states=None, encoder_attention_mask=None,
                past_key_value=None, output_attentions=False):
        """Run block-diagonal attention over ``hidden_states``.

        ``head_mask``, ``encoder_attention_mask`` and ``past_key_value`` are
        accepted for signature compatibility and ignored.

        Returns:
            ``(context,)`` or, when ``output_attentions`` is true,
            ``(context, zeros)`` where ``zeros`` is an all-zero placeholder of
            shape ``(batch, heads, seq, seq)``.

        Raises:
            NotImplementedError: If ``encoder_hidden_states`` is given.
        """
        if encoder_hidden_states is not None:
            raise NotImplementedError("Cross-attention not implemented")

        # Floating-point masks of any precision are taken to be additive
        # already and passed through; testing only for float32 would push
        # fp16/bf16 additive masks through the binary conversion below.
        if attention_mask is None:
            extended_attention_mask = None
        elif attention_mask.is_floating_point():
            extended_attention_mask = attention_mask
        else:
            extended_attention_mask = (1.0 - attention_mask.to(torch.float32)) * -10000.0

        try:
            context_layer = self.block_3d_attention(hidden_states, extended_attention_mask)
        except RuntimeError as e:
            # Best effort: on a tensor-level failure the layer degrades to an
            # identity mapping. This is NOT standard attention.
            print(f"Error in 3D block diagonal attention: {e}, "
                  "passing hidden states through unchanged")
            context_layer = hidden_states

        outputs = (context_layer,)
        if output_attentions:
            # Attention probabilities are not materialized; return zeros
            batch_size, seq_length = hidden_states.size(0), hidden_states.size(1)
            attention_probs = torch.zeros(
                batch_size, self.num_attention_heads, seq_length, seq_length,
                device=hidden_states.device, dtype=hidden_states.dtype
            )
            outputs = outputs + (attention_probs,)

        return outputs


class Block3DRobertaAttention(RobertaAttention):
    """RobertaAttention whose self-attention is Block3DRobertaSelfAttention.

    When ``attn`` is given, its self-attention seeds the new projections, its
    ``output`` module is reused, and its output dense weights are copied into
    the block attention's own output projection.
    """

    def __init__(self, config, position_embedding_type=None, order=2, block_size=64, attn=None):
        super().__init__(config, position_embedding_type)

        donor_self = None if attn is None else attn.self
        self.self = Block3DRobertaSelfAttention(
            config,
            position_embedding_type=position_embedding_type,
            order=order,
            block_size=block_size,
            self_attn=donor_self
        )
        if attn is None:
            return

        self.output = attn.output

        # NOTE(review): the donor's output dense weights end up both in
        # self.output and in the block attention's output_projection; if the
        # parent forward applies self.output after self.self, that projection
        # runs twice. Confirm this is intended.
        projection = self.self.block_3d_attention.output_projection
        with torch.no_grad():
            projection.weight.data = attn.output.dense.weight.data.clone()
            projection.bias.data = attn.output.dense.bias.data.clone()


class Block3DRobertaLayer(RobertaLayer):
    """RobertaLayer that uses Block3DRobertaAttention.

    When ``layer`` is given, its attention seeds the new attention module and
    its ``intermediate`` and ``output`` sub-modules are reused as-is.
    """

    def __init__(self, config, order=2, block_size=64, layer=None):
        super().__init__(config)

        donor_attention = layer.attention if layer is not None else None
        self.attention = Block3DRobertaAttention(
            config,
            order=order,
            block_size=block_size,
            attn=donor_attention
        )

        if layer is not None:
            self.intermediate = layer.intermediate
            self.output = layer.output


class Block3DRobertaEncoder(RobertaEncoder):
    """RobertaEncoder built from ``config.num_hidden_layers`` Block3DRobertaLayer.

    When ``original_encoder`` is given, layer ``i`` is seeded from
    ``original_encoder.layer[i]``.
    """

    def __init__(self, config, order=2, block_size=64, original_encoder=None):
        super().__init__(config)

        depth = config.num_hidden_layers
        if original_encoder is None:
            donors = [None] * depth
        else:
            donors = [original_encoder.layer[i] for i in range(depth)]

        self.layer = nn.ModuleList([
            Block3DRobertaLayer(
                config,
                order=order,
                block_size=block_size,
                layer=donor
            )
            for donor in donors
        ])


class Block3DRobertaModel(RobertaModel):
    """RobertaModel whose encoder is a Block3DRobertaEncoder.

    Args:
        config: Model configuration.
        order: Forwarded to the encoder.
        block_size: Forwarded to the encoder.
        original_model: Optional donor model. When given, its ``embeddings``
            and ``pooler`` are reused and its encoder seeds the new encoder.
    """

    def __init__(self, config, order=2, block_size=64, original_model=None):
        super().__init__(config)

        if original_model is not None:
            self.embeddings = original_model.embeddings
            self.encoder = Block3DRobertaEncoder(
                config,
                order=order,
                block_size=block_size,
                original_encoder=original_model.encoder
            )
            self.pooler = original_model.pooler
            # post_init() is deliberately NOT called here: it re-runs weight
            # initialization over the freshly created attention modules and
            # would overwrite the weights just copied from the donor.
        else:
            self.encoder = Block3DRobertaEncoder(
                config,
                order=order,
                block_size=block_size
            )
            # No donor weights: let the standard initialization run over the
            # newly created encoder.
            self.post_init()


class Block3DRobertaForSequenceClassification(RobertaForSequenceClassification):
    """RobertaForSequenceClassification backed by a Block3DRobertaModel.

    Args:
        config: Model configuration.
        order: Forwarded to ``Block3DRobertaModel``.
        block_size: Forwarded to ``Block3DRobertaModel``.
        original_model: Optional donor model. When given, its ``roberta``
            backbone seeds the new backbone and its ``classifier`` is reused.
    """

    def __init__(self, config, order=2, block_size=64, original_model=None):
        super().__init__(config)

        if original_model is not None:
            self.roberta = Block3DRobertaModel(
                config,
                order=order,
                block_size=block_size,
                original_model=original_model.roberta
            )
            self.classifier = original_model.classifier
            # post_init() is deliberately NOT called here: it re-runs weight
            # initialization over the new attention modules and would
            # overwrite the weights transplanted from the donor.
        else:
            self.roberta = Block3DRobertaModel(
                config,
                order=order,
                block_size=block_size
            )
            # No donor weights: run the standard initialization.
            self.post_init()


def _stratified_subset(texts, labels, fraction, random_state):
    """Return a label-stratified ``fraction`` of ``(texts, labels)``."""
    if fraction >= 1.0:
        # train_test_split cannot return the whole set; nothing to sample.
        return texts, labels
    kept_texts, _, kept_labels, _ = train_test_split(
        texts, labels, train_size=fraction, stratify=labels, random_state=random_state
    )
    return kept_texts, kept_labels


def load_imdb_data(subset_fraction=0.2, random_state=42):
    """Load the IMDB dataset and return a stratified subset of each split.

    Args:
        subset_fraction: Fraction of each split to keep, in ``(0, 1]``.
            ``1.0`` keeps the full split.
        random_state: Seed forwarded to ``train_test_split``.

    Returns:
        ``(train_texts, train_labels, test_texts, test_labels)`` as lists.

    Raises:
        ValueError: If ``subset_fraction`` is outside ``(0, 1]``.
    """
    if not 0.0 < subset_fraction <= 1.0:
        raise ValueError(f"subset_fraction must be in (0, 1], got {subset_fraction}")

    print("Loading IMDB dataset...")
    dataset = load_dataset('imdb')

    # Materialize the columns as plain lists so they can be split and indexed
    # regardless of the column type the datasets library returns.
    train_texts = list(dataset['train']['text'])
    train_labels = list(dataset['train']['label'])
    test_texts = list(dataset['test']['text'])
    test_labels = list(dataset['test']['label'])

    if subset_fraction < 1.0:
        print(f"Using stratified subset ({subset_fraction:.0%}) for computational efficiency...")
    train_texts, train_labels = _stratified_subset(
        train_texts, train_labels, subset_fraction, random_state
    )
    test_texts, test_labels = _stratified_subset(
        test_texts, test_labels, subset_fraction, random_state
    )

    print(f"Training samples: {len(train_texts)}")
    print(f"Test samples: {len(test_texts)}")
    print(f"Training label distribution: {np.bincount(train_labels)}")
    print(f"Test label distribution: {np.bincount(test_labels)}")

    return train_texts, train_labels, test_texts, test_labels


def create_data_loaders(train_texts, train_labels, test_texts, test_labels,
                       tokenizer, max_length=512, batch_size=4):
    """Wrap the train and test splits in DataLoaders.

    The training loader shuffles its examples; the test loader keeps the
    original order. Returns ``(train_loader, test_loader)``.
    """
    train_loader = DataLoader(
        IMDBDataset(train_texts, train_labels, tokenizer, max_length),
        batch_size=batch_size,
        shuffle=True,
    )
    test_loader = DataLoader(
        IMDBDataset(test_texts, test_labels, tokenizer, max_length),
        batch_size=batch_size,
        shuffle=False,
    )
    return train_loader, test_loader


def train_model(model, train_loader, optimizer, scheduler, device, epoch, max_grad_norm=1.0):
    """Run one training epoch.

    For every batch: forward pass with labels, backward pass, gradient-norm
    clipping to ``max_grad_norm``, optimizer step, scheduler step.

    Returns:
        ``(mean loss per batch, accuracy over all training examples)``.
    """
    model.train()
    running_loss = 0
    seen_preds, seen_targets = [], []

    batches = tqdm(train_loader, desc=f'Epoch {epoch+1}')
    for batch in batches:
        token_ids = batch['input_ids'].to(device)
        mask = batch['attention_mask'].to(device)
        targets = batch['labels'].to(device)

        optimizer.zero_grad()
        result = model(input_ids=token_ids, attention_mask=mask, labels=targets)
        batch_loss, batch_logits = result.loss, result.logits

        batch_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        scheduler.step()

        loss_value = batch_loss.item()
        running_loss += loss_value
        seen_preds.extend(torch.argmax(batch_logits, dim=1).cpu().numpy())
        seen_targets.extend(targets.cpu().numpy())

        # Live progress: current batch loss and learning rate
        batches.set_postfix({
            'loss': loss_value,
            'lr': scheduler.get_last_lr()[0]
        })

    mean_loss = running_loss / len(train_loader)
    return mean_loss, accuracy_score(seen_targets, seen_preds)


def evaluate_model(model, test_loader, device):
    """Score the model on ``test_loader`` without tracking gradients.

    Returns:
        ``(mean loss per batch, accuracy, predictions, true_labels)`` where
        the last two are lists in loader order.
    """
    model.eval()
    running_loss = 0
    predictions, true_labels = [], []

    with torch.no_grad():
        for batch in tqdm(test_loader, desc='Evaluating'):
            token_ids = batch['input_ids'].to(device)
            mask = batch['attention_mask'].to(device)
            targets = batch['labels'].to(device)

            result = model(input_ids=token_ids, attention_mask=mask, labels=targets)

            running_loss += result.loss.item()
            predictions.extend(torch.argmax(result.logits, dim=1).cpu().numpy())
            true_labels.extend(targets.cpu().numpy())

    mean_loss = running_loss / len(test_loader)
    return mean_loss, accuracy_score(true_labels, predictions), predictions, true_labels


def measure_performance(model, data_loader, device):
    """Measure inference wall-clock time, throughput and peak GPU memory.

    Args:
        model: Model called as ``model(input_ids=..., attention_mask=...)``.
        data_loader: Iterable of batches with ``'input_ids'`` and
            ``'attention_mask'`` tensors.
        device: Device the batches are moved to.

    Returns:
        Dict with ``inference_time`` (seconds), ``samples_per_second``,
        ``memory_used_gb`` (0 when not running on CUDA) and ``total_samples``.
    """
    model.eval()

    # Only touch CUDA statistics when the run actually uses a CUDA device.
    on_cuda = torch.cuda.is_available() and torch.device(device).type == "cuda"

    if on_cuda:
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
        # Drain pending kernels so they are not billed to the timed region.
        torch.cuda.synchronize()

    start_time = time.perf_counter()
    total_samples = 0

    with torch.no_grad():
        for batch in tqdm(data_loader, desc='Measuring performance'):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)

            model(input_ids=input_ids, attention_mask=attention_mask)
            total_samples += input_ids.size(0)

    if on_cuda:
        # CUDA kernels run asynchronously; wait for them before stopping the clock.
        torch.cuda.synchronize()
    inference_time = time.perf_counter() - start_time

    memory_used = 0
    if on_cuda:
        memory_used = torch.cuda.max_memory_allocated() / 1024**3  # GB

    # Guard against a zero elapsed time (e.g. an empty loader).
    samples_per_second = total_samples / inference_time if inference_time > 0 else 0.0

    return {
        'inference_time': inference_time,
        'samples_per_second': samples_per_second,
        'memory_used_gb': memory_used,
        'total_samples': total_samples
    }


def save_comparison_results(results, filename='roberta_3d_baseline_results.json'):
    """Write a JSON summary of the run and return the same summary as a dict.

    ``results`` must contain ``'best_accuracy'``, ``'final_train_loss'``,
    ``'final_test_loss'`` and ``'model'``; hyperparameters and performance
    figures missing from it fall back to the defaults written below.
    """
    hyperparameter_defaults = (
        ('max_length', 512),
        ('batch_size', 4),
        ('learning_rate', 2e-5),
        ('epochs', 5),
        ('weight_decay', 0.01),
        ('block_size', 64),
        ('order', 2),
    )
    hyperparameters = {
        key: results.get(key, fallback) for key, fallback in hyperparameter_defaults
    }

    performance_metrics = {
        'best_accuracy': results['best_accuracy'],
        'final_train_loss': results['final_train_loss'],
        'final_test_loss': results['final_test_loss'],
    }
    for key in ('inference_time', 'samples_per_second', 'memory_used_gb'):
        performance_metrics[key] = results.get(key, 0)

    # Total and trainable parameter counts of the stored model
    model = results['model']
    model_size = {
        'parameters': sum(p.numel() for p in model.parameters()),
        'trainable_parameters': sum(
            p.numel() for p in model.parameters() if p.requires_grad
        ),
    }

    comparison_data = {
        'model_name': 'RoBERTa 3D Block Diagonal',
        'dataset': 'IMDB Binary Classification',
        'hyperparameters': hyperparameters,
        'performance_metrics': performance_metrics,
        'model_size': model_size,
        'attention_pattern': '3D Block Diagonal Higher-Order',
        'complexity': 'O(n) with 3D block structure',
    }

    with open(filename, 'w') as handle:
        handle.write(json.dumps(comparison_data, indent=2))

    print(f"RoBERTa 3D Block Diagonal results saved to {filename}")
    return comparison_data


def main():
    """Run the RoBERTa 3D Block Diagonal experiment on IMDB.

    Fine-tunes the model, checkpoints the epoch with the best test accuracy,
    reloads that checkpoint, measures inference performance, writes the
    comparison JSON and prints a classification report for the best epoch.

    Returns:
        ``(results, comparison_data)``: the raw results dict (which includes
        the model under ``'model'``) and the summary written to disk.
    """
    print("\n=== RoBERTa 3D Block Diagonal for IMDB Comparison ===\n")

    # Hyperparameters
    MAX_LENGTH = 512       # RoBERTa's typical max length
    BATCH_SIZE = 8
    LEARNING_RATE = 2e-5   # Standard for transformer fine-tuning
    EPOCHS = 9
    WARMUP_STEPS = 500     # Optimizer steps of linear LR warm-up
    WEIGHT_DECAY = 0.01    # L2 regularization
    MAX_GRAD_NORM = 1.0    # Gradient clipping
    BLOCK_SIZE = 16        # 3D block size
    ORDER = 2              # Higher-order interactions
    BEST_MODEL_PATH = 'results/models/roberta_3d_best.pt'

    print("Hyperparameters:")
    print(f"  Max Length: {MAX_LENGTH}")
    print(f"  Batch Size: {BATCH_SIZE}")
    print(f"  Learning Rate: {LEARNING_RATE}")
    print(f"  Epochs: {EPOCHS}")
    print(f"  Block Size: {BLOCK_SIZE}")
    print(f"  Order: {ORDER}")

    # Load data
    train_texts, train_labels, test_texts, test_labels = load_imdb_data()

    # Initialize tokenizer
    tokenizer = RobertaTokenizer.from_pretrained('roberta-base')

    # Load base RoBERTa model
    print("Loading base RoBERTa model...")
    base_model = RobertaForSequenceClassification.from_pretrained(
        'roberta-base',
        num_labels=2  # Binary classification
    )

    # Create 3D Block Diagonal model
    print(f"Creating RoBERTa 3D Block Diagonal model (block_size={BLOCK_SIZE}, order={ORDER})...")
    model = Block3DRobertaForSequenceClassification(
        base_model.config,
        order=ORDER,
        block_size=BLOCK_SIZE,
        original_model=base_model
    )

    # Move model to device
    model.to(device)

    print(f"Model loaded with {sum(p.numel() for p in model.parameters())/1e6:.1f}M parameters")

    # Create data loaders
    train_loader, test_loader = create_data_loaders(
        train_texts, train_labels, test_texts, test_labels,
        tokenizer, MAX_LENGTH, BATCH_SIZE
    )

    # Setup optimizer and scheduler
    optimizer = AdamW(
        model.parameters(),
        lr=LEARNING_RATE,
        weight_decay=WEIGHT_DECAY,
        eps=1e-8
    )

    total_steps = len(train_loader) * EPOCHS
    # Linear warm-up from 10% of the learning rate over WARMUP_STEPS steps,
    # constant afterwards.
    scheduler = torch.optim.lr_scheduler.LinearLR(
        optimizer,
        start_factor=0.1,
        total_iters=WARMUP_STEPS
    )

    print(f"Starting training for {EPOCHS} epochs...")
    print(f"Total training steps: {total_steps}")

    # Start below any reachable accuracy so the first epoch is always
    # recorded and `results` always carries the keys needed when saving.
    best_accuracy = float('-inf')
    best_predictions, best_true_labels = [], []
    results = {
        'max_length': MAX_LENGTH,
        'batch_size': BATCH_SIZE,
        'learning_rate': LEARNING_RATE,
        'epochs': EPOCHS,
        'weight_decay': WEIGHT_DECAY,
        'block_size': BLOCK_SIZE,
        'order': ORDER
    }

    for epoch in range(EPOCHS):
        print(f"\n{'='*50}")
        print(f"Epoch {epoch + 1}/{EPOCHS}")
        print('='*50)

        # Train
        train_loss, train_accuracy = train_model(
            model, train_loader, optimizer, scheduler, device, epoch, MAX_GRAD_NORM
        )

        print(f"Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.4f}")

        # Evaluate
        test_loss, test_accuracy, predictions, true_labels = evaluate_model(
            model, test_loader, device
        )

        print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.4f}")

        # Save best model
        if test_accuracy > best_accuracy:
            best_accuracy = test_accuracy
            torch.save(model.state_dict(), BEST_MODEL_PATH)
            print(f"New best model saved! Accuracy: {best_accuracy:.4f}")

            # Keep this epoch's predictions so the final report describes the
            # same checkpoint as the reported best accuracy.
            best_predictions, best_true_labels = predictions, true_labels

            # Store results for final comparison
            results.update({
                'best_accuracy': best_accuracy,
                'final_train_loss': train_loss,
                'final_test_loss': test_loss,
                'model': model
            })

    # Restore the best checkpoint so that the performance measurement and the
    # returned model correspond to the reported best accuracy, not to the
    # weights of the last epoch.
    model.load_state_dict(torch.load(BEST_MODEL_PATH, map_location=device))

    # Measure performance for comparison
    print(f"\n{'='*50}")
    print("MEASURING PERFORMANCE FOR COMPARISON")
    print('='*50)

    perf_metrics = measure_performance(model, test_loader, device)
    results.update(perf_metrics)

    print(f"Inference time: {perf_metrics['inference_time']:.2f} seconds")
    print(f"Samples per second: {perf_metrics['samples_per_second']:.2f}")
    print(f"Memory used: {perf_metrics['memory_used_gb']:.2f} GB")

    # Save comparison results
    comparison_data = save_comparison_results(results)

    # Final evaluation with classification report
    print(f"\n{'='*50}")
    print("FINAL RESULTS - RoBERTa 3D BLOCK DIAGONAL")
    print('='*50)
    print(f" Best Test Accuracy: {best_accuracy:.4f}")
    print(f" Model Parameters: {comparison_data['model_size']['parameters']:,}")
    print(" Attention Pattern: 3D Block Diagonal Higher-Order")
    print(f" Block Size: {BLOCK_SIZE}")
    print(f" Order: {ORDER}")
    print(f" Inference Speed: {perf_metrics['samples_per_second']:.2f} samples/sec")
    print(f" Memory Usage: {perf_metrics['memory_used_gb']:.2f} GB")

    print("\nClassification Report:")
    print(classification_report(best_true_labels, best_predictions,
                                target_names=['Negative', 'Positive']))

    print("\n 3D BLOCK DIAGONAL BASELINE ESTABLISHED!")
    print("="*60)
    print("Compare with:")
    print("• Longformer (sliding window attention)")
    print("• BigBird (sparse + random + global attention)")
    print("• Your approach shows 3D higher-order interactions")
    print("• Block diagonal structure for efficiency")
    print("• O(n) complexity with enhanced modeling capacity")

    return results, comparison_data


if __name__ == "__main__":
    # Script entry point: run the full experiment and keep its two return
    # values as module-level names.
    results, comparison_data = main()
