#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
OPTIMIZED BigBird with BLOCK SPARSE attention for Yelp Binary Classification
"""
import os
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
from tqdm import tqdm
import time
import json
import warnings
warnings.filterwarnings('ignore')

# Clear datasets cache
import datasets
datasets.disable_caching()

from datasets import load_dataset
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report

from transformers import (
    BigBirdTokenizer,
    BigBirdModel,
    BigBirdConfig,
    BigBirdForSequenceClassification
)

from torch.utils.data import DataLoader, Dataset
from torch.optim import AdamW

# --- Process-wide runtime setup (executes once, at import time) -------------

# Favour speed over bitwise reproducibility: cuDNN may benchmark candidate
# kernels and choose non-deterministic ones.
torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic = True, False

# Run on the GPU whenever CUDA is visible, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

if torch.cuda.is_available():
    # Report the current GPU's name and the total memory of device 0.
    print("GPU: " + torch.cuda.get_device_name())
    print("Memory: {:.1f} GB".format(torch.cuda.get_device_properties(0).total_memory / 1e9))

# Output folders; exist_ok keeps re-runs idempotent.
for _results_subdir in ("results", "results/models", "results/metrics"):
    os.makedirs(_results_subdir, exist_ok=True)
del _results_subdir

# Fixed seed (42) for the NumPy and torch global RNGs.
np.random.seed(42)
torch.manual_seed(42)


class OptimizedYelpDataset(Dataset):
    """Pre-tokenized Yelp dataset with an optional pickle cache on disk.

    Every text is tokenized once in ``__init__`` (truncated and padded to
    ``max_length``) so that ``__getitem__`` only indexes stored tensors.

    Args:
        texts: Review texts; each one is converted with ``str()`` before
            tokenization.
        labels: Integer class labels, one per text.
        tokenizer: Called as ``tokenizer(text, truncation=True,
            padding='max_length', max_length=..., return_tensors='pt')``;
            its ``input_ids`` / ``attention_mask`` outputs are flattened.
        max_length: Truncation/padding length passed to the tokenizer.
        cache_file: Optional pickle path. An existing cache is reused only if
            it holds exactly one encoding per text; otherwise the texts are
            re-tokenized and the cache file is rewritten.

    Raises:
        ValueError: if ``texts`` and ``labels`` differ in length.

    NOTE(review): the cache is identified only by its path (the caller in this
    file encodes sample count and ``max_length`` in the name), not by text
    content or tokenizer, so a same-sized cache built from different data is
    reused silently. It is read with ``pickle.load`` and must never point at
    an untrusted file.
    """

    def __init__(self, texts, labels, tokenizer, max_length=512, cache_file=None):
        import pickle  # local: only needed for the optional on-disk cache

        if len(texts) != len(labels):
            raise ValueError(
                f"texts and labels must have the same length, "
                f"got {len(texts)} and {len(labels)}"
            )
        self.labels = labels
        self.max_length = max_length
        self.encodings = None

        # Check if cached tokenized data exists and still matches the inputs.
        if cache_file and os.path.exists(cache_file):
            print(f"📁 Loading cached tokenized data from {cache_file}...")
            with open(cache_file, 'rb') as f:
                cached = pickle.load(f)  # trusted, self-written cache only
            if len(cached) == len(texts):
                self.encodings = cached
                print(f"✅ Loaded {len(self.encodings)} cached encodings")
            else:
                # A stale cache would index past its end in __getitem__ (too
                # short) or pair encodings with the wrong labels (too long).
                print(f"⚠️ Cache has {len(cached)} encodings but {len(texts)} "
                      f"texts were given; re-tokenizing...")

        if self.encodings is None:
            self.encodings = self._tokenize_all(texts, tokenizer, max_length)

            # Save tokenized data for future runs.
            if cache_file:
                print(f"💾 Saving tokenized data to {cache_file}...")
                cache_dir = os.path.dirname(cache_file)
                if cache_dir:  # os.makedirs('') raises for a bare file name
                    os.makedirs(cache_dir, exist_ok=True)
                with open(cache_file, 'wb') as f:
                    pickle.dump(self.encodings, f)
                print(f"✅ Cached {len(self.encodings)} encodings")

    @staticmethod
    def _tokenize_all(texts, tokenizer, max_length):
        """Tokenize each text into a dict of flat ``input_ids``/``attention_mask`` tensors."""
        print(f"🔤 Pre-tokenizing {len(texts)} samples for BigBird...")
        encodings = []
        for text in tqdm(texts, desc="Tokenizing"):
            encoding = tokenizer(
                str(text),
                truncation=True,
                padding='max_length',
                max_length=max_length,
                return_tensors='pt'
            )
            encodings.append({
                'input_ids': encoding['input_ids'].flatten(),
                'attention_mask': encoding['attention_mask'].flatten()
            })
        return encodings

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return {
            'input_ids': self.encodings[idx]['input_ids'],
            'attention_mask': self.encodings[idx]['attention_mask'],
            'labels': torch.tensor(self.labels[idx], dtype=torch.long)
        }


def calculate_optimal_sequence_length(block_size, max_length=512):
    """Return the largest multiple of ``block_size`` that is <= ``max_length``.

    Block-sparse attention works on whole blocks, so the padded sequence
    length should be divisible by ``block_size``. The previous per-size
    branches (64/32/16 -> 512) were all special cases of this single rule.

    NOTE(review): divisibility is necessary but not sufficient for Hugging
    Face's BigBird block_sparse mode; its ``BigBirdModel`` switches to
    ``original_full`` when the sequence length is
    <= (5 + 2 * num_random_blocks) * block_size. Verify against the installed
    ``transformers`` version.

    Args:
        block_size: Attention block size; must be in ``1..max_length``.
        max_length: Upper bound on the sequence length (default 512).

    Returns:
        ``(max_length // block_size) * block_size``.

    Raises:
        ValueError: if ``block_size`` is not positive or exceeds
            ``max_length`` (which would yield a useless length of 0).
    """
    if block_size <= 0:
        raise ValueError(f"block_size must be positive, got {block_size}")
    if block_size > max_length:
        raise ValueError(
            f"block_size ({block_size}) must not exceed max_length ({max_length})"
        )
    return (max_length // block_size) * block_size


class OptimizedBigBirdForSequenceClassification(nn.Module):
    """BigBird encoder with a dropout + linear head on the first ([CLS]) token.

    The constructor writes the attention settings onto ``config`` in place
    (the caller's config object is modified) and then builds a freshly
    initialised ``BigBirdModel`` backbone from it.

    Args:
        config: BigBird configuration; ``num_labels``, ``hidden_size`` and
            ``hidden_dropout_prob`` are read from it.
        attention_type: Written to ``config.attention_type``.
        block_size: Written to ``config.block_size``.
        num_random_blocks: Written to ``config.num_random_blocks``.
    """

    def __init__(self, config, attention_type="block_sparse", block_size=64, num_random_blocks=3):
        super().__init__()
        self.num_labels = config.num_labels
        self.config = config

        # Push the attention settings into the config before the backbone
        # is built from it.
        attention_settings = {
            "attention_type": attention_type,
            "block_size": block_size,
            "num_random_blocks": num_random_blocks,
        }
        for setting_name, setting_value in attention_settings.items():
            setattr(config, setting_name, setting_value)

        if attention_type == "block_sparse":
            # NOTE(review): num_global_tokens is recorded and printed here
            # only; nothing in this file reads it afterwards.
            config.num_global_tokens = 2

            summary_lines = (
                "🔧 BLOCK SPARSE CONFIGURATION:",
                f"  Attention Type: {attention_type}",
                f"  Block Size: {block_size}",
                f"  Random Blocks: {num_random_blocks}",
                f"  Global Tokens: {config.num_global_tokens}",
            )
            for summary_line in summary_lines:
                print(summary_line)

        # Backbone first, then the head (keeps parameter init order stable).
        self.bigbird = BigBirdModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        """Encode the batch and classify it from the first token's hidden state.

        Extra keyword arguments are forwarded to the backbone unchanged.

        Returns:
            dict with ``loss`` (cross-entropy, or None when ``labels`` is
            None), ``logits`` from the linear head, and the backbone's
            ``hidden_states`` / ``attentions`` passed through as-is.
        """
        backbone_out = self.bigbird(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
            **kwargs
        )

        first_token_state = backbone_out.last_hidden_state[:, 0, :]
        logits = self.classifier(self.dropout(first_token_state))

        if labels is None:
            loss = None
        else:
            criterion = nn.CrossEntropyLoss()
            loss = criterion(logits.view(-1, self.num_labels), labels.view(-1))

        return dict(
            loss=loss,
            logits=logits,
            hidden_states=backbone_out.hidden_states,
            attentions=backbone_out.attentions,
        )


def _stratified_subset(texts, labels, size):
    """Return a label-stratified sample of ``size`` items (all items if size is None or >= len)."""
    # train_test_split rejects an integer train_size >= n_samples, so asking
    # for a whole split would otherwise raise ValueError.
    if size is None or size >= len(texts):
        return list(texts), list(labels)
    subset_texts, _, subset_labels, _ = train_test_split(
        texts, labels,
        train_size=size,
        stratify=labels,
        random_state=42
    )
    return subset_texts, subset_labels


def load_yelp_data_fast(train_size=50000, test_size=10000):
    """Load Yelp reviews as a binary sentiment task and take stratified subsets.

    Tries the ``yelp_polarity`` dataset first. If that fails, falls back to
    ``yelp_review_full`` and maps its labels 0,1,2 -> 0 (negative) and
    3,4 -> 1 (positive). Each split is then reduced to a label-stratified
    sample (``random_state=42``).

    Args:
        train_size: Training samples to keep; ``None`` or a value >= the
            split size keeps the whole training split.
        test_size: Test samples to keep, with the same convention.

    Returns:
        ``(train_texts, train_labels, test_texts, test_labels)``.

    Raises:
        Exception: the error from loading ``yelp_review_full`` when both
            datasets fail to load.
    """
    print(f"Loading medium Yelp subset: {train_size} train, {test_size} test samples...")

    try:
        # Try to load the Yelp polarity dataset (binary classification)
        dataset = load_dataset('yelp_polarity')

        train_texts = dataset['train']['text']
        train_labels = dataset['train']['label']
        test_texts = dataset['test']['text']
        test_labels = dataset['test']['label']

        print(f"Loaded Yelp Polarity dataset")
        print(f"Labels: 0 (negative), 1 (positive)")

    except Exception as e:  # any load failure -> try the fallback dataset
        print(f"Failed to load yelp_polarity: {e}")
        print("Trying yelp_review_full dataset and converting to binary...")

        try:
            # Load full Yelp dataset and convert to binary
            dataset = load_dataset('yelp_review_full')

            train_texts = dataset['train']['text']
            train_labels_full = dataset['train']['label']
            test_texts = dataset['test']['text']
            test_labels_full = dataset['test']['label']

            # Convert 5-class to binary: 0,1,2 -> 0 (negative), 3,4 -> 1 (positive)
            train_labels = [0 if label < 3 else 1 for label in train_labels_full]
            test_labels = [0 if label < 3 else 1 for label in test_labels_full]

            print(f"Loaded Yelp Review Full dataset and converted to binary")
            print(f"Original labels 0,1,2 -> 0 (negative), labels 3,4 -> 1 (positive)")

        except Exception as e2:
            print(f"Failed to load yelp datasets: {e2}")
            print("Please install the datasets library and check your internet connection")
            raise

    # 🚀 SPEED OPTIMIZATION: Use medium subset for good balance
    print(f"\n🚀 USING MEDIUM SUBSET FOR BALANCED TRAINING:")
    print(f"  Original train size: {len(train_texts)}")
    print(f"  Using train size: {train_size}")
    print(f"  Original test size: {len(test_texts)}")
    print(f"  Using test size: {test_size}")

    # Stratified sampling keeps the label balance of each split.
    train_texts, train_labels = _stratified_subset(train_texts, train_labels, train_size)
    test_texts, test_labels = _stratified_subset(test_texts, test_labels, test_size)

    print(f"\nFinal dataset sizes:")
    print(f"  Training samples: {len(train_texts)}")
    print(f"  Test samples: {len(test_texts)}")
    print(f"  Label distribution - Train: {np.bincount(train_labels)}")
    print(f"  Label distribution - Test: {np.bincount(test_labels)}")

    return train_texts, train_labels, test_texts, test_labels


def create_fast_data_loaders(train_texts, train_labels, test_texts, test_labels,
                            tokenizer, max_length=512, batch_size=8, use_cache=True):
    """Build train/test DataLoaders over pre-tokenized Yelp datasets.

    With ``use_cache`` enabled, each split's encodings are pickled under
    ``cache/`` in a file named after the split, its sample count and
    ``max_length``, so later runs can skip tokenization.

    Returns:
        ``(train_loader, test_loader)``. Only the training loader shuffles.
        Both use a single worker and no pinned memory to limit host memory.
    """
    def cache_path_for(split_name, sample_count):
        # No cache file at all when caching is disabled.
        if not use_cache:
            return None
        return f"cache/{split_name}_cache_{sample_count}_{max_length}.pkl"

    split_inputs = (
        ("train", train_texts, train_labels),
        ("test", test_texts, test_labels),
    )
    split_datasets = {}
    for split_name, split_texts, split_labels in split_inputs:
        split_datasets[split_name] = OptimizedYelpDataset(
            split_texts, split_labels, tokenizer, max_length,
            cache_path_for(split_name, len(split_texts))
        )

    shared_loader_options = dict(batch_size=batch_size, num_workers=1, pin_memory=False)
    train_loader = DataLoader(split_datasets["train"], shuffle=True, **shared_loader_options)
    test_loader = DataLoader(split_datasets["test"], shuffle=False, **shared_loader_options)

    return train_loader, test_loader


def train_model_fast(model, train_loader, optimizer, scheduler, device, epoch,
                    max_grad_norm=1.0, accumulation_steps=4):
    """Train ``model`` for one epoch with gradient accumulation and optional AMP.

    Gradients of ``accumulation_steps`` consecutive batches are summed before
    each clip/optimizer/scheduler step. A trailing partial group at the end of
    the epoch is stepped too, so its gradients are applied now instead of
    leaking into the next epoch's first update. Its loss is still divided by
    ``accumulation_steps``, so that last update is proportionally smaller.

    When ``device`` is a CUDA device, the forward pass runs under
    ``torch.cuda.amp.autocast`` and the loss is scaled with a ``GradScaler``.
    ``GradScaler.unscale_`` refuses FP16 gradients, so the model's parameters
    must be kept in FP32 in that case.

    Args:
        model: Module whose forward accepts ``input_ids``, ``attention_mask``
            and ``labels`` and returns a dict with ``'loss'`` and ``'logits'``.
        train_loader: Yields dicts with ``input_ids``, ``attention_mask`` and
            ``labels`` tensors; must support ``len()``.
        optimizer: Optimizer over the model's parameters.
        scheduler: LR scheduler, stepped once per optimizer step.
        device: Target device (``torch.device`` or string).
        epoch: Zero-based epoch index, used only for the progress-bar label.
        max_grad_norm: Gradient-norm clipping threshold.
        accumulation_steps: Batches per optimizer step; must be >= 1.

    Returns:
        ``(avg_loss, accuracy)`` for the epoch: the mean per-batch loss and
        the fraction of training samples whose argmax logit matched the
        label. Both are NaN if the loader is empty.

    Raises:
        ValueError: if ``accumulation_steps`` < 1.
    """
    import contextlib  # local: only used to select the autocast context

    if accumulation_steps < 1:
        raise ValueError(f"accumulation_steps must be >= 1, got {accumulation_steps}")

    model.train()
    # Decide AMP from the device we actually train on, not mere CUDA presence.
    on_cuda = torch.device(device).type == "cuda"
    scaler = torch.cuda.amp.GradScaler() if on_cuda else None

    num_batches = len(train_loader)
    total_loss = 0.0
    num_correct = 0
    num_seen = 0

    # Start from clean gradients in case any were left on the parameters.
    optimizer.zero_grad()

    progress_bar = tqdm(train_loader, desc=f'Epoch {epoch+1}')

    for i, batch in enumerate(progress_bar):
        input_ids = batch['input_ids'].to(device, non_blocking=True)
        attention_mask = batch['attention_mask'].to(device, non_blocking=True)
        labels = batch['labels'].to(device, non_blocking=True)

        autocast_ctx = torch.cuda.amp.autocast() if on_cuda else contextlib.nullcontext()
        with autocast_ctx:
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels
            )
            # Scale so the gradients summed over a group average its batches.
            loss = outputs['loss'] / accumulation_steps

        if scaler is not None:
            scaler.scale(loss).backward()
        else:
            loss.backward()

        # Record predictions BEFORE dropping `outputs`, otherwise the logits
        # are gone and the accuracy cannot be computed.
        preds = outputs['logits'].detach().argmax(dim=1)
        num_correct += (preds == labels).sum().item()
        num_seen += labels.size(0)
        del outputs  # release the reference to the model outputs early

        is_last_batch = (i + 1) == num_batches
        if (i + 1) % accumulation_steps == 0 or is_last_batch:
            if scaler is not None:
                scaler.unscale_(optimizer)  # clip on the true (unscaled) grads
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
                scaler.step(optimizer)
                scaler.update()
            else:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
                optimizer.step()

            scheduler.step()
            optimizer.zero_grad()

            # Periodically return cached allocator blocks to the driver.
            if on_cuda and (i + 1) % (accumulation_steps * 10) == 0:
                torch.cuda.empty_cache()

        batch_loss = loss.item() * accumulation_steps  # undo accumulation scaling
        total_loss += batch_loss

        progress_bar.set_postfix({
            'loss': batch_loss,
            'lr': scheduler.get_last_lr()[0]
        })

    if num_batches == 0:
        return float('nan'), float('nan')

    avg_loss = total_loss / num_batches
    accuracy = num_correct / num_seen if num_seen else float('nan')
    return avg_loss, accuracy


def _load_pretrained_backbone(model, checkpoint='google/bigbird-roberta-base'):
    """Copy pretrained BigBird weights into ``model.bigbird`` (best effort).

    The checkpoint is loaded in the default (FP32) dtype, so weights are not
    rounded to FP16 before being copied into the FP32 training model. Any
    failure is reported and the randomly initialised backbone is kept.
    """
    print("Loading pretrained BigBird weights...")
    try:
        pretrained_model = BigBirdModel.from_pretrained(checkpoint)
        model.bigbird.load_state_dict(pretrained_model.state_dict())
        print("✅ Successfully loaded pretrained BigBird weights")

        # Delete the temporary copy to free memory.
        del pretrained_model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    except Exception as e:  # deliberate best effort: fall back to random init
        print(f"⚠️ Warning: Could not load pretrained weights: {e}")
        print("Continuing with random initialization...")


def _evaluate(model, data_loader, device):
    """Evaluate ``model`` on ``data_loader`` without gradients.

    Returns:
        ``(mean_loss, accuracy)``: the loss averaged over batches and the
        fraction of correctly classified samples (NaN for an empty loader).
    """
    model.eval()
    on_cuda = torch.device(device).type == "cuda"
    total_loss, num_correct, num_seen = 0.0, 0, 0

    with torch.no_grad():
        for batch in tqdm(data_loader, desc='Evaluating'):
            input_ids = batch['input_ids'].to(device, non_blocking=True)
            attention_mask = batch['attention_mask'].to(device, non_blocking=True)
            labels = batch['labels'].to(device, non_blocking=True)

            if on_cuda:
                with torch.cuda.amp.autocast():
                    outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            else:
                outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)

            total_loss += outputs['loss'].item()
            preds = torch.argmax(outputs['logits'], dim=1)
            num_correct += (preds == labels).sum().item()
            num_seen += labels.size(0)

    num_batches = len(data_loader)
    mean_loss = total_loss / num_batches if num_batches else float('nan')
    accuracy = num_correct / num_seen if num_seen else float('nan')
    return mean_loss, accuracy


def main_bigbird_block_sparse():
    """Fine-tune BigBird (block sparse config) on a stratified Yelp subset.

    Loads the data and the ``google/bigbird-roberta-base`` tokenizer and
    weights, then trains for ``EPOCHS`` epochs with gradient accumulation
    (autocast + GradScaler on CUDA). The model is evaluated after every epoch.

    Returns:
        ``(model, best_accuracy)``: the trained model and the best test
        accuracy observed across epochs.
    """
    print("\n=== OPTIMIZED BigBird BLOCK SPARSE for Yelp Reviews ===\n")

    # BLOCK SPARSE specific parameters
    BLOCK_SIZE = 64            # Block size for sparse attention
    NUM_RANDOM_BLOCKS = 3      # Number of random attention blocks
    ATTENTION_TYPE = "block_sparse"

    # Largest multiple of BLOCK_SIZE that fits in 512 tokens
    MAX_LENGTH = calculate_optimal_sequence_length(BLOCK_SIZE)

    # Memory-optimized hyperparameters
    BATCH_SIZE = 8             # Reduced for memory efficiency
    LEARNING_RATE = 5e-5
    EPOCHS = 2                 # Reduced for testing
    WARMUP_STEPS = 100         # Optimizer steps of linear LR warmup
    WEIGHT_DECAY = 0.01
    MAX_GRAD_NORM = 1.0
    ACCUMULATION_STEPS = 8     # Increased to maintain effective batch size

    # 🚀 DATASET SIZE OPTIMIZATION: Use medium subset for good balance
    TRAIN_SIZE = 50000         # Use 50K training samples (vs 560K)
    TEST_SIZE = 10000          # Use 10K test samples (vs 38K)

    print(f"🚀 BIGBIRD BLOCK SPARSE SETTINGS:")
    print(f"  Dataset Size: {TRAIN_SIZE} train + {TEST_SIZE} test (MEDIUM SUBSET)")
    print(f"  Attention Type: {ATTENTION_TYPE}")
    print(f"  Max Length: {MAX_LENGTH} (optimized for block_size {BLOCK_SIZE})")
    print(f"  Block Size: {BLOCK_SIZE}")
    print(f"  Random Blocks: {NUM_RANDOM_BLOCKS}")
    print(f"  Batch Size: {BATCH_SIZE}")
    print(f"  Accumulation Steps: {ACCUMULATION_STEPS}")
    print(f"  Effective Batch Size: {BATCH_SIZE * ACCUMULATION_STEPS}")
    print(f"  Epochs: {EPOCHS}")
    print(f"  Mixed Precision: {'Yes' if torch.cuda.is_available() else 'No'}")

    # Memory optimization: Clear cache and cap this process's GPU memory share
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.set_per_process_memory_fraction(0.8)

    # Load medium subset of Yelp data for good balance of speed and accuracy
    train_texts, train_labels, test_texts, test_labels = load_yelp_data_fast(
        train_size=TRAIN_SIZE,
        test_size=TEST_SIZE
    )

    print("Loading BigBird tokenizer...")
    tokenizer = BigBirdTokenizer.from_pretrained('google/bigbird-roberta-base')

    print("Loading base BigBird model...")
    config = BigBirdConfig.from_pretrained('google/bigbird-roberta-base')
    config.num_labels = 2  # Binary classification for Yelp sentiment

    # IMPORTANT: Validate block size compatibility
    if MAX_LENGTH % BLOCK_SIZE != 0:
        print(f"⚠️ Warning: MAX_LENGTH ({MAX_LENGTH}) is not divisible by BLOCK_SIZE ({BLOCK_SIZE})")
        MAX_LENGTH = (MAX_LENGTH // BLOCK_SIZE) * BLOCK_SIZE
        print(f"Adjusted MAX_LENGTH to: {MAX_LENGTH}")

    # Hugging Face's BigBirdModel only keeps block_sparse attention when the
    # sequence is longer than (5 + 2 * num_random_blocks) * block_size tokens.
    # Otherwise it logs a warning and switches to 'original_full' attention.
    # With BLOCK_SIZE=64 and NUM_RANDOM_BLOCKS=3 that threshold is 704 > 512.
    min_sparse_length = (5 + 2 * NUM_RANDOM_BLOCKS) * BLOCK_SIZE
    if ATTENTION_TYPE == "block_sparse" and MAX_LENGTH <= min_sparse_length:
        print(f"⚠️ Warning: MAX_LENGTH ({MAX_LENGTH}) <= (5 + 2 * NUM_RANDOM_BLOCKS) * BLOCK_SIZE "
              f"({min_sparse_length}); BigBird will fall back to 'original_full' attention.")
        print("   Use a smaller BLOCK_SIZE / NUM_RANDOM_BLOCKS or a longer MAX_LENGTH for true block sparse.")

    # Create optimized BigBird model with block sparse attention
    print(f"Creating BigBird model with block sparse attention...")
    model = OptimizedBigBirdForSequenceClassification(
        config,
        attention_type=ATTENTION_TYPE,
        block_size=BLOCK_SIZE,
        num_random_blocks=NUM_RANDOM_BLOCKS
    )

    _load_pretrained_backbone(model)

    # Initialize classifier weights
    nn.init.normal_(model.classifier.weight, std=0.02)
    nn.init.zeros_(model.classifier.bias)

    # Keep parameters in FP32. Autocast in the training/eval loops provides
    # the FP16 compute, and GradScaler.unscale_ raises "Attempting to unscale
    # FP16 gradients" if the parameters themselves are half precision.
    model.to(device)

    print(f"Model loaded with {sum(p.numel() for p in model.parameters())/1e6:.1f}M parameters")

    # Create data loaders
    train_loader, test_loader = create_fast_data_loaders(
        train_texts, train_labels, test_texts, test_labels,
        tokenizer, MAX_LENGTH, BATCH_SIZE
    )

    # Setup optimizer
    optimizer = AdamW(
        model.parameters(),
        lr=LEARNING_RATE,
        weight_decay=WEIGHT_DECAY,
        eps=1e-8
    )

    # Approximate optimizer-step budget: one step per ACCUMULATION_STEPS batches.
    total_steps = math.ceil(len(train_loader) / ACCUMULATION_STEPS) * EPOCHS
    # LR ramps linearly from 0.1x to 1x over WARMUP_STEPS scheduler steps,
    # then stays constant.
    scheduler = torch.optim.lr_scheduler.LinearLR(
        optimizer,
        start_factor=0.1,
        total_iters=WARMUP_STEPS
    )
    print(f"Optimizer steps: ~{total_steps} (LR warmup over first {WARMUP_STEPS}, then constant)")

    print(f"Starting BigBird Block Sparse training for {EPOCHS} epochs on Yelp dataset...")

    best_accuracy = 0
    total_train_time = 0.0

    for epoch in range(EPOCHS):
        print(f"\nEpoch {epoch + 1}/{EPOCHS}")

        start_time = time.time()
        train_loss, train_accuracy = train_model_fast(
            model, train_loader, optimizer, scheduler, device, epoch,
            MAX_GRAD_NORM, ACCUMULATION_STEPS
        )
        train_time = time.time() - start_time
        total_train_time += train_time

        print(f"Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.4f}")
        print(f"Training time: {train_time:.2f}s")

        test_loss, test_accuracy = _evaluate(model, test_loader, device)
        print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.4f}")

        if test_accuracy > best_accuracy:
            best_accuracy = test_accuracy
            print(f"New best accuracy: {best_accuracy:.4f}")

    # Average over all epochs (not just the last) and avoid a NameError or
    # ZeroDivisionError when EPOCHS == 0.
    avg_epoch_time = total_train_time / EPOCHS if EPOCHS else 0.0
    # The backbone may have switched itself to 'original_full' (see above).
    effective_attention = getattr(model.bigbird, "attention_type", "unknown")

    print(f"\n🚀 BIGBIRD BLOCK SPARSE TRAINING COMPLETED!")
    print(f"✅ Best Accuracy: {best_accuracy:.4f}")
    print(f"✅ Attention type in use: {effective_attention}")
    print(f"✅ Training Speed: ~{avg_epoch_time/60:.1f} min/epoch")
    print(f"✅ Total Time: ~{total_train_time/60:.1f} minutes")

    print(f"\n🔍 BLOCK SPARSE ADVANTAGES:")
    print(f"✅ More memory efficient than original_full")
    print(f"✅ Faster computation due to sparse patterns")
    print(f"✅ Better scalability for long sequences")
    print(f"✅ Linear complexity O(n) instead of O(n²)")

    return model, best_accuracy


if __name__ == "__main__":
    # Pre-flight GPU memory report before training starts.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

        # mem_get_info asks the driver for (free, total) bytes on the current
        # device. "Available" therefore also reflects memory held by the
        # caching allocator or other processes, which
        # total - memory_allocated() ignored.
        free_bytes, total_bytes = torch.cuda.mem_get_info()
        gpu_memory = total_bytes / 1e9
        available_memory = free_bytes / 1e9

        print(f"🔍 GPU Memory Check:")
        print(f"  Total GPU Memory: {gpu_memory:.1f} GB")
        print(f"  Available Memory: {available_memory:.1f} GB")

        if available_memory < 4.0:  # Less than 4GB available
            print("⚠️ WARNING: Low GPU memory detected!")
            print("Block sparse should be more memory efficient than original_full")

    try:
        model, accuracy = main_bigbird_block_sparse()
    except RuntimeError as e:
        if "out of memory" not in str(e):
            raise
        print("\n❌ CUDA Out of Memory Error!")
        print("🔧 SUGGESTED SOLUTIONS for BLOCK SPARSE:")
        print("1. Reduce BLOCK_SIZE to 32 or 16")
        print("2. Reduce BATCH_SIZE to 4 or 2")
        print("3. Reduce MAX_LENGTH to 256 or 128")
        print("4. Use CPU instead of GPU")
        # Exit non-zero so shells/CI see the failure instead of a clean exit.
        raise SystemExit(1) from e
